Skip to content

Commit 21a3199

Browse files
[Feature] allow different number of discrete actions for each action dimension (#119)
* support custom nvec for discrete actions Changes: - add action_nvec property to the Dynamics ABC (defaults to 3s as before when not overridden) - add Agent.action_nvec - In Environment: update get_agent_action_space, get_random_action, _set_action to support Agent.action_nvec * add composite dynamics Changes: - add simulator.dynamics.composite with Composite class - add Rotation to simulator.dynamics.holonomic_with_rot * revert changes to dynamics, improve logic for bc-compatibility * Apply suggestions from code review Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> * fix discrete to multi-discrete mapping, add tests * Apply suggestions from code review Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> * improve tests --------- Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent fe9c3b9 commit 21a3199

3 files changed

Lines changed: 230 additions & 33 deletions

File tree

tests/test_vmas.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
4+
import math
45
import os
6+
import random
57
import sys
68
from pathlib import Path
79

@@ -26,6 +28,14 @@ def scenario_names():
2628
return scenarios
2729

2830

31+
def random_nvecs(count, l_min=2, l_max=6, n_min=2, n_max=6, seed=0):
32+
random.seed(seed)
33+
return [
34+
[random.randint(n_min, n_max) for _ in range(random.randint(l_min, l_max))]
35+
for _ in range(count)
36+
]
37+
38+
2939
def test_all_scenarios_included():
3040
from vmas import debug_scenarios, mpe_scenarios, scenarios
3141

@@ -70,6 +80,163 @@ def test_multi_discrete_actions(scenario, num_envs=10, n_steps=10):
7080
env.step(env.get_random_actions())
7181

7282

83+
@pytest.mark.parametrize("scenario", scenario_names())
84+
@pytest.mark.parametrize("multidiscrete_actions", [True, False])
85+
def test_discrete_action_nvec(scenario, multidiscrete_actions, num_envs=10, n_steps=5):
86+
env = make_env(
87+
scenario=scenario,
88+
num_envs=num_envs,
89+
seed=0,
90+
multidiscrete_actions=multidiscrete_actions,
91+
continuous_actions=False,
92+
)
93+
if (
94+
type(env.scenario).process_action
95+
is not vmas.simulator.scenario.BaseScenario.process_action
96+
):
97+
pytest.skip("Scenario uses a custom process_action method.")
98+
99+
random.seed(0)
100+
for agent in env.world.agents:
101+
agent.discrete_action_nvec = [
102+
random.randint(2, 6) for _ in range(agent.action_size)
103+
]
104+
env.action_space = env.get_action_space()
105+
106+
def to_multidiscrete(action, nvec):
107+
action_multi = []
108+
for i in range(len(nvec)):
109+
n = math.prod(nvec[i + 1 :])
110+
action_multi.append(action // n)
111+
action = action % n
112+
return torch.stack(action_multi, dim=-1)
113+
114+
def full_nvec(agent, world):
115+
return list(agent.discrete_action_nvec) + (
116+
[world.dim_c] if not agent.silent and world.dim_c != 0 else []
117+
)
118+
119+
for _ in range(n_steps):
120+
actions = env.get_random_actions()
121+
122+
# Check that generated actions are in the action space
123+
for a_batch, s in zip(actions, env.action_space.spaces):
124+
for a in a_batch:
125+
assert a.numpy() in s
126+
127+
env.step(actions)
128+
129+
if not multidiscrete_actions:
130+
actions = [
131+
to_multidiscrete(a.squeeze(-1), full_nvec(agent, env.world))
132+
for a, agent in zip(actions, env.world.policy_agents)
133+
]
134+
135+
# Check that discrete action to continuous control mapping is correct.
136+
for i_a, agent in enumerate(env.world.policy_agents):
137+
for i, n in enumerate(agent.discrete_action_nvec):
138+
a = actions[i_a][:, i]
139+
u = agent.action.u[:, i]
140+
U = agent.action.u_range_tensor[i]
141+
k = agent.action.u_multiplier_tensor[i]
142+
for aj, uj in zip(a, u):
143+
assert aj in range(
144+
n
145+
), f"discrete action {aj} not in [0,{n-1}] (n={n}, U={U}, k={k})"
146+
if n % 2 != 0:
147+
assert (
148+
aj != 0 or uj == 0
149+
), f"discrete action {aj} maps to control {uj} (n={n}), U={U}, k={k})"
150+
assert (aj < 1 or aj > n // 2) or torch.isclose(
151+
uj / k, (2 * U * (aj - 1)) / (n - 1) - U
152+
), f"discrete action {aj} maps to control {uj} (n={n}, U={U}, k={k})"
153+
assert (aj <= n // 2) or torch.isclose(
154+
uj / k, 2 * U * (aj / (n - 1)) - U
155+
), f"discrete action {aj} maps to control {uj} (n={n}), U={U}, k={k})"
156+
else:
157+
assert torch.isclose(
158+
uj / k, 2 * U * (aj / (n - 1)) - U
159+
), f"discrete action {aj} maps to control {uj} (n={n}), U={U}, k={k})"
160+
161+
162+
@pytest.mark.parametrize(
163+
"nvecs", list(zip(random_nvecs(10, seed=0), random_nvecs(10, seed=42)))
164+
)
165+
def test_discrete_action_nvec_discrete_to_multi(
166+
nvecs, scenario="transport", num_envs=10, n_steps=5
167+
):
168+
kwargs = {
169+
"scenario": scenario,
170+
"num_envs": num_envs,
171+
"seed": 0,
172+
"continuous_actions": False,
173+
}
174+
env = make_env(**kwargs, multidiscrete_actions=False)
175+
env_multi = make_env(**kwargs, multidiscrete_actions=True)
176+
if (
177+
type(env.scenario).process_action
178+
is not vmas.simulator.scenario.BaseScenario.process_action
179+
):
180+
pytest.skip("Scenario uses a custom process_action method.")
181+
182+
def set_nvec(agent, nvec):
183+
agent.action_size = len(nvec)
184+
agent.discrete_action_nvec = nvec
185+
agent.action.action_size = agent.action_size
186+
187+
random.seed(0)
188+
for agent, agent_multi, nvec in zip(
189+
env.world.policy_agents, env_multi.world.policy_agents, nvecs
190+
):
191+
set_nvec(agent, nvec)
192+
set_nvec(agent_multi, nvec)
193+
env.action_space = env.get_action_space()
194+
env_multi.action_space = env.get_action_space()
195+
196+
def full_nvec(agent, world):
197+
return list(agent.discrete_action_nvec) + (
198+
[world.dim_c] if not agent.silent and world.dim_c != 0 else []
199+
)
200+
201+
def full_action_size(agent, world):
202+
return len(full_nvec(agent, world))
203+
204+
for _ in range(n_steps):
205+
actions_multi = env_multi.get_random_actions()
206+
prodss = [
207+
[
208+
math.prod(full_nvec(agent, env.world)[i + 1 :])
209+
for i in range(full_action_size(agent, env.world))
210+
]
211+
for agent in env.world.policy_agents
212+
]
213+
# Compute the expected mapping from multi-discrete to discrete
214+
actions = [
215+
(a_multi * torch.tensor(prods)).sum(dim=1)
216+
for a_multi, prods in zip(actions_multi, prodss)
217+
]
218+
219+
env_multi.step(actions_multi)
220+
env.step(actions)
221+
222+
# Check that both discrete and multi-discrete actions result in the
223+
# same control value
224+
for agent, agent_multi, action, action_multi in zip(
225+
env.world.policy_agents,
226+
env_multi.world.policy_agents,
227+
actions,
228+
actions_multi,
229+
):
230+
U = agent.action.u_range_tensor
231+
k = agent.action.u_multiplier_tensor
232+
for u, u_multi, a, a_multi in zip(
233+
agent.action.u, agent_multi.action.u, action, action_multi
234+
):
235+
assert torch.allclose(
236+
u, u_multi
237+
), f"{u} != {u_multi} (nvec={agent.discrete_action_nvec}, a={a}, a_multi={a_multi}, U={U}, k={k})"
238+
239+
73240
@pytest.mark.parametrize("scenario", scenario_names())
74241
def test_non_dict_spaces_actions(scenario, num_envs=10, n_steps=10):
75242
env = make_env(

vmas/simulator/core.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,9 @@ def __init__(
862862
render_action: bool = False,
863863
dynamics: Dynamics = None, # Defaults to holonomic
864864
action_size: int = None, # Defaults to what required by the dynamics
865+
discrete_action_nvec: List[
866+
int
867+
] = None, # Defaults to 3-way discretization if discrete actions are chosen (stay, decrement, increment)
865868
):
866869
super().__init__(
867870
name,
@@ -884,6 +887,17 @@ def __init__(
884887
if obs_range == 0.0:
885888
assert sensors is None, f"Blind agent cannot have sensors, got {sensors}"
886889

890+
if action_size is not None and discrete_action_nvec is not None:
891+
if action_size != len(discrete_action_nvec):
892+
raise ValueError(
893+
f"action_size {action_size} is inconsistent with discrete_action_nvec {discrete_action_nvec}"
894+
)
895+
if discrete_action_nvec is not None:
896+
if not all(n > 1 for n in discrete_action_nvec):
897+
raise ValueError(
898+
f"All values in discrete_action_nvec must be greater than 1, got {discrete_action_nvec}"
899+
)
900+
887901
# cannot observe the world
888902
self._obs_range = obs_range
889903
# observation noise
@@ -914,9 +928,16 @@ def __init__(
914928
# Dynamics
915929
self.dynamics = dynamics if dynamics is not None else Holonomic()
916930
# Action
917-
self.action_size = (
918-
action_size if action_size is not None else self.dynamics.needed_action_size
919-
)
931+
if action_size is not None:
932+
self.action_size = action_size
933+
elif discrete_action_nvec is not None:
934+
self.action_size = len(discrete_action_nvec)
935+
else:
936+
self.action_size = self.dynamics.needed_action_size
937+
if discrete_action_nvec is None:
938+
self.discrete_action_nvec = [3] * self.action_size
939+
else:
940+
self.discrete_action_nvec = discrete_action_nvec
920941
self.dynamics.agent = self
921942
self._action = Action(
922943
u_range=u_range,

vmas/simulator/environment/environment.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
4+
import math
45
import random
56
from ctypes import byref
67
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -334,13 +335,13 @@ def get_agent_action_space(self, agent: Agent):
334335
dtype=np.float32,
335336
)
336337
elif self.multidiscrete_actions:
337-
actions = [3] * agent.action_size + (
338+
actions = agent.discrete_action_nvec + (
338339
[self.world.dim_c] if not agent.silent and self.world.dim_c != 0 else []
339340
)
340341
return spaces.MultiDiscrete(actions)
341342
else:
342343
return spaces.Discrete(
343-
3**agent.action_size
344+
math.prod(agent.discrete_action_nvec)
344345
* (
345346
self.world.dim_c
346347
if not agent.silent and self.world.dim_c != 0
@@ -503,41 +504,49 @@ def _set_action(self, action, agent):
503504
if not self.multidiscrete_actions:
504505
# This bit of code translates the discrete action (taken from a space that
505506
# is the cartesian product of all action spaces) into a multi discrete action.
506-
# For example, if agent.action_size=4, it will mean that the agent will have
507-
# 4 actions each with 3 possibilities (stay, decrement, increment).
508-
# The env will have a space Discrete(3**4).
509-
# This code will translate the action (with shape [n_envs,1] and range [0,3**4)) to an
510-
# action with shape [n_envs,4] and range [0,3).
511-
n_actions = self.get_agent_action_space(agent).n
512-
action_range = torch.arange(n_actions, device=self.device).expand(
513-
self.world.batch_dim, n_actions
514-
)
515-
physical_action = action
516-
action_range = torch.where(action_range == physical_action, 1.0, 0.0)
517-
action_range = action_range.view(
518-
(self.world.batch_dim,)
519-
+ (3,) * agent.action_size
520-
+ (self.world.dim_c,)
521-
* (1 if not agent.silent and self.world.dim_c != 0 else 0)
507+
# This is done by iteratively taking the modulo of the action and dividing by the
508+
# number of actions in the current action space, which treats the action as if
509+
# it was the "flat index" of the multi-discrete actions. E.g. if we have
510+
# nvec = [3,2], action 0 corresponds to the actions [0,0],
511+
# action 1 corresponds to the action [0,1], action 2 corresponds
512+
# to the action [1,0], action 3 corresponds to the action [1,1], etc.
513+
flat_action = action.squeeze(-1)
514+
actions = []
515+
nvec = list(agent.discrete_action_nvec) + (
516+
[self.world.dim_c]
517+
if not agent.silent and self.world.dim_c != 0
518+
else []
522519
)
523-
action = action_range.nonzero()[:, 1:]
520+
for i in range(len(nvec)):
521+
n = math.prod(nvec[i + 1 :])
522+
actions.append(flat_action // n)
523+
flat_action = flat_action % n
524+
action = torch.stack(actions, dim=-1)
524525

525526
# Now we have an action with shape [n_envs, action_size+comms_actions]
526-
for _ in range(agent.action_size):
527-
physical_action = action[:, action_index].unsqueeze(-1)
527+
for n in agent.discrete_action_nvec:
528+
physical_action = action[:, action_index]
528529
self._check_discrete_action(
529-
physical_action,
530+
physical_action.unsqueeze(-1),
530531
low=0,
531-
high=3,
532+
high=n,
532533
type="physical",
533534
)
534-
535-
arr1 = physical_action == 1
536-
arr2 = physical_action == 2
537-
538-
disc_action_value = agent.action.u_range_tensor[action_index]
539-
agent.action.u[:, action_index] -= disc_action_value * arr1.squeeze(-1)
540-
agent.action.u[:, action_index] += disc_action_value * arr2.squeeze(-1)
535+
u_max = agent.action.u_range_tensor[action_index]
536+
# For odd n we want the first action to always map to u=0, so
537+
# we swap 0 values with the middle value, and shift the first
538+
# half of the remaining values by -1.
539+
if n % 2 != 0:
540+
stay = physical_action == 0
541+
decrement = (physical_action > 0) & (physical_action <= n // 2)
542+
physical_action[stay] = n // 2
543+
physical_action[decrement] -= 1
544+
# We know u must be in [-u_max, u_max], and we know action is
545+
# in [0, n-1]. Conversion steps: [0, n-1] -> [0, 1] -> [0, 2*u_max] -> [-u_max, u_max]
546+
# E.g. action 0 -> -u_max, action n-1 -> u_max, action 1 -> -u_max + 2*u_max/(n-1)
547+
agent.action.u[:, action_index] = (physical_action / (n - 1)) * (
548+
2 * u_max
549+
) - u_max
541550

542551
action_index += 1
543552

0 commit comments

Comments (0)