Skip to content

Commit f50421f

Browse files
stmioFinlaySanders
andcommitted
Refactor drone environments and merge back into one env
Co-authored-by: Finlay Sanders <finlay_sanders@icloud.com>
1 parent 7a99b3b commit f50421f

16 files changed

Lines changed: 949 additions & 2182 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[base]
22
package = ocean
3-
env_name = puffer_drone_swarm
3+
env_name = puffer_drone
44
policy_name = Policy
55
rnn_name = Recurrent
66

pufferlib/config/ocean/drone_race.ini

Lines changed: 0 additions & 37 deletions
This file was deleted.
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
#include "drone_swarm.h"
1+
#include "drone.h"
2+
#include "render.h"
23

3-
#define Env DroneSwarm
4+
#define Env DroneEnv
45
#include "../env_binding.h"
56

67
static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
@@ -14,8 +15,10 @@ static int my_log(PyObject *dict, Log *log) {
1415
assign_to_dict(dict, "perf", log->perf);
1516
assign_to_dict(dict, "score", log->score);
1617
assign_to_dict(dict, "rings_passed", log->rings_passed);
18+
assign_to_dict(dict, "ring_collisions", log->ring_collision);
1719
assign_to_dict(dict, "collision_rate", log->collision_rate);
1820
assign_to_dict(dict, "oob", log->oob);
21+
assign_to_dict(dict, "timeout", log->timeout);
1922
assign_to_dict(dict, "episode_return", log->episode_return);
2023
assign_to_dict(dict, "episode_length", log->episode_length);
2124
assign_to_dict(dict, "n", log->n);
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
// Standalone C demo for DroneSwarm environment
1+
// Standalone C demo for drone environment
22
// Compile using: ./scripts/build_ocean.sh drone [local|fast]
33
// Run with: ./drone
44

5-
#include "drone_swarm.h"
5+
#include "drone.h"
66
#include "puffernet.h"
77
#include <time.h>
88

@@ -88,7 +88,7 @@ void forward_linearcontlstm(LinearContLSTM *net, float *observations, float *act
8888
}
8989
}
9090

91-
void generate_dummy_actions(DroneSwarm *env) {
91+
void generate_dummy_actions(DroneEnv *env) {
9292
// Generate random floats in [-1, 1] range
9393
env->actions[0] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
9494
env->actions[1] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
@@ -98,14 +98,14 @@ void generate_dummy_actions(DroneSwarm *env) {
9898

9999
#ifdef __EMSCRIPTEN__
100100
typedef struct {
101-
DroneSwarm *env;
101+
DroneEnv *env;
102102
LinearContLSTM *net;
103103
Weights *weights;
104104
} WebRenderArgs;
105105

106106
void emscriptenStep(void *e) {
107107
WebRenderArgs *args = (WebRenderArgs *)e;
108-
DroneSwarm *env = args->env;
108+
DroneEnv *env = args->env;
109109
LinearContLSTM *net = args->net;
110110

111111
forward_linearcontlstm(net, env->observations, env->actions);
@@ -120,13 +120,13 @@ WebRenderArgs *web_args = NULL;
120120
int main() {
121121
srand(time(NULL)); // Seed random number generator
122122

123-
DroneSwarm *env = calloc(1, sizeof(DroneSwarm));
123+
DroneEnv *env = calloc(1, sizeof(DroneEnv));
124124
env->num_agents = 64;
125125
env->max_rings = 10;
126-
env->task = TASK_ORBIT;
126+
env->task = ORBIT;
127127
init(env);
128128

129-
size_t obs_size = 41;
129+
size_t obs_size = 26;
130130
size_t act_size = 4;
131131
env->observations = (float *)calloc(env->num_agents * obs_size, sizeof(float));
132132
env->actions = (float *)calloc(env->num_agents * act_size, sizeof(float));

pufferlib/ocean/drone/drone.h

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
// Originally made by Sam Turner and Finlay Sanders, 2025.
2+
// Included in pufferlib under the original project's MIT license.
3+
// https://github.com/tensaur/drone
4+
5+
#pragma once
6+
7+
#include <math.h>
8+
#include <limits.h>
9+
#include <stdbool.h>
10+
#include <stdlib.h>
11+
12+
#include "dronelib.h"
13+
#include "tasks.h"
14+
15+
typedef struct Client Client;
16+
typedef struct DroneEnv DroneEnv;
17+
18+
struct DroneEnv {
19+
float *observations;
20+
float *actions;
21+
float *rewards;
22+
unsigned char *terminals;
23+
24+
Log log;
25+
int tick;
26+
int report_interval;
27+
28+
DroneTask task;
29+
int num_agents;
30+
Drone *agents;
31+
32+
int max_rings;
33+
Target* ring_buffer;
34+
35+
Client *client;
36+
};
37+
38+
void init(DroneEnv *env) {
39+
env->agents = (Drone*) calloc(env->num_agents, sizeof(Drone));
40+
env->ring_buffer = (Target*) calloc(env->max_rings, sizeof(Target));
41+
42+
for (int i = 0; i < env->num_agents; i++) {
43+
env->agents[i].target = (Target*) calloc(1, sizeof(Target));
44+
}
45+
46+
env->log = (Log){0};
47+
env->tick = 0;
48+
}
49+
50+
void add_log(DroneEnv *env, int idx, bool oob, bool ring_collision,
51+
bool timeout) {
52+
Drone *agent = &env->agents[idx];
53+
env->log.score += agent->score;
54+
env->log.episode_return += agent->episode_return;
55+
env->log.episode_length += agent->episode_length;
56+
env->log.collision_rate += agent->collisions / (float)agent->episode_length;
57+
env->log.perf += agent->score / (float)agent->episode_length;
58+
if (oob) {
59+
env->log.oob += 1.0f;
60+
}
61+
if (ring_collision) {
62+
env->log.ring_collision += 1.0f;
63+
}
64+
if (timeout) {
65+
env->log.timeout += 1.0f;
66+
}
67+
env->log.n += 1.0f;
68+
69+
agent->episode_length = 0;
70+
agent->episode_return = 0.0f;
71+
}
72+
73+
void compute_observations(DroneEnv *env) {
74+
int idx = 0;
75+
76+
for (int i = 0; i < env->num_agents; i++) {
77+
Drone *agent = &env->agents[i];
78+
79+
Quat q_inv = quat_inverse(agent->state.quat);
80+
Vec3 linear_vel_body = quat_rotate(q_inv, agent->state.vel);
81+
Vec3 to_target = sub3(agent->target->pos, agent->state.pos);
82+
83+
env->observations[idx++] = linear_vel_body.x / agent->params.max_vel;
84+
env->observations[idx++] = linear_vel_body.y / agent->params.max_vel;
85+
env->observations[idx++] = linear_vel_body.z / agent->params.max_vel;
86+
87+
env->observations[idx++] = agent->state.omega.x / agent->params.max_omega;
88+
env->observations[idx++] = agent->state.omega.y / agent->params.max_omega;
89+
env->observations[idx++] = agent->state.omega.z / agent->params.max_omega;
90+
91+
env->observations[idx++] = agent->state.quat.w;
92+
env->observations[idx++] = agent->state.quat.x;
93+
env->observations[idx++] = agent->state.quat.y;
94+
env->observations[idx++] = agent->state.quat.z;
95+
96+
env->observations[idx++] = agent->state.rpms[0] / agent->params.max_rpm;
97+
env->observations[idx++] = agent->state.rpms[1] / agent->params.max_rpm;
98+
env->observations[idx++] = agent->state.rpms[2] / agent->params.max_rpm;
99+
env->observations[idx++] = agent->state.rpms[3] / agent->params.max_rpm;
100+
101+
env->observations[idx++] = to_target.x / GRID_X;
102+
env->observations[idx++] = to_target.y / GRID_Y;
103+
env->observations[idx++] = to_target.z / GRID_Z;
104+
105+
env->observations[idx++] = clampf(to_target.x, -1.0f, 1.0f);
106+
env->observations[idx++] = clampf(to_target.y, -1.0f, 1.0f);
107+
env->observations[idx++] = clampf(to_target.z, -1.0f, 1.0f);
108+
109+
env->observations[idx++] = agent->target->normal.x;
110+
env->observations[idx++] = agent->target->normal.y;
111+
env->observations[idx++] = agent->target->normal.z;
112+
113+
// Multiagent obs
114+
Drone *nearest = nearest_drone(agent, env->agents, env->num_agents);
115+
if (env->num_agents > 1) {
116+
env->observations[idx++] =
117+
clampf(nearest->state.pos.x - agent->state.pos.x, -1.0f, 1.0f);
118+
env->observations[idx++] =
119+
clampf(nearest->state.pos.y - agent->state.pos.y, -1.0f, 1.0f);
120+
env->observations[idx++] =
121+
clampf(nearest->state.pos.z - agent->state.pos.z, -1.0f, 1.0f);
122+
} else {
123+
env->observations[idx++] = MAX_DIST;
124+
env->observations[idx++] = MAX_DIST;
125+
env->observations[idx++] = MAX_DIST;
126+
}
127+
}
128+
}
129+
130+
void reset_agent(DroneEnv *env, Drone *agent, int idx) {
131+
agent->last_dist_reward = 0.0f;
132+
agent->episode_return = 0.0f;
133+
agent->episode_length = 0;
134+
agent->collisions = 0.0f;
135+
agent->score = 0.0f;
136+
137+
agent->buffer = env->ring_buffer;
138+
agent->buffer_size = env->max_rings;
139+
agent->buffer_idx = -1;
140+
141+
float size = rndf(0.1f, 0.4f);
142+
init_drone(agent, size, 0.1f);
143+
144+
agent->state.pos =
145+
(Vec3){rndf(-MARGIN_X, MARGIN_X), rndf(-MARGIN_Y, MARGIN_Y),
146+
rndf(-MARGIN_Z, MARGIN_Z)};
147+
148+
if (env->task == RACE) {
149+
while (norm3(sub3(agent->state.pos, env->ring_buffer[0].pos)) <
150+
2.0f * RING_RADIUS) {
151+
agent->state.pos =
152+
(Vec3){rndf(-MARGIN_X, MARGIN_X), rndf(-MARGIN_Y, MARGIN_Y),
153+
rndf(-MARGIN_Z, MARGIN_Z)};
154+
}
155+
}
156+
157+
agent->prev_pos = agent->state.pos;
158+
}
159+
160+
void c_reset(DroneEnv *env) {
161+
env->tick = 0;
162+
int rng = rand();
163+
164+
if (rng > INT_MAX / 2) {
165+
env->task = RACE;
166+
} else {
167+
env->task = (DroneTask)(rng % (TASK_N - 1));
168+
}
169+
170+
if (env->task == RACE) {
171+
reset_rings(env->ring_buffer, env->max_rings);
172+
}
173+
174+
for (int i = 0; i < env->num_agents; i++) {
175+
Drone *agent = &env->agents[i];
176+
reset_agent(env, agent, i);
177+
set_target(env->task, env->agents, i, env->num_agents);
178+
}
179+
180+
compute_observations(env);
181+
}
182+
183+
void c_step(DroneEnv *env) {
184+
env->tick = (env->tick + 1) % HORIZON;
185+
186+
for (int i = 0; i < env->num_agents; i++) {
187+
Drone *agent = &env->agents[i];
188+
env->rewards[i] = 0;
189+
env->terminals[i] = 0;
190+
191+
float *atn = &env->actions[4 * i];
192+
move_drone(agent, atn);
193+
194+
bool out_of_bounds =
195+
agent->state.pos.x < -GRID_X || agent->state.pos.x > GRID_X ||
196+
agent->state.pos.y < -GRID_Y || agent->state.pos.y > GRID_Y ||
197+
agent->state.pos.z < -GRID_Z || agent->state.pos.z > GRID_Z;
198+
199+
move_target(agent);
200+
201+
bool collision = check_collision(agent, env->agents, env->num_agents);
202+
float reward = 0.0f;
203+
204+
if (env->task == RACE) {
205+
// Check ring passage
206+
Target *ring = &env->ring_buffer[agent->buffer_idx];
207+
int ring_passage = check_ring(agent, ring);
208+
209+
// Ring collision
210+
if (ring_passage < 0) {
211+
env->rewards[i] = (float)ring_passage;
212+
agent->episode_return += (float)ring_passage;
213+
env->terminals[i] = 1;
214+
add_log(env, i, false, true, false);
215+
reset_agent(env, agent, i);
216+
set_target(env->task, env->agents, i, env->num_agents);
217+
continue;
218+
}
219+
220+
// Successfully passed through ring - advance to next
221+
if (ring_passage > 0) {
222+
set_target(env->task, env->agents, i, env->num_agents);
223+
env->log.rings_passed += 1.0f;
224+
}
225+
226+
reward = dynamic_task_reward(agent, collision, ring_passage);
227+
} else {
228+
reward = static_task_reward(agent, collision);
229+
}
230+
231+
// Update agent state
232+
agent->episode_length++;
233+
agent->score += reward;
234+
if (collision) {
235+
agent->collisions += 1.0f;
236+
}
237+
238+
env->rewards[i] = reward;
239+
agent->episode_return += reward;
240+
241+
// Check termination conditions
242+
if (out_of_bounds) {
243+
env->rewards[i] -= 1.0f;
244+
agent->episode_return -= 1.0f;
245+
env->terminals[i] = 1;
246+
add_log(env, i, true, false, false);
247+
248+
reset_agent(env, agent, i);
249+
set_target(env->task, env->agents, i, env->num_agents);
250+
static_task_reward(agent, false);
251+
} else if (env->tick >= HORIZON - 1) {
252+
env->terminals[i] = 1;
253+
add_log(env, i, false, false, true);
254+
}
255+
}
256+
257+
if (env->tick >= HORIZON - 1) {
258+
c_reset(env);
259+
}
260+
261+
compute_observations(env);
262+
}
263+
264+
void c_close_client(Client* client);
265+
266+
void c_close(DroneEnv *env) {
267+
for (int i = 0; i < env->num_agents; i++) {
268+
free(env->agents[i].target);
269+
}
270+
271+
free(env->agents);
272+
free(env->ring_buffer);
273+
274+
if (env->client != NULL) {
275+
c_close_client(env->client);
276+
}
277+
}
278+

0 commit comments

Comments
 (0)