forked from real-stanford/universal_manipulation_interface
-
Notifications
You must be signed in to change notification settings - Fork 2
/
bimanual_umi_env.py
631 lines (564 loc) · 24.8 KB
/
bimanual_umi_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
from typing import Optional, List
import pathlib
import numpy as np
import time
import shutil
import math
from multiprocessing.managers import SharedMemoryManager
from umi.real_world.rtde_interpolation_controller import RTDEInterpolationController
from umi.real_world.wsg_controller import WSGController
from umi.real_world.franka_interpolation_controller import FrankaInterpolationController
from umi.real_world.multi_uvc_camera import MultiUvcCamera, VideoRecorder
from diffusion_policy.common.timestamp_accumulator import (
TimestampActionAccumulator,
ObsAccumulator
)
from umi.common.cv_util import draw_predefined_mask
from umi.real_world.multi_camera_visualizer import MultiCameraVisualizer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.cv2_util import (
get_image_transform, optimal_row_cols)
from umi.common.usb_util import reset_all_elgato_devices, get_sorted_v4l_paths
from umi.common.pose_util import pose_to_pos_rot
from umi.common.interpolation_util import get_interp1d, PoseInterpolator
class BimanualUmiEnv:
    """Real-world environment that drives several robot arms, their grippers
    and a set of UVC cameras, and records episodes to disk.

    The environment owns one controller per robot, one controller per gripper
    and a single multi-camera capture object. Observations are produced by
    `get_obs`, actions are scheduled with `exec_actions`, and the
    `start_episode` / `end_episode` / `drop_episode` methods manage recording
    into a replay buffer plus one video file per camera.

    The object can be used as a context manager: entering starts every
    device, exiting stops them (and finalizes any episode in progress).
    """

    def __init__(self,
            # required params
            output_dir,
            robots_config, # list of dict[{robot_type: 'ur5', robot_ip: XXX, robot_obs_latency: 0.0001, robot_action_latency: 0.1, tcp_offset: 0.21}]
            grippers_config, # list of dict[{gripper_ip: XXX, gripper_port: 1000, gripper_obs_latency: 0.01, gripper_action_latency: 0.1}]
            # env params
            frequency=20,
            # obs
            obs_image_resolution=(224,224),
            max_obs_buffer_size=60,
            obs_float32=False,
            camera_reorder=None,
            no_mirror=False,
            fisheye_converter=None,
            mirror_swap=False,
            # this latency compensates receive_timestamp
            # all in seconds
            camera_obs_latency=0.125,
            # all in steps (relative to frequency)
            camera_down_sample_steps=1,
            robot_down_sample_steps=1,
            gripper_down_sample_steps=1,
            # all in steps (relative to frequency)
            camera_obs_horizon=2,
            robot_obs_horizon=2,
            gripper_obs_horizon=2,
            # action
            max_pos_speed=0.25,
            max_rot_speed=0.6,
            init_joints=False,
            # vis params
            enable_multi_cam_vis=True,
            multi_cam_vis_resolution=(960, 960),
            # shared memory
            shm_manager=None
            ):
        """Create all device controllers (nothing is started yet).

        Args:
            output_dir: Directory for recordings. Its parent must already
                exist; a `videos` sub-directory and `replay_buffer.zarr` are
                created/opened inside it.
            robots_config: One dict per robot. Keys read here and in
                `exec_actions`: `robot_type` (`'ur5'`, `'ur5e'` or a name
                starting with `'franka'`), `robot_ip`, `robot_obs_latency`,
                `robot_action_latency`, and `tcp_offset` (UR robots only).
            grippers_config: One dict per gripper, same length as
                `robots_config`. Keys: `gripper_ip`, `gripper_port`,
                `gripper_obs_latency`, `gripper_action_latency`.
            frequency: Environment step rate in Hz; horizons and
                down-sample steps are expressed relative to it.
            obs_image_resolution: (width, height) of observation images.
            max_obs_buffer_size: Passed to the camera as `get_max_k`.
            obs_float32: If true, observation images are converted to
                float32 and divided by 255.
            camera_reorder: Optional list of indices used to reorder the
                detected v4l device paths.
            no_mirror: Forwarded as the `mirror` flag of the mask drawn on
                non-4K camera observations.
            fisheye_converter: Optional object with a `forward(img)` method;
                when given it replaces the resize/mask pipeline for non-4K
                cameras.
            mirror_swap: If true, pixels in the predefined mirror region are
                replaced with their horizontally flipped counterparts.
            camera_obs_latency: Camera receive latency in seconds.
            camera_down_sample_steps, robot_down_sample_steps,
            gripper_down_sample_steps: Spacing between returned observation
                samples, in environment steps.
            camera_obs_horizon, robot_obs_horizon, gripper_obs_horizon:
                Number of observation samples returned per modality.
            max_pos_speed, max_rot_speed: Speed limits; scaled by sqrt(3)
                before being handed to UR controllers.
            init_joints: If true, UR robots are sent to a fixed initial joint
                configuration on start.
            enable_multi_cam_vis: Whether to create the camera visualizer.
            multi_cam_vis_resolution: Maximum resolution of the visualizer.
            shm_manager: Optional started `SharedMemoryManager`; one is
                created and started when omitted.
        """
        output_dir = pathlib.Path(output_dir)
        assert output_dir.parent.is_dir()
        video_dir = output_dir.joinpath('videos')
        video_dir.mkdir(parents=True, exist_ok=True)
        zarr_path = str(output_dir.joinpath('replay_buffer.zarr').absolute())
        replay_buffer = ReplayBuffer.create_from_path(
            zarr_path=zarr_path, mode='a')

        if shm_manager is None:
            shm_manager = SharedMemoryManager()
            shm_manager.start()

        # Find and reset all Elgato capture cards.
        # Required to workaround a firmware bug.
        reset_all_elgato_devices()

        # Wait for all v4l cameras to be back online
        time.sleep(0.1)
        v4l_paths = get_sorted_v4l_paths()
        if camera_reorder is not None:
            paths = [v4l_paths[i] for i in camera_reorder]
            v4l_paths = paths

        # compute resolution for vis
        rw, rh, col, row = optimal_row_cols(
            n_cameras=len(v4l_paths),
            in_wh_ratio=4/3,
            max_resolution=multi_cam_vis_resolution
        )

        # HACK: Separate video setting for each camera
        # Elagto Cam Link 4k records at 4k 30fps
        # Other capture card records at 720p 60fps
        resolution = list()
        capture_fps = list()
        cap_buffer_size = list()
        video_recorder = list()
        transform = list()
        vis_transform = list()
        for path in v4l_paths:
            if 'Cam_Link_4K' in path:
                res = (3840, 2160)
                fps = 30
                buf = 3
                bit_rate = 6000*1000
                # input_res is bound as a default so each closure keeps the
                # resolution of its own camera.
                def tf4k(data, input_res=res):
                    img = data['color']
                    f = get_image_transform(
                        input_res=input_res,
                        output_res=obs_image_resolution,
                        # obs output rgb
                        bgr_to_rgb=True)
                    img = f(img)
                    if obs_float32:
                        img = img.astype(np.float32) / 255
                    data['color'] = img
                    return data
                transform.append(tf4k)
            else:
                res = (1920, 1080)
                fps = 60
                buf = 1
                bit_rate = 3000*1000

                is_mirror = None
                if mirror_swap:
                    # The mask must have the same (height, width) as the
                    # observation image it is later used to index, so it is
                    # sized from obs_image_resolution (given as width, height)
                    # rather than a fixed 224x224.
                    obs_w, obs_h = obs_image_resolution
                    mirror_mask = np.ones((obs_h, obs_w, 3), dtype=np.uint8)
                    mirror_mask = draw_predefined_mask(
                        mirror_mask, color=(0,0,0), mirror=True, gripper=False, finger=False)
                    is_mirror = (mirror_mask[...,0] == 0)

                def tf(data, input_res=res):
                    img = data['color']
                    if fisheye_converter is None:
                        f = get_image_transform(
                            input_res=input_res,
                            output_res=obs_image_resolution,
                            # obs output rgb
                            bgr_to_rgb=True)
                        img = np.ascontiguousarray(f(img))
                        if is_mirror is not None:
                            # swap in the horizontally flipped pixels inside
                            # the mirror region
                            img[is_mirror] = img[:,::-1,:][is_mirror]
                        img = draw_predefined_mask(img, color=(0,0,0),
                            mirror=no_mirror, gripper=True, finger=False, use_aa=True)
                    else:
                        img = fisheye_converter.forward(img)
                        # reverse the channel order
                        img = img[...,::-1]
                    if obs_float32:
                        img = img.astype(np.float32) / 255
                    data['color'] = img
                    return data
                transform.append(tf)

            resolution.append(res)
            capture_fps.append(fps)
            cap_buffer_size.append(buf)
            video_recorder.append(VideoRecorder.create_hevc_nvenc(
                fps=fps,
                input_pix_fmt='bgr24',
                bit_rate=bit_rate
            ))

            def vis_tf(data, input_res=res):
                img = data['color']
                f = get_image_transform(
                    input_res=input_res,
                    output_res=(rw,rh),
                    bgr_to_rgb=False
                )
                img = f(img)
                data['color'] = img
                return data
            vis_transform.append(vis_tf)

        camera = MultiUvcCamera(
            dev_video_paths=v4l_paths,
            shm_manager=shm_manager,
            resolution=resolution,
            capture_fps=capture_fps,
            # send every frame immediately after arrival
            # ignores put_fps
            put_downsample=False,
            get_max_k=max_obs_buffer_size,
            receive_latency=camera_obs_latency,
            cap_buffer_size=cap_buffer_size,
            transform=transform,
            vis_transform=vis_transform,
            video_recorder=video_recorder,
            verbose=False
        )

        multi_cam_vis = None
        if enable_multi_cam_vis:
            multi_cam_vis = MultiCameraVisualizer(
                camera=camera,
                row=row,
                col=col,
                rgb_to_bgr=False
            )

        cube_diag = np.linalg.norm([1,1,1])
        j_init = np.array([0,-90,-90,-90,90,0]) / 180 * np.pi
        if not init_joints:
            j_init = None

        assert len(robots_config) == len(grippers_config)
        # NOTE: Franka controllers are appended to this list as well.
        robots: List[RTDEInterpolationController] = list()
        grippers: List[WSGController] = list()
        for rc in robots_config:
            if rc['robot_type'].startswith('ur5'):
                assert rc['robot_type'] in ['ur5', 'ur5e']
                this_robot = RTDEInterpolationController(
                    shm_manager=shm_manager,
                    robot_ip=rc['robot_ip'],
                    frequency=500 if rc['robot_type'] == 'ur5e' else 125,
                    lookahead_time=0.1,
                    gain=300,
                    max_pos_speed=max_pos_speed*cube_diag,
                    max_rot_speed=max_rot_speed*cube_diag,
                    launch_timeout=3,
                    tcp_offset_pose=[0, 0, rc['tcp_offset'], 0, 0, 0],
                    payload_mass=None,
                    payload_cog=None,
                    joints_init=j_init,
                    joints_init_speed=1.05,
                    soft_real_time=False,
                    verbose=False,
                    receive_keys=None,
                    receive_latency=rc['robot_obs_latency']
                )
            elif rc['robot_type'].startswith('franka'):
                this_robot = FrankaInterpolationController(
                    shm_manager=shm_manager,
                    robot_ip=rc['robot_ip'],
                    frequency=200,
                    Kx_scale=1.0,
                    Kxd_scale=np.array([2.0,1.5,2.0,1.0,1.0,1.0]),
                    verbose=False,
                    receive_latency=rc['robot_obs_latency']
                )
            else:
                raise NotImplementedError()
            robots.append(this_robot)

        for gc in grippers_config:
            this_gripper = WSGController(
                shm_manager=shm_manager,
                hostname=gc['gripper_ip'],
                port=gc['gripper_port'],
                receive_latency=gc['gripper_obs_latency'],
                use_meters=True
            )
            grippers.append(this_gripper)

        self.camera = camera
        self.robots = robots
        self.robots_config = robots_config
        self.grippers = grippers
        self.grippers_config = grippers_config
        self.multi_cam_vis = multi_cam_vis
        self.frequency = frequency
        self.max_obs_buffer_size = max_obs_buffer_size
        self.max_pos_speed = max_pos_speed
        self.max_rot_speed = max_rot_speed
        # timing
        self.camera_obs_latency = camera_obs_latency
        self.camera_down_sample_steps = camera_down_sample_steps
        self.robot_down_sample_steps = robot_down_sample_steps
        self.gripper_down_sample_steps = gripper_down_sample_steps
        self.camera_obs_horizon = camera_obs_horizon
        self.robot_obs_horizon = robot_obs_horizon
        self.gripper_obs_horizon = gripper_obs_horizon
        # recording
        self.output_dir = output_dir
        self.video_dir = video_dir
        self.replay_buffer = replay_buffer
        # temp memory buffers
        self.last_camera_data = None
        # recording buffers
        self.obs_accumulator = None
        self.action_accumulator = None

        self.start_time = None
        self.last_time_step = 0

    # ======== start-stop API =============
    @property
    def is_ready(self):
        """True only when the camera, every robot and every gripper are ready."""
        ready_flag = self.camera.is_ready
        for robot in self.robots:
            ready_flag = ready_flag and robot.is_ready
        for gripper in self.grippers:
            ready_flag = ready_flag and gripper.is_ready
        return ready_flag

    def start(self, wait=True):
        """Start all devices without blocking, then optionally wait for them."""
        self.camera.start(wait=False)
        for robot in self.robots:
            robot.start(wait=False)
        for gripper in self.grippers:
            gripper.start(wait=False)
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.start(wait=False)
        if wait:
            self.start_wait()

    def stop(self, wait=True):
        """Finalize any episode in progress and stop all devices.

        Note that `end_episode` asserts `is_ready`, so this raises
        AssertionError if a device is not ready.
        """
        self.end_episode()
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.stop(wait=False)
        for robot in self.robots:
            robot.stop(wait=False)
        for gripper in self.grippers:
            gripper.stop(wait=False)
        self.camera.stop(wait=False)
        if wait:
            self.stop_wait()

    def start_wait(self):
        """Block until every started device reports that it has started."""
        self.camera.start_wait()
        for robot in self.robots:
            robot.start_wait()
        for gripper in self.grippers:
            gripper.start_wait()
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.start_wait()

    def stop_wait(self):
        """Block until every device has stopped."""
        for robot in self.robots:
            robot.stop_wait()
        for gripper in self.grippers:
            gripper.stop_wait()
        self.camera.stop_wait()
        if self.multi_cam_vis is not None:
            self.multi_cam_vis.stop_wait()

    # ========= context manager ===========
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    # ========= async env API ===========
    def get_obs(self) -> dict:
        """Return the latest time-aligned observation dict.

        Timestamp alignment policy
        We assume the cameras used for obs are always [0, k - 1], where k is the number of robots
        All other cameras, find corresponding frame with the nearest timestamp
        All low-dim observations, interpolate with respect to 'current' time

        Returns:
            dict with `camera{i}_rgb` for every camera, `timestamp` (the
            aligned camera timestamps), and for every robot index
            `robot{i}_eef_pos`, `robot{i}_eef_rot_axis_angle` and
            `robot{i}_gripper_width`.

        Side effect: while an episode is being recorded, the raw robot and
        gripper states fetched here are appended to the obs accumulator.
        """
        assert self.is_ready

        # get data
        # 60 Hz, camera_calibrated_timestamp
        k = math.ceil(
            self.camera_obs_horizon * self.camera_down_sample_steps \
            * (60 / self.frequency)) + 2 # here 2 is adjustable, typically 1 should be enough
        # print('==>k  ', k, self.camera_obs_horizon, self.camera_down_sample_steps, self.frequency)
        self.last_camera_data = self.camera.get(
            k=k,
            out=self.last_camera_data)

        # both have more than n_obs_steps data
        last_robots_data = list()
        last_grippers_data = list()
        # 125/500 hz, robot_receive_timestamp
        for robot in self.robots:
            last_robots_data.append(robot.get_all_state())
        # 30 hz, gripper_receive_timestamp
        for gripper in self.grippers:
            last_grippers_data.append(gripper.get_all_state())

        # select align_camera_idx
        # For each candidate camera, sum how far behind its newest frame the
        # most recent *earlier* frame of every other obs camera is; the
        # candidate with the smallest total is used as the time reference.
        num_obs_cameras = len(self.robots)
        align_camera_idx = None
        running_best_error = np.inf
        for camera_idx in range(num_obs_cameras):
            this_error = 0
            this_timestamp = self.last_camera_data[camera_idx]['timestamp'][-1]
            for other_camera_idx in range(num_obs_cameras):
                if other_camera_idx == camera_idx:
                    continue
                other_timestamps = np.asarray(
                    self.last_camera_data[other_camera_idx]['timestamp'])
                earlier_idxs = np.nonzero(other_timestamps < this_timestamp)[0]
                if len(earlier_idxs) == 0:
                    # The other camera has no buffered frame older than this
                    # candidate: treat the candidate as the worst possible
                    # choice instead of indexing past the start of the buffer.
                    this_error = np.inf
                    break
                this_error += this_timestamp - other_timestamps[earlier_idxs[-1]]
            if align_camera_idx is None or this_error < running_best_error:
                running_best_error = this_error
                align_camera_idx = camera_idx

        last_timestamp = self.last_camera_data[align_camera_idx]['timestamp'][-1]
        dt = 1 / self.frequency

        # align camera obs timestamps
        camera_obs_timestamps = last_timestamp - (
            np.arange(self.camera_obs_horizon)[::-1] * self.camera_down_sample_steps * dt)
        camera_obs = dict()
        for camera_idx, value in self.last_camera_data.items():
            this_timestamps = value['timestamp']
            this_idxs = list()
            for t in camera_obs_timestamps:
                # nearest-neighbour frame for each requested timestamp
                nn_idx = np.argmin(np.abs(this_timestamps - t))
                # if np.abs(this_timestamps - t)[nn_idx] > 1.0 / 120 and camera_idx != 3:
                #     print('ERROR!!!  ', camera_idx, len(this_timestamps), nn_idx, (this_timestamps - t)[nn_idx-1: nn_idx+2])
                this_idxs.append(nn_idx)
            # remap key
            camera_obs[f'camera{camera_idx}_rgb'] = value['color'][this_idxs]

        # obs_data to return (it only includes camera data at this stage)
        obs_data = dict(camera_obs)

        # include camera timesteps
        obs_data['timestamp'] = camera_obs_timestamps

        # align robot obs
        robot_obs_timestamps = last_timestamp - (
            np.arange(self.robot_obs_horizon)[::-1] * self.robot_down_sample_steps * dt)
        for robot_idx, last_robot_data in enumerate(last_robots_data):
            robot_pose_interpolator = PoseInterpolator(
                t=last_robot_data['robot_timestamp'],
                x=last_robot_data['ActualTCPPose'])
            robot_pose = robot_pose_interpolator(robot_obs_timestamps)
            robot_obs = {
                f'robot{robot_idx}_eef_pos': robot_pose[...,:3],
                f'robot{robot_idx}_eef_rot_axis_angle': robot_pose[...,3:]
            }
            # update obs_data
            obs_data.update(robot_obs)

        # align gripper obs
        gripper_obs_timestamps = last_timestamp - (
            np.arange(self.gripper_obs_horizon)[::-1] * self.gripper_down_sample_steps * dt)
        for robot_idx, last_gripper_data in enumerate(last_grippers_data):
            # align gripper obs
            gripper_interpolator = get_interp1d(
                t=last_gripper_data['gripper_timestamp'],
                x=last_gripper_data['gripper_position'][...,None]
            )
            gripper_obs = {
                f'robot{robot_idx}_gripper_width': gripper_interpolator(gripper_obs_timestamps)
            }
            # update obs_data
            obs_data.update(gripper_obs)

        # accumulate obs
        if self.obs_accumulator is not None:
            for robot_idx, last_robot_data in enumerate(last_robots_data):
                self.obs_accumulator.put(
                    data={
                        f'robot{robot_idx}_eef_pose': last_robot_data['ActualTCPPose'],
                        f'robot{robot_idx}_joint_pos': last_robot_data['ActualQ'],
                        f'robot{robot_idx}_joint_vel': last_robot_data['ActualQd'],
                    },
                    timestamps=last_robot_data['robot_timestamp']
                )

            for robot_idx, last_gripper_data in enumerate(last_grippers_data):
                self.obs_accumulator.put(
                    data={
                        f'robot{robot_idx}_gripper_width': last_gripper_data['gripper_position'][...,None]
                    },
                    timestamps=last_gripper_data['gripper_timestamp']
                )

        return obs_data

    def exec_actions(self,
            actions: np.ndarray,
            timestamps: np.ndarray,
            compensate_latency=False):
        """Schedule future actions on every robot and gripper.

        Args:
            actions: Array of shape (T, 7 * n_robots). For robot i, columns
                7*i .. 7*i+5 are passed as the robot pose waypoint and column
                7*i+6 as the gripper position.
            timestamps: Length-T target times, compared against `time.time()`.
                Entries that are not in the future are dropped.
            compensate_latency: If true, each waypoint is scheduled earlier by
                the configured `robot_action_latency` /
                `gripper_action_latency`.

        While an episode is being recorded, the scheduled (future) actions
        and their uncompensated timestamps are added to the action
        accumulator.
        """
        assert self.is_ready
        if not isinstance(actions, np.ndarray):
            actions = np.array(actions)
        if not isinstance(timestamps, np.ndarray):
            timestamps = np.array(timestamps)

        # convert action to pose
        receive_time = time.time()
        is_new = timestamps > receive_time
        new_actions = actions[is_new]
        new_timestamps = timestamps[is_new]

        assert new_actions.shape[1] // len(self.robots) == 7
        assert new_actions.shape[1] % len(self.robots) == 0

        # schedule waypoints
        for action, target_time in zip(new_actions, new_timestamps):
            for robot_idx, (robot, gripper, rc, gc) in enumerate(zip(self.robots, self.grippers, self.robots_config, self.grippers_config)):
                r_latency = rc['robot_action_latency'] if compensate_latency else 0.0
                g_latency = gc['gripper_action_latency'] if compensate_latency else 0.0
                r_actions = action[7 * robot_idx + 0: 7 * robot_idx + 6]
                g_actions = action[7 * robot_idx + 6]
                robot.schedule_waypoint(
                    pose=r_actions,
                    target_time=target_time - r_latency
                )
                gripper.schedule_waypoint(
                    pos=g_actions,
                    target_time=target_time - g_latency
                )

        # record actions
        if self.action_accumulator is not None:
            self.action_accumulator.put(
                new_actions,
                new_timestamps
            )

    def get_robot_state(self):
        """Return the current state of every robot, in robot order."""
        return [robot.get_state() for robot in self.robots]

    def get_gripper_state(self):
        """Return the current state of every gripper, in gripper order."""
        return [gripper.get_state() for gripper in self.grippers]

    # recording API
    def start_episode(self, start_time=None):
        """Start recording a new episode.

        Creates the per-episode video directory, starts video recording on
        every camera and creates fresh obs/action accumulators.

        Args:
            start_time: Episode start time; defaults to `time.time()`.
        """
        if start_time is None:
            start_time = time.time()
        self.start_time = start_time

        assert self.is_ready

        # prepare recording stuff
        episode_id = self.replay_buffer.n_episodes
        this_video_dir = self.video_dir.joinpath(str(episode_id))
        this_video_dir.mkdir(parents=True, exist_ok=True)
        n_cameras = self.camera.n_cameras
        video_paths = list()
        for i in range(n_cameras):
            video_paths.append(
                str(this_video_dir.joinpath(f'{i}.mp4').absolute()))

        # start recording on camera
        self.camera.restart_put(start_time=start_time)
        self.camera.start_recording(video_path=video_paths, start_time=start_time)

        # create accumulators
        self.obs_accumulator = ObsAccumulator()
        self.action_accumulator = TimestampActionAccumulator(
            start_time=start_time,
            dt=1/self.frequency
        )
        print(f'Episode {episode_id} started!')

    def end_episode(self):
        """Stop recording and, if any steps were collected, save the episode.

        The saved episode is truncated to the action steps whose timestamp
        does not exceed the last timestamp of every observation stream;
        observations are interpolated onto the action timestamps. Nothing is
        saved when no recording is active or no usable step exists.
        """
        assert self.is_ready

        # stop video recorder
        self.camera.stop_recording()

        # TODO
        if self.obs_accumulator is not None:
            # recording
            assert self.action_accumulator is not None

            # Since the only way to accumulate obs and action is by calling
            # get_obs and exec_actions, which will be in the same thread.
            # We don't need to worry new data come in here.
            actions = self.action_accumulator.actions
            action_timestamps = self.action_accumulator.timestamps

            end_time = float('inf')
            for value in self.obs_accumulator.timestamps.values():
                # an empty stream puts no bound on end_time
                if len(value) > 0:
                    end_time = min(end_time, value[-1])

            n_steps = 0
            # Guard against an episode that ended before any action was
            # executed: indexing [-1] on empty timestamps would raise.
            if len(action_timestamps) > 0:
                end_time = min(end_time, action_timestamps[-1])
                valid_idxs = np.nonzero(action_timestamps <= end_time)[0]
                if len(valid_idxs) > 0:
                    n_steps = valid_idxs[-1] + 1

            if n_steps > 0:
                timestamps = action_timestamps[:n_steps]
                episode = {
                    'timestamp': timestamps,
                    'action': actions[:n_steps],
                }
                for robot_idx in range(len(self.robots)):
                    robot_pose_interpolator = PoseInterpolator(
                        t=np.array(self.obs_accumulator.timestamps[f'robot{robot_idx}_eef_pose']),
                        x=np.array(self.obs_accumulator.data[f'robot{robot_idx}_eef_pose'])
                    )
                    robot_pose = robot_pose_interpolator(timestamps)
                    episode[f'robot{robot_idx}_eef_pos'] = robot_pose[:,:3]
                    episode[f'robot{robot_idx}_eef_rot_axis_angle'] = robot_pose[:,3:]
                    joint_pos_interpolator = get_interp1d(
                        np.array(self.obs_accumulator.timestamps[f'robot{robot_idx}_joint_pos']),
                        np.array(self.obs_accumulator.data[f'robot{robot_idx}_joint_pos'])
                    )
                    joint_vel_interpolator = get_interp1d(
                        np.array(self.obs_accumulator.timestamps[f'robot{robot_idx}_joint_vel']),
                        np.array(self.obs_accumulator.data[f'robot{robot_idx}_joint_vel'])
                    )
                    episode[f'robot{robot_idx}_joint_pos'] = joint_pos_interpolator(timestamps)
                    episode[f'robot{robot_idx}_joint_vel'] = joint_vel_interpolator(timestamps)

                    gripper_interpolator = get_interp1d(
                        t=np.array(self.obs_accumulator.timestamps[f'robot{robot_idx}_gripper_width']),
                        x=np.array(self.obs_accumulator.data[f'robot{robot_idx}_gripper_width'])
                    )
                    episode[f'robot{robot_idx}_gripper_width'] = gripper_interpolator(timestamps)

                self.replay_buffer.add_episode(episode, compressors='disk')
                episode_id = self.replay_buffer.n_episodes - 1
                print(f'Episode {episode_id} saved!')

            self.obs_accumulator = None
            self.action_accumulator = None

    def drop_episode(self):
        """End the current episode and discard it, including its videos.

        The replay buffer is only asked to drop an episode when ending the
        current one actually stored something; otherwise the previously
        saved episode would be deleted by mistake.
        """
        n_episodes_before = self.replay_buffer.n_episodes
        self.end_episode()
        if self.replay_buffer.n_episodes > n_episodes_before:
            self.replay_buffer.drop_episode()
        episode_id = self.replay_buffer.n_episodes
        this_video_dir = self.video_dir.joinpath(str(episode_id))
        if this_video_dir.exists():
            shutil.rmtree(str(this_video_dir))
        print(f'Episode {episode_id} dropped!')