
Missing architectures jsonnet files #4

Open
jeremiahvpratt opened this issue Sep 20, 2019 · 2 comments

Comments

@jeremiahvpratt

`ma_policy/graph_construct.py` says that the file `mas/ppo/base-architectures.jsonnet` contains example architectures, but as far as I can tell that file isn't in the repository.

@todor-markov
Contributor

My bad. That comment was left over from the internal version of the file, and I've removed it now.
Currently, if you want to see the architecture for a given policy, you should load the .npz file for that policy and extract it from there:

```python
import pickle
import numpy as np

# policy_path is the path to the policy's .npz file
policy_dict = dict(np.load(policy_path))
policy_fn_and_args_raw = pickle.loads(policy_dict['policy_fn_and_args'])
policy_args = policy_fn_and_args_raw['args']  # the arguments passed to the policy class
network_spec = policy_args['network_spec']  # the network architecture
print(network_spec)
```

I believe all of the policies we provide use the architecture described in the paper.
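The sketch below shows why the extraction above works. It builds a tiny `.npz` in memory with a pickled dict stored as a byte string under the key `policy_fn_and_args`, then reads it back the same way. This layout is an assumption inferred from the snippet, not a description of the real files. The toy `network_spec`, the name `fake_policy_fn_and_args`, and the in-memory buffer are all hypothetical stand-ins, and the real policy files probably hold more keys than this.

```python
import io
import pickle

import numpy as np

# Hypothetical stand-in for what a policy file stores: a dict whose 'args'
# entry holds the policy's constructor arguments, including 'network_spec'.
fake_policy_fn_and_args = {
    "args": {
        "network_spec": [
            {"layer_type": "dense", "units": 128, "activation": "relu"},
            {"layer_type": "layernorm"},
        ]
    }
}

# Store the pickled dict as a 0-d byte-string array inside an in-memory .npz.
buf = io.BytesIO()
np.savez(buf, policy_fn_and_args=np.array(pickle.dumps(fake_policy_fn_and_args)))
buf.seek(0)

# Read it back the same way as the snippet above.
with np.load(buf) as npz:
    policy_dict = dict(npz)
raw = policy_dict["policy_fn_and_args"]  # 0-d numpy array of dtype 'S<n>'
# tobytes() hands pickle a plain bytes object with the exact pickled payload.
policy_fn_and_args_raw = pickle.loads(raw.tobytes())
network_spec = policy_fn_and_args_raw["args"]["network_spec"]
print(network_spec)
```

Because the payload is stored as a fixed-width byte string rather than an object array, `np.load` can read it without `allow_pickle=True`. The actual unpickling happens explicitly in `pickle.loads`.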

@phiresky

To make this easier for others (the code depends on an old gym version, and the paper is somewhat underspecified), here's the network spec for `hide_and_seek_full`. Note that it's a Python literal (it contains `None` and `True`), not strict JSON:

```python
[
  {
    "activation": "relu",
    "filters": 9,
    "kernel_size": 3,
    "layer_type": "circ_conv1d",
    "nodes_in": ["lidar"],
    "nodes_out": ["lidar"]
  },
  {
    "layer_type": "flatten_outer",
    "nodes_in": ["lidar"],
    "nodes_out": ["lidar"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "lidar"],
    "nodes_out": ["main"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "agent_qpos_qvel"],
    "nodes_out": ["agent_qpos_qvel"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "box_obs"],
    "nodes_out": ["box_obs"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "ramp_obs"],
    "nodes_out": ["ramp_obs"]
  },
  {
    "activation": "relu",
    "layer_type": "dense",
    "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "nodes_out": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "units": 128
  },
  {
    "layer_type": "entity_concat",
    "mask_out": "objects_mask",
    "masks_in": ["mask_aa_obs", "mask_ab_obs", "mask_ar_obs", None],
    "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "nodes_out": ["objects"]
  },
  {
    "heads": 4,
    "layer_norm": True,
    "layer_type": "residual_sa_block",
    "mask": "objects_mask",
    "n_embd": 128,
    "n_mlp": 1,
    "nodes_in": ["objects"],
    "nodes_out": ["objects"],
    "post_sa_layer_norm": True
  },
  {
    "layer_type": "entity_pooling",
    "mask": "objects_mask",
    "nodes_in": ["objects"],
    "nodes_out": ["objects_pooled"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "objects_pooled"],
    "nodes_out": ["main"]
  },
  { "layer_type": "layernorm" },
  { "activation": "relu", "layer_type": "dense", "units": 256 },
  { "layer_type": "layernorm" },
  { "layer_type": "lstm", "units": 256 },
  { "layer_type": "layernorm" }
]
```
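Each layer in the spec reads the named tensors in `nodes_in` and writes `nodes_out`. You can therefore walk the list to see which names the network expects to be fed from outside. This is a minimal sketch built on two assumptions that the thread doesn't confirm: first, that layers with no `nodes_in`/`nodes_out` (the last five entries) act on `"main"`; second, that mask names in `masks_in`/`mask` count as reads while `mask_out` counts as a write. The helper name `required_inputs` is hypothetical. `SPEC` is a trimmed copy of the spec above that keeps only the routing keys.

```python
# Trimmed copy of the spec above, keeping only the keys that route data.
SPEC = [
    {"layer_type": "circ_conv1d", "nodes_in": ["lidar"], "nodes_out": ["lidar"]},
    {"layer_type": "flatten_outer", "nodes_in": ["lidar"], "nodes_out": ["lidar"]},
    {"layer_type": "concat", "nodes_in": ["main", "lidar"], "nodes_out": ["main"]},
    {"layer_type": "concat", "nodes_in": ["main", "agent_qpos_qvel"], "nodes_out": ["agent_qpos_qvel"]},
    {"layer_type": "concat", "nodes_in": ["main", "box_obs"], "nodes_out": ["box_obs"]},
    {"layer_type": "concat", "nodes_in": ["main", "ramp_obs"], "nodes_out": ["ramp_obs"]},
    {"layer_type": "dense",
     "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
     "nodes_out": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"]},
    {"layer_type": "entity_concat", "mask_out": "objects_mask",
     "masks_in": ["mask_aa_obs", "mask_ab_obs", "mask_ar_obs", None],
     "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
     "nodes_out": ["objects"]},
    {"layer_type": "residual_sa_block", "mask": "objects_mask",
     "nodes_in": ["objects"], "nodes_out": ["objects"]},
    {"layer_type": "entity_pooling", "mask": "objects_mask",
     "nodes_in": ["objects"], "nodes_out": ["objects_pooled"]},
    {"layer_type": "concat", "nodes_in": ["main", "objects_pooled"], "nodes_out": ["main"]},
    {"layer_type": "layernorm"},
    {"layer_type": "dense"},
    {"layer_type": "layernorm"},
    {"layer_type": "lstm"},
    {"layer_type": "layernorm"},
]

DEFAULT_NODE = "main"  # assumption: layers without nodes_in/nodes_out act on "main"


def required_inputs(spec):
    """Return names that some layer reads before any earlier layer writes them."""
    written = set()
    needed = []
    for layer in spec:
        reads = list(layer.get("nodes_in", [DEFAULT_NODE]))
        reads += [m for m in layer.get("masks_in", []) if m is not None]
        if layer.get("mask") is not None:
            reads.append(layer["mask"])
        # A layer's reads are checked before its own writes are recorded.
        for name in reads:
            if name not in written and name not in needed:
                needed.append(name)
        written.update(layer.get("nodes_out", [DEFAULT_NODE]))
        if layer.get("mask_out") is not None:
            written.add(layer["mask_out"])
    return needed


print(required_inputs(SPEC))
# ['lidar', 'main', 'agent_qpos_qvel', 'box_obs', 'ramp_obs',
#  'mask_aa_obs', 'mask_ab_obs', 'mask_ar_obs']
```

Names such as `objects`, `objects_mask` and `objects_pooled` don't appear in the result because the spec produces them internally.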

Here's the script I used to extract the spec:

```python
import pickle
from pprint import pprint

import numpy as np

policy_path = "examples/hide_and_seek_full.npz"
policy_dict = dict(np.load(policy_path))
policy_fn_and_args_raw = pickle.loads(policy_dict['policy_fn_and_args'])
policy_args = policy_fn_and_args_raw['args']  # the arguments passed to the policy class
network_spec = policy_args['network_spec']  # the network architecture
pprint(network_spec)
```
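`pprint` prints Python literals such as `None` and `True`, so its output isn't valid JSON as-is. If you want a file that JSON tools can read, `json.dumps` handles the conversion. This minimal sketch uses a two-layer excerpt of the spec in place of the full `network_spec`. It assumes the spec contains only plain dicts, lists, strings, numbers, booleans and `None`, which is what the printout suggests.

```python
import json

# Two-layer excerpt of the hide_and_seek_full spec; pass the full network_spec in practice.
network_spec = [
    {"layer_type": "entity_concat", "mask_out": "objects_mask",
     "masks_in": ["mask_aa_obs", "mask_ab_obs", "mask_ar_obs", None],
     "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
     "nodes_out": ["objects"]},
    {"heads": 4, "layer_norm": True, "layer_type": "residual_sa_block",
     "mask": "objects_mask", "n_embd": 128, "n_mlp": 1,
     "nodes_in": ["objects"], "nodes_out": ["objects"],
     "post_sa_layer_norm": True},
]

# json.dumps maps None -> null and True -> true; sort_keys gives a stable key order.
spec_json = json.dumps(network_spec, indent=2, sort_keys=True)
print(spec_json)

# Parsing the JSON back yields the original Python structure.
assert json.loads(spec_json) == network_spec
```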
