Skip to content

Commit fe9c3b9

Browse files
[Feature] Warn if spwaning entities takes too long (#128)
* amend * amend
1 parent b5eeeca commit fe9c3b9

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

vmas/simulator/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def spawn_entities_randomly(
237237
x_bounds: Tuple[int, int],
238238
y_bounds: Tuple[int, int],
239239
occupied_positions: Tensor = None,
240+
disable_warn: bool = False,
240241
):
241242
batch_size = world.batch_dim if env_index is None else 1
242243

@@ -253,6 +254,7 @@ def spawn_entities_randomly(
253254
min_dist_between_entities,
254255
x_bounds,
255256
y_bounds,
257+
disable_warn,
256258
)
257259
occupied_positions = torch.cat([occupied_positions, pos], dim=1)
258260
entity.set_pos(pos.squeeze(1), batch_index=env_index)
@@ -265,10 +267,12 @@ def find_random_pos_for_entity(
265267
min_dist_between_entities: float,
266268
x_bounds: Tuple[int, int],
267269
y_bounds: Tuple[int, int],
270+
disable_warn: bool = False,
268271
):
269272
batch_size = world.batch_dim if env_index is None else 1
270273

271274
pos = None
275+
tries = 0
272276
while True:
273277
proposed_pos = torch.cat(
274278
[
@@ -296,6 +300,13 @@ def find_random_pos_for_entity(
296300
pos[overlaps] = proposed_pos[overlaps]
297301
else:
298302
break
303+
tries += 1
304+
if tries > 50_000 and not disable_warn:
305+
warnings.warn(
306+
"It is taking many iterations to spawn the entity, make sure the bounds or "
307+
"the min_dist_between_entities are not too tight to fit all entities."
308+
"You can disable this warning by setting disable_warn=True"
309+
)
299310
return pos
300311

301312
@staticmethod

0 commit comments

Comments
 (0)