|
2 | 2 | import json |
3 | 3 | import os |
4 | 4 | from collections import OrderedDict |
5 | | -from typing import Optional |
| 5 | +from typing import Optional, Dict |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 |
|
@@ -328,6 +328,10 @@ def setup_references(self): |
328 | 328 | self._ref_torso_joint_pos_indexes = [ |
329 | 329 | self.sim.model.get_joint_qpos_addr(x) for x in self.robot_model._torso_joints |
330 | 330 | ] |
| 331 | + # indices for all joints (includes arms, base, torso, but not gripper joints, as those are handled separately) |
| 332 | + self._ref_all_joint_pos_indexes = [ |
| 333 | + self.sim.model.get_joint_qpos_addr(x) for x in self.robot_model.all_joints |
| 334 | + ] |
331 | 335 |
|
332 | 336 | def setup_observables(self): |
333 | 337 | """ |
@@ -638,6 +642,58 @@ def set_gripper_joint_positions(self, jpos: np.ndarray, gripper_arm: Optional[st |
638 | 642 | else: |
639 | 643 | raise ValueError(f"No gripper found for arm {gripper_arm}") |
640 | 644 |
|
| 645 | + def set_all_robot_joint_positions(self, joint_positions: Dict[str, float]): |
| 646 | + """ |
| 647 | + Helper method to force robot joint positions to the passed values. |
| 648 | + Assumes valid joint names are passed. |
| 649 | +
|
| 650 | + Args: |
| 651 | + joint_positions (dict): joint name -> joint position (in angles / radians) |
| 652 | + """ |
| 653 | + for joint_name, position in joint_positions.items(): |
| 654 | + self.sim.data.qpos[self.sim.model.joint_name2id(joint_name)] = position |
| 655 | + self.sim.forward() |
| 656 | + |
| 657 | + def get_robot_joint_positions(self): |
| 658 | + """ |
| 659 | + Returns: |
| 660 | + np.array: joint positions (in angles / radians) |
| 661 | + """ |
| 662 | + return self.sim.data.qpos[self._ref_joint_pos_indexes] |
| 663 | + |
| 664 | + def get_all_robot_joint_positions(self) -> Dict[str, float]: |
| 665 | + """ |
| 666 | + Returns all joint positions including robot joints (arms, base, torso) and gripper joints. |
| 667 | + |
| 668 | + Returns: |
| 669 | + dict: joint name -> joint position (in angles / radians) |
| 670 | + """ |
| 671 | + joint_positions = {} |
| 672 | + |
| 673 | + # Get all robot joint positions (includes arms, base, torso, etc.) |
| 674 | + all_joint_names = self.robot_model.all_joints |
| 675 | + for i, joint_name in enumerate(all_joint_names): |
| 676 | + joint_positions[joint_name] = self.sim.data.qpos[self._ref_all_joint_pos_indexes[i]] |
| 677 | + |
| 678 | + # Add gripper joint positions for all arms |
| 679 | + for arm in self.arms: |
| 680 | + if self.has_gripper[arm]: |
| 681 | + gripper_joint_names = self.gripper_joints[arm] |
| 682 | + gripper_positions = self.sim.data.qpos[self._ref_gripper_joint_pos_indexes[arm]] |
| 683 | + for joint_name, position in zip(gripper_joint_names, gripper_positions): |
| 684 | + joint_positions[joint_name] = position |
| 685 | + |
| 686 | + return joint_positions |
| 687 | + |
| 688 | + def get_gripper_joint_positions(self, gripper_arm: Optional[str] = None): |
| 689 | + """ |
| 690 | + Returns: |
| 691 | + np.array: gripper joint positions (in angles / radians) |
| 692 | + """ |
| 693 | + if gripper_arm is None: |
| 694 | + gripper_arm = self.arms[0] |
| 695 | + return self.sim.data.qpos[self._ref_gripper_joint_pos_indexes[gripper_arm]] |
| 696 | + |
641 | 697 | @property |
642 | 698 | def js_energy(self): |
643 | 699 | """ |
|
0 commit comments