|
| 1 | +// Originally made by Sam Turner and Finlay Sanders, 2025. |
| 2 | +// Included in pufferlib under the original project's MIT license. |
| 3 | +// https://github.com/tensaur/drone |
| 4 | + |
| 5 | +#pragma once |
| 6 | + |
| 7 | +#include <math.h> |
| 8 | +#include <limits.h> |
| 9 | +#include <stdbool.h> |
| 10 | +#include <stdlib.h> |
| 11 | + |
| 12 | +#include "dronelib.h" |
| 13 | +#include "tasks.h" |
| 14 | + |
| 15 | +typedef struct Client Client; |
| 16 | +typedef struct DroneEnv DroneEnv; |
| 17 | + |
| 18 | +struct DroneEnv { |
| 19 | + float *observations; |
| 20 | + float *actions; |
| 21 | + float *rewards; |
| 22 | + unsigned char *terminals; |
| 23 | + |
| 24 | + Log log; |
| 25 | + int tick; |
| 26 | + int report_interval; |
| 27 | + |
| 28 | + DroneTask task; |
| 29 | + int num_agents; |
| 30 | + Drone *agents; |
| 31 | + |
| 32 | + int max_rings; |
| 33 | + Target* ring_buffer; |
| 34 | + |
| 35 | + Client *client; |
| 36 | +}; |
| 37 | + |
| 38 | +void init(DroneEnv *env) { |
| 39 | + env->agents = (Drone*) calloc(env->num_agents, sizeof(Drone)); |
| 40 | + env->ring_buffer = (Target*) calloc(env->max_rings, sizeof(Target)); |
| 41 | + |
| 42 | + for (int i = 0; i < env->num_agents; i++) { |
| 43 | + env->agents[i].target = (Target*) calloc(1, sizeof(Target)); |
| 44 | + } |
| 45 | + |
| 46 | + env->log = (Log){0}; |
| 47 | + env->tick = 0; |
| 48 | +} |
| 49 | + |
| 50 | +void add_log(DroneEnv *env, int idx, bool oob, bool ring_collision, |
| 51 | + bool timeout) { |
| 52 | + Drone *agent = &env->agents[idx]; |
| 53 | + env->log.score += agent->score; |
| 54 | + env->log.episode_return += agent->episode_return; |
| 55 | + env->log.episode_length += agent->episode_length; |
| 56 | + env->log.collision_rate += agent->collisions / (float)agent->episode_length; |
| 57 | + env->log.perf += agent->score / (float)agent->episode_length; |
| 58 | + if (oob) { |
| 59 | + env->log.oob += 1.0f; |
| 60 | + } |
| 61 | + if (ring_collision) { |
| 62 | + env->log.ring_collision += 1.0f; |
| 63 | + } |
| 64 | + if (timeout) { |
| 65 | + env->log.timeout += 1.0f; |
| 66 | + } |
| 67 | + env->log.n += 1.0f; |
| 68 | + |
| 69 | + agent->episode_length = 0; |
| 70 | + agent->episode_return = 0.0f; |
| 71 | +} |
| 72 | + |
| 73 | +void compute_observations(DroneEnv *env) { |
| 74 | + int idx = 0; |
| 75 | + |
| 76 | + for (int i = 0; i < env->num_agents; i++) { |
| 77 | + Drone *agent = &env->agents[i]; |
| 78 | + |
| 79 | + Quat q_inv = quat_inverse(agent->state.quat); |
| 80 | + Vec3 linear_vel_body = quat_rotate(q_inv, agent->state.vel); |
| 81 | + Vec3 to_target = sub3(agent->target->pos, agent->state.pos); |
| 82 | + |
| 83 | + env->observations[idx++] = linear_vel_body.x / agent->params.max_vel; |
| 84 | + env->observations[idx++] = linear_vel_body.y / agent->params.max_vel; |
| 85 | + env->observations[idx++] = linear_vel_body.z / agent->params.max_vel; |
| 86 | + |
| 87 | + env->observations[idx++] = agent->state.omega.x / agent->params.max_omega; |
| 88 | + env->observations[idx++] = agent->state.omega.y / agent->params.max_omega; |
| 89 | + env->observations[idx++] = agent->state.omega.z / agent->params.max_omega; |
| 90 | + |
| 91 | + env->observations[idx++] = agent->state.quat.w; |
| 92 | + env->observations[idx++] = agent->state.quat.x; |
| 93 | + env->observations[idx++] = agent->state.quat.y; |
| 94 | + env->observations[idx++] = agent->state.quat.z; |
| 95 | + |
| 96 | + env->observations[idx++] = agent->state.rpms[0] / agent->params.max_rpm; |
| 97 | + env->observations[idx++] = agent->state.rpms[1] / agent->params.max_rpm; |
| 98 | + env->observations[idx++] = agent->state.rpms[2] / agent->params.max_rpm; |
| 99 | + env->observations[idx++] = agent->state.rpms[3] / agent->params.max_rpm; |
| 100 | + |
| 101 | + env->observations[idx++] = to_target.x / GRID_X; |
| 102 | + env->observations[idx++] = to_target.y / GRID_Y; |
| 103 | + env->observations[idx++] = to_target.z / GRID_Z; |
| 104 | + |
| 105 | + env->observations[idx++] = clampf(to_target.x, -1.0f, 1.0f); |
| 106 | + env->observations[idx++] = clampf(to_target.y, -1.0f, 1.0f); |
| 107 | + env->observations[idx++] = clampf(to_target.z, -1.0f, 1.0f); |
| 108 | + |
| 109 | + env->observations[idx++] = agent->target->normal.x; |
| 110 | + env->observations[idx++] = agent->target->normal.y; |
| 111 | + env->observations[idx++] = agent->target->normal.z; |
| 112 | + |
| 113 | + // Multiagent obs |
| 114 | + Drone *nearest = nearest_drone(agent, env->agents, env->num_agents); |
| 115 | + if (env->num_agents > 1) { |
| 116 | + env->observations[idx++] = |
| 117 | + clampf(nearest->state.pos.x - agent->state.pos.x, -1.0f, 1.0f); |
| 118 | + env->observations[idx++] = |
| 119 | + clampf(nearest->state.pos.y - agent->state.pos.y, -1.0f, 1.0f); |
| 120 | + env->observations[idx++] = |
| 121 | + clampf(nearest->state.pos.z - agent->state.pos.z, -1.0f, 1.0f); |
| 122 | + } else { |
| 123 | + env->observations[idx++] = MAX_DIST; |
| 124 | + env->observations[idx++] = MAX_DIST; |
| 125 | + env->observations[idx++] = MAX_DIST; |
| 126 | + } |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +void reset_agent(DroneEnv *env, Drone *agent, int idx) { |
| 131 | + agent->last_dist_reward = 0.0f; |
| 132 | + agent->episode_return = 0.0f; |
| 133 | + agent->episode_length = 0; |
| 134 | + agent->collisions = 0.0f; |
| 135 | + agent->score = 0.0f; |
| 136 | + |
| 137 | + agent->buffer = env->ring_buffer; |
| 138 | + agent->buffer_size = env->max_rings; |
| 139 | + agent->buffer_idx = -1; |
| 140 | + |
| 141 | + float size = rndf(0.1f, 0.4f); |
| 142 | + init_drone(agent, size, 0.1f); |
| 143 | + |
| 144 | + agent->state.pos = |
| 145 | + (Vec3){rndf(-MARGIN_X, MARGIN_X), rndf(-MARGIN_Y, MARGIN_Y), |
| 146 | + rndf(-MARGIN_Z, MARGIN_Z)}; |
| 147 | + |
| 148 | + if (env->task == RACE) { |
| 149 | + while (norm3(sub3(agent->state.pos, env->ring_buffer[0].pos)) < |
| 150 | + 2.0f * RING_RADIUS) { |
| 151 | + agent->state.pos = |
| 152 | + (Vec3){rndf(-MARGIN_X, MARGIN_X), rndf(-MARGIN_Y, MARGIN_Y), |
| 153 | + rndf(-MARGIN_Z, MARGIN_Z)}; |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + agent->prev_pos = agent->state.pos; |
| 158 | +} |
| 159 | + |
| 160 | +void c_reset(DroneEnv *env) { |
| 161 | + env->tick = 0; |
| 162 | + int rng = rand(); |
| 163 | + |
| 164 | + if (rng > INT_MAX / 2) { |
| 165 | + env->task = RACE; |
| 166 | + } else { |
| 167 | + env->task = (DroneTask)(rng % (TASK_N - 1)); |
| 168 | + } |
| 169 | + |
| 170 | + if (env->task == RACE) { |
| 171 | + reset_rings(env->ring_buffer, env->max_rings); |
| 172 | + } |
| 173 | + |
| 174 | + for (int i = 0; i < env->num_agents; i++) { |
| 175 | + Drone *agent = &env->agents[i]; |
| 176 | + reset_agent(env, agent, i); |
| 177 | + set_target(env->task, env->agents, i, env->num_agents); |
| 178 | + } |
| 179 | + |
| 180 | + compute_observations(env); |
| 181 | +} |
| 182 | + |
| 183 | +void c_step(DroneEnv *env) { |
| 184 | + env->tick = (env->tick + 1) % HORIZON; |
| 185 | + |
| 186 | + for (int i = 0; i < env->num_agents; i++) { |
| 187 | + Drone *agent = &env->agents[i]; |
| 188 | + env->rewards[i] = 0; |
| 189 | + env->terminals[i] = 0; |
| 190 | + |
| 191 | + float *atn = &env->actions[4 * i]; |
| 192 | + move_drone(agent, atn); |
| 193 | + |
| 194 | + bool out_of_bounds = |
| 195 | + agent->state.pos.x < -GRID_X || agent->state.pos.x > GRID_X || |
| 196 | + agent->state.pos.y < -GRID_Y || agent->state.pos.y > GRID_Y || |
| 197 | + agent->state.pos.z < -GRID_Z || agent->state.pos.z > GRID_Z; |
| 198 | + |
| 199 | + move_target(agent); |
| 200 | + |
| 201 | + bool collision = check_collision(agent, env->agents, env->num_agents); |
| 202 | + float reward = 0.0f; |
| 203 | + |
| 204 | + if (env->task == RACE) { |
| 205 | + // Check ring passage |
| 206 | + Target *ring = &env->ring_buffer[agent->buffer_idx]; |
| 207 | + int ring_passage = check_ring(agent, ring); |
| 208 | + |
| 209 | + // Ring collision |
| 210 | + if (ring_passage < 0) { |
| 211 | + env->rewards[i] = (float)ring_passage; |
| 212 | + agent->episode_return += (float)ring_passage; |
| 213 | + env->terminals[i] = 1; |
| 214 | + add_log(env, i, false, true, false); |
| 215 | + reset_agent(env, agent, i); |
| 216 | + set_target(env->task, env->agents, i, env->num_agents); |
| 217 | + continue; |
| 218 | + } |
| 219 | + |
| 220 | + // Successfully passed through ring - advance to next |
| 221 | + if (ring_passage > 0) { |
| 222 | + set_target(env->task, env->agents, i, env->num_agents); |
| 223 | + env->log.rings_passed += 1.0f; |
| 224 | + } |
| 225 | + |
| 226 | + reward = dynamic_task_reward(agent, collision, ring_passage); |
| 227 | + } else { |
| 228 | + reward = static_task_reward(agent, collision); |
| 229 | + } |
| 230 | + |
| 231 | + // Update agent state |
| 232 | + agent->episode_length++; |
| 233 | + agent->score += reward; |
| 234 | + if (collision) { |
| 235 | + agent->collisions += 1.0f; |
| 236 | + } |
| 237 | + |
| 238 | + env->rewards[i] = reward; |
| 239 | + agent->episode_return += reward; |
| 240 | + |
| 241 | + // Check termination conditions |
| 242 | + if (out_of_bounds) { |
| 243 | + env->rewards[i] -= 1.0f; |
| 244 | + agent->episode_return -= 1.0f; |
| 245 | + env->terminals[i] = 1; |
| 246 | + add_log(env, i, true, false, false); |
| 247 | + |
| 248 | + reset_agent(env, agent, i); |
| 249 | + set_target(env->task, env->agents, i, env->num_agents); |
| 250 | + static_task_reward(agent, false); |
| 251 | + } else if (env->tick >= HORIZON - 1) { |
| 252 | + env->terminals[i] = 1; |
| 253 | + add_log(env, i, false, false, true); |
| 254 | + } |
| 255 | + } |
| 256 | + |
| 257 | + if (env->tick >= HORIZON - 1) { |
| 258 | + c_reset(env); |
| 259 | + } |
| 260 | + |
| 261 | + compute_observations(env); |
| 262 | +} |
| 263 | + |
| 264 | +void c_close_client(Client* client); |
| 265 | + |
| 266 | +void c_close(DroneEnv *env) { |
| 267 | + for (int i = 0; i < env->num_agents; i++) { |
| 268 | + free(env->agents[i].target); |
| 269 | + } |
| 270 | + |
| 271 | + free(env->agents); |
| 272 | + free(env->ring_buffer); |
| 273 | + |
| 274 | + if (env->client != NULL) { |
| 275 | + c_close_client(env->client); |
| 276 | + } |
| 277 | +} |
| 278 | + |
0 commit comments