Skip to content

Commit 1a0fe90

Browse files
Allow skip_wbc_action in wbik controller
1 parent cb173eb commit 1a0fe90

1 file changed

Lines changed: 103 additions & 27 deletions

File tree

robosuite/controllers/composite/composite_controller.py

Lines changed: 103 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -122,23 +122,6 @@ def get_control_dim(self, part_name):
122122
return self.part_controllers[part_name].control_dim
123123

124124
def get_controller_base_pose(self, controller_name):
125-
"""
126-
Get the base position and orientation of a specified controller's part. Note: this pose may likely differ from
127-
the robot base's pose.
128-
129-
Args:
130-
controller_name (str): The name of the controller, used to look up part-specific information.
131-
132-
Returns:
133-
tuple[np.ndarray, np.ndarray]: A tuple containing:
134-
- base_pos (np.ndarray): The 3D position of the part's center in world coordinates (shape: (3,)).
135-
- base_ori (np.ndarray): The 3x3 rotation matrix representing the part's orientation in world coordinates.
136-
137-
Details:
138-
- Uses the controller's `naming_prefix` and `part_name` to construct the corresponding site name.
139-
- Queries the simulation (`self.sim`) for the site's position (`site_xpos`) and orientation (`site_xmat`).
140-
- The site orientation matrix is reshaped from a flat array of size 9 to a 3x3 rotation matrix.
141-
"""
142125
naming_prefix = self.part_controllers[controller_name].naming_prefix
143126
part_name = self.part_controllers[controller_name].part_name
144127
base_pos = np.array(self.sim.data.site_xpos[self.sim.model.site_name2id(f"{naming_prefix}{part_name}_center")])
@@ -161,7 +144,10 @@ def action_limits(self):
161144
for part_name, controller in self.part_controllers.items():
162145
if part_name not in self.arms:
163146
if part_name in self.grippers.keys():
164-
low_g, high_g = ([-1] * self.grippers[part_name].dof, [1] * self.grippers[part_name].dof)
147+
low_g, high_g = (
148+
[-1] * self.grippers[part_name].dof,
149+
[1] * self.grippers[part_name].dof,
150+
)
165151
low, high = np.concatenate([low, low_g]), np.concatenate([high, high_g])
166152
else:
167153
control_dim = controller.control_dim
@@ -172,6 +158,55 @@ def action_limits(self):
172158
low, high = np.concatenate([low, low_c]), np.concatenate([high, high_c])
173159
return low, high
174160

161+
def create_action_vector(self, action_dict):
    """
    Assemble the flat action vector from a dictionary of per-part actions.

    Args:
        action_dict (dict): Maps a part name to the action values for that part.
            Part names missing from ``self._action_split_indexes`` are logged at
            debug level and ignored. Parts whose slice is empty are skipped.

    Returns:
        np.ndarray: Vector with the same shape as ``self.action_limits[0]``. Slices
            for parts not supplied in ``action_dict`` stay at zero.

    Raises:
        AssertionError: If a part's action length does not match the width of its slice.
    """
    full_action_vector = np.zeros(self.action_limits[0].shape)
    for part_name, action_vector in action_dict.items():
        if part_name not in self._action_split_indexes:
            ROBOSUITE_DEFAULT_LOGGER.debug(f"{part_name} is not specified in the action space")
            continue
        start_idx, end_idx = self._action_split_indexes[part_name]
        expected_dim = end_idx - start_idx
        if expected_dim == 0:
            # skipping not controlling actions
            continue
        if len(action_vector) != expected_dim:
            # Raise explicitly rather than via `assert`: the check then survives
            # `python -O`, and the exception carries the message instead of the
            # None returned by the logger call.
            message = (
                f"Action vector for {part_name} is not the correct size. "
                f"Expected {expected_dim} for {part_name}, got {len(action_vector)}"
            )
            ROBOSUITE_DEFAULT_LOGGER.error(message)
            raise AssertionError(message)
        full_action_vector[start_idx:end_idx] = action_vector
    return full_action_vector
179+
180+
def get_action_info(self):
    """
    Describe how the flat action vector is partitioned between parts.

    Returns:
        tuple[list[str], list[str]]: Two lists, in the iteration order of
            ``self._action_split_indexes``:
            - index descriptions of the form ``"<part>: <start>:<end>"``
            - dimension descriptions of the form ``"<part>: <n> dim"``
    """
    split_indexes = self._action_split_indexes
    index_descriptions = [f"{name}: {lo}:{hi}" for name, (lo, hi) in split_indexes.items()]
    dim_descriptions = [f"{name}: {(hi - lo)} dim" for name, (lo, hi) in split_indexes.items()]
    return index_descriptions, dim_descriptions
191+
192+
def print_action_info(self):
    """
    Log the per-part action dimensions and index ranges at info level.

    The entries come from ``self.get_action_info()`` and are joined with ", ".
    The dimensions line is emitted first, then the indices line.
    """
    index_entries, dim_entries = self.get_action_info()
    joined_indices = ", ".join(index_entries)
    joined_dims = ", ".join(dim_entries)
    ROBOSUITE_DEFAULT_LOGGER.info(f"Action Dimensions: [{joined_dims}]")
    ROBOSUITE_DEFAULT_LOGGER.info(f"Action Indices: [{joined_indices}]")
198+
199+
def get_action_info_dict(self):
    """
    Build a dictionary summarising the action space layout.

    Returns:
        dict: ``"Action Dimension"`` mapped to the shape of ``self.action_limits[0]``,
            followed by every entry of ``self._action_split_indexes``.
    """
    return {
        "Action Dimension": self.action_limits[0].shape,
        **dict(self._action_split_indexes),
    }
204+
205+
def print_action_info_dict(self, name: str = ""):
    """
    Log the action info dictionary as indented JSON at info level.

    Args:
        name (str): Label inserted into the "Action Info for ..." header line.
    """
    serialized = json.dumps(dict(self.get_action_info_dict()), indent=4)
    ROBOSUITE_DEFAULT_LOGGER.info(f"\nAction Info for {name}:\n\n{serialized}")
209+
175210

176211
@register_composite_controller
177212
class HybridMobileBase(CompositeController):
@@ -217,7 +252,7 @@ def create_action_vector(self, action_dict):
217252
A helper function that creates the action vector given a dictionary
218253
"""
219254
full_action_vector = np.zeros(self.action_limits[0].shape)
220-
for (part_name, action_vector) in action_dict.items():
255+
for part_name, action_vector in action_dict.items():
221256
if part_name not in self._action_split_indexes:
222257
ROBOSUITE_DEFAULT_LOGGER.debug(f"{part_name} is not specified in the action space")
223258
continue
@@ -323,6 +358,10 @@ def setup_whole_body_controller_action_split_idx(self):
323358
previous_idx = last_idx
324359

325360
def set_goal(self, all_action):
361+
if self.composite_controller_specific_config.get("skip_wbc_action", False):
362+
super().set_goal(all_action)
363+
return
364+
326365
target_qpos = self.joint_action_policy.solve(all_action[: self.joint_action_policy.control_dim])
327366
# create new all_action vector with the IK solver's actions first
328367
all_action = np.concatenate([target_qpos, all_action[self.joint_action_policy.control_dim :]])
@@ -343,6 +382,9 @@ def action_limits(self):
343382
Returns the action limits for the whole body controller.
344383
Corresponds to each term in the action vector passed to env.step().
345384
"""
385+
if self.composite_controller_specific_config.get("skip_wbc_action", False):
386+
return super().action_limits
387+
346388
low, high = [], []
347389
# assumption: IK solver's actions come first
348390
low_c, high_c = self.joint_action_policy.control_limits
@@ -353,7 +395,10 @@ def action_limits(self):
353395
continue
354396
if part_name not in self.arms:
355397
if part_name in self.grippers.keys():
356-
low_g, high_g = ([-1] * self.grippers[part_name].dof, [1] * self.grippers[part_name].dof)
398+
low_g, high_g = (
399+
[-1] * self.grippers[part_name].dof,
400+
[1] * self.grippers[part_name].dof,
401+
)
357402
low, high = np.concatenate([low, low_g]), np.concatenate([high, high_g])
358403
else:
359404
control_dim = controller.control_dim
@@ -365,8 +410,11 @@ def action_limits(self):
365410
return low, high
366411

367412
def create_action_vector(self, action_dict: Dict[str, np.ndarray]) -> np.ndarray:
413+
if self.composite_controller_specific_config.get("skip_wbc_action", False):
414+
return super().create_action_vector(action_dict)
415+
368416
full_action_vector = np.zeros(self.action_limits[0].shape)
369-
for (part_name, action_vector) in action_dict.items():
417+
for part_name, action_vector in action_dict.items():
370418
if part_name not in self._whole_body_controller_action_split_indexes:
371419
ROBOSUITE_DEFAULT_LOGGER.debug(f"{part_name} is not specified in the action space")
372420
continue
@@ -380,24 +428,53 @@ def create_action_vector(self, action_dict: Dict[str, np.ndarray]) -> np.ndarray
380428
full_action_vector[start_idx:end_idx] = action_vector
381429
return full_action_vector
382430

383-
def print_action_info(self):
431+
def get_action_info(self):
    """
    Describe how the whole-body action vector is partitioned between parts.

    When the ``skip_wbc_action`` entry of ``self.composite_controller_specific_config``
    is truthy, the parent class implementation is used instead.

    Returns:
        tuple[list[str], list[str]]: Two lists, in the iteration order of
            ``self._whole_body_controller_action_split_indexes``:
            - index descriptions of the form ``"<part>: <start>:<end>"``
            - dimension descriptions of the form ``"<part>: <n> dim"``
    """
    if self.composite_controller_specific_config.get("skip_wbc_action", False):
        return super().get_action_info()

    split_indexes = self._whole_body_controller_action_split_indexes
    index_descriptions = [f"{name}: {lo}:{hi}" for name, (lo, hi) in split_indexes.items()]
    dim_descriptions = [f"{name}: {(hi - lo)} dim" for name, (lo, hi) in split_indexes.items()]
    return index_descriptions, dim_descriptions
389444

390-
action_dim_info_str = ", ".join(action_dim_info)
391-
ROBOSUITE_DEFAULT_LOGGER.info(f"Action Dimensions: [{action_dim_info_str}]")
445+
def print_action_info(self):
    """
    Log the per-part action dimensions and index ranges at info level.

    When the ``skip_wbc_action`` entry of ``self.composite_controller_specific_config``
    is truthy, the parent class implementation is used instead. Otherwise the
    entries come from ``self.get_action_info()``; the dimensions line is emitted
    first, then the indices line.
    """
    if self.composite_controller_specific_config.get("skip_wbc_action", False):
        return super().print_action_info()

    index_entries, dim_entries = self.get_action_info()
    joined_indices = ", ".join(index_entries)
    joined_dims = ", ".join(dim_entries)
    ROBOSUITE_DEFAULT_LOGGER.info(f"Action Dimensions: [{joined_dims}]")
    ROBOSUITE_DEFAULT_LOGGER.info(f"Action Indices: [{joined_indices}]")
395454

396-
def print_action_info_dict(self, name: str = ""):
455+
def get_action_info_dict(self):
    """
    Build a dictionary summarising the whole-body action space layout.

    When the ``skip_wbc_action`` entry of ``self.composite_controller_specific_config``
    is truthy, the parent class implementation is used instead.

    Returns:
        dict: ``"Action Dimension"`` mapped to the shape of ``self.action_limits[0]``,
            followed by every entry of
            ``self._whole_body_controller_action_split_indexes``.
    """
    # This method was previously defined twice in a row with identical bodies;
    # the second definition silently replaced the first, so one copy is kept.
    if self.composite_controller_specific_config.get("skip_wbc_action", False):
        return super().get_action_info_dict()

    info_dict = {}
    info_dict["Action Dimension"] = self.action_limits[0].shape
    info_dict.update(dict(self._whole_body_controller_action_split_indexes))
    return info_dict
472+
473+
def print_action_info_dict(self, name: str = ""):
    """
    Log the action info dictionary as indented JSON at info level.

    When the ``skip_wbc_action`` entry of ``self.composite_controller_specific_config``
    is truthy, the parent class implementation is used instead.

    Args:
        name (str): Label inserted into the "Action Info for ..." header line.
    """
    if self.composite_controller_specific_config.get("skip_wbc_action", False):
        return super().print_action_info_dict(name)

    serialized = json.dumps(dict(self.get_action_info_dict()), indent=4)
    ROBOSUITE_DEFAULT_LOGGER.info(f"\nAction Info for {name}:\n\n{serialized}")
403480

@@ -474,6 +551,5 @@ def _init_joint_action_policy(self):
474551
max_dq_torso=self.composite_controller_specific_config.get("ik_max_dq_torso", 0.2),
475552
input_rotation_repr=self.composite_controller_specific_config.get("ik_input_rotation_repr", "axis_angle"),
476553
input_type=self.composite_controller_specific_config.get("ik_input_type", "axis_angle"),
477-
input_ref_frame=self.composite_controller_specific_config.get("ik_input_ref_frame", "world"),
478554
debug=self.composite_controller_specific_config.get("verbose", False),
479555
)

0 commit comments

Comments
 (0)