risk model with simhash
manila95 committed Mar 18, 2024
1 parent ae6ebdc commit 6125e9a
Showing 1 changed file with 115 additions and 16 deletions.
131 changes: 115 additions & 16 deletions cleanrl/ppo_continuous_action.py
@@ -6,6 +6,7 @@
from distutils.util import strtobool

import gymnasium as gym
import safety_gymnasium
import numpy as np
import torch
import torch.nn as nn
@@ -27,16 +28,34 @@ def parse_args():
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
parser.add_argument("--wandb-project-name", type=str, default="risk_aware_exploration",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
parser.add_argument("--wandb-entity", type=str, default="kaustubh_umontreal",
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
help="the id of the environment")
parser.add_argument("--reward-goal", type=float, default=10,
help="reward to give when the goal is achieved")
parser.add_argument("--reward-distance", type=float, default=0.0,
help="reward to give when the goal is achieved")
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to terminate early i.e. when the catastrophe has happened")
parser.add_argument("--unifying-lidar", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="what kind of sensor is used (same for every environment?)")
parser.add_argument("--term-cost", type=int, default=1,
help="how many violations before you terminate")
parser.add_argument("--failure-penalty", type=float, default=0.0,
help="Reward Penalty when you fail")
parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="store data while trianing")
parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
help="the storage path for the data collected")


parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=3e-4,
@@ -69,19 +88,29 @@ def parse_args():
help="the maximum norm for the gradient clipping")
parser.add_argument("--target-kl", type=float, default=None,
help="the target KL divergence threshold")

parser.add_argument("--use-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--quantile-num", type=int, default=10,
help="num of quantiles")
parser.add_argument("--quantile-size", type=int, default=4,
help="quantile size")
args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
# fmt: on
return args


def make_env(env_id, idx, capture_video, run_name, gamma):
def make_env(args, idx, capture_video, run_name, gamma):
def thunk():
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
if "safety" in args.env_id.lower():
env = gym.make(args.env_id, early_termination=args.early_termination, reward_goal=args.reward_goal, reward_distance=args.reward_distance)
else:
env = gym.make(env_id)
if capture_video:
env = gym.make(args.env_id, render_mode="rgb_array")
else:
env = gym.make(args.env_id)
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
@@ -104,17 +133,17 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):


class Agent(nn.Module):
def __init__(self, envs):
def __init__(self, envs, risk_size=0):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+risk_size, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 1), std=1.0),
)
self.actor_mean = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod()+risk_size, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
@@ -125,7 +154,9 @@ def __init__(self, envs):
def get_value(self, x):
return self.critic(x)

def get_action_and_value(self, x, action=None):
def get_action_and_value(self, x, risk=None, action=None):
if risk is not None:
x = torch.cat([x, risk], axis=-1)
action_mean = self.actor_mean(x)
action_logstd = self.actor_logstd.expand_as(action_mean)
action_std = torch.exp(action_logstd)
@@ -135,6 +166,38 @@ def get_action_and_value(self, x, action=None):
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
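
With --use-risk enabled, the agent conditions both the actor and the critic on a per-state risk estimate by concatenating it to the observation, which is why the first Linear layers above take the observation size plus risk_size inputs. A minimal standalone shape sketch (the sizes below are illustrative, not taken from the commit):

import torch

obs_dim, risk_size, n_envs = 60, 10, 4                    # illustrative; risk_size equals --quantile-num
obs = torch.zeros(n_envs, obs_dim)
risk = torch.full((n_envs, risk_size), 1.0 / risk_size)   # e.g. the uniform risk prior
x = torch.cat([obs, risk], dim=-1)                        # what get_action_and_value builds when risk is not None
print(x.shape)                                            # torch.Size([4, 70])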


class SimHash(object):
def __init__(self, state_emb, k, quantile_num, device, count_thresh=0):
''' Hashing between continuous state space and discrete state space '''
self.hash = {}
self.A = np.random.normal(0, 1, (k, state_emb))  # random projection matrix, one hyperplane per row
self.device = device
self.uniform_risk = np.array([1 / float(quantile_num)]*quantile_num)
self.count_thresh = count_thresh

def update_risk(self, states, risks):
''' Accumulate the risk estimate for each hashed state '''
for i, state in enumerate(states):
key = str(np.sign(self.A @ state.detach().cpu().numpy()).tolist())
if key in self.hash:
self.hash[key] += risks[i]
else:
self.hash[key] = risks[i]

def get_risk(self, states):
risks = []
for state in states:
key = str(np.sign(self.A @ state.detach().cpu().numpy()).tolist())
if key in self.hash and np.sum(self.hash[key]) > self.count_thresh:
risks.append(self.hash[key])
else:
risks.append(self.uniform_risk)
return torch.from_numpy(np.array(risks)).to(self.device)
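
With this class, continuous states are projected onto k random hyperplanes and bucketed by the sign pattern, so nearby states tend to share one accumulated risk histogram, while unseen buckets fall back to the uniform distribution. A small standalone sketch of that behavior (toy sizes, not part of the commit; it only exercises the SimHash class above):

import numpy as np
import torch

state_dim, k, quantile_num = 8, 10, 10
sim = SimHash(state_dim, k, quantile_num, device="cpu")

s = torch.randn(1, state_dim)
near = s + 1e-3 * torch.randn(1, state_dim)             # tiny perturbation, very likely the same hash key
states = torch.cat([s, near], dim=0)

print(sim.get_risk(states))                             # unseen keys -> uniform risk, 1/quantile_num per bin
sim.update_risk(states, np.eye(quantile_num)[[0, 0]])   # record a one-hot "imminent failure" label for both
print(sim.get_risk(states))                             # now returns the accumulated quantile histograms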



if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
Expand Down Expand Up @@ -166,13 +229,19 @@ def get_action_and_value(self, x, action=None):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
[make_env(args, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

agent = Agent(envs).to(device)

agent = Agent(envs, args.quantile_num if args.use_risk else 0).to(device)

optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

if args.use_risk:
risk_bins = np.array([i*args.quantile_size for i in range(args.quantile_num+1)])
risk_hash = SimHash(np.array(envs.single_observation_space.shape).prod(), 10, args.quantile_num, device, 0)

# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
@@ -189,6 +258,8 @@ def get_action_and_value(self, x, action=None):
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size

ep_states = [None]*args.num_envs

for update in range(1, num_updates + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
@@ -203,32 +274,54 @@ def get_action_and_value(self, x, action=None):

# ALGO LOGIC: action logic
with torch.no_grad():
action, logprob, _, value = agent.get_action_and_value(next_obs)
risks = risk_hash.get_risk(next_obs).float() if args.use_risk else None
# print(risks.size())
action, logprob, _, value = agent.get_action_and_value(next_obs, risks)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())

done = np.logical_or(terminated, truncated)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

if args.use_risk:
for i in range(args.num_envs):
ep_states[i] = next_obs[i].unsqueeze(0) if ep_states[i] is None else torch.concat([ep_states[i], next_obs[i].unsqueeze(0)], axis=0)

# Only print when at least 1 env is done
if "final_info" not in infos:
continue

for info in infos["final_info"]:
for i, info in enumerate(infos["final_info"]):
# Skip the envs that are not done
if info is None:
continue

ep_len = info["episode"]["l"]
cost = info["cost_sum"]
if args.use_risk:
ep_risks = np.array(list(reversed(range(int(ep_len))))) if cost > 0 else np.array([int(ep_len)]*int(ep_len))
ep_risks_quant = np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(ep_risks, 1))
print(ep_risks_quant.shape)
risk_hash.update_risk(ep_states[i], ep_risks_quant)
ep_states[i] = None

print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
if args.use_risk:
next_risk = risk_hash.get_risk(next_obs) if args.use_risk else None
next_value = agent.get_value(torch.cat([next_obs, next_risk.float()], -1)).reshape(1, -1)
else:
next_value = agent.get_value(next_obs).reshape(1, -1)

advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
@@ -250,6 +343,8 @@ def get_action_and_value(self, x, action=None):
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

b_risks = risk_hash.get_risk(b_obs).float() if args.use_risk else None

# Optimizing the policy and value network
b_inds = np.arange(args.batch_size)
clipfracs = []
@@ -259,7 +354,11 @@ def get_action_and_value(self, x, action=None):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]

_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
if args.use_risk:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
else:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], None, b_actions[mb_inds])

logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()

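For reference, the per-step risk targets built in the episode loop above discretize "steps until failure" into quantile bins with np.histogram. A standalone numpy sketch under the default settings (quantile_size=4, quantile_num=10; not part of the commit):

import numpy as np

quantile_size, quantile_num = 4, 10
risk_bins = np.array([i * quantile_size for i in range(quantile_num + 1)])   # [0, 4, 8, ..., 40]

ep_len, cost = 6, 1.0
# With a violation (cost > 0) each step is labeled by its distance to failure; otherwise every
# step gets the "far from failure" label ep_len.
ep_risks = np.array(list(reversed(range(ep_len))))        # [5, 4, 3, 2, 1, 0]
ep_risks_quant = np.apply_along_axis(
    lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(ep_risks, 1)
)
print(ep_risks_quant.shape)   # (6, 10): one one-hot quantile vector per step
# Note: labels larger than risk_bins[-1] (here 40) fall outside every bin and yield an all-zero vector.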
