Skip to content

Commit 66793b6

Browse files
Add options to get/set all robot joints, not just gripper/arm joints
1 parent 7871abb commit 66793b6

1 file changed

Lines changed: 57 additions & 1 deletion

File tree

robosuite/robots/robot.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
from collections import OrderedDict
5-
from typing import Optional
5+
from typing import Optional, Dict
66

77
import numpy as np
88

@@ -328,6 +328,10 @@ def setup_references(self):
328328
self._ref_torso_joint_pos_indexes = [
329329
self.sim.model.get_joint_qpos_addr(x) for x in self.robot_model._torso_joints
330330
]
331+
# indices for all joints (includes arms, base, torso, but not gripper joints, as those are handled separately)
332+
self._ref_all_joint_pos_indexes = [
333+
self.sim.model.get_joint_qpos_addr(x) for x in self.robot_model.all_joints
334+
]
331335

332336
def setup_observables(self):
333337
"""
@@ -638,6 +642,58 @@ def set_gripper_joint_positions(self, jpos: np.ndarray, gripper_arm: Optional[st
638642
else:
639643
raise ValueError(f"No gripper found for arm {gripper_arm}")
640644

645+
def set_all_robot_joint_positions(self, joint_positions: Dict[str, float]):
646+
"""
647+
Helper method to force robot joint positions to the passed values.
648+
Assumes valid joint names are passed.
649+
650+
Args:
651+
joint_positions (dict): joint name -> joint position (in angles / radians)
652+
"""
653+
for joint_name, position in joint_positions.items():
654+
self.sim.data.qpos[self.sim.model.joint_name2id(joint_name)] = position
655+
self.sim.forward()
656+
657+
def get_robot_joint_positions(self):
658+
"""
659+
Returns:
660+
np.array: joint positions (in angles / radians)
661+
"""
662+
return self.sim.data.qpos[self._ref_joint_pos_indexes]
663+
664+
def get_all_robot_joint_positions(self) -> Dict[str, float]:
665+
"""
666+
Returns all joint positions including robot joints (arms, base, torso) and gripper joints.
667+
668+
Returns:
669+
dict: joint name -> joint position (in angles / radians)
670+
"""
671+
joint_positions = {}
672+
673+
# Get all robot joint positions (includes arms, base, torso, etc.)
674+
all_joint_names = self.robot_model.all_joints
675+
for i, joint_name in enumerate(all_joint_names):
676+
joint_positions[joint_name] = self.sim.data.qpos[self._ref_all_joint_pos_indexes[i]]
677+
678+
# Add gripper joint positions for all arms
679+
for arm in self.arms:
680+
if self.has_gripper[arm]:
681+
gripper_joint_names = self.gripper_joints[arm]
682+
gripper_positions = self.sim.data.qpos[self._ref_gripper_joint_pos_indexes[arm]]
683+
for joint_name, position in zip(gripper_joint_names, gripper_positions):
684+
joint_positions[joint_name] = position
685+
686+
return joint_positions
687+
688+
def get_gripper_joint_positions(self, gripper_arm: Optional[str] = None):
689+
"""
690+
Returns:
691+
np.array: gripper joint positions (in angles / radians)
692+
"""
693+
if gripper_arm is None:
694+
gripper_arm = self.arms[0]
695+
return self.sim.data.qpos[self._ref_gripper_joint_pos_indexes[gripper_arm]]
696+
641697
@property
642698
def js_energy(self):
643699
"""

0 commit comments

Comments
 (0)