Skip to content

Commit

Permalink
Add support for integer observations
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed May 9, 2024
1 parent 6952abb commit af4f5b7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
8 changes: 6 additions & 2 deletions dreamerv3/jaxutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,11 +720,15 @@ def concat_dict(mapping, batch_shape=None):
return jnp.concatenate(tensors, -1)


def onehot_dict(mapping, spaces):
def onehot_dict(mapping, spaces, filter=False, limit=256):
result = {}
for key, value in mapping.items():
if key not in spaces and filter:
continue
space = spaces[key]
if space.discrete:
if space.discrete and space.dtype != jnp.uint8:
if limit:
assert space.classes <= limit, (key, space, limit)
value = jax.nn.one_hot(value, space.classes)
result[key] = value
return result
Expand Down
7 changes: 6 additions & 1 deletion dreamerv3/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class SimpleEncoder(nj.Module):

def __init__(self, spaces, **kw):
assert all(len(s.shape) <= 3 for s in spaces.values()), spaces
self.spaces = spaces
self.veckeys = [k for k, s in spaces.items() if len(s.shape) <= 2]
self.imgkeys = [k for k, s in spaces.items() if len(s.shape) == 3]
self.vecinp = Input(self.veckeys, featdims=1)
Expand All @@ -246,6 +247,10 @@ def __call__(self, data, bdims=2):
kw = dict(**self.kw, norm=self.norm, act=self.act)
outs = []

shape = data['is_first'].shape[:bdims]
data = {k: data[k] for k in self.spaces}
data = jaxutils.onehot_dict(data, self.spaces)

if self.veckeys:
x = self.vecinp(data, bdims, f32)
x = x.reshape((-1, *x.shape[bdims:]))
Expand All @@ -268,7 +273,7 @@ def __call__(self, data, bdims=2):
outs.append(x)

x = jnp.concatenate(outs, -1)
x = x.reshape((*data['is_first'].shape, *x.shape[1:]))
x = x.reshape((*shape, *x.shape[1:]))
return x


Expand Down
15 changes: 4 additions & 11 deletions embodied/envs/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def obs_space(self):
return {
'image': embodied.Space(np.uint8, self._size + (3,)),
'vector': embodied.Space(np.float32, (7,)),
# 'token': embodied.Space(np.int32, (), 0, 256),
'step': embodied.Space(np.int32, (), 0, self._length),
'token': embodied.Space(np.int32, (), 0, 256),
'step': embodied.Space(np.float32, (), 0, self._length),
'reward': embodied.Space(np.float32),
'is_first': embodied.Space(bool),
'is_last': embodied.Space(bool),
Expand All @@ -27,13 +27,6 @@ def obs_space(self):

@property
def act_space(self):

# if self._task == 'cont':
# space = embodied.Space(np.float32, (6,))
# else:
# space = embodied.Space(np.int32, (), 0, 5)
# return {'action': space, 'reset': embodied.Space(bool)}

return {
'action': embodied.Space(np.int32, (), 0, 5),
'other': embodied.Space(np.float32, (6,)),
Expand All @@ -54,8 +47,8 @@ def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
return dict(
image=np.zeros(self._size + (3,), np.uint8),
vector=np.zeros(7, np.float32),
# token=np.zeros((), np.int32),
step=np.int32(self._step),
token=np.zeros((), np.int32),
step=np.float32(self._step),
reward=np.float32(reward),
is_first=is_first,
is_last=is_last,
Expand Down

0 comments on commit af4f5b7

Please sign in to comment.