
feature(zjow): add Implicit Q-Learning #821

Open · wants to merge 9 commits into base: main
361 changes: 361 additions & 0 deletions ding/model/template/qvac.py

Large diffs are not rendered by default.
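The 361-line model file is not rendered here. IQL needs both a state-action critic Q(s, a) and a separate state-value head V(s) alongside the actor, so a model template along the following lines is plausible; every class, method, and argument name below is an illustrative assumption, not the actual qvac.py contents.

import torch
import torch.nn as nn


class QVACSketch(nn.Module):
    # Hypothetical actor-critic bundling an actor, a Q(s, a) critic, and a V(s) head.

    def __init__(self, obs_shape: int, action_shape: int, hidden: int = 256) -> None:
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(obs_shape, hidden), nn.ReLU(), nn.Linear(hidden, action_shape), nn.Tanh())
        self.critic_q = nn.Sequential(nn.Linear(obs_shape + action_shape, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.critic_v = nn.Sequential(nn.Linear(obs_shape, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def compute_actor(self, obs: torch.Tensor) -> torch.Tensor:
        return self.actor(obs)

    def compute_q(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.critic_q(torch.cat([obs, action], dim=-1)).squeeze(-1)

    def compute_v(self, obs: torch.Tensor) -> torch.Tensor:
        return self.critic_v(obs).squeeze(-1)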

6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -43,6 +43,7 @@

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .iql import IQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .madqn import MADQNPolicy
@@ -322,6 +323,11 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('iql_command')
class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
pass
646 changes: 646 additions & 0 deletions ding/policy/iql.py

Large diffs are not rendered by default.
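The 646-line policy file is not rendered here. For orientation, below is a minimal sketch of the three objectives IQL combines (expectile regression for V, TD learning for Q against V(s'), and advantage-weighted policy extraction); the function names, clipping constant, and default hyperparameters are illustrative assumptions, not the actual ding/policy/iql.py code.

import torch
import torch.nn.functional as F


def expectile_loss(diff: torch.Tensor, tau: float = 0.7) -> torch.Tensor:
    # Asymmetric squared error |tau - 1(u < 0)| * u^2; pushes V(s) toward an upper expectile of Q(s, a).
    weight = torch.abs(tau - (diff < 0).float())
    return (weight * diff.pow(2)).mean()


def iql_losses(q_pred, q_detached, v_pred, v_next, log_prob, reward, done,
               gamma: float = 0.99, tau: float = 0.7, beta: float = 3.0):
    # 1) Value loss: regress V(s) onto the tau-expectile of Q(s, a) over dataset actions only.
    value_loss = expectile_loss(q_detached - v_pred, tau)
    # 2) Critic loss: TD target built from V(s'), so no max over out-of-distribution actions.
    q_target = reward + gamma * (1.0 - done) * v_next.detach()
    q_loss = F.mse_loss(q_pred, q_target)
    # 3) Policy loss: advantage-weighted regression with clipped exponential weights.
    adv = q_detached - v_pred.detach()
    weight = torch.clamp((beta * adv).exp(), max=100.0)
    policy_loss = -(weight * log_prob).mean()
    return value_loss, q_loss, policy_loss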

32 changes: 32 additions & 0 deletions ding/utils/data/dataset.py
@@ -114,6 +114,38 @@ def __init__(self, cfg: dict) -> None:
except (KeyError, AttributeError):
# do not normalize
pass
if hasattr(cfg.env, "reward_norm"):
if cfg.env.reward_norm == "normalize":
dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / dataset['rewards'].std()
Review comment (Member): add an eps to the std denominator to avoid division by zero.
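A minimal sketch of the suggested fix (the 1e-6 value is an assumption, not part of the PR):

eps = 1e-6  # hypothetical small constant keeping the denominator non-zero
dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / (dataset['rewards'].std() + eps)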

elif cfg.env.reward_norm == "iql_antmaze":
dataset['rewards'] = dataset['rewards'] - 1.0
elif cfg.env.reward_norm == "iql_locomotion":

def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0.0, 0
for r, d in zip(dataset["rewards"], dataset["terminals"]):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0.0, 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset["rewards"])
return min(returns), max(returns)

min_ret, max_ret = return_range(dataset, 1000)
dataset['rewards'] /= max_ret - min_ret
dataset['rewards'] *= 1000
elif cfg.env.reward_norm == "cql_antmaze":
dataset['rewards'] = (dataset['rewards'] - 0.5) * 4.0
elif cfg.env.reward_norm == "antmaze":
dataset['rewards'] = (dataset['rewards'] - 0.25) * 2.0
else:
raise NotImplementedError

self._data = []
self._load_d4rl(dataset)

54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="halfcheetah_medium_expert_iql_seed0",
env=dict(
env_id='halfcheetah-medium-expert-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=17,
action_shape=6,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
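The comment at the top of each config points at d4rl_iql_main.py, which is not shown in this section. A minimal sketch of such an entry, assuming it mirrors DI-engine's existing offline entries built on serial_pipeline_offline (the argument handling is an assumption):

from ding.entry import serial_pipeline_offline
from dizoo.d4rl.config.halfcheetah_medium_expert_iql_config import main_config, create_config


def train(seed: int = 0) -> None:
    # The offline pipeline consumes the (main_config, create_config) pair defined above.
    serial_pipeline_offline([main_config, create_config], seed=seed)


if __name__ == "__main__":
    train()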
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="halfcheetah_medium_iql_seed0",
env=dict(
env_id='halfcheetah-medium-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=17,
action_shape=6,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
Review comment (Member): why is a replay buffer configured here?

),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/halfcheetah_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="halfcheetah_medium_replay_iql_seed0",
env=dict(
env_id='halfcheetah-medium-replay-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=17,
action_shape=6,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_expert_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_expert_iql_seed0",
env=dict(
env_id='hopper-medium-expert-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_iql_seed0",
env=dict(
env_id='hopper-medium-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
54 changes: 54 additions & 0 deletions dizoo/d4rl/config/hopper_medium_replay_iql_config.py
@@ -0,0 +1,54 @@
# You can conduct experiments on D4RL with this config file through the following command:
# cd ../entry && python d4rl_iql_main.py
from easydict import EasyDict

main_config = dict(
exp_name="hopper_medium_replay_iql_seed0",
env=dict(
env_id='hopper-medium-replay-v2',
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
reward_norm="iql_locomotion",
),
policy=dict(
cuda=True,
model=dict(
obs_shape=11,
action_shape=3,

),
learn=dict(
data_path=None,
train_epoch=30000,
batch_size=4096,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
beta=0.05,
tau=0.7,
),
collect=dict(data_type='d4rl', ),
eval=dict(evaluator=dict(eval_freq=5000, )),
other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
),
)

main_config = EasyDict(main_config)
main_config = main_config

create_config = dict(
env=dict(
type='d4rl',
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='iql',
import_names=['ding.policy.iql'],
),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config