Clean up API and tests
danijar committed Apr 29, 2024
1 parent 2411f7d commit 2aba861
Showing 22 changed files with 285 additions and 253 deletions.
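
The central interface change in this commit is that report now threads a recurrent carry between calls, mirroring train, and a new init_report method creates that carry. A minimal sketch of the resulting call pattern; agent, dataset, and num_batches are assumed to already exist (their construction is unchanged by this commit):

  # Hedged sketch of the updated report interface, not code from this commit.
  carry = agent.init_report(config.batch_size)  # new: report state, like init_train
  for _ in range(num_batches):
      mets, carry = agent.report(next(dataset), carry)  # now returns (metrics, carry)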
74 changes: 39 additions & 35 deletions dreamerv3/agent.py
@@ -43,8 +43,6 @@ def __init__(self, obs_space, act_space, config):
not k.startswith('log_') and re.match(config.dec.spaces, k)}
embodied.print('Encoder:', {k: v.shape for k, v in enc_space.items()})
embodied.print('Decoder:', {k: v.shape for k, v in dec_space.items()})
- # nets.Initializer.VARIANCE_FACTOR = self.config.init_scale
- nets.Initializer.FORCE_STDDEV = self.config.winit_scale

# World Model
self.enc = {
@@ -85,16 +83,6 @@ def __init__(self, obs_space, act_space, config):
# Optimizer
kw = dict(config.opt)
lr = kw.pop('lr')
- if config.compute_lr:
-   assert not config.separate_lrs
-   width = float(config.actor.units)
-   replay = float(config.run.train_ratio)
-   a = config.compute_lr_params.lnwidth
-   b = config.compute_lr_params.lnreplay
-   c = config.compute_lr_params.bias
-   lr = np.exp(a * np.log(width) + b * np.log(replay) + c)
-   message = f'Computed LR (width={width}, replay={replay}): {lr:.1e}'
-   embodied.print(message)
if config.separate_lrs:
lr = {f'agent/{k}': v for k, v in config.lrs.items()}
self.opt = jaxutils.Optimizer(lr, **kw, name='opt')
@@ -114,7 +102,6 @@ def policy_keys(self):

@property
def aux_spaces(self):
- import numpy as np
spaces = {}
spaces['stepid'] = embodied.Space(np.uint8, 20)
if self.config.replay_context:
@@ -142,6 +129,9 @@ def init_train(self, batch_size):
for k, v in self.act_space.items()}
return (self.dyn.initial(batch_size), prevact)

+ def init_report(self, batch_size):
+   return self.init_train(batch_size)

def policy(self, obs, carry, mode='train'):
self.config.jax.jit and embodied.print(
'Tracing policy function', color='yellow')
@@ -332,13 +322,8 @@ def imgstep(carry, _):
adv_normed = (adv - aoffset) / ascale
logpi = sum([v.log_prob(sg(acts[k]))[:, :-1] for k, v in actor.items()])
ents = {k: v.entropy()[:, :-1] for k, v in actor.items()}
- if self.config.scale_by_actent:
-   actor_loss = sg(weight[:, :-1]) * -(
-       logpi * sg(adv_normed) * (1 / self.config.actent) +
-       sum(ents.values()))
- else:
-   actor_loss = sg(weight[:, :-1]) * -(
-       logpi * sg(adv_normed) + self.config.actent * sum(ents.values()))
+ actor_loss = sg(weight[:, :-1]) * -(
+     logpi * sg(adv_normed) + self.config.actent * sum(ents.values()))
losses['actor'] = actor_loss

# Critic
@@ -414,34 +399,55 @@ def imgstep(carry, _):
losses = {k: v * self.scales[k] for k, v in losses.items()}
loss = jnp.stack([v.mean() for k, v in losses.items()]).sum()
newact = {k: data[k][:, -1] for k in self.act_space}
- outs = {'replay_outs': replay_outs, 'prevacts': prevacts}
+ outs = {'replay_outs': replay_outs, 'prevacts': prevacts, 'embed': embed}
outs.update({f'{k}_loss': v for k, v in losses.items()})
carry = (newlat, newact)
return loss, (outs, carry, metrics)

- def report(self, data):
+ def report(self, data, carry):
self.config.jax.jit and embodied.print(
'Tracing report function', color='yellow')
if not self.config.report:
- return {}
+ return {}, carry
metrics = {}
data = self.preprocess(data)

# Train metrics
- carry = self.init_train(len(data['is_first']))
- _, (outs, _, mets) = self.loss(data, carry, update=False)
+ _, (outs, carry_out, mets) = self.loss(data, carry, update=False)
metrics.update(mets)

+ # Open loop predictions
+ B, T = data['is_first'].shape
+ num_obs = min(self.config.report_openl_context, T // 2)
+ # Rerun observe to get the correct intermediate state, because
+ # outs_to_carry doesn't work with num_obs<context.
+ img_start, rec_outs = self.dyn.observe(
+     carry[0],
+     {k: v[:, :num_obs] for k, v in outs['prevacts'].items()},
+     outs['embed'][:, :num_obs],
+     data['is_first'][:, :num_obs])
+ img_acts = {k: v[:, num_obs:] for k, v in outs['prevacts'].items()}
+ img_outs = self.dyn.imagine(img_start, img_acts)[1]
+ rec = dict(
+     **self.dec(rec_outs), reward=self.rew(rec_outs),
+     cont=self.con(rec_outs))
+ img = dict(
+     **self.dec(img_outs), reward=self.rew(img_outs),
+     cont=self.con(img_outs))

+ # Prediction losses
+ data_img = {k: v[:, num_obs:] for k, v in data.items()}
+ losses = {k: -v.log_prob(data_img[k].astype(f32)) for k, v in img.items()}
+ metrics.update({f'openl_{k}_loss': v.mean() for k, v in losses.items()})
+ stats = jaxutils.balance_stats(img['reward'], data_img['reward'], 0.1)
+ metrics.update({f'openl_reward_{k}': v for k, v in stats.items()})
+ stats = jaxutils.balance_stats(img['cont'], data_img['cont'], 0.5)
+ metrics.update({f'openl_cont_{k}': v for k, v in stats.items()})
# Video predictions
- future_acts = {k: v[:6, 8:] for k, v in outs['prevacts'].items()}
- context_outs = {k: v[:6, :8] for k, v in outs['replay_outs'].items()}
- start = self.dyn.outs_to_carry(context_outs)
- recon = self.dec(context_outs)
- openl = self.dec(self.dyn.imagine(start, future_acts)[1])
for key in self.dec.imgkeys:
true = f32(data[key][:6])
- pred = jnp.concatenate([
-     recon[key].mode()[:, :8], openl[key].mode()], 1)
+ pred = jnp.concatenate([rec[key].mode()[:6], img[key].mode()[:6]], 1)
error = (pred - true + 1) / 2
video = jnp.concatenate([true, pred, error], 2)
metrics[f'openloop/{key}'] = jaxutils.video_grid(video)
@@ -457,7 +463,7 @@ def report(self, data):
except KeyError:
print(f'Skipping gradnorm summary for missing loss: {key}')

- return metrics
+ return metrics, carry_out

def preprocess(self, obs):
spaces = {**self.obs_space, **self.act_space, **self.aux_spaces}
@@ -471,5 +477,3 @@ def preprocess(self, obs):
result[key] = value
result['cont'] = 1.0 - f32(result['is_terminal'])
return result


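The report method above now evaluates open-loop predictions: it conditions the latent state on the first num_obs observed steps, rolls the dynamics forward on actions alone, and scores the predictions against the held-out suffix of the batch. A self-contained sketch of that pattern; dyn, heads, and the data layout are assumptions modeled on the diff, not the repository's exact API:

  import jax.numpy as jnp

  def openloop_losses(dyn, heads, carry, prevacts, embed, data, context):
      # Closed loop: condition the latent state on the first num_obs steps.
      T = data['is_first'].shape[1]
      num_obs = min(context, T // 2)
      start, _ = dyn.observe(
          carry,
          {k: v[:, :num_obs] for k, v in prevacts.items()},
          embed[:, :num_obs],
          data['is_first'][:, :num_obs])
      # Open loop: roll forward on the remaining actions alone.
      img_acts = {k: v[:, num_obs:] for k, v in prevacts.items()}
      img_outs = dyn.imagine(start, img_acts)[1]
      # Score open-loop predictions against the held-out suffix of the data.
      losses = {}
      for name, head in heads.items():
          target = data[name][:, num_obs:].astype(jnp.float32)
          losses[f'openl_{name}_loss'] = -head(img_outs).log_prob(target).mean()
      return losses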
13 changes: 6 additions & 7 deletions dreamerv3/configs.yaml
@@ -16,7 +16,7 @@ defaults:
priosignal: model
recexp: 1.0
chunksize: 1024
- debug_save_wait: False
+ save_wait: False

jax:
platform: gpu
@@ -91,12 +91,12 @@ defaults:
batch_size: 16
batch_length: 65
batch_length_eval: 33
+ replay_length: 0
+ replay_length_eval: 0
replay_context: 1
random_agent: False
loss_scales: {dec_cnn: 1.0, dec_mlp: 1.0, reward: 1.0, cont: 1.0, dyn: 1.0, rep: 0.1, actor: 1.0, critic: 1.0, replay_critic: 0.3}
opt: {scaler: rms, lr: 4e-5, eps: 1e-20, momentum: True, wd: 0.0, warmup: 1000, globclip: 0.0, agc: 0.3, beta1: 0.9, beta2: 0.999, details: False, pmin: 1e-3, anneal: 0, schedule: constant}
- compute_lr: False
- compute_lr_params: {lnwidth: -1.195, lnreplay: -0.529, bias: 0}
separate_lrs: False
lrs: {dec: 1e-4, enc: 1e-4, dyn: 1e-4, rew: 1e-4, con: 1e-4, actor: 3e-5, critic: 3e-5}
ac_grads: none
@@ -105,7 +105,7 @@ defaults:
replay_critic_grad: True
replay_critic_bootstrap: imag
reward_grad: True
- winit_scale: 0.0
+ report_openl_context: 8

# World Model
dyn:
@@ -114,11 +114,11 @@ defaults:
enc:
spaces: '.*'
typ: simple
- simple: {depth: 64, mults: [1, 2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, winit: normal, symlog: True, debug_outer: True, kernel: 5, minres: 4}
+ simple: {depth: 64, mults: [1, 2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, winit: normal, symlog: True, outer: True, kernel: 5, minres: 4}
dec:
spaces: '.*'
typ: simple
- simple: {inputs: [deter, stoch], vecdist: symlog_mse, depth: 64, mults: [1, 2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, outscale: 1.0, winit: normal, debug_outer: True, kernel: 5, minres: 4, block_space: 8, block_fans: False, block_norm: False, hidden_stoch: True, space_hidden: 0}
+ simple: {inputs: [deter, stoch], vecdist: symlog_mse, depth: 64, mults: [1, 2, 3, 4, 4], layers: 3, units: 1024, act: silu, norm: rms, outscale: 1.0, winit: normal, outer: True, kernel: 5, minres: 4, block_space: 8, block_fans: False, block_norm: False, hidden_stoch: True, space_hidden: 0}
rewhead: {layers: 1, units: 1024, act: silu, norm: rms, dist: symexp_twohot, outscale: 0.0, inputs: [deter, stoch], winit: normal, bins: 255, block_fans: False, block_norm: False}
conhead: {layers: 1, units: 1024, act: silu, norm: rms, dist: binary, outscale: 1.0, inputs: [deter, stoch], winit: normal, block_fans: False, block_norm: False}
contdisc: True
@@ -144,7 +144,6 @@ defaults:
actent: 3e-4
slowreg: 1.0
slowtar: False
- scale_by_actent: False

size12m: &size12m
dyn.rssm: {deter: 2048, hidden: 256, classes: 16}
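With compute_lr and its fitted parameters removed, separate_lrs is now the only remaining learning-rate switch: it replaces the scalar rate with a dict keyed by parameter-group prefix. A minimal sketch of the remaining dispatch logic, using a plain dict as a stand-in for the config object:

  def resolve_lr(config):
      # Returns either one global rate or a per-module dict keyed like
      # 'agent/<module>', matching what the agent passes to the optimizer.
      kw = dict(config['opt'])
      lr = kw.pop('lr')
      if config['separate_lrs']:
          lr = {f'agent/{k}': v for k, v in config['lrs'].items()}
      return lr, kw

  config = {
      'opt': {'lr': 4e-5, 'scaler': 'rms', 'warmup': 1000},
      'separate_lrs': True,
      'lrs': {'actor': 3e-5, 'critic': 3e-5},
  }
  lr, kw = resolve_lr(config)
  print(lr)  # {'agent/actor': 3e-05, 'agent/critic': 3e-05}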
37 changes: 27 additions & 10 deletions dreamerv3/jaxagent.py
@@ -87,7 +87,7 @@ def __init__(self, agent_cls, obs_space, act_space, config):
self._train = self._train.compile()
self._report = self._report.compile()
self._stack = jax.jit(lambda xs: jax.tree.map(
- jnp.stack, xs, is_leaf=lambda x: isinstance(x, list)))
+ jnp.stack, xs, is_leaf=lambda x: isinstance(x, list)))
self._split = jax.jit(lambda xs: jax.tree.map(
lambda x: [y[0] for y in jnp.split(x, len(x))], xs))
print('Done compiling train and report!')
@@ -108,6 +108,12 @@ def init_train(self, batch_size):
carry = self._init_train(self.params, seed, batch_size)
return carry

+ def init_report(self, batch_size):
+   seed = self._next_seeds(self.train_sharded)
+   batch_size //= len(self.train_mesh.devices)
+   carry = self._init_report(self.params, seed, batch_size)
+   return carry

@embodied.timer.section('jaxagent_policy')
def policy(self, obs, carry, mode='train'):
obs = self._filter_data(obs)
@@ -215,14 +221,14 @@ def train(self, data, carry):
return return_outs, carry, return_mets

@embodied.timer.section('jaxagent_report')
- def report(self, data):
+ def report(self, data, carry):
seed = data['seed']
data = self._filter_data(data)
with embodied.timer.section('jit_report'):
with self.train_lock:
- mets = self._report(self.params, data, seed)
+ mets, carry = self._report(self.params, data, carry, seed)
mets = self._take_mets(fetch_async(mets))
- return mets
+ return mets, carry

def dataset(self, generator):
def transform(data):
@@ -307,10 +313,15 @@ def train(alloc, donated, data, carry, seed):
mets = {k: v[None] for k, v in mets.items()}
return params, outs, carry, mets

- def report(params, data, seed):
+ def init_report(params, seed, batch_size):
+   pure = nj.pure(self.agent.init_report)
+   return pure(params, batch_size, seed=seed)[1]
+
+ def report(params, data, carry, seed):
  pure = nj.pure(self.agent.report)
- _, mets = pure(params, data, seed=seed)
- return {k: v[None] for k, v in mets.items()}
+ _, (mets, carry) = pure(params, data, carry, seed=seed)
+ mets = {k: v[None] for k, v in mets.items()}
+ return mets, carry

from jax.experimental.shard_map import shard_map
s = jax.sharding.PartitionSpec('i') # sharded
@@ -330,9 +341,12 @@ def report(params, data, seed):
train = shard_map(
train, self.train_mesh,
(m, m, s, s, s), (m, s, s, m), check_rep=False)
+ init_report = lambda params, seed, batch_size, fn=init_report: shard_map(
+     lambda params, seed: fn(params, seed, batch_size),
+     self.train_mesh, (m, s), s, check_rep=False)(params, seed)
report = shard_map(
report, self.train_mesh,
- (m, s, s), m, check_rep=False)
+ (m, s, s, s), (m, s), check_rep=False)

ps, pm = self.policy_sharded, self.policy_mirrored
self._init_policy = jax.jit(
@@ -345,8 +359,10 @@ def report(params, data, seed):
init_train, (tm, ts), ts, static_argnames=['batch_size'])
self._train = jax.jit(
train, (tm, tm, ts, ts, ts), (tm, ts, ts, tm), donate_argnums=[1])
+ self._init_report = jax.jit(
+     init_report, (tm, ts), ts, static_argnames=['batch_size'])
self._report = jax.jit(
- report, (tm, ts, ts), tm)
+ report, (tm, ts, ts, ts), (tm, ts))

def _take_mets(self, mets):
mets = jax.tree.map(lambda x: x.__array__(), mets)
@@ -405,7 +421,8 @@ def _lower_report(self):
data = self._dummy_batch(self.spaces, (B, T))
data = jax.device_put(data, self.train_sharded)
seed = self._next_seeds(self.train_sharded)
- self._report = self._report.lower(self.params, data, seed)
+ carry = self.init_report(self.config.batch_size)
+ self._report = self._report.lower(self.params, data, carry, seed)


def fetch_async(value):
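The new report wiring shards data, carry, and seed across the mesh axis 'i' while mirroring the parameters, and returns mirrored metrics plus a sharded carry, hence in_specs (m, s, s, s) and out_specs (m, s). A toy shard_map example under those conventions; this is a standalone illustration, not the repository's code:

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.experimental.shard_map import shard_map
  from jax.sharding import Mesh, PartitionSpec

  mesh = Mesh(np.array(jax.devices()), ('i',))
  s = PartitionSpec('i')  # split along the mesh axis
  m = PartitionSpec()     # mirrored on every device

  def report(params, data, carry, seed):
      # Each device sees mirrored params and its own shard of data/carry/seed.
      local = jnp.abs(data - params).mean()
      # Average across devices so the mirrored output really is replicated.
      mets = {'mean_abs_err': jax.lax.pmean(local, 'i')}
      return mets, carry + 1

  # Leading axes of sharded arguments must divide the device count.
  report = shard_map(report, mesh, (m, s, s, s), (m, s), check_rep=False)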
25 changes: 13 additions & 12 deletions dreamerv3/jaxutils.py
@@ -505,7 +505,7 @@ def wrapped(*args, **kwargs):
self.put('state', optstate)

if self.details:
- metrics.update(self._detailed_stats(optstate, params, updates))
+ metrics.update(self._detailed_stats(optstate, params, updates, grads))

scale = 1
step = self.step.read().astype(f32)
@@ -562,18 +562,16 @@ def _update_scale(self, grads, finite):
1e-4, 1e5))
return finite

- def _detailed_stats(self, optstate, params, updates):
+ def _detailed_stats(self, optstate, params, updates, grads):
groups = {
'all': r'.*',
- 'enc': r'/enc/.*/kernel$',
- 'dec': r'/dec/.*/kernel$',
- 'rssm': r'/rssm/.*/kernel$',
- 'cont': r'/cont/.*/kernel$',
- 'rew': r'/rew/.*/kernel$',
- 'actor': r'/actor/.*/kernel$',
- 'critic': r'/critic/.*/kernel$',
- 'gru': r'/gru/kernel$',
- 'bias': r'/bias$',
+ 'enc': r'/enc/.*',
+ 'dec': r'/dec/.*',
+ 'dyn': r'/dyn/.*',
+ 'con': r'/con/.*',
+ 'rew': r'/rew/.*',
+ 'actor': r'/actor/.*',
+ 'critic': r'/critic/.*',
'out': r'/out/kernel$',
'repr': r'/repr_logit/kernel$',
'prior': r'/prior_logit/kernel$',
@@ -590,15 +588,18 @@ def _detailed_stats(self, optstate, params, updates):
keys = [k for k in params if re.search(pattern, k)]
ps = [params[k] for k in keys]
us = [updates[k] for k in keys]
+ gs = [grads[k] for k in keys]
if not ps:
continue
metrics.update({f'{k}/{name}': v for k, v in dict(
param_count=jnp.array(np.sum([np.prod(x.shape) for x in ps])),
param_abs_max=jnp.stack([jnp.abs(x).max() for x in ps]).max(),
param_abs_mean=jnp.stack([jnp.abs(x).mean() for x in ps]).mean(),
+ param_norm=optax.global_norm(ps),
update_abs_max=jnp.stack([jnp.abs(x).max() for x in us]).max(),
update_abs_mean=jnp.stack([jnp.abs(x).mean() for x in us]).mean(),
update_norm=optax.global_norm(us),
+ grad_norm=optax.global_norm(gs),
).items()})
if stddev is not None:
sc = [stddev[k] for k in keys]
@@ -611,7 +612,7 @@ def _detailed_stats(self, optstate, params, updates):
prop_max=jnp.stack([x.max() for x in pr]).max(),
prop_min=jnp.stack([x.min() for x in pr]).min(),
prop_mean=jnp.stack([x.mean() for x in pr]).mean(),
- ).items()})
+ ).items()})
return metrics


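The grouped statistics above boil down to: select parameters whose names match a per-group regex, then reduce each matched set, now including gradient norms via optax.global_norm. A self-contained sketch of that pattern; the parameter names and shapes are illustrative:

  import re
  import jax.numpy as jnp
  import optax

  def grouped_stats(params, grads, groups):
      # params/grads: flat dicts mapping slash-separated names to arrays.
      metrics = {}
      for name, pattern in groups.items():
          keys = [k for k in params if re.search(pattern, k)]
          if not keys:
              continue
          metrics[f'{name}/param_norm'] = optax.global_norm([params[k] for k in keys])
          metrics[f'{name}/grad_norm'] = optax.global_norm([grads[k] for k in keys])
      return metrics

  params = {'/enc/conv/kernel': jnp.ones((3, 3)), '/actor/mlp/kernel': jnp.ones((4, 2))}
  grads = {k: 0.1 * v for k, v in params.items()}
  print(grouped_stats(params, grads, {'all': r'.*', 'enc': r'/enc/.*'}))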
(18 more changed files not shown.)
