-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
106 lines (86 loc) · 3.09 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#%%
# Sanity-check cell: confirm torch is importable and report which version is
# active (useful when switching between environments in a notebook).
import torch
print("torch version:", torch.__version__)
#%%
#*******************************************************************************
# Imports and Setup
#*******************************************************************************
# packages
import argparse
import os
import pandas as pd
import pickle
import shutil
import tqdm
import yaml
# torch imports
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
# nflows imports
from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
# file imports
from construct_transform import create_transform
# Seed torch's RNG so training runs are reproducible.
torch.manual_seed(0)
# CLI setup: the single option selects which simulator's YAML config to load.
cli_parser = argparse.ArgumentParser()
cli_parser.add_argument('--simulator',
                        choices=['robot', 'racecar', 'f16'],
                        default='robot',
                        help='Choose an autonomous systems simulator.')
cli_opts = cli_parser.parse_args()
# Hyperparameters come entirely from the per-simulator config file.
with open('configs/{}.yaml'.format(cli_opts.simulator), 'r') as cfg_file:
    args = yaml.safe_load(cfg_file)
# Train on the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#*******************************************************************************
# File IO
#*******************************************************************************
# Load the training set (one sample per row, no header) as float32 tensors.
csv_path = "data/{}-flow.csv".format(args['key'])
frame = pd.read_csv(csv_path, header=None)
samples = torch.tensor(frame.values, dtype=torch.float32)
# drop_last keeps every batch the same size for stable training statistics.
train_dataloader = DataLoader(
    TensorDataset(samples),
    batch_size=args['batch_size'],
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
#*******************************************************************************
# Flow Construction
#*******************************************************************************
# A normalizing flow = learned transform + fixed base (prior) distribution.
prior = StandardNormal(shape=[args['features']])
flow = Flow(create_transform(args), prior).to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=args['learning_rate'])
# Total optimizer steps across all epochs (used to size the progress bar).
total_steps = args['epochs'] * len(train_dataloader)
#*******************************************************************************
# Flow Training
#*******************************************************************************
# TensorBoard run name encodes the architecture and dataset; a fresh run
# replaces any previous run with the same key.
tb_key = args['base'] + '-' + args['linear'] + '-' + args['key'] + "test"
if os.path.isdir("runs/" + tb_key):
    shutil.rmtree("runs/" + tb_key)
writer = SummaryWriter("runs/" + tb_key)
i = 0
print("training flow...")
pbar = tqdm.tqdm(total=total_steps)
for epoch in range(args['epochs']):
    for x in train_dataloader:
        # TensorDataset yields a 1-tuple; unwrap and move to the device.
        x = x[0].to(device)
        flow.train()
        optimizer.zero_grad()
        # Maximum-likelihood training: minimize the negative log-likelihood.
        loss = -flow.log_prob(inputs=x).mean()
        i += 1
        # Log the Python float, not the tensor, so the autograd graph is
        # not retained by the writer.
        writer.add_scalar("Loss/train", loss.item(), i)
        loss.backward()
        if args['grad_norm_clip_value'] is not None:
            clip_grad_norm_(flow.parameters(), args['grad_norm_clip_value'])
        optimizer.step()
        # Flush every 50 steps; the counter was already incremented above,
        # so test i directly (the old `(i+1) % 50` flushed at steps 49, 99, ...).
        if i % 50 == 0:
            writer.flush()
        pbar.update(1)
# Release resources so the final events are flushed to disk.
pbar.close()
writer.close()
# save flow model (pickled whole; loaders unpickle the same object)
flow.to('cpu')
name = str(tb_key)
with open("flows/" + name, "wb") as f:
    pickle.dump(flow, f)