Skip to content

Commit 255ff00

Browse files
Add create_action_dict_from_action_vector function in composite controller
1 parent 9a1bba8 commit 255ff00

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

robosuite/controllers/composite/composite_controller.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ def create_action_vector(self, action_dict):
194194
full_action_vector[start_idx:end_idx] = action_vector
195195
return full_action_vector
196196

197+
def create_action_dict_from_action_vector(self, action_vector):
198+
action_dict = {}
199+
for part_name, (start_idx, end_idx) in self._action_split_indexes.items():
200+
action_dict[part_name] = action_vector[start_idx:end_idx]
201+
return action_dict
202+
197203
def get_action_info(self):
198204
action_index_info = []
199205
action_dim_info = []
@@ -445,6 +451,15 @@ def create_action_vector(self, action_dict: Dict[str, np.ndarray]) -> np.ndarray
445451
full_action_vector[start_idx:end_idx] = action_vector
446452
return full_action_vector
447453

454+
def create_action_dict_from_action_vector(self, action_vector):
455+
if self.composite_controller_specific_config.get("skip_wbc_action", False):
456+
return super().create_action_dict_from_action_vector(action_vector)
457+
458+
action_dict = {}
459+
for part_name, (start_idx, end_idx) in self._whole_body_controller_action_split_indexes.items():
460+
action_dict[part_name] = action_vector[start_idx:end_idx]
461+
return action_dict
462+
448463
def get_action_info(self):
449464
if self.composite_controller_specific_config.get("skip_wbc_action", False):
450465
return super().get_action_info()

0 commit comments

Comments
 (0)