Skip to content

Commit 3964475

Browse files
[Tests] Improve tests (#118)
1 parent 195f78f commit 3964475

2 files changed

Lines changed: 62 additions & 7 deletions

File tree

tests/test_vmas.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def test_all_scenarios_included():
3636

3737
@pytest.mark.parametrize("scenario", scenario_names())
3838
@pytest.mark.parametrize("continuous_actions", [True, False])
39-
def test_use_vmas_env(scenario, continuous_actions, num_envs=10, n_steps=10):
39+
def test_use_vmas_env(
40+
scenario, continuous_actions, dict_spaces=True, num_envs=10, n_steps=10
41+
):
4042
render = True
4143
if sys.platform.startswith("win32"):
4244
# Windows on github servers has issues with pyglet
@@ -51,9 +53,36 @@ def test_use_vmas_env(scenario, continuous_actions, num_envs=10, n_steps=10):
5153
continuous_actions=continuous_actions,
5254
num_envs=num_envs,
5355
n_steps=n_steps,
56+
dict_spaces=dict_spaces,
5457
)
5558

5659

60+
@pytest.mark.parametrize("scenario", scenario_names())
61+
def test_multi_discrete_actions(scenario, num_envs=10, n_steps=10):
62+
env = make_env(
63+
scenario=scenario,
64+
num_envs=num_envs,
65+
seed=0,
66+
multidiscrete_actions=True,
67+
continuous_actions=False,
68+
)
69+
for _ in range(n_steps):
70+
env.step(env.get_random_actions())
71+
72+
73+
@pytest.mark.parametrize("scenario", scenario_names())
74+
def test_non_dict_spaces_actions(scenario, num_envs=10, n_steps=10):
75+
env = make_env(
76+
scenario=scenario,
77+
num_envs=num_envs,
78+
seed=0,
79+
continuous_actions=True,
80+
dict_spaces=False,
81+
)
82+
for _ in range(n_steps):
83+
env.step(env.get_random_actions())
84+
85+
5786
@pytest.mark.parametrize("scenario", scenario_names())
5887
def test_partial_reset(scenario, num_envs=10, n_steps=10):
5988
env = make_env(
@@ -70,6 +99,19 @@ def test_partial_reset(scenario, num_envs=10, n_steps=10):
7099
env_index = 0
71100

72101

102+
@pytest.mark.parametrize("scenario", scenario_names())
103+
def test_global_reset(scenario, num_envs=10, n_steps=10):
104+
env = make_env(
105+
scenario=scenario,
106+
num_envs=num_envs,
107+
seed=0,
108+
)
109+
for step in range(n_steps):
110+
env.step(env.get_random_actions())
111+
if step == n_steps // 2:
112+
env.reset()
113+
114+
73115
@pytest.mark.parametrize("scenario", vmas.scenarios + vmas.mpe_scenarios)
74116
def test_vmas_differentiable(scenario, n_steps=10, n_envs=10):
75117
if scenario == "football" or scenario == "simple_crypto":

vmas/simulator/environment/environment.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,25 @@ def get_random_action(self, agent: Agent) -> torch.Tensor:
406406
)
407407
action = torch.stack(actions, dim=-1)
408408
else:
409-
action = torch.randint(
410-
low=0,
411-
high=self.get_agent_action_space(agent).n,
412-
size=(agent.batch_dim,),
413-
device=agent.device,
414-
)
409+
action_space = self.get_agent_action_space(agent)
410+
if self.multidiscrete_actions:
411+
actions = [
412+
torch.randint(
413+
low=0,
414+
high=action_space.nvec[action_index],
415+
size=(agent.batch_dim,),
416+
device=agent.device,
417+
)
418+
for action_index in range(action_space.shape[0])
419+
]
420+
action = torch.stack(actions, dim=-1)
421+
else:
422+
action = torch.randint(
423+
low=0,
424+
high=action_space.n,
425+
size=(agent.batch_dim,),
426+
device=agent.device,
427+
)
415428
return action
416429

417430
def get_random_actions(self) -> Sequence[torch.Tensor]:

0 commit comments

Comments
 (0)