Skip to content

Commit fc52c9b

Browse files
[Feature] Warn when kwargs passed to a Scenario are not used (#117)
* amend * replace get with pop * add check * amend * amend * amend
1 parent a9c710c commit fc52c9b

45 files changed

Lines changed: 348 additions & 295 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

tests/test_scenarios/test_discovery.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def setup_env(
2525

2626
@pytest.mark.parametrize("n_agents", [1, 4])
2727
def test_heuristic(self, n_agents, n_steps=50, n_envs=4):
28-
self.setup_env(
29-
n_agents=n_agents, random_package_pos_on_line=False, n_envs=n_envs
30-
)
28+
self.setup_env(n_agents=n_agents, n_envs=n_envs)
3129
policy = discovery.HeuristicPolicy(True)
3230

3331
obs = self.env.reset()

tests/test_scenarios/test_flocking.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def setup_env(
2525

2626
@pytest.mark.parametrize("n_agents", [1, 5])
2727
def test_heuristic(self, n_agents, n_steps=50, n_envs=4):
28-
self.setup_env(
29-
n_agents=n_agents, random_package_pos_on_line=False, n_envs=n_envs
30-
)
28+
self.setup_env(n_agents=n_agents, n_envs=n_envs)
3129
policy = flocking.HeuristicPolicy(True)
3230

3331
obs = self.env.reset()

tests/test_scenarios/test_transport.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ def test_not_passing_through_packages(self, n_agents=1, n_envs=4):
5353

5454
@pytest.mark.parametrize("n_agents", [6])
5555
def test_heuristic(self, n_agents, n_envs=4):
56-
self.setup_env(
57-
n_agents=n_agents, random_package_pos_on_line=False, n_envs=n_envs
58-
)
56+
self.setup_env(n_agents=n_agents, n_envs=n_envs)
5957
policy = transport.HeuristicPolicy(self.continuous_actions)
6058

6159
obs = self.env.reset()

vmas/examples/use_vmas_env.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def use_vmas_env(
3131
random_action: bool = False,
3232
device: str = "cpu",
3333
scenario_name: str = "waterfall",
34-
n_agents: int = 4,
3534
continuous_actions: bool = True,
3635
visualize_render: bool = True,
36+
dict_spaces: bool = True,
37+
**kwargs,
3738
):
3839
"""Example function to use a vmas environment
3940
4041
Args:
4142
continuous_actions (bool): Whether the agents have continuous or discrete actions
42-
n_agents (int): Number of agents
4343
scenario_name (str): Name of scenario
4444
device (str): Torch device to use
4545
render (bool): Whether to render the scenario
@@ -48,15 +48,15 @@ def use_vmas_env(
4848
n_steps (int): Number of steps before returning done
4949
random_action (bool): Use random actions or have all agents perform the down action
5050
visualize_render (bool, optional): Whether to visualize the render. Defaults to ``True``.
51+
dict_spaces (bool, optional): Weather to return obs, rewards, and infos as dictionaries with agent names.
52+
By default, they are lists of len # of agents
53+
kwargs (dict, optional): Keyword arguments to pass to the scenario
5154
5255
Returns:
5356
5457
"""
5558
assert not (save_render and not render), "To save the video you have to render it"
5659

57-
dict_spaces = True # Weather to return obs, rewards, and infos as dictionaries with agent names
58-
# (by default they are lists of len # of agents)
59-
6060
env = make_env(
6161
scenario=scenario_name,
6262
num_envs=num_envs,
@@ -66,7 +66,7 @@ def use_vmas_env(
6666
wrapper=None,
6767
seed=None,
6868
# Environment specific variables
69-
n_agents=n_agents,
69+
**kwargs,
7070
)
7171

7272
frame_list = [] # For creating a gif
@@ -121,4 +121,6 @@ def use_vmas_env(
121121
save_render=False,
122122
random_action=False,
123123
continuous_actions=False,
124+
# Environment specific
125+
n_agents=4,
124126
)

vmas/scenarios/balance.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
from vmas.simulator.core import Agent, Box, Landmark, Line, Sphere, World
99
from vmas.simulator.heuristic_policy import BaseHeuristicPolicy
1010
from vmas.simulator.scenario import BaseScenario
11-
from vmas.simulator.utils import Color, Y
11+
from vmas.simulator.utils import Color, ScenarioUtils, Y
1212

1313

1414
class Scenario(BaseScenario):
1515
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
16-
self.n_agents = kwargs.get("n_agents", 3)
17-
self.package_mass = kwargs.get("package_mass", 5)
18-
self.random_package_pos_on_line = kwargs.get("random_package_pos_on_line", True)
16+
self.n_agents = kwargs.pop("n_agents", 3)
17+
self.package_mass = kwargs.pop("package_mass", 5)
18+
self.random_package_pos_on_line = kwargs.pop("random_package_pos_on_line", True)
19+
ScenarioUtils.check_kwargs_consumed(kwargs)
1920

2021
assert self.n_agents > 1
2122

vmas/scenarios/ball_passage.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from vmas import render_interactively
1111
from vmas.simulator.core import Agent, Box, Landmark, Line, Sphere, World
1212
from vmas.simulator.scenario import BaseScenario
13-
from vmas.simulator.utils import Color, X, Y
13+
from vmas.simulator.utils import Color, ScenarioUtils, X, Y
1414

1515

1616
class Scenario(BaseScenario):
1717
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
18-
self.n_passages = kwargs.get("n_passages", 1)
19-
self.fixed_passage = kwargs.get("fixed_passage", False)
20-
self.random_start_angle = kwargs.get("random_start_angle", True)
18+
self.n_passages = kwargs.pop("n_passages", 1)
19+
self.fixed_passage = kwargs.pop("fixed_passage", False)
20+
self.random_start_angle = kwargs.pop("random_start_angle", True)
21+
ScenarioUtils.check_kwargs_consumed(kwargs)
2122

2223
assert 1 <= self.n_passages <= 20
2324

vmas/scenarios/ball_trajectory.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
from vmas.simulator.core import Agent, Landmark, Sphere, World
1111
from vmas.simulator.joints import Joint
1212
from vmas.simulator.scenario import BaseScenario
13-
from vmas.simulator.utils import Color, JOINT_FORCE, X
13+
from vmas.simulator.utils import Color, JOINT_FORCE, ScenarioUtils, X
1414

1515

1616
class Scenario(BaseScenario):
1717
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
18-
self.pos_shaping_factor = kwargs.get("pos_shaping_factor", 0)
19-
self.speed_shaping_factor = kwargs.get("speed_shaping_factor", 1)
20-
self.dist_shaping_factor = kwargs.get("dist_shaping_factor", 0)
21-
self.joints = kwargs.get("joints", True)
18+
self.pos_shaping_factor = kwargs.pop("pos_shaping_factor", 0)
19+
self.speed_shaping_factor = kwargs.pop("speed_shaping_factor", 1)
20+
self.dist_shaping_factor = kwargs.pop("dist_shaping_factor", 0)
21+
self.joints = kwargs.pop("joints", True)
22+
ScenarioUtils.check_kwargs_consumed(kwargs)
2223

2324
self.n_agents = 2
2425

vmas/scenarios/buzz_wire.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
from vmas.simulator.core import Agent, Landmark, Line, Sphere, World
1111
from vmas.simulator.joints import Joint
1212
from vmas.simulator.scenario import BaseScenario
13-
from vmas.simulator.utils import Color
13+
from vmas.simulator.utils import Color, ScenarioUtils
1414

1515

1616
class Scenario(BaseScenario):
1717
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
18-
self.random_start_angle = kwargs.get("random_start_angle", True)
19-
self.pos_shaping_factor = kwargs.get("pos_shaping_factor", 1)
20-
self.collision_reward = kwargs.get("collision_reward", -10)
21-
self.max_speed_1 = kwargs.get("max_speed_1", None) # 0.05
18+
self.random_start_angle = kwargs.pop("random_start_angle", True)
19+
self.pos_shaping_factor = kwargs.pop("pos_shaping_factor", 1)
20+
self.collision_reward = kwargs.pop("collision_reward", -10)
21+
self.max_speed_1 = kwargs.pop("max_speed_1", None) # 0.05
22+
ScenarioUtils.check_kwargs_consumed(kwargs)
2223

2324
self.pos_shaping_factor = 1
2425

vmas/scenarios/debug/asym_joint.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from vmas.simulator.core import Agent, Landmark, Sphere, World
1313
from vmas.simulator.joints import Joint
1414
from vmas.simulator.scenario import BaseScenario
15-
from vmas.simulator.utils import Color
15+
from vmas.simulator.utils import Color, ScenarioUtils
1616

1717
if typing.TYPE_CHECKING:
1818
from vmas.simulator.rendering import Geom
@@ -47,19 +47,21 @@ def angle_to_vector(angle):
4747

4848
class Scenario(BaseScenario):
4949
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
50-
self.joint_length = kwargs.get("joint_length", 0.5)
51-
self.random_start_angle = kwargs.get("random_start_angle", False)
52-
self.observe_joint_angle = kwargs.get("observe_joint_angle", False)
53-
self.joint_angle_obs_noise = kwargs.get("joint_angle_obs_noise", 0.0)
54-
self.asym_package = kwargs.get("asym_package", True)
55-
self.mass_ratio = kwargs.get("mass_ratio", 5)
56-
self.mass_position = kwargs.get("mass_position", 0.75)
57-
self.max_speed_1 = kwargs.get("max_speed_1", None) # 0.1
58-
self.obs_noise = kwargs.get("obs_noise", 0.2)
50+
self.joint_length = kwargs.pop("joint_length", 0.5)
51+
self.random_start_angle = kwargs.pop("random_start_angle", False)
52+
self.observe_joint_angle = kwargs.pop("observe_joint_angle", False)
53+
self.joint_angle_obs_noise = kwargs.pop("joint_angle_obs_noise", 0.0)
54+
self.asym_package = kwargs.pop("asym_package", True)
55+
self.mass_ratio = kwargs.pop("mass_ratio", 5)
56+
self.mass_position = kwargs.pop("mass_position", 0.75)
57+
self.max_speed_1 = kwargs.pop("max_speed_1", None) # 0.1
58+
self.obs_noise = kwargs.pop("obs_noise", 0.2)
5959

6060
# Reward
61-
self.rot_shaping_factor = kwargs.get("rot_shaping_factor", 1)
62-
self.energy_reward_coeff = kwargs.get("energy_reward_coeff", 0.08)
61+
self.rot_shaping_factor = kwargs.pop("rot_shaping_factor", 1)
62+
self.energy_reward_coeff = kwargs.pop("energy_reward_coeff", 0.08)
63+
64+
ScenarioUtils.check_kwargs_consumed(kwargs)
6365

6466
# Make world
6567
world = World(

vmas/scenarios/debug/circle_trajectory.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@
1010
from vmas.simulator.controllers.velocity_controller import VelocityController
1111
from vmas.simulator.core import Agent, Sphere, World
1212
from vmas.simulator.scenario import BaseScenario
13-
from vmas.simulator.utils import Color, TorchUtils, X, Y
13+
from vmas.simulator.utils import Color, ScenarioUtils, TorchUtils, X, Y
1414

1515

1616
class Scenario(BaseScenario):
1717
def make_world(self, batch_dim: int, device: torch.device, **kwargs):
18-
self.u_range = kwargs.get("u_range", 1)
19-
self.a_range = kwargs.get("a_range", 1)
20-
self.obs_noise = kwargs.get("obs_noise", 0.0)
21-
self.dt_delay = kwargs.get("dt_delay", 0)
22-
self.min_input_norm = kwargs.get("min_input_norm", 0.08)
23-
self.linear_friction = kwargs.get("linear_friction", 0.1)
18+
self.u_range = kwargs.pop("u_range", 1)
19+
self.a_range = kwargs.pop("a_range", 1)
20+
self.obs_noise = kwargs.pop("obs_noise", 0.0)
21+
self.dt_delay = kwargs.pop("dt_delay", 0)
22+
self.min_input_norm = kwargs.pop("min_input_norm", 0.08)
23+
self.linear_friction = kwargs.pop("linear_friction", 0.1)
24+
ScenarioUtils.check_kwargs_consumed(kwargs)
2425

2526
self.agent_radius = 0.16
2627
self.desired_radius = 1.5

0 commit comments

Comments
 (0)