06_pytorch_oxe_dataloader.py
"""
This example shows how to use the `octo.data` dataloader with PyTorch by wrapping it in a simple PyTorch
dataloader. The config below also happens to be our exact pretraining config (except for the batch size and
shuffle buffer size, which are reduced for demonstration purposes).
"""
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import DataLoader
import tqdm

from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights

DATA_PATH = "gs://rail-orca-central2/resize_256_256"
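
# Hide all GPUs from TensorFlow: the tf.data pipeline runs on CPU, and this keeps
# TensorFlow from grabbing GPU memory that PyTorch will want for the model.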
tf.config.set_visible_devices([], "GPU")


class TorchRLDSDataset(torch.utils.data.IterableDataset):
    """Thin wrapper around an RLDS dataset for use with PyTorch dataloaders."""

    def __init__(
        self,
        rlds_dataset,
        train=True,
    ):
        self._rlds_dataset = rlds_dataset
        self._is_train = train

    def __iter__(self):
        for sample in self._rlds_dataset.as_numpy_iterator():
            yield sample

    def __len__(self):
        lengths = np.array(
            [
                stats["num_transitions"]
                for stats in self._rlds_dataset.dataset_statistics
            ]
        )
        if hasattr(self._rlds_dataset, "sample_weights"):
            # Not `lengths *= ...`: the weights are floats, and NumPy refuses to cast
            # the float product back into the integer `lengths` array in place.
            lengths = lengths * np.array(self._rlds_dataset.sample_weights)
        total_len = lengths.sum()
        if self._is_train:
            return int(0.95 * total_len)
        else:
            return int(0.05 * total_len)
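

# The wrapper yields nested dicts of NumPy arrays; PyTorch's default collate turns
# these into torch tensors when batching. If you want torch tensors per sample
# instead (e.g. with batch_size=None in the DataLoader), a conversion helper could
# be applied inside __iter__. A minimal sketch, not part of octo:
def _to_torch(tree):
    """Recursively convert numeric NumPy leaves to torch tensors, leaving
    non-numeric leaves (e.g. byte-string language instructions) untouched."""
    if isinstance(tree, dict):
        return {k: _to_torch(v) for k, v in tree.items()}
    if isinstance(tree, np.ndarray) and tree.dtype.kind in "biufc":
        return torch.from_numpy(tree)
    return tree
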

dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
    "oxe_magic_soup",
    DATA_PATH,
    load_camera_views=("primary", "wrist"),
)
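
# `dataset_kwargs_list` now holds the per-dataset loader kwargs for every dataset in
# the "oxe_magic_soup" Open X-Embodiment mix, and `sample_weights` the corresponding
# sampling weights used to interleave them below.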
dataset = make_interleaved_dataset(
    dataset_kwargs_list,
    sample_weights,
    train=True,
    shuffle_buffer_size=1000,  # set to 500k for training; large shuffle buffers are important, but adjust to your RAM
    batch_size=None,  # batching will be handled in the PyTorch DataLoader
    balance_weights=True,
    traj_transform_kwargs=dict(
        goal_relabeling_strategy="uniform",
        window_size=2,
        action_horizon=4,
        subsample_length=100,
    ),
    frame_transform_kwargs=dict(
        image_augment_kwargs={
            "primary": dict(
                random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
                random_brightness=[0.1],
                random_contrast=[0.9, 1.1],
                random_saturation=[0.9, 1.1],
                random_hue=[0.05],
                augment_order=[
                    "random_resized_crop",
                    "random_brightness",
                    "random_contrast",
                    "random_saturation",
                    "random_hue",
                ],
            ),
            "wrist": dict(
                random_brightness=[0.1],
                random_contrast=[0.9, 1.1],
                random_saturation=[0.9, 1.1],
                random_hue=[0.05],
                augment_order=[
                    "random_brightness",
                    "random_contrast",
                    "random_saturation",
                    "random_hue",
                ],
            ),
        },
        resize_size=dict(
            primary=(256, 256),
            wrist=(128, 128),
        ),
        num_parallel_calls=200,
    ),
    traj_transform_threads=48,
    traj_read_threads=48,
)
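
# Optional sanity check before wrapping for PyTorch: assuming the returned dataset
# exposes the standard tf.data `element_spec` attribute, this prints the nested
# structure (keys, shapes, dtypes) of a single frame:
# print(dataset.element_spec)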

pytorch_dataset = TorchRLDSDataset(dataset)
dataloader = DataLoader(
    pytorch_dataset,
    batch_size=16,
    num_workers=0,  # important to keep this at 0, so PyTorch does not interfere with tf.data's parallelism
)

for i, sample in tqdm.tqdm(enumerate(dataloader)):
    if i == 5000:
        break
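
# Each `sample` above is a nested dict of batched torch tensors built by the default
# collate. With this config the camera streams sit in the observation dict; a minimal
# sketch of pulling one out (the `image_primary` key follows octo.data's
# `image_<view>` naming, so treat it as an assumption for other mixes):
# images = sample["observation"]["image_primary"]  # [batch, window, H, W, 3], uint8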