Fine-tuning the risk model as we train the risk-conditioned policy
Kaustubh Mani authored and committed on Aug 7, 2023
1 parent 1b868f4 commit 720b671
Showing 1 changed file with 73 additions and 16 deletions.
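Before the diff itself, a minimal, self-contained sketch of the fine-tuning step this commit adds: a toy risk classifier is refit on class-balanced (observation, risk-label) pairs, mirroring the RiskDataset / fine_tune_risk pair introduced below. The toy model, dimensions, and random data are illustrative stand-ins, not the repository's code.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class ToyRiskDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        return self.inputs.size(0)

    def __getitem__(self, idx):
        y = torch.zeros(2)
        y[int(self.targets[idx][0])] = 1.0  # one-hot label, index 1 = risky
        return self.inputs[idx], y

obs_dim = 8  # stand-in observation size
model = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(), nn.Linear(32, 2))
opt = optim.Adam(model.parameters(), lr=1e-5)
inputs = torch.randn(1000, obs_dim)            # stand-in observations
targets = (torch.rand(1000, 1) < 0.1).float()  # ~10% "risky" labels

# Reweight the rare risky class, as the diff's fine_tune_risk does.
weight = torch.sum(targets == 0) / torch.sum(targets == 1)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0, weight.item()]))

loader = DataLoader(ToyRiskDataset(inputs, targets), batch_size=10, shuffle=True)
for epoch in range(2):
    for x, y in loader:
        loss = criterion(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()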
cleanrl/ppo_continuous_action_wandb.py (73 additions, 16 deletions)
@@ -11,6 +11,8 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

@@ -45,7 +47,7 @@ def parse_args():
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
parser.add_argument("--env-id", type=str, default="SafetyCarGoal1Gymnasium-v0",
help="the id of the environment")
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to terminate early i.e. when the catastrophe has happened")
@@ -108,11 +110,16 @@ def parse_args():
help="fear radius for training the risk model")
parser.add_argument("--num-risk-datapoints", type=int, default=10000,
help="fear radius for training the risk model")
parser.add_argument("--update-risk-model", type=int, default=100,
parser.add_argument("--update-risk-model", type=int, default=10000,
help="number of epochs to update the risk model")
parser.add_argument("--risk-sgd-steps", type=int, default=100,
parser.add_argument("--risk-epochs", type=int, default=10,
help="number of epochs to update the risk model")

parser.add_argument("--risk-lr", type=float, default=1e-5,
help="the learning rate of the optimizer")
parser.add_argument("--risk-batch-size", type=int, default=10,
help="number of epochs to update the risk model")
parser.add_argument("--fine-tune-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")

args = parser.parse_args()
args.batch_size = int(args.num_envs * args.num_steps)
@@ -229,6 +236,53 @@ def get_action_and_value(self, x, action=None):
action = probs.sample()
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

class RiskDataset(Dataset):
    """Wraps (observation, risk-label) tensors; labels come back one-hot."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return self.inputs.size()[0]

    def __getitem__(self, idx):
        # one-hot encode the binary risk label (index 1 = risky)
        y = torch.zeros(2)
        y[int(self.targets[idx][0])] = 1.0
        return self.inputs[idx], y



def fine_tune_risk(cfg, model, inputs, targets, opt, device):
    model.train()
    dataset = RiskDataset(inputs, targets)
    # Reweight the rare risky class by the safe-to-risky count ratio.
    weight = torch.sum(targets == 0) / torch.sum(targets == 1)
    if cfg.model_type == "bayesian":
        criterion = nn.NLLLoss(weight=torch.Tensor([1, weight]).to(device))
    else:
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([1, weight]).to(device))

    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True,
                            num_workers=10, generator=torch.Generator(device=device))
    for epoch in range(cfg.risk_epochs):
        net_loss = 0
        for batch in dataloader:
            pred = model(batch[0].to(device))
            if cfg.model_type == "mlp":
                # BCE-with-logits consumes the one-hot targets directly
                loss = criterion(pred, batch[1].to(device))
            else:
                # NLL expects class indices, so collapse the one-hot targets
                loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
            # scheduler.step()
            net_loss += loss.item()
        print("Average Risk training loss: %.4f" % (net_loss / len(dataloader)))

    model.eval()
    return model
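# Called from the training loop below, every `cfg.update_risk_model`
# environment steps, on the most recent slice of the risk buffer:
#     fine_tune_risk(cfg, risk_model, f_obs[-cfg.num_risk_datapoints:],
#                    f_risks[-cfg.num_risk_datapoints:], opt_risk, device)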

def train(cfg):
# fmt: on
@@ -252,11 +306,6 @@ def train(cfg):
# "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(cfg.).items()])),
#)

-    #experiment.add_tag(cfg.tag)
-    #experiment.log_parameters(cfg.ppo)
-    #experiment.log_parameters(cfg.risk)
-    #experiment.log_parameters(cfg.env)
# TRY NOT TO MODIFY: seeding
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
@@ -274,14 +323,16 @@ def train(cfg):
risk_model_class = BayesRiskEst
else:
risk_model_class = RiskEst
-    print(envs.single_observation_space.shape)
+    # print(envs.single_observation_space.shape)

if cfg.use_risk:
agent = RiskAgent(envs=envs).to(device)
if os.path.exists(cfg.risk_model_path):
-        risk_model = risk_model_class(obs_size=np.array(envs.single_observation_space.shape).prod())
+        risk_model = risk_model_class(obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=True)
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
risk_model.to(device)
if cfg.fine_tune_risk:
opt_risk = optim.Adam(risk_model.parameters(), lr=cfg.risk_lr, eps=1e-5)
risk_model.eval()
else:
            raise FileNotFoundError("No model in the path specified!!")
@@ -290,7 +341,7 @@ def train(cfg):

optimizer = optim.Adam(agent.parameters(), lr=cfg.learning_rate, eps=1e-5)

-    print(envs.single_observation_space.shape)
+    # print(envs.single_observation_space.shape)
# ALGO Logic: Storage setup
obs = torch.zeros((cfg.num_steps, cfg.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((cfg.num_steps, cfg.num_envs) + envs.single_action_space.shape).to(device)
@@ -323,7 +374,7 @@ def train(cfg):
f_obs = next_obs
f_risks = torch.Tensor([[0.]]).to(device)

-    print(f_obs.size(), f_risks.size())
+    # print(f_obs.size(), f_risks.size())


if cfg.collect_data:
@@ -351,7 +402,7 @@ def train(cfg):
id_risk = torch.argmax(next_risk, axis=1)
next_risk = torch.zeros_like(next_risk)
next_risk[:, id_risk] = 1

# print(next_risk)
risks[step] = next_risk
all_risks[global_step] = torch.argmax(next_risk, axis=-1)

@@ -379,7 +430,7 @@ def train(cfg):
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
f_obs = torch.concat([f_obs, next_obs], axis=0)
f_risks = torch.concat([f_risks, risk], axis=0)
-        print(f_risks.size(), f_obs.size())
+        # print(f_risks.size(), f_obs.size())

if cost > 0:
f_risks[global_step-cfg.fear_radius:, 0] = 1.
@@ -391,10 +442,15 @@ def train(cfg):
cost = torch.Tensor(np.array([infos["final_info"][0]["cost"]])).to(device).view(-1)
ep_cost += np.array([infos["final_info"][0]["cost"]]); cum_cost += np.array([infos["final_info"][0]["cost"]])

if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
fine_tune_risk(cfg, risk_model, f_obs[-cfg.num_risk_datapoints:], f_risks[-cfg.num_risk_datapoints:], opt_risk, device)


# Only print when at least 1 env is done
if "final_info" not in infos:
continue


for info in infos["final_info"]:
# Skip the envs that are not done
if info is None:
@@ -409,7 +465,7 @@ def train(cfg):
ep_risk = torch.sum(all_risks[last_step:global_step]).item()
cum_risk += ep_risk

-                risk_cost_int = torch.logical_and(all_costs[last_step:global_step], all_risks[last_step:global_step])
+                risk_cost_int = torch.logical_and(f_risks[last_step:global_step], all_risks[last_step:global_step])
ep_risk_cost_int = torch.sum(risk_cost_int).item()
cum_risk_cost_int += ep_risk_cost_int

@@ -475,6 +531,7 @@ def train(cfg):
b_values = values.reshape(-1)
b_risks = risks.reshape((-1, ) + (2, ))


# Optimizing the policy and value network
b_inds = np.arange(cfg.batch_size)
clipfracs = []

0 comments on commit 720b671
