@@ -122,23 +122,6 @@ def get_control_dim(self, part_name):
122122 return self .part_controllers [part_name ].control_dim
123123
124124 def get_controller_base_pose (self , controller_name ):
125- """
126- Get the base position and orientation of a specified controller's part. Note: this pose may likely differ from
127- the robot base's pose.
128-
129- Args:
130- controller_name (str): The name of the controller, used to look up part-specific information.
131-
132- Returns:
133- tuple[np.ndarray, np.ndarray]: A tuple containing:
134- - base_pos (np.ndarray): The 3D position of the part's center in world coordinates (shape: (3,)).
135- - base_ori (np.ndarray): The 3x3 rotation matrix representing the part's orientation in world coordinates.
136-
137- Details:
138- - Uses the controller's `naming_prefix` and `part_name` to construct the corresponding site name.
139- - Queries the simulation (`self.sim`) for the site's position (`site_xpos`) and orientation (`site_xmat`).
140- - The site orientation matrix is reshaped from a flat array of size 9 to a 3x3 rotation matrix.
141- """
142125 naming_prefix = self .part_controllers [controller_name ].naming_prefix
143126 part_name = self .part_controllers [controller_name ].part_name
144127 base_pos = np .array (self .sim .data .site_xpos [self .sim .model .site_name2id (f"{ naming_prefix } { part_name } _center" )])
@@ -161,7 +144,10 @@ def action_limits(self):
161144 for part_name , controller in self .part_controllers .items ():
162145 if part_name not in self .arms :
163146 if part_name in self .grippers .keys ():
164- low_g , high_g = ([- 1 ] * self .grippers [part_name ].dof , [1 ] * self .grippers [part_name ].dof )
147+ low_g , high_g = (
148+ [- 1 ] * self .grippers [part_name ].dof ,
149+ [1 ] * self .grippers [part_name ].dof ,
150+ )
165151 low , high = np .concatenate ([low , low_g ]), np .concatenate ([high , high_g ])
166152 else :
167153 control_dim = controller .control_dim
@@ -172,6 +158,55 @@ def action_limits(self):
172158 low , high = np .concatenate ([low , low_c ]), np .concatenate ([high , high_c ])
173159 return low , high
174160
161+ def create_action_vector (self , action_dict ):
162+ """
163+ A helper function that creates the action vector given a dictionary
164+ """
165+ full_action_vector = np .zeros (self .action_limits [0 ].shape )
166+ for part_name , action_vector in action_dict .items ():
167+ if part_name not in self ._action_split_indexes :
168+ ROBOSUITE_DEFAULT_LOGGER .debug (f"{ part_name } is not specified in the action space" )
169+ continue
170+ start_idx , end_idx = self ._action_split_indexes [part_name ]
171+ if end_idx - start_idx == 0 :
172+ # skipping not controlling actions
173+ continue
174+ assert len (action_vector ) == (end_idx - start_idx ), ROBOSUITE_DEFAULT_LOGGER .error (
175+ f"Action vector for { part_name } is not the correct size. Expected { end_idx - start_idx } for { part_name } , got { len (action_vector )} "
176+ )
177+ full_action_vector [start_idx :end_idx ] = action_vector
178+ return full_action_vector
179+
180+ def get_action_info (self ):
181+ action_index_info = []
182+ action_dim_info = []
183+ for part_name , (
184+ start_idx ,
185+ end_idx ,
186+ ) in self ._action_split_indexes .items ():
187+ action_dim_info .append (f"{ part_name } : { (end_idx - start_idx )} dim" )
188+ action_index_info .append (f"{ part_name } : { start_idx } :{ end_idx } " )
189+
190+ return action_index_info , action_dim_info
191+
192+ def print_action_info (self ):
193+ action_index_info , action_dim_info = self .get_action_info ()
194+ action_index_info_str = ", " .join (action_index_info )
195+ action_dim_info_str = ", " .join (action_dim_info )
196+ ROBOSUITE_DEFAULT_LOGGER .info (f"Action Dimensions: [{ action_dim_info_str } ]" )
197+ ROBOSUITE_DEFAULT_LOGGER .info (f"Action Indices: [{ action_index_info_str } ]" )
198+
199+ def get_action_info_dict (self ):
200+ info_dict = {}
201+ info_dict ["Action Dimension" ] = self .action_limits [0 ].shape
202+ info_dict .update (dict (self ._action_split_indexes ))
203+ return info_dict
204+
205+ def print_action_info_dict (self , name : str = "" ):
206+ info_dict = self .get_action_info_dict ()
207+ info_dict_str = f"\n Action Info for { name } :\n \n { json .dumps (dict (info_dict ), indent = 4 )} "
208+ ROBOSUITE_DEFAULT_LOGGER .info (info_dict_str )
209+
175210
176211@register_composite_controller
177212class HybridMobileBase (CompositeController ):
@@ -217,7 +252,7 @@ def create_action_vector(self, action_dict):
217252 A helper function that creates the action vector given a dictionary
218253 """
219254 full_action_vector = np .zeros (self .action_limits [0 ].shape )
220- for ( part_name , action_vector ) in action_dict .items ():
255+ for part_name , action_vector in action_dict .items ():
221256 if part_name not in self ._action_split_indexes :
222257 ROBOSUITE_DEFAULT_LOGGER .debug (f"{ part_name } is not specified in the action space" )
223258 continue
@@ -323,6 +358,10 @@ def setup_whole_body_controller_action_split_idx(self):
323358 previous_idx = last_idx
324359
325360 def set_goal (self , all_action ):
361+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
362+ super ().set_goal (all_action )
363+ return
364+
326365 target_qpos = self .joint_action_policy .solve (all_action [: self .joint_action_policy .control_dim ])
327366 # create new all_action vector with the IK solver's actions first
328367 all_action = np .concatenate ([target_qpos , all_action [self .joint_action_policy .control_dim :]])
@@ -343,6 +382,9 @@ def action_limits(self):
343382 Returns the action limits for the whole body controller.
344383 Corresponds to each term in the action vector passed to env.step().
345384 """
385+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
386+ return super ().action_limits
387+
346388 low , high = [], []
347389 # assumption: IK solver's actions come first
348390 low_c , high_c = self .joint_action_policy .control_limits
@@ -353,7 +395,10 @@ def action_limits(self):
353395 continue
354396 if part_name not in self .arms :
355397 if part_name in self .grippers .keys ():
356- low_g , high_g = ([- 1 ] * self .grippers [part_name ].dof , [1 ] * self .grippers [part_name ].dof )
398+ low_g , high_g = (
399+ [- 1 ] * self .grippers [part_name ].dof ,
400+ [1 ] * self .grippers [part_name ].dof ,
401+ )
357402 low , high = np .concatenate ([low , low_g ]), np .concatenate ([high , high_g ])
358403 else :
359404 control_dim = controller .control_dim
@@ -365,8 +410,11 @@ def action_limits(self):
365410 return low , high
366411
367412 def create_action_vector (self , action_dict : Dict [str , np .ndarray ]) -> np .ndarray :
413+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
414+ return super ().create_action_vector (action_dict )
415+
368416 full_action_vector = np .zeros (self .action_limits [0 ].shape )
369- for ( part_name , action_vector ) in action_dict .items ():
417+ for part_name , action_vector in action_dict .items ():
370418 if part_name not in self ._whole_body_controller_action_split_indexes :
371419 ROBOSUITE_DEFAULT_LOGGER .debug (f"{ part_name } is not specified in the action space" )
372420 continue
@@ -380,24 +428,53 @@ def create_action_vector(self, action_dict: Dict[str, np.ndarray]) -> np.ndarray
380428 full_action_vector [start_idx :end_idx ] = action_vector
381429 return full_action_vector
382430
383- def print_action_info (self ):
431+ def get_action_info (self ):
432+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
433+ return super ().get_action_info ()
434+
384435 action_index_info = []
385436 action_dim_info = []
386- for part_name , (start_idx , end_idx ) in self ._whole_body_controller_action_split_indexes .items ():
437+ for part_name , (
438+ start_idx ,
439+ end_idx ,
440+ ) in self ._whole_body_controller_action_split_indexes .items ():
387441 action_dim_info .append (f"{ part_name } : { (end_idx - start_idx )} dim" )
388442 action_index_info .append (f"{ part_name } : { start_idx } :{ end_idx } " )
443+ return action_index_info , action_dim_info
389444
390- action_dim_info_str = ", " .join (action_dim_info )
391- ROBOSUITE_DEFAULT_LOGGER .info (f"Action Dimensions: [{ action_dim_info_str } ]" )
445+ def print_action_info (self ):
446+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
447+ return super ().print_action_info ()
392448
449+ action_index_info , action_dim_info = self .get_action_info ()
393450 action_index_info_str = ", " .join (action_index_info )
451+ action_dim_info_str = ", " .join (action_dim_info )
452+ ROBOSUITE_DEFAULT_LOGGER .info (f"Action Dimensions: [{ action_dim_info_str } ]" )
394453 ROBOSUITE_DEFAULT_LOGGER .info (f"Action Indices: [{ action_index_info_str } ]" )
395454
396- def print_action_info_dict (self , name : str = "" ):
455+ def get_action_info_dict (self ):
456+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
457+ return super ().get_action_info_dict ()
458+
397459 info_dict = {}
398460 info_dict ["Action Dimension" ] = self .action_limits [0 ].shape
399461 info_dict .update (dict (self ._whole_body_controller_action_split_indexes ))
462+ return info_dict
463+
464+ def get_action_info_dict (self ):
465+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
466+ return super ().get_action_info_dict ()
467+
468+ info_dict = {}
469+ info_dict ["Action Dimension" ] = self .action_limits [0 ].shape
470+ info_dict .update (dict (self ._whole_body_controller_action_split_indexes ))
471+ return info_dict
472+
473+ def print_action_info_dict (self , name : str = "" ):
474+ if self .composite_controller_specific_config .get ("skip_wbc_action" , False ):
475+ return super ().print_action_info_dict (name )
400476
477+ info_dict = self .get_action_info_dict ()
401478 info_dict_str = f"\n Action Info for { name } :\n \n { json .dumps (dict (info_dict ), indent = 4 )} "
402479 ROBOSUITE_DEFAULT_LOGGER .info (info_dict_str )
403480
@@ -474,6 +551,5 @@ def _init_joint_action_policy(self):
474551 max_dq_torso = self .composite_controller_specific_config .get ("ik_max_dq_torso" , 0.2 ),
475552 input_rotation_repr = self .composite_controller_specific_config .get ("ik_input_rotation_repr" , "axis_angle" ),
476553 input_type = self .composite_controller_specific_config .get ("ik_input_type" , "axis_angle" ),
477- input_ref_frame = self .composite_controller_specific_config .get ("ik_input_ref_frame" , "world" ),
478554 debug = self .composite_controller_specific_config .get ("verbose" , False ),
479555 )
0 commit comments