Skip to content

Commit cd7c14a

Browse files
[Feature] Vectorized lidar (#124)
* init * added raycasting with multiple angles * out commenting the old * aemnd * aemnd * aemnd * aemnd * aemnd * aemnd * aemnd --------- Co-authored-by: Zartris <Jonas.le.fevre@gmail.com>
1 parent a9d545f commit cd7c14a

5 files changed

Lines changed: 439 additions & 12 deletions

File tree

tests/test_lidar.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.
4+
5+
import torch
6+
7+
from vmas import make_env
8+
9+
10+
def test_vectorized_lidar(n_envs=12, n_steps=15):
11+
def get_obs(env):
12+
rollout_obs = []
13+
for _ in range(n_steps):
14+
obs, _, _, _ = env.step(env.get_random_actions())
15+
obs = torch.stack(obs, dim=-1)
16+
rollout_obs.append(obs)
17+
return torch.stack(rollout_obs, dim=-1)
18+
19+
env_vec_lidar = make_env(
20+
scenario="pollock", num_envs=n_envs, seed=0, lidar=True, vectorized_lidar=True
21+
)
22+
obs_vec_lidar = get_obs(env_vec_lidar)
23+
env_non_vec_lidar = make_env(
24+
scenario="pollock", num_envs=n_envs, seed=0, lidar=True, vectorized_lidar=False
25+
)
26+
obs_non_vec_lidar = get_obs(env_non_vec_lidar)
27+
28+
assert torch.allclose(obs_vec_lidar, obs_non_vec_lidar)

vmas/scenarios/debug/pollock.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from vmas import render_interactively
88
from vmas.simulator.core import Agent, Box, Landmark, Line, Sphere, World
99
from vmas.simulator.scenario import BaseScenario
10+
11+
from vmas.simulator.sensors import Lidar
1012
from vmas.simulator.utils import Color, ScenarioUtils
1113

1214

@@ -15,6 +17,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
1517
self.n_agents = kwargs.pop("n_agents", 15)
1618
self.n_lines = kwargs.pop("n_lines", 15)
1719
self.n_boxes = kwargs.pop("n_boxes", 15)
20+
self.lidar = kwargs.pop("lidar", False)
21+
self.vectorized_lidar = kwargs.pop("vectorized_lidar", True)
1822
ScenarioUtils.check_kwargs_consumed(kwargs)
1923

2024
self.agent_radius = 0.05
@@ -43,6 +47,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
4347
shape=Sphere(radius=self.agent_radius),
4448
u_multiplier=0.7,
4549
rotatable=True,
50+
sensors=[Lidar(world, n_rays=16, max_range=0.5)] if self.lidar else [],
4651
)
4752
world.add_agent(agent)
4853

@@ -85,7 +90,11 @@ def reward(self, agent: Agent):
8590
return torch.zeros(self.world.batch_dim, device=self.world.device)
8691

8792
def observation(self, agent: Agent):
88-
return torch.zeros(self.world.batch_dim, 1, device=self.world.device)
93+
return (
94+
torch.zeros(self.world.batch_dim, 1, device=self.world.device)
95+
if not self.lidar
96+
else agent.sensors[0].measure(vectorized=self.vectorized_lidar)
97+
)
8998

9099

91100
if __name__ == "__main__":

0 commit comments

Comments
 (0)