-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
381 lines (314 loc) · 11.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim
from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image
# Absolute path of this script and the directory containing it. ROOT is the
# base for the default --project directory and for the relative file names
# printed by print_args.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
# Value of the RANK environment variable as an int, or -1 when it is unset;
# train() only wraps the model in paddle.DataParallel when this is -1.
RANK = int(os.getenv("RANK", "-1"))
def init_seeds(seed=0, deterministic=False):
    """Seed the Python, NumPy and Paddle random number generators.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``paddle.seed``.
        deterministic: If True, additionally request deterministic cuDNN
            kernels and export the related environment variables.
    """
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)
    if deterministic:
        # Paddle reads FLAGS_* environment variables when it is imported, so
        # exporting the variable at this point cannot reconfigure the already
        # running framework; set_flags changes the live value.
        paddle.set_flags({"FLAGS_cudnn_deterministic": True})
        # Still export the variables so processes started afterwards inherit
        # the same configuration.
        os.environ["FLAGS_cudnn_deterministic"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        # NOTE: PYTHONHASHSEED only affects interpreters launched after this
        # point; the current process's hash seed is fixed at startup.
        os.environ["PYTHONHASHSEED"] = str(seed)
# ANSI escape sequences understood by colorstr, built once at import time.
# See https://en.wikipedia.org/wiki/ANSI_escape_code
_ANSI_COLORS = {
    "black": "\033[30m",  # basic colors
    "red": "\033[31m",
    "green": "\033[32m",
    "yellow": "\033[33m",
    "blue": "\033[34m",
    "magenta": "\033[35m",
    "cyan": "\033[36m",
    "white": "\033[37m",
    "bright_black": "\033[90m",  # bright colors
    "bright_red": "\033[91m",
    "bright_green": "\033[92m",
    "bright_yellow": "\033[93m",
    "bright_blue": "\033[94m",
    "bright_magenta": "\033[95m",
    "bright_cyan": "\033[96m",
    "bright_white": "\033[97m",
    "end": "\033[0m",  # misc
    "bold": "\033[1m",
    "underline": "\033[4m",
}


def colorstr(*parts):
    """Wrap a value in ANSI color/style escape codes.

    The last positional argument is the text; any preceding arguments are
    style names from ``_ANSI_COLORS``. With a single argument the text is
    rendered blue and bold, e.g. ``colorstr("hello")`` or
    ``colorstr("red", "underline", "hello")``.

    Returns:
        The escaped string, always terminated with the reset code.

    Raises:
        KeyError: if a style name is not in ``_ANSI_COLORS``.
        IndexError: if called with no arguments.
    """
    # Parameter renamed from ``input`` so it no longer shadows the builtin;
    # varargs cannot be passed by keyword, so callers are unaffected.
    *styles, string = parts if len(parts) > 1 else ("blue", "bold", parts[0])
    return (
        "".join(_ANSI_COLORS[s] for s in styles) + f"{string}" + _ANSI_COLORS["end"]
    )
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """Log ``key=value`` pairs, optionally prefixed with the caller's location.

    Args:
        args: Mapping to print. When None, the calling function's own
            arguments are read from its frame.
        show_file: Prefix the line with the caller's file, shown relative to
            ROOT without suffix, or just its stem when outside ROOT.
        show_func: Prefix the line with the caller's function name.
    """
    caller = inspect.currentframe().f_back  # frame of whoever called us
    caller_file, _, caller_func, _, _ = inspect.getframeinfo(caller)
    if args is None:
        arg_names, _, _, local_vars = inspect.getargvalues(caller)
        args = {name: value for name, value in local_vars.items() if name in arg_names}
    try:
        shown_file = Path(caller_file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        # Caller lives outside ROOT: fall back to the bare file stem.
        shown_file = Path(caller_file).stem
    prefix = ""
    if show_file:
        prefix += f"{shown_file}: "
    if show_func:
        prefix += f"{caller_func}: "
    body = ", ".join(f"{key}={val}" for key, val in args.items())
    logger.info(colorstr(prefix) + body)
def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """Return a non-clashing variant of *path*.

    If *path* exists and ``exist_ok`` is False, a counter starting at 2 is
    appended (e.g. ``runs/exp`` -> ``runs/exp{sep}2``, ``runs/exp{sep}3``, ...).
    For an existing file the counter goes before the suffix
    (``a.txt`` -> ``a2.txt``). The counter is tried up to 9998. If every
    candidate exists, the last one tried is returned.

    Args:
        path: Path or string to start from.
        exist_ok: Return *path* unchanged even if it exists.
        sep: Separator inserted between the name and the counter.
        mkdir: Create the resulting path as a directory (with parents).

    Returns:
        The chosen ``Path``.
    """
    candidate = Path(path)
    if candidate.exists() and not exist_ok:
        if candidate.is_file():
            stem_path, ext = candidate.with_suffix(""), candidate.suffix
        else:
            stem_path, ext = candidate, ""
        for counter in range(2, 9999):
            trial = f"{stem_path}{sep}{counter}{ext}"
            if not os.path.exists(trial):
                break
        candidate = Path(trial)
    if mkdir:
        candidate.mkdir(parents=True, exist_ok=True)
    return candidate
def _log_val_images(vdl_writer, out, out_gt, epoch, max_images=3):
    """Write the first ``max_images`` predicted/target unwarped images to VisualDL.

    The count is clamped to the batch size so a validation batch with fewer
    than ``max_images`` samples does not raise IndexError.
    """
    for k in range(min(max_images, out.shape[0])):
        vdl_writer.add_image(f"Val/Predicted Image No.{k}", to_image(out[k]), epoch)
        vdl_writer.add_image(f"Val/Target Image No.{k}", to_image(out_gt[k]), epoch)


def _train_one_epoch(
    model, train_loader, optimizer, scheduler, l1_loss_fn, mse_loss_fn, epoch,
    vdl_writer=None,
):
    """Run one training epoch; only the L1 loss is optimized, MSE is logged."""
    model.train()
    steps_per_epoch = len(train_loader)
    for i, (img, target) in enumerate(train_loader):
        img = paddle.to_tensor(img)  # NCHW
        target = paddle.to_tensor(target)  # NHWC
        pred = model(img)  # NCHW
        pred_nhwc = pred.transpose([0, 2, 3, 1])  # match target layout
        loss = l1_loss_fn(pred_nhwc, target)
        mse_loss = mse_loss_fn(pred_nhwc, target)
        if vdl_writer is not None:
            step = epoch * steps_per_epoch + i
            vdl_writer.add_scalar("Train/L1 Loss", float(loss), step)
            vdl_writer.add_scalar("Train/MSE Loss", float(mse_loss), step)
            # Logged before scheduler.step(): the LR used for this update.
            vdl_writer.add_scalar(
                "Train/Learning Rate", float(scheduler.get_lr()), step
            )
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped per iteration
        optimizer.clear_grad()
        if i % 10 == 0:
            logger.info(
                f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
                f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
            )


def _validate(
    model, val_loader, l1_loss_fn, mse_loss_fn, img_size, epoch, vdl_writer=None
):
    """Evaluate *model* on *val_loader* and return per-batch-averaged metrics.

    The input image is unwarped with both the predicted and the ground-truth
    maps via ``grid_sample``. SSIM / MS-SSIM are computed between the two
    results. The maps are rescaled by ``img_size`` into grid_sample's
    ``[-1, 1]`` range, which presumes they are in pixel units.

    Returns:
        dict of floats with keys ``"l1"``, ``"mse"``, ``"ssim"``, ``"ms_ssim"``.
        All values are NaN when the loader yields no batches.
    """
    model.eval()
    sum_ssim = paddle.zeros([])
    sum_ms_ssim = paddle.zeros([])
    sum_l1 = paddle.zeros([])
    sum_mse = paddle.zeros([])
    num_batches = 0
    with paddle.no_grad():
        for i, (img, target) in enumerate(val_loader):
            img = paddle.to_tensor(img)
            target = paddle.to_tensor(target)
            pred_nhwc = model(img).transpose([0, 2, 3, 1])
            # Unwarp the input with the predicted and the ground-truth maps.
            out = F.grid_sample(img, (pred_nhwc / img_size - 0.5) * 2)
            out_gt = F.grid_sample(img, (target / img_size - 0.5) * 2)
            ssim_val = ssim(out, out_gt, data_range=1.0)
            ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
            loss = l1_loss_fn(pred_nhwc, target)
            mse_loss = mse_loss_fn(pred_nhwc, target)
            sum_ssim += ssim_val
            sum_ms_ssim += ms_ssim_val
            sum_l1 += loss
            sum_mse += mse_loss
            num_batches += 1
            if i % 10 == 0:
                logger.info(
                    f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
                    f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
                    f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
                )
            if vdl_writer is not None and i == 0:
                _log_val_images(vdl_writer, out, out_gt, epoch)
    if num_batches == 0:
        logger.warning("Validation loader yielded no batches; metrics are NaN")
        nan = float("nan")
        return {"l1": nan, "mse": nan, "ssim": nan, "ms_ssim": nan}
    return {
        "l1": float(sum_l1) / num_batches,
        "mse": float(sum_mse) / num_batches,
        "ssim": float(sum_ssim) / num_batches,
        "ms_ssim": float(sum_ms_ssim) / num_batches,
    }


def train(args):
    """Train GeoTr on Doc3D, validating and checkpointing after every epoch.

    Args:
        args: Parsed command-line namespace. Reads save_dir, use_vdl, seed,
            data_root, img_size, batch_size, workers, lr, epochs and resume.

    Side effects:
        Writes ``<save_dir>/weights/last.ckpt`` every epoch. Writes
        ``<save_dir>/weights/best.ckpt`` whenever the mean validation SSIM
        improves. With ``use_vdl``, logs to VisualDL under ``<save_dir>/vdl``.
    """
    save_dir = Path(args.save_dir)
    vdl_writer = None
    if args.use_vdl:
        from visualdl import LogWriter

        vdl_writer = LogWriter(str(save_dir / "vdl"))

    # Directories: create the weights directory itself, not only its parent.
    weights_dir = save_dir / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)
    last = weights_dir / "last.ckpt"
    best = weights_dir / "best.ckpt"

    init_seeds(args.seed)

    # Data
    train_loader = DataLoader(
        Doc3dDataset(
            args.data_root,
            split="train",
            is_augment=True,
            image_size=args.img_size,
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    val_loader = DataLoader(
        Doc3dDataset(
            args.data_root,
            split="val",
            is_augment=False,
            image_size=args.img_size,
        ),
        batch_size=args.batch_size,
        num_workers=args.workers,
    )

    # Model
    model = GeoTr()
    if vdl_writer is not None:
        vdl_writer.add_graph(
            model,
            input_spec=[
                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
            ],
        )
    # Data Parallel Mode: only when not launched with a RANK set.
    if RANK == -1 and paddle.device.cuda.device_count() > 1:
        model = paddle.DataParallel(model)

    # Scheduler / optimizer / losses
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=args.lr,
        total_steps=args.epochs * len(train_loader),
        phase_pct=0.1,
        end_learning_rate=args.lr / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    l1_loss_fn = nn.L1Loss()
    mse_loss_fn = nn.MSELoss()

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if args.resume:
        ckpt = paddle.load(args.resume)
        model.set_state_dict(ckpt["model"])
        optimizer.set_state_dict(ckpt["optimizer"])
        scheduler.set_state_dict(ckpt["scheduler"])
        # float() also accepts checkpoints that stored a 0-d tensor here.
        best_fitness = float(ckpt["best_fitness"])
        start_epoch = ckpt["epoch"] + 1

    try:
        for epoch in range(start_epoch, args.epochs):
            _train_one_epoch(
                model, train_loader, optimizer, scheduler,
                l1_loss_fn, mse_loss_fn, epoch, vdl_writer,
            )
            metrics = _validate(
                model, val_loader, l1_loss_fn, mse_loss_fn,
                args.img_size, epoch, vdl_writer,
            )
            if vdl_writer is not None:
                # Epoch-level averages, not the last batch's values.
                vdl_writer.add_scalar("Val/L1 Loss", metrics["l1"], epoch)
                vdl_writer.add_scalar("Val/MSE Loss", metrics["mse"], epoch)
                vdl_writer.add_scalar("Val/SSIM", metrics["ssim"], epoch)
                vdl_writer.add_scalar("Val/MS-SSIM", metrics["ms_ssim"], epoch)

            # Update best_fitness BEFORE building the checkpoint so both
            # last.ckpt and best.ckpt record the value they correspond to.
            is_best = metrics["ssim"] > best_fitness  # False for NaN
            if is_best:
                best_fitness = metrics["ssim"]
            ckpt = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_fitness": best_fitness,
                "epoch": epoch,
            }
            paddle.save(ckpt, str(last))
            if is_best:
                paddle.save(ckpt, str(best))
    finally:
        # Close the writer even if training raises.
        if vdl_writer is not None:
            vdl_writer.close()
def main(args):
    """Log the parsed options, pick a fresh run directory, then train.

    Sets ``args.save_dir`` to ``project/name``. When that directory already
    exists and ``--exist-ok`` was not given, a numeric suffix is appended.
    """
    print_args(vars(args))
    run_dir = increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
    args.save_dir = str(run_dir)
    train(args)
if __name__ == "__main__":
    # NOTE: the typed options no longer use nargs="?". With it, a bare flag
    # such as "--epochs" parsed to None and crashed later inside train();
    # argparse now reports the missing value instead.
    parser = argparse.ArgumentParser(description="Hyperparams")
    parser.add_argument(
        "--data-root",
        # expanduser is applied to the string default too, so the "~" below is
        # resolved even when the option is not given.
        type=os.path.expanduser,
        default="~/datasets/doc3d",
        help="The root path of the dataset",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=288,
        help="The size of the input image",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=65,
        help="The number of training epochs",
    )
    parser.add_argument("--batch-size", type=int, default=12, help="Batch Size")
    parser.add_argument("--lr", type=float, default=1e-04, help="Learning Rate")
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Path to previous saved model to restart from",
    )
    parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
    parser.add_argument(
        "--project", default=ROOT / "runs/train", help="save to project/name"
    )
    parser.add_argument("--name", default="exp", help="save to project/name")
    parser.add_argument(
        "--exist-ok",
        action="store_true",
        help="existing project/name ok, do not increment",
    )
    parser.add_argument("--seed", type=int, default=0, help="Global training seed")
    parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
    args = parser.parse_args()
    main(args)