Skip to content

Commit aa1e02a

Browse files
[Feature] Joint rotation offset and more dynamics (#125)
* amend * amend * amend * amend * amend * amend
1 parent cd7c14a commit aa1e02a

6 files changed

Lines changed: 112 additions & 8 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ To create a fake screen you need to have `Xvfb` installed.
345345
| **<p align="center">joint_passage_size</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/joint_passage_size.gif?raw=true"/> | **<p align="center">flocking</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/flocking.gif?raw=true"/> | **<p align="center">discovery</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/discovery.gif?raw=true"/> |
346346
| **<p align="center">joint_passage</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/joint_passage.gif?raw=true"/> | **<p align="center">ball_passage</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/ball_passage.gif?raw=true"/> | **<p align="center">ball_trajectory</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/ball_trajectory.gif?raw=true"/> |
347347
| **<p align="center">buzz_wire</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/buzz_wire.gif?raw=true"/> | **<p align="center">multi_give_way</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/multi_give_way.gif?raw=true"/> | **<p align="center">navigation</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/navigation.gif?raw=true"/> |
348-
| **<p align="center">sampling</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/sampling.gif?raw=true"/> | **<p align="center">wind_flocking</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/wind_flocking.gif?raw=true"/> | **<p align="center">road_traffic</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/road_traffic_cpm_lab.gif?raw=true"/> |
348+
| **<p align="center">sampling</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/sampling.gif?raw=true"/> | **<p align="center">wind_flocking</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/wind_flocking.gif?raw=true"/> | **<p align="center">road_traffic</p>** <br/> <img src="https://github.com/matteobettini/vmas-media/blob/main/media/scenarios/road_traffic_cpm_lab.gif?raw=true"/> |
349+
349350
#### Main scenarios
350351

351352
| Env name | Description | GIF |

vmas/simulator/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,7 +2186,7 @@ def _vectorized_joint_constraints(self, joints):
21862186
rotate = []
21872187
rot_a = []
21882188
rot_b = []
2189-
2189+
joint_rot = []
21902190
for entity_a, entity_b, joint in joints:
21912191
pos_joint_a.append(joint.pos_point(entity_a))
21922192
pos_joint_b.append(joint.pos_point(entity_b))
@@ -2196,6 +2196,13 @@ def _vectorized_joint_constraints(self, joints):
21962196
rotate.append(torch.tensor(joint.rotate, device=self.device))
21972197
rot_a.append(entity_a.state.rot)
21982198
rot_b.append(entity_b.state.rot)
2199+
joint_rot.append(
2200+
torch.tensor(joint.fixed_rotation, device=self.device)
2201+
.unsqueeze(-1)
2202+
.expand(self.batch_dim, 1)
2203+
if isinstance(joint.fixed_rotation, float)
2204+
else joint.fixed_rotation
2205+
)
21992206
pos_a = torch.stack(pos_a, dim=-2)
22002207
pos_b = torch.stack(pos_b, dim=-2)
22012208
pos_joint_a = torch.stack(pos_joint_a, dim=-2)
@@ -2215,6 +2222,7 @@ def _vectorized_joint_constraints(self, joints):
22152222
dim=-1,
22162223
)
22172224
rotate = rotate_prior.unsqueeze(0).expand(self.batch_dim, -1).unsqueeze(-1)
2225+
joint_rot = torch.stack(joint_rot, dim=-2)
22182226

22192227
(force_a_attractive, force_b_attractive,) = self._get_constraint_forces(
22202228
pos_joint_a,
@@ -2239,7 +2247,7 @@ def _vectorized_joint_constraints(self, joints):
22392247
torque_b_rotate = TorchUtils.compute_torque(force_b, r_b)
22402248

22412249
torque_a_fixed, torque_b_fixed = self._get_constraint_torques(
2242-
rot_a, rot_b, force_multiplier=self._torque_constraint_force
2250+
rot_a, rot_b + joint_rot, force_multiplier=self._torque_constraint_force
22432251
)
22442252

22452253
torque_a = torch.where(

vmas/simulator/dynamics/forward.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.
4+
import torch
5+
6+
from vmas.simulator.dynamics.common import Dynamics
7+
from vmas.simulator.utils import TorchUtils, X
8+
9+
10+
class Forward(Dynamics):
11+
@property
12+
def needed_action_size(self) -> int:
13+
return 1
14+
15+
def process_action(self):
16+
force = torch.zeros(
17+
self.agent.batch_dim, 2, device=self.agent.device, dtype=torch.float
18+
)
19+
force[:, X] = self.agent.action.u[:, 0]
20+
self.agent.state.force = TorchUtils.rotate_vector(force, self.agent.state.rot)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.
4+
5+
from vmas.simulator.dynamics.common import Dynamics
6+
7+
8+
class Rotation(Dynamics):
9+
@property
10+
def needed_action_size(self) -> int:
11+
return 1
12+
13+
def process_action(self):
14+
self.agent.state.torque = self.agent.action.u[:, 0]

vmas/simulator/dynamics/static.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.
4+
5+
from vmas.simulator.dynamics.common import Dynamics
6+
7+
8+
class Static(Dynamics):
9+
@property
10+
def needed_action_size(self) -> int:
11+
return 0
12+
13+
def process_action(self):
14+
pass

vmas/simulator/joints.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from typing import List, Tuple, TYPE_CHECKING
7+
from typing import List, Optional, Tuple, TYPE_CHECKING
88

99
import torch
1010

@@ -30,6 +30,8 @@ def __init__(
3030
collidable: bool = False,
3131
width: float = 0.0,
3232
mass: float = 1.0,
33+
fixed_rotation_a: Optional[float] = None,
34+
fixed_rotation_b: Optional[float] = None,
3335
):
3436
assert entity_a != entity_b, "Cannot join same entity"
3537
for anchor in (anchor_a, anchor_b):
@@ -40,11 +42,27 @@ def __init__(
4042
if dist == 0:
4143
assert not collidable, "Cannot have collidable joint with dist 0"
4244
assert width == 0, "Cannot have width for joint with dist 0"
45+
assert (
46+
fixed_rotation_a == fixed_rotation_b
47+
), "If dist is 0, fixed_rotation_a and fixed_rotation_b should be the same"
48+
if fixed_rotation_a is not None:
49+
assert (
50+
not rotate_a
51+
), "If you provide a fixed rotation for a, rotate_a should be False"
52+
if fixed_rotation_b is not None:
53+
assert (
54+
not rotate_b
55+
), "If you provide a fixed rotation for b, rotate_b should be False"
56+
4357
if width > 0:
4458
assert collidable
4559

4660
self.entity_a = entity_a
4761
self.entity_b = entity_b
62+
self.rotate_a = rotate_a
63+
self.rotate_b = rotate_b
64+
self.fixed_rotation_a = fixed_rotation_a
65+
self.fixed_rotation_b = fixed_rotation_b
4866
self.landmark = None
4967
self.joint_constraints = []
5068

@@ -57,6 +75,7 @@ def __init__(
5775
anchor_b=anchor_b,
5876
dist=dist,
5977
rotate=rotate_a and rotate_b,
78+
fixed_rotation=fixed_rotation_a, # or b, it is the same
6079
),
6180
)
6281
else:
@@ -85,6 +104,7 @@ def __init__(
85104
anchor_b=anchor_a,
86105
dist=0.0,
87106
rotate=rotate_a,
107+
fixed_rotation=fixed_rotation_a,
88108
),
89109
JointConstraint(
90110
self.landmark,
@@ -93,6 +113,7 @@ def __init__(
93113
anchor_b=anchor_b,
94114
dist=0.0,
95115
rotate=rotate_b,
116+
fixed_rotation=fixed_rotation_b,
96117
),
97118
]
98119

@@ -104,14 +125,31 @@ def notify(self, observable, *args, **kwargs):
104125
(pos_a + pos_b) / 2,
105126
batch_index=None,
106127
)
128+
129+
angle = torch.atan2(
130+
pos_b[:, vmas.simulator.utils.Y] - pos_a[:, vmas.simulator.utils.Y],
131+
pos_b[:, vmas.simulator.utils.X] - pos_a[:, vmas.simulator.utils.X],
132+
).unsqueeze(-1)
133+
107134
self.landmark.set_rot(
108-
torch.atan2(
109-
pos_b[:, vmas.simulator.utils.Y] - pos_a[:, vmas.simulator.utils.Y],
110-
pos_b[:, vmas.simulator.utils.X] - pos_a[:, vmas.simulator.utils.X],
111-
).unsqueeze(-1),
135+
angle,
112136
batch_index=None,
113137
)
114138

139+
# If we do not allow rotation, and we did not provide a fixed rotation value, we infer it
140+
if not self.rotate_a and self.fixed_rotation_a is None:
141+
self.joint_constraints[0].fixed_rotation = torch.where(
142+
angle >= 0,
143+
angle - self.entity_a.state.rot,
144+
-angle + self.entity_a.state.rot,
145+
)
146+
if not self.rotate_b and self.fixed_rotation_b is None:
147+
self.joint_constraints[1].fixed_rotation = torch.where(
148+
angle >= 0,
149+
angle - self.entity_b.state.rot,
150+
-angle + self.entity_b.state.rot,
151+
)
152+
115153

116154
# Private class: do not instantiate directly
117155
class JointConstraint:
@@ -127,19 +165,28 @@ def __init__(
127165
anchor_b: Tuple[float, float] = (0.0, 0.0),
128166
dist: float = 0.0,
129167
rotate: bool = True,
168+
fixed_rotation: Optional[float] = None,
130169
):
131170
assert entity_a != entity_b, "Cannot join same entity"
132171
for anchor in (anchor_a, anchor_b):
133172
assert (
134173
max(anchor) <= 1 and min(anchor) >= -1
135174
), f"Joint anchor points should be between -1 and 1, got {anchor}"
136175
assert dist >= 0, f"Joint dist must be >= 0, got {dist}"
176+
if fixed_rotation is not None:
177+
assert not rotate, "If fixed rotation is provided, rotate should be False"
178+
if rotate:
179+
assert (
180+
fixed_rotation is None
181+
), "If you provide a fixed rotation, rotate should be False"
182+
fixed_rotation = 0.0
137183

138184
self.entity_a = entity_a
139185
self.entity_b = entity_b
140186
self.anchor_a = anchor_a
141187
self.anchor_b = anchor_b
142188
self.dist = dist
189+
self.fixed_rotation = fixed_rotation
143190
self.rotate = rotate
144191
self._delta_anchor_tensor_map = {}
145192

0 commit comments

Comments
 (0)