44
55from __future__ import annotations
66
7- from typing import List , Tuple , TYPE_CHECKING
7+ from typing import List , Optional , Tuple , TYPE_CHECKING
88
99import torch
1010
@@ -30,6 +30,8 @@ def __init__(
3030 collidable : bool = False ,
3131 width : float = 0.0 ,
3232 mass : float = 1.0 ,
33+ fixed_rotation_a : Optional [float ] = None ,
34+ fixed_rotation_b : Optional [float ] = None ,
3335 ):
3436 assert entity_a != entity_b , "Cannot join same entity"
3537 for anchor in (anchor_a , anchor_b ):
@@ -40,11 +42,27 @@ def __init__(
4042 if dist == 0 :
4143 assert not collidable , "Cannot have collidable joint with dist 0"
4244 assert width == 0 , "Cannot have width for joint with dist 0"
45+ assert (
46+ fixed_rotation_a == fixed_rotation_b
47+ ), "If dist is 0, fixed_rotation_a and fixed_rotation_b should be the same"
48+ if fixed_rotation_a is not None :
49+ assert (
50+ not rotate_a
51+ ), "If you provide a fixed rotation for a, rotate_a should be False"
52+ if fixed_rotation_b is not None :
53+ assert (
54+ not rotate_b
55+ ), "If you provide a fixed rotation for b, rotate_b should be False"
56+
4357 if width > 0 :
4458 assert collidable
4559
4660 self .entity_a = entity_a
4761 self .entity_b = entity_b
62+ self .rotate_a = rotate_a
63+ self .rotate_b = rotate_b
64+ self .fixed_rotation_a = fixed_rotation_a
65+ self .fixed_rotation_b = fixed_rotation_b
4866 self .landmark = None
4967 self .joint_constraints = []
5068
@@ -57,6 +75,7 @@ def __init__(
5775 anchor_b = anchor_b ,
5876 dist = dist ,
5977 rotate = rotate_a and rotate_b ,
78+ fixed_rotation = fixed_rotation_a , # or b, it is the same
6079 ),
6180 )
6281 else :
@@ -85,6 +104,7 @@ def __init__(
85104 anchor_b = anchor_a ,
86105 dist = 0.0 ,
87106 rotate = rotate_a ,
107+ fixed_rotation = fixed_rotation_a ,
88108 ),
89109 JointConstraint (
90110 self .landmark ,
@@ -93,6 +113,7 @@ def __init__(
93113 anchor_b = anchor_b ,
94114 dist = 0.0 ,
95115 rotate = rotate_b ,
116+ fixed_rotation = fixed_rotation_b ,
96117 ),
97118 ]
98119
@@ -104,14 +125,31 @@ def notify(self, observable, *args, **kwargs):
104125 (pos_a + pos_b ) / 2 ,
105126 batch_index = None ,
106127 )
128+
129+ angle = torch .atan2 (
130+ pos_b [:, vmas .simulator .utils .Y ] - pos_a [:, vmas .simulator .utils .Y ],
131+ pos_b [:, vmas .simulator .utils .X ] - pos_a [:, vmas .simulator .utils .X ],
132+ ).unsqueeze (- 1 )
133+
107134 self .landmark .set_rot (
108- torch .atan2 (
109- pos_b [:, vmas .simulator .utils .Y ] - pos_a [:, vmas .simulator .utils .Y ],
110- pos_b [:, vmas .simulator .utils .X ] - pos_a [:, vmas .simulator .utils .X ],
111- ).unsqueeze (- 1 ),
135+ angle ,
112136 batch_index = None ,
113137 )
114138
139+ # If we do not allow rotation, and we did not provide a fixed rotation value, we infer it
140+ if not self .rotate_a and self .fixed_rotation_a is None :
141+ self .joint_constraints [0 ].fixed_rotation = torch .where (
142+ angle >= 0 ,
143+ angle - self .entity_a .state .rot ,
144+ - angle + self .entity_a .state .rot ,
145+ )
146+ if not self .rotate_b and self .fixed_rotation_b is None :
147+ self .joint_constraints [1 ].fixed_rotation = torch .where (
148+ angle >= 0 ,
149+ angle - self .entity_b .state .rot ,
150+ - angle + self .entity_b .state .rot ,
151+ )
152+
115153
116154# Private class: do not instantiate directly
117155class JointConstraint :
@@ -127,19 +165,28 @@ def __init__(
127165 anchor_b : Tuple [float , float ] = (0.0 , 0.0 ),
128166 dist : float = 0.0 ,
129167 rotate : bool = True ,
168+ fixed_rotation : Optional [float ] = None ,
130169 ):
131170 assert entity_a != entity_b , "Cannot join same entity"
132171 for anchor in (anchor_a , anchor_b ):
133172 assert (
134173 max (anchor ) <= 1 and min (anchor ) >= - 1
135174 ), f"Joint anchor points should be between -1 and 1, got { anchor } "
136175 assert dist >= 0 , f"Joint dist must be >= 0, got { dist } "
176+ if fixed_rotation is not None :
177+ assert not rotate , "If fixed rotation is provided, rotate should be False"
178+ if rotate :
179+ assert (
180+ fixed_rotation is None
181+ ), "If you provide a fixed rotation, rotate should be False"
182+ fixed_rotation = 0.0
137183
138184 self .entity_a = entity_a
139185 self .entity_b = entity_b
140186 self .anchor_a = anchor_a
141187 self .anchor_b = anchor_b
142188 self .dist = dist
189+ self .fixed_rotation = fixed_rotation
143190 self .rotate = rotate
144191 self ._delta_anchor_tensor_map = {}
145192
0 commit comments