Skip to content

Commit 92232b7

Browse files
[Feature] Joint rotations (#113)
* amend * amend * amend
1 parent f4f8a89 commit 92232b7

4 files changed

Lines changed: 51 additions & 12 deletions

File tree

vmas/scenarios/debug/waterfall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6868
anchor_a=(1, 0),
6969
anchor_b=(-1, 0),
7070
dist=self.agent_dist,
71-
rotate_a=True,
72-
rotate_b=True,
71+
rotate_a=False,
72+
rotate_b=False,
7373
collidable=True,
7474
width=0,
7575
mass=1,

vmas/simulator/core.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Observable,
3636
override,
3737
TorchUtils,
38+
TORQUE_CONSTRAINT_FORCE,
3839
X,
3940
Y,
4041
)
@@ -1079,6 +1080,7 @@ def __init__(
10791080
dim_c: int = 0,
10801081
collision_force: float = COLLISION_FORCE,
10811082
joint_force: float = JOINT_FORCE,
1083+
torque_constraint_force: float = TORQUE_CONSTRAINT_FORCE,
10821084
contact_margin: float = 1e-3,
10831085
gravity: Tuple[float, float] = (0.0, 0.0),
10841086
):
@@ -1110,6 +1112,7 @@ def __init__(
11101112
self._collision_force = collision_force
11111113
self._joint_force = joint_force
11121114
self._contact_margin = contact_margin
1115+
self._torque_constraint_force = torque_constraint_force
11131116
# joints
11141117
self._joints = {}
11151118
# Pairs of collidable shapes
@@ -1597,8 +1600,6 @@ def step(self):
15971600
# apply gravity
15981601
self._apply_gravity(entity)
15991602

1600-
# self._apply_environment_force(entity, i)
1601-
16021603
self._apply_vectorized_enviornment_force()
16031604

16041605
for entity in self.entities:
@@ -1802,17 +1803,24 @@ def _vectorized_joint_constraints(self, joints):
18021803
pos_joint_b = []
18031804
dist = []
18041805
rotate = []
1806+
rot_a = []
1807+
rot_b = []
1808+
18051809
for entity_a, entity_b, joint in joints:
18061810
pos_joint_a.append(joint.pos_point(entity_a))
18071811
pos_joint_b.append(joint.pos_point(entity_b))
18081812
pos_a.append(entity_a.state.pos)
18091813
pos_b.append(entity_b.state.pos)
18101814
dist.append(torch.tensor(joint.dist, device=self.device))
18111815
rotate.append(torch.tensor(joint.rotate, device=self.device))
1816+
rot_a.append(entity_a.state.rot)
1817+
rot_b.append(entity_b.state.rot)
18121818
pos_a = torch.stack(pos_a, dim=-2)
18131819
pos_b = torch.stack(pos_b, dim=-2)
18141820
pos_joint_a = torch.stack(pos_joint_a, dim=-2)
18151821
pos_joint_b = torch.stack(pos_joint_b, dim=-2)
1822+
rot_a = torch.stack(rot_a, dim=-2)
1823+
rot_b = torch.stack(rot_b, dim=-2)
18161824
dist = (
18171825
torch.stack(
18181826
dist,
@@ -1846,13 +1854,19 @@ def _vectorized_joint_constraints(self, joints):
18461854
r_a = pos_joint_a - pos_a
18471855
r_b = pos_joint_b - pos_b
18481856

1849-
torque_a = torch.zeros_like(rotate, device=self.device, dtype=torch.float)
1850-
torque_b = torch.zeros_like(rotate, device=self.device, dtype=torch.float)
1851-
if rotate_prior.any():
1852-
torque_a_rotate = TorchUtils.compute_torque(force_a, r_a)
1853-
torque_b_rotate = TorchUtils.compute_torque(force_b, r_b)
1854-
torque_a = torch.where(rotate, torque_a_rotate, 0)
1855-
torque_b = torch.where(rotate, torque_b_rotate, 0)
1857+
torque_a_rotate = TorchUtils.compute_torque(force_a, r_a)
1858+
torque_b_rotate = TorchUtils.compute_torque(force_b, r_b)
1859+
1860+
torque_a_fixed, torque_b_fixed = self._get_constraint_torques(
1861+
rot_a, rot_b, force_multiplier=self._torque_constraint_force
1862+
)
1863+
1864+
torque_a = torch.where(
1865+
rotate, torque_a_rotate, torque_a_rotate + torque_a_fixed
1866+
)
1867+
torque_b = torch.where(
1868+
rotate, torque_b_rotate, torque_b_rotate + torque_b_fixed
1869+
)
18561870

18571871
for i, (entity_a, entity_b, _) in enumerate(joints):
18581872
self.update_env_forces(
@@ -2411,6 +2425,30 @@ def _get_constraint_forces(
24112425
force = torch.where((dist < dist_min).unsqueeze(-1), 0.0, force)
24122426
return force, -force
24132427

2428+
def _get_constraint_torques(
2429+
self,
2430+
rot_a: Tensor,
2431+
rot_b: Tensor,
2432+
force_multiplier: float = TORQUE_CONSTRAINT_FORCE,
2433+
) -> Tensor:
2434+
min_delta_rot = 1e-9
2435+
delta_rot = rot_a - rot_b
2436+
abs_delta_rot = torch.linalg.vector_norm(delta_rot, dim=-1).unsqueeze(-1)
2437+
2438+
# softmax penetration
2439+
k = 1
2440+
penetration = k * (torch.exp(abs_delta_rot / k) - 1)
2441+
2442+
torque = (
2443+
force_multiplier
2444+
* delta_rot
2445+
/ torch.where(abs_delta_rot > 0, abs_delta_rot, 1e-8)
2446+
* penetration
2447+
)
2448+
torque = torch.where((abs_delta_rot < min_delta_rot), 0.0, torque)
2449+
2450+
return -torque, torque
2451+
24142452
# integrate physical state
24152453
# uses semi-implicit euler with sub-stepping
24162454
def _integrate_state(self, entity: Entity, substep: int):

vmas/simulator/joints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
name=f"joint {entity_a.name} {entity_b.name}",
6868
collide=collidable,
6969
movable=True,
70-
rotatable=rotate_a and rotate_b,
70+
rotatable=True,
7171
mass=mass,
7272
shape=(
7373
vmas.simulator.core.Box(length=dist, width=width)

vmas/simulator/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
LINE_MIN_DIST = 4 / 6e2
2323
COLLISION_FORCE = 100
2424
JOINT_FORCE = 130
25+
TORQUE_CONSTRAINT_FORCE = 1
2526

2627
DRAG = 0.25
2728
LINEAR_FRICTION = 0.0

0 commit comments

Comments
 (0)