Missing architectures jsonnet files #4
My bad, that specific comment was left over from the internal version of the file; it's removed now.
I believe all of the policies we provide have the same architecture we describe in the paper.
Just to make it easier for others (since the code depends on an old gym version and the paper is somewhat underspecified), here's the network spec as printed by the script below. It's a Python list rather than strict JSON, so it uses `None`/`True` instead of `null`/`true`:

```python
[
  {
    "activation": "relu",
    "filters": 9,
    "kernel_size": 3,
    "layer_type": "circ_conv1d",
    "nodes_in": ["lidar"],
    "nodes_out": ["lidar"]
  },
  {
    "layer_type": "flatten_outer",
    "nodes_in": ["lidar"],
    "nodes_out": ["lidar"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "lidar"],
    "nodes_out": ["main"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "agent_qpos_qvel"],
    "nodes_out": ["agent_qpos_qvel"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "box_obs"],
    "nodes_out": ["box_obs"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "ramp_obs"],
    "nodes_out": ["ramp_obs"]
  },
  {
    "activation": "relu",
    "layer_type": "dense",
    "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "nodes_out": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "units": 128
  },
  {
    "layer_type": "entity_concat",
    "mask_out": "objects_mask",
    "masks_in": ["mask_aa_obs", "mask_ab_obs", "mask_ar_obs", None],
    "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
    "nodes_out": ["objects"]
  },
  {
    "heads": 4,
    "layer_norm": True,
    "layer_type": "residual_sa_block",
    "mask": "objects_mask",
    "n_embd": 128,
    "n_mlp": 1,
    "nodes_in": ["objects"],
    "nodes_out": ["objects"],
    "post_sa_layer_norm": True
  },
  {
    "layer_type": "entity_pooling",
    "mask": "objects_mask",
    "nodes_in": ["objects"],
    "nodes_out": ["objects_pooled"]
  },
  {
    "layer_type": "concat",
    "nodes_in": ["main", "objects_pooled"],
    "nodes_out": ["main"]
  },
  { "layer_type": "layernorm" },
  { "activation": "relu", "layer_type": "dense", "units": 256 },
  { "layer_type": "layernorm" },
  { "layer_type": "lstm", "units": 256 },
  { "layer_type": "layernorm" }
]
```

The script I used to extract it:

```python
import numpy as np
import pickle
from pprint import pprint

policy_path = "examples/hide_and_seek_full.npz"
policy_dict = dict(np.load(policy_path))
policy_fn_and_args_raw = pickle.loads(policy_dict['policy_fn_and_args'])
policy_args = policy_fn_and_args_raw['args']  # this contains the arguments fed into the policy class
network_spec = policy_args['network_spec']  # this contains the network architecture
pprint(network_spec)
```
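The `entity_concat` and `entity_pooling` layers in the spec are where the paper is thinnest, so here is a NumPy sketch of what they plausibly compute. It assumes (not confirmed from the `ma_policy` source) that `entity_concat` stacks entity groups along a single entity axis and treats a `None` entry in `masks_in` as "always visible", and that `entity_pooling` is a masked average over visible entities. The function names mirror the `layer_type` strings but are hypothetical, each input is assumed to already have shape `(batch, n_entities, n_embd)`, and the `residual_sa_block` that sits between the two layers is left out.

```python
import numpy as np


def entity_concat(nodes, masks):
    """Hypothetical sketch: stack several entity groups into one entity axis.

    nodes: list of arrays shaped (batch, n_i, n_embd).
    masks: list of arrays shaped (batch, n_i) with 1 = visible, or None for a
           group assumed to be always visible (like "main" in the spec).
    Returns the combined entities and the combined mask.
    """
    full_masks = [
        np.ones(node.shape[:2]) if mask is None else mask.astype(float)
        for node, mask in zip(nodes, masks)
    ]
    return np.concatenate(nodes, axis=1), np.concatenate(full_masks, axis=1)


def entity_pooling(entities, mask):
    """Hypothetical sketch: masked mean over the entity axis.

    (batch, n, n_embd) -> (batch, n_embd); masked-out entities are ignored.
    """
    weights = mask[..., None]                      # (batch, n, 1)
    total = (entities * weights).sum(axis=1)       # sum over visible entities
    count = np.maximum(weights.sum(axis=1), 1.0)   # guard against all-masked rows
    return total / count


# Tiny example: two boxes (the second one not visible) plus a single "main" entity.
boxes = np.array([[[1.0, 2.0], [50.0, 50.0]]])     # (1, 2, 2)
main = np.array([[[3.0, 4.0]]])                    # (1, 1, 2)
objects, objects_mask = entity_concat([boxes, main], [np.array([[1, 0]]), None])
pooled = entity_pooling(objects, objects_mask)
print(objects_mask)  # [[1. 0. 1.]]
print(pooled)        # [[2. 3.]]
```

The hidden box contributes nothing to the pooled vector, which is the point of carrying `objects_mask` from `entity_concat` through to `entity_pooling`.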
`ma_policy/graph_construct.py` says that the file `mas/ppo/base-architectures.jsonnet` contains example architectures, but as far as I can tell that file is not in the repository.
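Since a spec extracted with the `pprint` approach in this thread is a Python literal rather than JSON, anyone who wants it as a standalone config file could serialize it with the standard `json` module, which converts `None`/`True` to `null`/`true`; because Jsonnet is a superset of JSON, the result is also valid Jsonnet. This is a minimal sketch using an abbreviated two-layer copy of the spec and a hypothetical output filename; it does not imply that anything in the repository loads such a file.

```python
import json

# Abbreviated copy of the extracted spec (two layers only). Note the Python
# literals None and True, which JSON spells null and true.
network_spec = [
    {
        "layer_type": "entity_concat",
        "mask_out": "objects_mask",
        "masks_in": ["mask_aa_obs", "mask_ab_obs", "mask_ar_obs", None],
        "nodes_in": ["agent_qpos_qvel", "box_obs", "ramp_obs", "main"],
        "nodes_out": ["objects"],
    },
    {
        "heads": 4,
        "layer_norm": True,
        "layer_type": "residual_sa_block",
        "mask": "objects_mask",
        "n_embd": 128,
        "n_mlp": 1,
        "nodes_in": ["objects"],
        "nodes_out": ["objects"],
        "post_sa_layer_norm": True,
    },
]

out_path = "hide_and_seek_network_spec.json"  # hypothetical filename
with open(out_path, "w") as f:
    json.dump(network_spec, f, indent=2)

# Round-trip check: the file parses back to the same Python structure.
with open(out_path) as f:
    reloaded = json.load(f)
print(reloaded == network_spec)  # True
```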