Skip to content

Commit

Permalink
Add scores
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed May 12, 2023
1 parent c64680b commit 423291a
Show file tree
Hide file tree
Showing 21 changed files with 241 additions and 0 deletions.
51 changes: 51 additions & 0 deletions scores/clean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import gzip
import json
import pathlib
import re
import sys

filenames = list((pathlib.Path(__file__).parent / 'data').glob('*.json.gz'))

if len(sys.argv) > 1:
filenames = [x for x in filenames if re.search(sys.argv[1], x.name)]

for filename in sorted(filenames):
print(filename.name)

with gzip.open(filename, 'rb') as f:
runs = original = json.load(f)
edited = False

runs = [r for r in runs if not r['task'].startswith('stats_')]

tasks = sorted(set(run['task'] for run in runs))
methods = sorted(set(run['method'] for run in runs))
seeds = sorted(set(run['seed'] for run in runs))

new = sorted([str(x) for x in range(len(seeds))])
renames = {k: v for k, v in zip(seeds, new) if k != v}
for run in runs:
if run['seed'] in renames:
run['seed'] = renames[run['seed']]
edited = True

if filename.name.startswith('atari200m'):
for run in runs:
if run['task'] == 'atari_james_bond':
run['task'] = 'atari_jamesbond'
edited = True

# if filename.name.startswith('...'):
# for run in runs:
# keep = len([x for x in run['xs'] if x <= 1e6])
# if keep < len(run['xs']):
# run['xs'] = run['xs'][:keep]
# run['ys'] = run['ys'][:keep]
# edited = True

runs = sorted(runs, key=lambda x: ((x['task'], x['method'], x['seed'])))

if (runs != original) or edited:
print(f'Writing changes')
with gzip.open(filename, 'wb') as f:
f.write(json.dumps(runs).encode('utf-8'))
Binary file added scores/data/atari100k_dreamerv3.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_c51.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_dqn.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_dreamerv2.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_dreamerv3.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_iqn.json.gz
Binary file not shown.
Binary file added scores/data/atari200m_rainbow.json.gz
Binary file not shown.
Binary file added scores/data/dmcproprio_d4pg.json.gz
Binary file not shown.
Binary file added scores/data/dmcproprio_dmpo.json.gz
Binary file not shown.
Binary file added scores/data/dmcproprio_dreamerv3.json.gz
Binary file not shown.
Binary file added scores/data/dmcproprio_mpo.json.gz
Binary file not shown.
Binary file added scores/data/dmcvision_curl.json.gz
Binary file not shown.
Binary file added scores/data/dmcvision_dreamerv3.json.gz
Binary file not shown.
Binary file added scores/data/dmcvision_drq.json.gz
Binary file not shown.
Binary file added scores/data/dmcvision_drqv2.json.gz
Binary file not shown.
Binary file added scores/data/dmcvision_sac.json.gz
Binary file not shown.
Binary file added scores/data/dmlab_dreamerv3.json.gz
Binary file not shown.
Binary file added scores/data/dmlab_impala.json.gz
Binary file not shown.
164 changes: 164 additions & 0 deletions scores/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import gzip
import json
import pathlib
import re
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker


COLORS = (
'#377eb8', '#4daf4a', '#984ea3', '#e41a1c', '#ff7f00', '#a65628',
'#f781bf', '#888888', '#a6cee3', '#b2df8a', '#cab2d6', '#fb9a99',
)


def plots(
amount, cols=4, size=(2, 2.3), xticks=4, yticks=5, grid=(1, 1), **kwargs):
rows = int(np.ceil(amount / cols))
size = (cols * size[0], rows * size[1])
fig, axes = plt.subplots(rows, cols, figsize=size, squeeze=False, **kwargs)
axes = axes.flatten()
for ax in axes:
ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks))
ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks))
if grid:
grid = (grid, grid) if not hasattr(grid, '__len__') else grid
ax.grid(which='both', color='#eeeeee')
ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(int(grid[0])))
ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(int(grid[1])))
ax.tick_params(which='minor', length=0)
for ax in axes[amount:]:
ax.axis('off')
return fig, axes


def curve(
ax, domain, values, low=None, high=None, label=None, order=0, **kwargs):
finite = np.isfinite(values)
ax.plot(
domain[finite], values[finite],
label=label, zorder=1000 - order, **kwargs)
if low is not None:
ax.fill_between(
domain[finite], low[finite], high[finite],
zorder=100 - order, alpha=0.2, lw=0, **kwargs)


def legend(fig, mapping=None, adjust=False, **kwargs):
options = dict(
fontsize='medium', numpoints=1, labelspacing=0, columnspacing=1.2,
handlelength=1.5, handletextpad=0.5, ncol=4, loc='lower center')
options.update(kwargs)
# Find all labels and remove duplicates.
entries = {}
for ax in fig.axes:
for handle, label in zip(*ax.get_legend_handles_labels()):
if mapping and label in mapping:
label = mapping[label]
entries[label] = handle
leg = fig.legend(entries.values(), entries.keys(), **options)
leg.get_frame().set_edgecolor('white')
if adjust is not False:
pad = adjust if isinstance(adjust, (int, float)) else 0.5
extent = leg.get_window_extent(fig.canvas.get_renderer())
extent = extent.transformed(fig.transFigure.inverted())
yloc, xloc = options['loc'].split()
y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
x0 = dict(left=extent.x1, center=0, right=0)[xloc]
x1 = dict(left=1, center=1, right=extent.x0)[xloc]
fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=pad, w_pad=pad)


def binning(xs, ys, borders, reducer=np.nanmean, fill='nan'):
assert fill in ('nan', 'last', 'zeros')
xs = xs if isinstance(xs, np.ndarray) else np.asarray(xs)
ys = ys if isinstance(ys, np.ndarray) else np.asarray(ys)
order = np.argsort(xs)
xs, ys = xs[order], ys[order]
binned = []
for start, stop in zip(borders[:-1], borders[1:]):
left = (xs <= start).sum()
right = (xs <= stop).sum()
value = np.nan
if left < right:
value = reduce(ys[left:right], reducer)
if np.isnan(value):
if fill == 'zeros':
value = 0
if fill == 'last' and binned:
value = binned[-1]
binned.append(value)
return borders[1:], np.array(binned)


def reduce(values, reducer=np.nanmean, *args, **kwargs):
with warnings.catch_warnings(): # Buckets can be empty.
warnings.simplefilter('ignore', category=RuntimeWarning)
return reducer(values, *args, **kwargs)


datadir = pathlib.Path(__file__).parent / 'data'
outdir = pathlib.Path(__file__).parent / 'figs'
outdir.mkdir(exist_ok=True)

suites = sorted(set(x.name.split('_')[0] for x in datadir.glob('*.json.gz')))

if len(sys.argv) > 1:
suites = [x for x in suites if re.search(sys.argv[1], x)]
print(f'Pattern matches {len(suites)} suites: {", ".join(suites)}')

for suite in suites:
print('-' * 79)
print(suite)
print('-' * 79)

runs = []
for filename in datadir.glob(f'{suite}_*.json.gz'):
with gzip.open(filename, 'rb') as f:
runs += json.load(f)

tasks = sorted(set(run['task'] for run in runs))
methods = sorted(set(run['method'] for run in runs))
seeds = sorted(set(run['seed'] for run in runs))

fig, axes = plots(len(tasks), cols=6, size=(2, 2))
for i, task in enumerate(tasks):
ax = axes[i]

title = task.split('_', 1)[-1]
title = title.replace('_', ' ').title()
ax.set_title(title)

for j, method in enumerate(methods):
relevant = [run for run in runs if (
run['task'] == task and run['method'] == method)]
if not relevant:
print(f'No runs for {method} on {task}')
continue
lo = min([min(run['xs']) for run in relevant])
hi = max([max(run['xs']) for run in relevant])
borders = np.linspace(lo, hi, 30)
scores = []
for run in relevant:
scores.append(binning(run['xs'], run['ys'], borders, fill='last')[1])
mean = np.nanmean(scores, 0)
std = np.nanstd(scores, 0)
curve(
ax, borders[1:], mean, mean - std, mean + std,
label=method, order=j, color=COLORS[j])

ax.tick_params(
axis='both', which='major', labelsize='small', pad=1, length=1)
ax.ticklabel_format(
axis='x', style='sci', scilimits=(-2, 2))
legend(fig, adjust=1)

filename = outdir / (suite + '.png')
fig.savefig(filename, dpi=300)
print('Saved', filename)
print('')
26 changes: 26 additions & 0 deletions scores/view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import gzip
import json
import pathlib
import re
import sys

filenames = list((pathlib.Path(__file__).parent / 'data').glob('*.json.gz'))

if len(sys.argv) > 1:
filenames = [x for x in filenames if re.search(sys.argv[1], x.name)]

for filename in sorted(filenames):
print('-' * 79)
print(filename.name)
print('-' * 79)
with gzip.open(filename) as f:
runs = json.load(f)
tasks = sorted(set(run['task'] for run in runs))
methods = sorted(set(run['method'] for run in runs))
seeds = sorted(set(run['seed'] for run in runs))
print(f'Methods ({len(methods)}):', ', '.join(methods))
print(f'Seeds ({len(seeds)}):', ', '.join(seeds))
print(f'Tasks ({len(tasks)}):', ', '.join(tasks))
print('Possible combinations:', len(tasks) * len(methods) * len(seeds))
print('Runs:', len(runs))
print('')

0 comments on commit 423291a

Please sign in to comment.