|
| 1 | +""" |
| 2 | +A script to collect a batch of human demonstrations. |
| 3 | +
|
| 4 | +The demonstrations can be played back using the `playback_demonstrations_from_hdf5.py` script. |
| 5 | +""" |
| 6 | + |
| 7 | +import argparse |
| 8 | +import datetime |
| 9 | +import json |
| 10 | +import os |
| 11 | +import time |
| 12 | +from glob import glob |
| 13 | + |
| 14 | +import h5py |
| 15 | +import numpy as np |
| 16 | + |
| 17 | +import robosuite as suite |
| 18 | +from robosuite.controllers import load_composite_controller_config |
| 19 | +from robosuite.controllers.composite.composite_controller import WholeBody |
| 20 | +from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper |
| 21 | + |
| 22 | + |
def collect_human_trajectory(env, device, arm, max_fr):
    """
    Use the device (keyboard or SpaceNav 3D mouse) to collect a demonstration.

    The rollout trajectory is saved to files in npz format by the enclosing
    DataCollectionWrapper; modify that wrapper to add new fields or change
    data formats.

    Args:
        env (MujocoEnv): environment to control
        device (Device): to receive controls from the device
        arm (str): which arm to control (eg bimanual) 'right' or 'left'
        max_fr (int): if specified, pause the simulation whenever simulation runs faster than max_fr
    """
    from copy import deepcopy  # hoisted: was previously re-imported on every loop iteration

    env.reset()
    env.render()

    # Nominal gains used below to perturb "absolute"-mode controller actions with gaussian noise.
    DAMPING_RATIO = 1
    DEFAULT_KD = 150

    task_completion_hold_count = -1  # counter to collect 10 timesteps after reaching goal
    device.start_control()

    for robot in env.robots:
        robot.print_action_info_dict()

    # Keep track of prev gripper actions since they are position-based and must be
    # maintained for the inactive arms when the active arm is switched.
    all_prev_gripper_actions = [
        {
            f"{robot_arm}_gripper": np.repeat([0], robot.gripper[robot_arm].dof)
            for robot_arm in robot.arms
            if robot.gripper[robot_arm].dof > 0
        }
        for robot in env.robots
    ]

    # Loop until we get a reset from the input or the task completes
    while True:
        start = time.time()

        # Set active robot
        active_robot = env.robots[device.active_robot]

        # Get the newest action
        input_ac_dict = device.input2action()

        # If action is None, then this is a reset so we should break
        if input_ac_dict is None:
            break

        action_dict = deepcopy(input_ac_dict)

        # Set arm actions. NOTE: the loop variable is deliberately NOT named `arm`,
        # which would silently shadow the function parameter of the same name.
        for robot_arm in active_robot.arms:
            if isinstance(active_robot.composite_controller, WholeBody):  # input type passed to joint_action_policy
                controller_input_type = active_robot.composite_controller.joint_action_policy.input_type
            else:
                controller_input_type = active_robot.part_controllers[robot_arm].input_type

            if controller_input_type == "delta":
                action_dict[robot_arm] = input_ac_dict[f"{robot_arm}_delta"]
            elif controller_input_type == "absolute":
                # Sample gaussian noise around the nominal gains
                # (stdev at 15% of kd, 5% of the damping ratio).
                kd = np.random.normal(loc=DEFAULT_KD, scale=0.15 * DEFAULT_KD)
                damping_ratio = np.random.normal(loc=DAMPING_RATIO, scale=0.05 * DAMPING_RATIO)
                damping_ratio = np.clip(damping_ratio, 0.1, 1.0)
                # NOTE(review): kd is left unclipped, so it can occasionally be sampled
                # negative -- confirm this is intended.
                # Action layout as built here: first six entries are the damping ratio,
                # next six are kd, followed by the absolute pose command.
                action_dict[robot_arm] = np.concatenate(
                    [np.repeat([damping_ratio], 6), np.repeat([kd], 6), input_ac_dict[f"{robot_arm}_abs"]]
                )
            else:
                raise ValueError(f"Unknown controller input type: {controller_input_type!r}")

        # Maintain gripper state for each robot but only update the active robot with action
        env_action = [robot.create_action_vector(all_prev_gripper_actions[i]) for i, robot in enumerate(env.robots)]
        env_action[device.active_robot] = active_robot.create_action_vector(action_dict)
        env_action = np.concatenate(env_action)
        for gripper_ac in all_prev_gripper_actions[device.active_robot]:
            all_prev_gripper_actions[device.active_robot][gripper_ac] = action_dict[gripper_ac]

        env.step(env_action)
        env.render()

        # Also break if we complete the task
        if task_completion_hold_count == 0:
            break

        # state machine to check for having a success for 10 consecutive timesteps
        if env._check_success():
            if task_completion_hold_count > 0:
                task_completion_hold_count -= 1  # latched state, decrement count
            else:
                task_completion_hold_count = 10  # reset count on first success timestep
        else:
            task_completion_hold_count = -1  # null the counter if there's no success

        # limit frame rate if necessary
        if max_fr is not None:
            elapsed = time.time() - start
            diff = 1 / max_fr - elapsed
            if diff > 0:
                time.sleep(diff)

    # cleanup for end of data collection episodes
    env.close()
| 131 | + |
| 132 | + |
def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
    """
    Gathers the demonstrations saved in @directory into a
    single hdf5 file.

    The structure of the hdf5 file is as follows.

    data (group)
        date (attribute) - date of collection
        time (attribute) - time of collection
        repository_version (attribute) - repository version used during collection
        env (attribute) - environment name on which demos were collected

        demo1 (group) - every demonstration has a group
            model_file (attribute) - model xml string for demonstration
            states (dataset) - flattened mujoco states
            actions (dataset) - actions applied during demonstration

        demo2 (group)
        ...

    Args:
        directory (str): Path to the directory containing raw demonstrations.
        out_dir (str): Path to where to store the hdf5 file.
        env_info (str): JSON-encoded string containing environment information,
            including controller and robot info
    """

    hdf5_path = os.path.join(out_dir, "demo.hdf5")

    # Context manager guarantees the hdf5 file is closed even on error.
    # BUG FIX: the xml file handle below previously reused the name `f`,
    # shadowing the hdf5 handle so the hdf5 file was never explicitly closed.
    with h5py.File(hdf5_path, "w") as f:

        # store some metadata in the attributes of one group
        grp = f.create_group("data")

        num_eps = 0
        env_name = None  # will get populated at some point

        for ep_directory in os.listdir(directory):

            state_paths = os.path.join(directory, ep_directory, "state_*.npz")
            states = []
            actions = []
            success = False

            for state_file in sorted(glob(state_paths)):
                dic = np.load(state_file, allow_pickle=True)
                env_name = str(dic["env"])

                states.extend(dic["states"])
                for ai in dic["action_infos"]:
                    actions.append(ai["actions"])
                success = success or dic["successful"]

            if len(states) == 0:
                continue

            # Add only the successful demonstration to dataset
            if success:
                print("Demonstration is successful and has been saved")
                # Delete the last state. This is because when the DataCollector wrapper
                # recorded the states and actions, the states were recorded AFTER playing that action,
                # so we end up with an extra state at the end.
                del states[-1]
                assert len(states) == len(actions)

                num_eps += 1
                ep_data_grp = grp.create_group("demo_{}".format(num_eps))

                # store model xml as an attribute (use a distinct handle name so we
                # do not shadow the open hdf5 file handle `f`)
                xml_path = os.path.join(directory, ep_directory, "model.xml")
                with open(xml_path, "r") as xml_file:
                    xml_str = xml_file.read()
                ep_data_grp.attrs["model_file"] = xml_str

                # write datasets for states and actions
                ep_data_grp.create_dataset("states", data=np.array(states))
                ep_data_grp.create_dataset("actions", data=np.array(actions))
            else:
                print("Demonstration is unsuccessful and has NOT been saved")

        # write dataset attributes (metadata)
        now = datetime.datetime.now()
        grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
        grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)
        grp.attrs["repository_version"] = suite.__version__
        grp.attrs["env"] = env_name
        grp.attrs["env_info"] = env_info
| 222 | + |
| 223 | + |
if __name__ == "__main__":
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--directory",
        type=str,
        default=os.path.join(suite.models.assets_root, "demonstrations_private"),
    )
    parser.add_argument("--environment", type=str, default="Lift")
    parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
    parser.add_argument(
        "--config", type=str, default="default", help="Specified environment configuration if necessary"
    )
    parser.add_argument("--arm", type=str, default="right", help="Which arm to control (eg bimanual) 'right' or 'left'")
    parser.add_argument("--camera", type=str, default="agentview", help="Which camera to use for collecting demos")
    parser.add_argument(
        "--controller",
        type=str,
        default=None,
        help="Choice of controller. Can be generic (eg. 'BASIC' or 'WHOLE_BODY_MINK_IK') or json file (see robosuite/controllers/config for examples)",
    )
    parser.add_argument("--device", type=str, default="keyboard")
    parser.add_argument("--pos-sensitivity", type=float, default=1.0, help="How much to scale position user inputs")
    parser.add_argument("--rot-sensitivity", type=float, default=1.0, help="How much to scale rotation user inputs")
    parser.add_argument(
        "--renderer",
        type=str,
        default="mjviewer",
        help="Use Mujoco's builtin interactive viewer (mjviewer) or OpenCV viewer (mujoco)",
    )
    parser.add_argument(
        "--max_fr",
        default=20,
        type=int,
        help="Sleep when simulation runs faster than specified frame rate; 20 fps is real time.",
    )
    args = parser.parse_args()

    # Get controller config
    controller_config = load_composite_controller_config(
        controller=args.controller,
        robot=args.robots[0],
    )

    if controller_config["type"] == "WHOLE_BODY_MINK_IK":
        # mink-specific import; requires installing mink.
        # NOTE(review): the imported class is not referenced below -- presumably the
        # import is needed for its controller-registration side effect; confirm.
        from robosuite.examples.third_party_controller.mink_controller import WholeBodyMinkIK

    # Create argument configuration
    config = {
        "env_name": args.environment,
        "robots": args.robots,
        "controller_configs": controller_config,
    }

    # Check if we're using a multi-armed environment and use env_configuration argument if so
    if "TwoArm" in args.environment:
        config["env_configuration"] = args.config

    # Create environment
    env = suite.make(
        **config,
        has_renderer=True,
        renderer=args.renderer,
        has_offscreen_renderer=False,
        render_camera=args.camera,
        ignore_done=True,
        use_camera_obs=False,
        reward_shaping=True,
        control_freq=20,
    )

    # Wrap this with visualization wrapper
    env = VisualizationWrapper(env)

    # Grab reference to controller config and convert it to json-encoded string
    env_info = json.dumps(config)

    # wrap the environment with data collection wrapper
    tmp_directory = "/tmp/{}".format(str(time.time()).replace(".", "_"))
    env = DataCollectionWrapper(env, tmp_directory)

    # initialize device
    if args.device == "keyboard":
        from robosuite.devices import Keyboard

        device = Keyboard(env=env, pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
    elif args.device == "spacemouse":
        from robosuite.devices import SpaceMouse

        device = SpaceMouse(env=env, pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
    elif args.device == "mjgui":
        assert args.renderer == "mjviewer", "Mocap is only supported with the mjviewer renderer"
        from robosuite.devices.mjgui import MJGUI

        device = MJGUI(env=env)
    else:
        # BUG FIX: message previously omitted the supported 'mjgui' option
        raise Exception("Invalid device choice: choose 'keyboard', 'spacemouse', or 'mjgui'.")

    # make a new timestamped directory
    t1, t2 = str(time.time()).split(".")
    new_dir = os.path.join(args.directory, "{}_{}".format(t1, t2))
    os.makedirs(new_dir)

    # collect demonstrations until the user interrupts the process (Ctrl-C);
    # each episode is flushed to hdf5 immediately after collection
    while True:
        collect_human_trajectory(env, device, args.arm, args.max_fr)
        gather_demonstrations_as_hdf5(tmp_directory, new_dir, env_info)
0 commit comments