@@ -28,6 +28,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2828 self ._min_dist_between_entities = kwargs .pop ("min_dist_between_entities" , 0.2 )
2929 self ._lidar_range = kwargs .pop ("lidar_range" , 0.35 )
3030 self ._covering_range = kwargs .pop ("covering_range" , 0.25 )
31+ self .use_agent_lidar = kwargs .pop ("use_agent_lidar" , False )
3132 self ._agents_per_target = kwargs .pop ("agents_per_target" , 2 )
3233 self .targets_respawn = kwargs .pop ("targets_respawn" , True )
3334 self .shared_reward = kwargs .pop ("shared_reward" , False )
@@ -57,9 +58,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5758 )
5859
5960 # Add agents
60- # entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
61- # "agent"
62- # )
61+ entity_filter_agents : Callable [[Entity ], bool ] = lambda e : e .name .startswith (
62+ "agent"
63+ )
6364 entity_filter_targets : Callable [[Entity ], bool ] = lambda e : e .name .startswith (
6465 "target"
6566 )
@@ -69,24 +70,32 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6970 name = f"agent_{ i } " ,
7071 collide = True ,
7172 shape = Sphere (radius = self .agent_radius ),
72- sensors = [
73- # Lidar(
74- # world,
75- # angle_start=0.05,
76- # angle_end=2 * torch.pi + 0.05,
77- # n_rays=12,
78- # max_range=self._lidar_range,
79- # entity_filter=entity_filter_agents,
80- # render_color=Color.BLUE,
81- # ),
82- Lidar (
83- world ,
84- n_rays = 15 ,
85- max_range = self ._lidar_range ,
86- entity_filter = entity_filter_targets ,
87- render_color = Color .GREEN ,
88- ),
89- ],
73+ sensors = (
74+ [
75+ Lidar (
76+ world ,
77+ n_rays = 15 ,
78+ max_range = self ._lidar_range ,
79+ entity_filter = entity_filter_targets ,
80+ render_color = Color .GREEN ,
81+ )
82+ ]
83+ + (
84+ [
85+ Lidar (
86+ world ,
87+ angle_start = 0.05 ,
88+ angle_end = 2 * torch .pi + 0.05 ,
89+ n_rays = 12 ,
90+ max_range = self ._lidar_range ,
91+ entity_filter = entity_filter_agents ,
92+ render_color = Color .BLUE ,
93+ )
94+ ]
95+ if self .use_agent_lidar
96+ else []
97+ )
98+ ),
9099 )
91100 agent .collision_rew = torch .zeros (batch_dim , device = device )
92101 agent .covering_reward = agent .collision_rew .clone ()
@@ -230,15 +239,9 @@ def agent_reward(self, agent):
230239
231240 def observation (self , agent : Agent ):
232241 lidar_1_measures = agent .sensors [0 ].measure ()
233- # lidar_2_measures = agent.sensors[1].measure()
234242 return torch .cat (
235- [
236- agent .state .pos ,
237- agent .state .vel ,
238- agent .state .pos ,
239- lidar_1_measures ,
240- # lidar_2_measures,
241- ],
243+ [agent .state .pos , agent .state .vel , lidar_1_measures ]
244+ + ([agent .sensors [1 ].measure ()] if self .use_agent_lidar else []),
242245 dim = - 1 ,
243246 )
244247
@@ -317,24 +320,25 @@ def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Ten
317320 closest_point_on_circ_normal *= 0.1
318321 des_pos = closest_point_on_circ + closest_point_on_circ_normal
319322
320- # Move away from other agents within visibility range
321- lidar_agents = observation [:, 4 :16 ]
322- agent_visible = torch .any (lidar_agents < 0.15 , dim = 1 )
323- _ , agent_dir_index = torch .min (lidar_agents , dim = 1 )
324- agent_dir = agent_dir_index / lidar_agents .shape [1 ] * 2 * torch .pi
325- agent_vec = torch .stack ([torch .cos (agent_dir ), torch .sin (agent_dir )], dim = 1 )
326- des_pos_agent = current_pos - agent_vec * 0.1
327- des_pos [agent_visible ] = des_pos_agent [agent_visible ]
328-
329323 # Move towards targets within visibility range
330- lidar_targets = observation [:, 16 : 28 ]
324+ lidar_targets = observation [:, 4 : 19 ]
331325 target_visible = torch .any (lidar_targets < 0.3 , dim = 1 )
332326 _ , target_dir_index = torch .min (lidar_targets , dim = 1 )
333327 target_dir = target_dir_index / lidar_targets .shape [1 ] * 2 * torch .pi
334328 target_vec = torch .stack ([torch .cos (target_dir ), torch .sin (target_dir )], dim = 1 )
335329 des_pos_target = current_pos + target_vec * 0.1
336330 des_pos [target_visible ] = des_pos_target [target_visible ]
337331
332+ if observation .shape [- 1 ] > 19 :
333+ # Move away from other agents within visibility range
334+ lidar_agents = observation [:, 19 :31 ]
335+ agent_visible = torch .any (lidar_agents < 0.15 , dim = 1 )
336+ _ , agent_dir_index = torch .min (lidar_agents , dim = 1 )
337+ agent_dir = agent_dir_index / lidar_agents .shape [1 ] * 2 * torch .pi
338+ agent_vec = torch .stack ([torch .cos (agent_dir ), torch .sin (agent_dir )], dim = 1 )
339+ des_pos_agent = current_pos - agent_vec * 0.1
340+ des_pos [agent_visible ] = des_pos_agent [agent_visible ]
341+
338342 action = torch .clamp (
339343 (des_pos - current_pos ) * 10 ,
340344 min = - u_range ,
0 commit comments