|
35 | 35 | Observable, |
36 | 36 | override, |
37 | 37 | TorchUtils, |
| 38 | + TORQUE_CONSTRAINT_FORCE, |
38 | 39 | X, |
39 | 40 | Y, |
40 | 41 | ) |
@@ -1079,6 +1080,7 @@ def __init__( |
1079 | 1080 | dim_c: int = 0, |
1080 | 1081 | collision_force: float = COLLISION_FORCE, |
1081 | 1082 | joint_force: float = JOINT_FORCE, |
| 1083 | + torque_constraint_force: float = TORQUE_CONSTRAINT_FORCE, |
1082 | 1084 | contact_margin: float = 1e-3, |
1083 | 1085 | gravity: Tuple[float, float] = (0.0, 0.0), |
1084 | 1086 | ): |
@@ -1110,6 +1112,7 @@ def __init__( |
1110 | 1112 | self._collision_force = collision_force |
1111 | 1113 | self._joint_force = joint_force |
1112 | 1114 | self._contact_margin = contact_margin |
| 1115 | + self._torque_constraint_force = torque_constraint_force |
1113 | 1116 | # joints |
1114 | 1117 | self._joints = {} |
1115 | 1118 | # Pairs of collidable shapes |
@@ -1597,8 +1600,6 @@ def step(self): |
1597 | 1600 | # apply gravity |
1598 | 1601 | self._apply_gravity(entity) |
1599 | 1602 |
|
1600 | | - # self._apply_environment_force(entity, i) |
1601 | | - |
1602 | 1603 | self._apply_vectorized_enviornment_force() |
1603 | 1604 |
|
1604 | 1605 | for entity in self.entities: |
@@ -1802,17 +1803,24 @@ def _vectorized_joint_constraints(self, joints): |
1802 | 1803 | pos_joint_b = [] |
1803 | 1804 | dist = [] |
1804 | 1805 | rotate = [] |
| 1806 | + rot_a = [] |
| 1807 | + rot_b = [] |
| 1808 | + |
1805 | 1809 | for entity_a, entity_b, joint in joints: |
1806 | 1810 | pos_joint_a.append(joint.pos_point(entity_a)) |
1807 | 1811 | pos_joint_b.append(joint.pos_point(entity_b)) |
1808 | 1812 | pos_a.append(entity_a.state.pos) |
1809 | 1813 | pos_b.append(entity_b.state.pos) |
1810 | 1814 | dist.append(torch.tensor(joint.dist, device=self.device)) |
1811 | 1815 | rotate.append(torch.tensor(joint.rotate, device=self.device)) |
| 1816 | + rot_a.append(entity_a.state.rot) |
| 1817 | + rot_b.append(entity_b.state.rot) |
1812 | 1818 | pos_a = torch.stack(pos_a, dim=-2) |
1813 | 1819 | pos_b = torch.stack(pos_b, dim=-2) |
1814 | 1820 | pos_joint_a = torch.stack(pos_joint_a, dim=-2) |
1815 | 1821 | pos_joint_b = torch.stack(pos_joint_b, dim=-2) |
| 1822 | + rot_a = torch.stack(rot_a, dim=-2) |
| 1823 | + rot_b = torch.stack(rot_b, dim=-2) |
1816 | 1824 | dist = ( |
1817 | 1825 | torch.stack( |
1818 | 1826 | dist, |
@@ -1846,13 +1854,19 @@ def _vectorized_joint_constraints(self, joints): |
1846 | 1854 | r_a = pos_joint_a - pos_a |
1847 | 1855 | r_b = pos_joint_b - pos_b |
1848 | 1856 |
|
1849 | | - torque_a = torch.zeros_like(rotate, device=self.device, dtype=torch.float) |
1850 | | - torque_b = torch.zeros_like(rotate, device=self.device, dtype=torch.float) |
1851 | | - if rotate_prior.any(): |
1852 | | - torque_a_rotate = TorchUtils.compute_torque(force_a, r_a) |
1853 | | - torque_b_rotate = TorchUtils.compute_torque(force_b, r_b) |
1854 | | - torque_a = torch.where(rotate, torque_a_rotate, 0) |
1855 | | - torque_b = torch.where(rotate, torque_b_rotate, 0) |
| 1857 | + torque_a_rotate = TorchUtils.compute_torque(force_a, r_a) |
| 1858 | + torque_b_rotate = TorchUtils.compute_torque(force_b, r_b) |
| 1859 | + |
| 1860 | + torque_a_fixed, torque_b_fixed = self._get_constraint_torques( |
| 1861 | + rot_a, rot_b, force_multiplier=self._torque_constraint_force |
| 1862 | + ) |
| 1863 | + |
| 1864 | + torque_a = torch.where( |
| 1865 | + rotate, torque_a_rotate, torque_a_rotate + torque_a_fixed |
| 1866 | + ) |
| 1867 | + torque_b = torch.where( |
| 1868 | + rotate, torque_b_rotate, torque_b_rotate + torque_b_fixed |
| 1869 | + ) |
1856 | 1870 |
|
1857 | 1871 | for i, (entity_a, entity_b, _) in enumerate(joints): |
1858 | 1872 | self.update_env_forces( |
@@ -2411,6 +2425,30 @@ def _get_constraint_forces( |
2411 | 2425 | force = torch.where((dist < dist_min).unsqueeze(-1), 0.0, force) |
2412 | 2426 | return force, -force |
2413 | 2427 |
|
| 2428 | + def _get_constraint_torques( |
| 2429 | + self, |
| 2430 | + rot_a: Tensor, |
| 2431 | + rot_b: Tensor, |
| 2432 | + force_multiplier: float = TORQUE_CONSTRAINT_FORCE, |
| 2433 | + ) -> Tensor: |
| 2434 | + min_delta_rot = 1e-9 |
| 2435 | + delta_rot = rot_a - rot_b |
| 2436 | + abs_delta_rot = torch.linalg.vector_norm(delta_rot, dim=-1).unsqueeze(-1) |
| 2437 | + |
| 2438 | + # softmax penetration |
| 2439 | + k = 1 |
| 2440 | + penetration = k * (torch.exp(abs_delta_rot / k) - 1) |
| 2441 | + |
| 2442 | + torque = ( |
| 2443 | + force_multiplier |
| 2444 | + * delta_rot |
| 2445 | + / torch.where(abs_delta_rot > 0, abs_delta_rot, 1e-8) |
| 2446 | + * penetration |
| 2447 | + ) |
| 2448 | + torque = torch.where((abs_delta_rot < min_delta_rot), 0.0, torque) |
| 2449 | + |
| 2450 | + return -torque, torque |
| 2451 | + |
2414 | 2452 | # integrate physical state |
2415 | 2453 | # uses semi-implicit euler with sub-stepping |
2416 | 2454 | def _integrate_state(self, entity: Entity, substep: int): |
|
0 commit comments