Ibrl (#2)
* remove dataset consistency check

* add pretrain configs

* rename

* transport pretrain cfg

* add ibrl

* fix base policy

* set `deterministic=True` when sampling in diffusion evaluation

* minors

* Revert "add rlpd framework"

* Revert "Revert "add rlpd framework"" (#4)

* match rlpd param names

* rename to `StitchedSequenceQLearningDataset`

* add configs

* add `tanh_output` and dropout to gaussians

* fix ibrl

* minors

---------

Co-authored-by: Justin M. Lidard <[email protected]>
Co-authored-by: allenzren <[email protected]>
3 people committed Oct 6, 2024
1 parent 5d361a8 commit c2bc18f
Showing 33 changed files with 1,674 additions and 353 deletions.
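Among the changes, `agent/dataset/sequence.py` gains the Q-learning variant of the stitched sequence dataset. Below is a minimal usage sketch based on the constructor arguments visible in the diff; the dataset path, the contents of the `.npz` file, and the exact structure of the returned transition are assumptions for illustration, not part of the commit:

```python
# Hypothetical usage of StitchedSequenceQLearningDataset (path and data contents are assumed).
# The .npz file is expected to provide states, actions, rewards, terminals, and traj_lengths.
from agent.dataset.sequence import StitchedSequenceQLearningDataset

dataset = StitchedSequenceQLearningDataset(
    dataset_path="data/gym/hopper-medium-v2/train.npz",  # assumed path
    horizon_steps=4,   # length of each action chunk
    cond_steps=1,      # observation history length
    max_n_episodes=100,
    device="cpu",
)
print(len(dataset))      # number of usable transitions (last step of truncated episodes skipped)
transition = dataset[0]  # one chunk-level transition with reward, done, and next state
```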
90 changes: 46 additions & 44 deletions agent/dataset/sequence.py
@@ -50,6 +50,8 @@ def __init__(
self.img_cond_steps = img_cond_steps
self.device = device
self.use_img = use_img
self.max_n_episodes = max_n_episodes
self.dataset_path = dataset_path

# Load dataset to device specified
if dataset_path.endswith(".npz"):
@@ -88,7 +90,7 @@ def __getitem__(self, idx):
"""
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[(start - num_before_start) : start + 1]
states = self.states[(start - num_before_start) : (start + 1)]
actions = self.actions[start:end]
states = torch.stack(
[
@@ -117,9 +119,9 @@ def make_indices(self, traj_lengths, horizon_steps):
indices = []
cur_traj_index = 0
for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps + 1
max_start = cur_traj_index + traj_length - horizon_steps
indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
]
cur_traj_index += traj_length
return indices
@@ -141,94 +143,94 @@ def __len__(self):
class StitchedSequenceQLearningDataset(StitchedSequenceDataset):
"""
Extends StitchedSequenceDataset to include rewards and dones for Q learning
Do not load the last step of **truncated** episodes, since we do not have the correct next state for the final step of each episode. An episode is truncated if it ends with terminal=False.
"""

def __init__(
self,
dataset_path,
horizon_steps=64,
cond_steps=1,
img_cond_steps=1,
max_n_episodes=10000,
use_img=False,
device="cuda:0",
clip_to_eps=True,
eps=1e-5,
**kwargs,
):
super().__init__(
dataset_path,
horizon_steps,
cond_steps,
img_cond_steps,
max_n_episodes,
use_img,
device,
)
if clip_to_eps:
lim = 1 - eps
self.actions = torch.clip(self.actions, -lim, lim)

# Load dataset to device specified (additional processing for rewards and dones)
if dataset_path.endswith(".npz"):
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
dataset = np.load(dataset_path, allow_pickle=False)
elif dataset_path.endswith(".pkl"):
with open(dataset_path, "rb") as f:
dataset = pickle.load(f)
else:
raise ValueError(f"Unsupported file format: {dataset_path}")
traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array
traj_lengths = dataset["traj_lengths"][:max_n_episodes]
total_num_steps = np.sum(traj_lengths)

# rewards and dones(terminals)
self.rewards = (
torch.from_numpy(dataset["rewards"][:total_num_steps]).float().to(device)
) # (total_num_steps, action_dim)
)
log.info(f"Rewards shape/type: {self.rewards.shape, self.rewards.dtype}")

# set the last done of each trajectory to 1
self.dones = torch.zeros_like(self.rewards)
cumulative_traj_length = np.cumsum(traj_lengths)
for i, traj_length in enumerate(cumulative_traj_length):
self.dones[traj_length - 1] = 1
self.dones = torch.from_numpy(dataset["terminals"][:total_num_steps]).to(device)
log.info(f"Dones shape/type: {self.dones.shape, self.dones.dtype}")

super().__init__(
dataset_path=dataset_path,
max_n_episodes=max_n_episodes,
device=device,
**kwargs,
)
log.info(f"Total number of transitions using: {len(self)}")

def make_indices(self, traj_lengths, horizon_steps):
"""
skip last step of truncated episodes
"""
num_skip = 0
indices = []
cur_traj_index = 0
for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps
if not self.dones[cur_traj_index + traj_length - 1]: # truncation
max_start -= 1
num_skip += 1
indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)
]
cur_traj_index += traj_length
log.info(f"Number of transitions skipped due to truncation: {num_skip}")
return indices

def __getitem__(self, idx):
# Sample a transition that includes rewards and dones.
# We take the last reward and done for the action chunk as the reward and done for the transition.
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[(start - num_before_start) : start + 1]
states = self.states[(start - num_before_start) : (start + 1)]
actions = self.actions[start:end]
rewards = self.rewards[start:end][-1:]
dones = self.dones[start:end][-1:]
rewards = self.rewards[start : (start + 1)]
dones = self.dones[start : (start + 1)]

# Note: for self.horizon_steps > 1, we need to include the action chunk in the environment dynamics.
# The next state is the state at the end of the action chunk. Therefore, when we index,
# we need to check whether idx is within self.horizon_steps of the end of the dataset.
# Account for action horizon
if idx < len(self.indices) - self.horizon_steps:
# the states after we apply the action chunk
next_states = self.states[
(start - num_before_start + self.horizon_steps) : start
+ 1
+ self.horizon_steps
]
] # even if this uses the first state(s) of the next episode, done=True will prevent bootstrapping. We have already filtered out cases where done=False but end of episode (truncation).
else:
# prevents indexing error, but ignored since done=True
next_states = torch.zeros_like(states)

# stack obs history
states = torch.stack(
[
states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end

next_states = torch.stack(
[
next_states[max(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end

conditions = {"state": states, "next_state": next_states}
if self.use_img:
images = self.images[(start - num_before_start) : end]
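As a standalone illustration of the truncation handling above, the following sketch (not part of the commit) mirrors the new `make_indices`: the last start index of an episode that ends with `terminal=False` is dropped, because its next state would have to come from a different episode.

```python
# Standalone sketch of truncation-aware index construction (mirrors make_indices above).
import numpy as np

def make_indices(traj_lengths, dones, horizon_steps):
    """Return (start, num_before_start) pairs, skipping the last start of truncated episodes."""
    indices = []
    num_skip = 0
    cur_traj_index = 0
    for traj_length in traj_lengths:
        max_start = cur_traj_index + traj_length - horizon_steps
        if not dones[cur_traj_index + traj_length - 1]:  # ends without terminal=True -> truncated
            max_start -= 1
            num_skip += 1
        indices += [(i, i - cur_traj_index) for i in range(cur_traj_index, max_start + 1)]
        cur_traj_index += traj_length
    return indices, num_skip

# Two stitched episodes of length 5; the first terminates, the second is truncated.
traj_lengths = [5, 5]
dones = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
indices, num_skip = make_indices(traj_lengths, dones, horizon_steps=2)
# Episode 1 contributes starts 0..3; episode 2 contributes only 5..7 (start 8 is skipped).
print(indices, num_skip)
```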
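The `__getitem__` above treats a whole action chunk as a single environment step: the transition carries one reward and one done flag, and the next state is the state reached `horizon_steps` later. Below is a simplified sketch of that indexing, ignoring the observation-history stacking and using the raw array length for the boundary check (the real code checks against the index list instead):

```python
# Simplified chunk-level transition extraction (assumes cond_steps=1 and flat tensors).
import torch

def get_transition(states, actions, rewards, dones, start, horizon_steps):
    state = states[start : start + 1]
    chunk_actions = actions[start : start + horizon_steps]  # whole chunk acts as one step
    reward = rewards[start : start + 1]
    done = dones[start : start + 1]
    if start + horizon_steps < len(states):
        # state reached after applying the full action chunk
        next_state = states[start + horizon_steps : start + horizon_steps + 1]
    else:
        # placeholder near the end of the dataset; ignored downstream since done=True
        next_state = torch.zeros_like(state)
    return state, chunk_actions, reward, done, next_state
```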