Skip to content

Commit 176d31a

Browse files
committed
testing script
1 parent dea5ba9 commit 176d31a

2 files changed

Lines changed: 333 additions & 2 deletions

File tree

robosuite/controllers/config/robots/default_panda.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
"output_min": [-0.05, -0.05, -0.05, -0.5, -0.5, -0.5],
1111
"kp": 150,
1212
"damping_ratio": 1,
13-
"impedance_mode": "fixed",
13+
"impedance_mode": "variable",
1414
"kp_limits": [0, 300],
1515
"damping_ratio_limits": [0, 10],
1616
"position_limits": null,
1717
"orientation_limits": null,
1818
"uncouple_pos_ori": true,
19-
"input_type": "delta",
19+
"input_type": "absolute",
2020
"input_ref_frame": "base",
2121
"interpolation": null,
2222
"ramp_ratio": 0.2,
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
"""
2+
A script to collect a batch of human demonstrations.
3+
4+
The demonstrations can be played back using the `playback_demonstrations_from_hdf5.py` script.
5+
"""
6+
7+
import argparse
8+
import datetime
9+
import json
10+
import os
11+
import time
12+
from glob import glob
13+
14+
import h5py
15+
import numpy as np
16+
17+
import robosuite as suite
18+
from robosuite.controllers import load_composite_controller_config
19+
from robosuite.controllers.composite.composite_controller import WholeBody
20+
from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper
21+
22+
23+
def collect_human_trajectory(env, device, arm, max_fr):
    """
    Use the device (keyboard or SpaceNav 3D mouse) to collect a demonstration.
    The rollout trajectory is saved to files in npz format.
    Modify the DataCollectionWrapper wrapper to add new fields or change data formats.

    Args:
        env (MujocoEnv): environment to control
        device (Device): to receive controls from the device
        arm (str): which arm to control (eg bimanual) 'right' or 'left'
        max_fr (int): if specified, pause the simulation whenever simulation runs faster than max_fr
    """
    from copy import deepcopy  # hoisted: was re-imported on every control-loop iteration

    env.reset()
    env.render()

    # Nominal impedance gains around which per-step noise is sampled. These
    # mirror the controller config (kp=150, damping_ratio=1). NOTE: the
    # original named the stiffness constant DEFAULT_KD, but 150 is the kp
    # value from the config, so it is named DEFAULT_KP here.
    DAMPING_RATIO = 1
    DEFAULT_KP = 150

    task_completion_hold_count = -1  # counter to collect 10 timesteps after reaching goal
    device.start_control()

    for robot in env.robots:
        robot.print_action_info_dict()

    # Keep track of prev gripper actions when using since they are position-based and must be maintained when arms switched
    all_prev_gripper_actions = [
        {
            f"{robot_arm}_gripper": np.repeat([0], robot.gripper[robot_arm].dof)
            for robot_arm in robot.arms
            if robot.gripper[robot_arm].dof > 0
        }
        for robot in env.robots
    ]

    # Loop until we get a reset from the input or the task completes
    while True:
        start = time.time()

        # Set active robot
        active_robot = env.robots[device.active_robot]

        # Get the newest action
        input_ac_dict = device.input2action()

        # If action is none, then this a reset so we should break
        if input_ac_dict is None:
            break

        action_dict = deepcopy(input_ac_dict)

        # Set arm actions. The loop variable is `robot_arm` rather than `arm`
        # so it no longer shadows the function parameter.
        for robot_arm in active_robot.arms:
            if isinstance(active_robot.composite_controller, WholeBody):  # input type passed to joint_action_policy
                controller_input_type = active_robot.composite_controller.joint_action_policy.input_type
            else:
                controller_input_type = active_robot.part_controllers[robot_arm].input_type

            if controller_input_type == "delta":
                action_dict[robot_arm] = input_ac_dict[f"{robot_arm}_delta"]
            elif controller_input_type == "absolute":
                # Sample gaussian noise around the nominal gains
                # (stdev = 15% of kp, 5% of the damping ratio).
                kp = np.random.normal(loc=DEFAULT_KP, scale=0.15 * DEFAULT_KP)
                # Clamp to the controller's kp_limits [0, 300] so a rare noise
                # draw cannot yield an invalid (e.g. negative) stiffness.
                kp = np.clip(kp, 0, 300)
                damping_ratio = np.random.normal(loc=DAMPING_RATIO, scale=0.05 * DAMPING_RATIO)
                damping_ratio = np.clip(damping_ratio, 0.1, 1.0)
                # Variable-impedance action layout, matching the concatenation
                # below: first six entries are the damping ratio, next six are
                # kp, then the absolute pose. (The original comment stated the
                # kp/damping order reversed relative to the code.)
                action_dict[robot_arm] = np.concatenate(
                    [np.repeat([damping_ratio], 6), np.repeat([kp], 6), input_ac_dict[f"{robot_arm}_abs"]]
                )
            else:
                raise ValueError(f"Unsupported controller input type: {controller_input_type!r}")

        # Maintain gripper state for each robot but only update the active robot with action
        env_action = [robot.create_action_vector(all_prev_gripper_actions[i]) for i, robot in enumerate(env.robots)]
        env_action[device.active_robot] = active_robot.create_action_vector(action_dict)
        env_action = np.concatenate(env_action)
        for gripper_ac in all_prev_gripper_actions[device.active_robot]:
            all_prev_gripper_actions[device.active_robot][gripper_ac] = action_dict[gripper_ac]

        env.step(env_action)
        env.render()

        # Also break if we complete the task
        if task_completion_hold_count == 0:
            break

        # state machine to check for having a success for 10 consecutive timesteps
        if env._check_success():
            if task_completion_hold_count > 0:
                task_completion_hold_count -= 1  # latched state, decrement count
            else:
                task_completion_hold_count = 10  # reset count on first success timestep
        else:
            task_completion_hold_count = -1  # null the counter if there's no success

        # limit frame rate if necessary
        if max_fr is not None:
            elapsed = time.time() - start
            diff = 1 / max_fr - elapsed
            if diff > 0:
                time.sleep(diff)

    # cleanup for end of data collection episodes
    env.close()
131+
132+
133+
def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
    """
    Gathers the demonstrations saved in @directory into a
    single hdf5 file.

    The strucure of the hdf5 file is as follows.

    data (group)
        date (attribute) - date of collection
        time (attribute) - time of collection
        repository_version (attribute) - repository version used during collection
        env (attribute) - environment name on which demos were collected

        demo1 (group) - every demonstration has a group
            model_file (attribute) - model xml string for demonstration
            states (dataset) - flattened mujoco states
            actions (dataset) - actions applied during demonstration

        demo2 (group)
        ...

    Args:
        directory (str): Path to the directory containing raw demonstrations.
        out_dir (str): Path to where to store the hdf5 file.
        env_info (str): JSON-encoded string containing environment information,
            including controller and robot info
    """
    hdf5_path = os.path.join(out_dir, "demo.hdf5")

    # BUG FIX: the original opened the model xml with `as f`, shadowing the
    # hdf5 handle `f`; the trailing f.close() then re-closed the (already
    # closed) xml file and the hdf5 file was never explicitly closed. A
    # context manager now guarantees the hdf5 file is closed, even on error.
    with h5py.File(hdf5_path, "w") as f:
        # store some metadata in the attributes of one group
        grp = f.create_group("data")

        num_eps = 0
        env_name = None  # will get populated at some point

        for ep_directory in os.listdir(directory):

            state_paths = os.path.join(directory, ep_directory, "state_*.npz")
            states = []
            actions = []
            success = False

            for state_file in sorted(glob(state_paths)):
                dic = np.load(state_file, allow_pickle=True)
                env_name = str(dic["env"])

                states.extend(dic["states"])
                for ai in dic["action_infos"]:
                    actions.append(ai["actions"])
                success = success or dic["successful"]

            if len(states) == 0:
                continue

            # Add only the successful demonstration to dataset
            if success:
                print("Demonstration is successful and has been saved")
                # Delete the last state. This is because when the DataCollector wrapper
                # recorded the states and actions, the states were recorded AFTER playing that action,
                # so we end up with an extra state at the end.
                del states[-1]
                assert len(states) == len(actions)

                num_eps += 1
                ep_data_grp = grp.create_group("demo_{}".format(num_eps))

                # store model xml as an attribute (note: distinct handle name
                # so the enclosing hdf5 file object `f` is not shadowed)
                xml_path = os.path.join(directory, ep_directory, "model.xml")
                with open(xml_path, "r") as xml_file:
                    xml_str = xml_file.read()
                ep_data_grp.attrs["model_file"] = xml_str

                # write datasets for states and actions
                ep_data_grp.create_dataset("states", data=np.array(states))
                ep_data_grp.create_dataset("actions", data=np.array(actions))
            else:
                print("Demonstration is unsuccessful and has NOT been saved")

        # write dataset attributes (metadata)
        now = datetime.datetime.now()
        grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
        grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)
        grp.attrs["repository_version"] = suite.__version__
        grp.attrs["env"] = env_name
        grp.attrs["env_info"] = env_info
222+
223+
224+
if __name__ == "__main__":
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--directory",
        type=str,
        default=os.path.join(suite.models.assets_root, "demonstrations_private"),
    )
    parser.add_argument("--environment", type=str, default="Lift")
    # BUG FIX: with nargs="+" the default must be a list. The original bare
    # string default "Panda" made args.robots[0] evaluate to the character
    # "P" when the flag was omitted.
    parser.add_argument("--robots", nargs="+", type=str, default=["Panda"], help="Which robot(s) to use in the env")
    parser.add_argument(
        "--config", type=str, default="default", help="Specified environment configuration if necessary"
    )
    parser.add_argument("--arm", type=str, default="right", help="Which arm to control (eg bimanual) 'right' or 'left'")
    parser.add_argument("--camera", type=str, default="agentview", help="Which camera to use for collecting demos")
    parser.add_argument(
        "--controller",
        type=str,
        default=None,
        help="Choice of controller. Can be generic (eg. 'BASIC' or 'WHOLE_BODY_MINK_IK') or json file (see robosuite/controllers/config for examples)",
    )
    parser.add_argument("--device", type=str, default="keyboard")
    parser.add_argument("--pos-sensitivity", type=float, default=1.0, help="How much to scale position user inputs")
    parser.add_argument("--rot-sensitivity", type=float, default=1.0, help="How much to scale rotation user inputs")
    parser.add_argument(
        "--renderer",
        type=str,
        default="mjviewer",
        help="Use Mujoco's builtin interactive viewer (mjviewer) or OpenCV viewer (mujoco)",
    )
    parser.add_argument(
        "--max_fr",
        default=20,
        type=int,
        help="Sleep when simulation runs faster than specified frame rate; 20 fps is real time.",
    )
    args = parser.parse_args()

    # Get controller config
    controller_config = load_composite_controller_config(
        controller=args.controller,
        robot=args.robots[0],
    )

    if controller_config["type"] == "WHOLE_BODY_MINK_IK":
        # mink-specific import; requires installing mink. Imported here
        # (presumably) for its registration side effect — the name itself is
        # not referenced below.
        from robosuite.examples.third_party_controller.mink_controller import WholeBodyMinkIK

    # Create argument configuration
    config = {
        "env_name": args.environment,
        "robots": args.robots,
        "controller_configs": controller_config,
    }

    # Check if we're using a multi-armed environment and use env_configuration argument if so
    if "TwoArm" in args.environment:
        config["env_configuration"] = args.config

    # Create environment
    env = suite.make(
        **config,
        has_renderer=True,
        renderer=args.renderer,
        has_offscreen_renderer=False,
        render_camera=args.camera,
        ignore_done=True,
        use_camera_obs=False,
        reward_shaping=True,
        control_freq=20,
    )

    # Wrap this with visualization wrapper
    env = VisualizationWrapper(env)

    # Grab reference to controller config and convert it to json-encoded string
    env_info = json.dumps(config)

    # wrap the environment with data collection wrapper
    tmp_directory = "/tmp/{}".format(str(time.time()).replace(".", "_"))
    env = DataCollectionWrapper(env, tmp_directory)

    # initialize device
    if args.device == "keyboard":
        from robosuite.devices import Keyboard

        device = Keyboard(env=env, pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
    elif args.device == "spacemouse":
        from robosuite.devices import SpaceMouse

        device = SpaceMouse(env=env, pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity)
    elif args.device == "mjgui":
        assert args.renderer == "mjviewer", "Mocap is only supported with the mjviewer renderer"
        from robosuite.devices.mjgui import MJGUI

        device = MJGUI(env=env)
    else:
        # Error message now lists all three supported devices (the original
        # omitted 'mjgui' even though it is handled above).
        raise Exception("Invalid device choice: choose 'keyboard', 'spacemouse', or 'mjgui'.")

    # make a new timestamped directory
    t1, t2 = str(time.time()).split(".")
    new_dir = os.path.join(args.directory, "{}_{}".format(t1, t2))
    os.makedirs(new_dir)

    # collect demonstrations
    while True:
        collect_human_trajectory(env, device, args.arm, args.max_fr)
        gather_demonstrations_as_hdf5(tmp_directory, new_dir, env_info)

0 commit comments

Comments
 (0)