-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_baseline.py
631 lines (544 loc) · 21.4 KB
/
train_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
import os
import argparse
from typing import Dict, Any, Optional
import copy
import logging
import yaml
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.nn import Module, ModuleDict
from torch.utils.tensorboard import SummaryWriter
from dataset import get_dataset
from dataset.cifar import get_cifar100_dataloaders_sample
from dataset.tiny_imagenet import get_tinyimagenet_dataloaders_sample
from models import get_model
from distiller_zoo import get_loss_module, get_loss_forward
from optim import get_optimizer
from models.classifier import LinearClassifier
from models.gs import gumbel_softmax
from models.util import ConvReg, Connector, Paraphraser, Translator, LinearEmbed
from distiller_zoo.FitNet import HintLoss
from distiller_zoo.AT import Attention
from distiller_zoo.crd.criterion import CRDLoss
from distiller_zoo.NST import NSTLoss
from distiller_zoo.SP import Similarity
from distiller_zoo.RKD import RKDLoss
from distiller_zoo.PKT import PKT
from distiller_zoo.KDSVD import KDSVD
from distiller_zoo.CC import Correlation
from distiller_zoo.VID import VIDLoss
from distiller_zoo.AB import ABLoss
from distiller_zoo.FT import FactorTransfer
from distiller_zoo.FSP import FSP
from helper.util import str2bool, get_logger, preserve_memory, adjust_learning_rate_stage, adjust_learning_rate_stage_agent
from helper.util import make_deterministic
from helper.util import AverageMeter, accuracy
from helper.validate import validate_policy, validate
from helper.pretrain import init_pretrain
def get_dataloader(cfg: Dict[str, Any]):
    """Build the train and validation DataLoaders described by ``cfg``.

    ``cfg["dataset"]`` is forwarded unchanged to ``get_dataset`` for both the
    ``"train"`` and ``"val"`` splits. The train loader takes its batch size and
    worker count from ``cfg["training"]`` and shuffles. The val loader takes
    them from ``cfg["validation"]`` and keeps the order. Both pin memory.

    Returns:
        ``(train_loader, val_loader, num_classes)``, where ``num_classes`` is
        ``len(train_dataset.classes)``.
    """
    ds_kwargs = cfg["dataset"]
    # Build the train split first, then val (same order as before).
    splits = {name: get_dataset(split=name, **ds_kwargs) for name in ("train", "val")}
    n_classes = len(splits["train"].classes)

    def _loader(split: str, section: str, shuffle: bool) -> DataLoader:
        # ``section`` selects the cfg block holding batch_size / num_workers.
        opts = cfg[section]
        return DataLoader(
            dataset=splits[split],
            batch_size=opts["batch_size"],
            num_workers=opts["num_workers"],
            shuffle=shuffle,
            pin_memory=True,
        )

    train_dl = _loader("train", "training", True)
    val_dl = _loader("val", "validation", False)
    return train_dl, val_dl, n_classes
def get_teacher(cfg: Dict[str, Any], num_classes: int) -> Module:
    """Instantiate the teacher network described by ``cfg["kd"]["teacher"]``.

    This function consumes the ``name`` and ``checkpoint`` keys. Every other key
    in the teacher section is forwarded to ``get_model`` as a keyword argument.
    The checkpoint is loaded onto the CPU, and its ``"model"`` entry is passed
    as ``state_dict``. ``cfg`` itself is not modified because a deep copy is
    used.
    """
    extra_kwargs = copy.deepcopy(cfg["kd"]["teacher"])
    arch = extra_kwargs.pop("name")
    checkpoint_path = extra_kwargs.pop("checkpoint")
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location="cpu")["model"]
    return get_model(
        model_name=arch,
        num_classes=num_classes,
        state_dict=weights,
        **extra_kwargs,
    )
def get_student(cfg: Dict[str, Any], num_classes: int) -> Module:
    """Instantiate the student network described by ``cfg["kd"]["student"]``.

    ``name`` selects the architecture. The optional ``checkpoint`` key, when
    present, is loaded onto the CPU, and its ``"model"`` entry is used as the
    initial state dict. Otherwise ``state_dict=None`` is passed. All remaining
    keys are forwarded to ``get_model``. ``cfg`` is left untouched (a deep copy
    is used).
    """
    student_kwargs = copy.deepcopy(cfg["kd"]["student"])
    arch = student_kwargs.pop("name")
    weights = None
    if "checkpoint" in student_kwargs:
        weights = torch.load(student_kwargs.pop("checkpoint"), map_location="cpu")["model"]
    return get_model(
        model_name=arch,
        num_classes=num_classes,
        state_dict=weights,
        **student_kwargs,
    )
def get_pre_student(cfg: Dict[str, Any], num_classes: int) -> Module:
    """Instantiate the network described by ``cfg["kd"]["prestudent"]``.

    This is handled the same way as the student section: ``name`` picks the
    architecture, and an optional ``checkpoint`` is loaded onto the CPU (its
    ``"model"`` entry becomes the state dict). Any other keys are forwarded to
    ``get_model``. ``cfg`` is not mutated.
    """
    section = copy.deepcopy(cfg["kd"]["prestudent"])
    model_name = section.pop("name")
    has_ckpt = "checkpoint" in section
    state_dict = (
        torch.load(section["checkpoint"], map_location="cpu")["model"]
        if has_ckpt
        else None
    )
    # Drop the key so it is not forwarded to get_model (no-op when absent).
    section.pop("checkpoint", None)
    return get_model(
        model_name=model_name,
        num_classes=num_classes,
        state_dict=state_dict,
        **section,
    )
def feature_loss_function(fea, target_fea):
    """Return the element-wise squared error, kept only where either input is positive.

    Positions where both ``fea`` and ``target_fea`` are <= 0 contribute 0.
    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    active = ((fea > 0) | (target_fea > 0)).float()
    sq_err = (fea - target_fea) ** 2
    return torch.abs(sq_err * active)
def _compute_kd_loss(
    kd_name: str,
    feat_s,
    feat_t,
    criterion_kd,
    module_dict: ModuleDict,
    device: torch.device,
    index=None,
    contrast_idx=None,
):
    """Compute the feature-distillation term for the method named ``kd_name``.

    ``feat_s`` / ``feat_t`` are the feature lists returned by the student /
    teacher when called with ``is_feat=True``. Auxiliary modules (hint
    regressor, CC embeddings, FT translator/paraphraser) are looked up in
    ``module_dict`` under the keys ``main`` registers them with.

    Raises:
        NotImplementedError: if ``kd_name`` is not a supported method.
    """
    if kd_name == "FitNet":
        return criterion_kd(module_dict["loss_hint"](feat_s[2]), feat_t[2])
    if kd_name == "CRD":
        return criterion_kd(feat_s[-1], feat_t[-1], index, contrast_idx)
    if kd_name in ("AT", "NST", "KDSVD"):
        # These criteria yield one loss per intermediate feature pair.
        return sum(criterion_kd(feat_s[1:-1], feat_t[1:-1]))
    if kd_name == "SP":
        return sum(criterion_kd([feat_s[-2]], [feat_t[-2]]))
    if kd_name in ("RKD", "PKT"):
        return criterion_kd(feat_s[-1], feat_t[-1])
    if kd_name == "CC":
        f_s = module_dict["cc_embed_s"](feat_s[-1])
        f_t = module_dict["cc_embed_t"](feat_t[-1])
        return criterion_kd(f_s, f_t)
    if kd_name == "VID":
        # criterion_kd is a ModuleList holding one VIDLoss per feature pair.
        return sum(
            c(f_s, f_t)
            for f_s, f_t, c in zip(feat_s[1:-1], feat_t[1:-1], criterion_kd)
        )
    if kd_name == "FT":
        factor_s = module_dict["translator"](feat_s[-2])
        factor_t = module_dict["paraphraser"](feat_t[-2], is_factor=True)
        return criterion_kd(factor_s, factor_t)
    if kd_name == "FSP":
        # Bug fix: ``feat_t[:-1]`` is a Python list, which has no ``.detach()``;
        # detach each tensor instead.
        out = criterion_kd(feat_s[:-1], [f.detach() for f in feat_t[:-1]])
        # NOTE(review): FSP criteria commonly return one loss per stage; reduce
        # a list/tuple to a scalar so ``beta * loss_kd`` and ``.item()`` work.
        return sum(out) if isinstance(out, (list, tuple)) else out
    if kd_name == "AB":
        # ``main`` trains the AB connector only in its init_pretrain stage and never
        # hands it to the optimizer, so there is no feature term in this stage.
        # Previously this name fell through to NotImplementedError.
        return torch.tensor(0.0, device=device, dtype=torch.float)
    raise NotImplementedError(kd_name)


def train_epoch(
    cfg: Dict[str, Any],
    epoch: int,
    train_loader: DataLoader,
    module_dict: ModuleDict,
    criterion_dict: ModuleDict,
    optimizer: Optimizer,
    tb_writer: SummaryWriter,
    device: torch.device
):
    """Train the student (and any trainable auxiliary modules) for one epoch.

    The optimized loss is ``gamma * L_cls + alpha * L_div + beta * L_kd``. The
    weights come from ``cfg["kd"]["loss_weights"]`` (``classify_weight``,
    ``kd_weight``, ``other_kd``). ``L_cls`` and ``L_div`` use
    ``criterion_dict["cls"]`` and ``criterion_dict["div"]`` on the student
    logits. ``L_kd`` is the feature-level term selected by
    ``cfg["kd_loss"]["name"]``. The teacher runs under ``torch.no_grad`` in
    eval mode; every other module in ``module_dict`` is put in train mode.

    Per-iteration accuracy and loss scalars are written to ``tb_writer``. The
    global step is the module-level ``__global_values__["it"]`` counter, which
    is created by the CLI entry point, so calling this function without that
    global raises NameError.

    Returns:
        ``(top1.avg, losses.avg)`` accumulated over the epoch.

    Raises:
        NotImplementedError: for an unsupported ``cfg["kd_loss"]["name"]``.
    """
    logger = logging.getLogger("train_epoch")
    loss_weights = cfg["kd"]["loss_weights"]
    gamma = loss_weights["classify_weight"]
    alpha = loss_weights["kd_weight"]
    beta = loss_weights["other_kd"]
    logger.info(
        "Starting train one epoch with [gamma: %.5f, alpha: %.5f, beta: %.5f]...",
        gamma, alpha, beta
    )
    kd_name = cfg["kd_loss"]["name"]
    print_freq = cfg["training"]["print_iter_freq"]

    # Teacher stays frozen; student and auxiliary modules are trained.
    for name, module in module_dict.items():
        if name == "teacher":
            module.eval()
        else:
            module.train()

    criterion_cls = criterion_dict["cls"]
    criterion_div = criterion_dict["div"]
    criterion_kd = criterion_dict["kd"]
    model_s = module_dict["student"]
    model_t = module_dict["teacher"]

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for idx, data in enumerate(train_loader):
        __global_values__["it"] += 1
        index = contrast_idx = None
        if kd_name == "CRD":
            # The CRD sampling loaders also yield sample indices and negatives.
            x, target, index, contrast_idx = data
            contrast_idx = contrast_idx.to(device)
            index = index.to(device)
        else:
            x, target = data
        x = x.to(device)
        target = target.to(device)

        # ===================forward=====================
        with torch.no_grad():
            feat_t, logit_t = model_t(x, begin=0, end=100, is_feat=True)
        feat_s, logit_s = model_s(x, begin=0, end=100, is_feat=True)

        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)
        loss_kd = _compute_kd_loss(
            kd_name, feat_s, feat_t, criterion_kd, module_dict, device,
            index=index, contrast_idx=contrast_idx,
        )
        # Bug fix: classify_weight (gamma) was read and logged but never applied.
        loss = gamma * loss_cls + alpha * loss_div + beta * loss_kd

        acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
        batch_size = x.shape[0]
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print info
        tb_writer.add_scalars(
            main_tag="train/acc",
            tag_scalar_dict={
                "@1": acc1,
                "@5": acc5,
            },
            global_step=__global_values__["it"]
        )
        tb_writer.add_scalars(
            main_tag="train/loss",
            tag_scalar_dict={
                "cls": loss_cls.item(),
                "div": loss_div.item(),
                "kd": loss_kd.item()
            },
            global_step=__global_values__["it"]
        )
        if idx % print_freq == 0:
            logger.info(
                "Epoch: [%3d|%3d], idx: %d, total iter: %d, loss: %.5f, acc@1: %.4f, acc@5: %.4f",
                epoch, cfg["training"]["epochs"],
                idx, __global_values__["it"],
                losses.val, top1.val, top5.val
            )
    return top1.avg, losses.avg
def train_kd(
    cfg: Dict[str, Any],
    train_loader: DataLoader,
    val_loader: DataLoader,
    module_dict: ModuleDict,
    criterion_dict: ModuleDict,
    optimizer: Optimizer,
    lr_scheduler: MultiStepLR,
    tb_writer: SummaryWriter,
    device: torch.device,
    ckpt_dir: str
):
    """Run every training epoch, validate after each one, and keep the best student.

    Each epoch does the following:
    - sets the LR through ``adjust_learning_rate_stage``;
    - trains one epoch with ``train_epoch``;
    - validates ``module_dict["student"]`` with ``criterion_dict["cls"]``;
    - logs epoch metrics to ``tb_writer`` and steps ``lr_scheduler``;
    - writes ``<ckpt_dir>/best.pth`` when validation top-1 improves.

    The checkpoint holds the epoch, the student, the optimizer, the scheduler
    state and the validation accuracy, plus the auxiliary modules for
    FitNet / CRD / CC.
    """
    logger = logging.getLogger("train")
    logger.info("Start training...")
    num_epochs = cfg["training"]["epochs"]
    kd_name = cfg["kd_loss"]["name"]
    best_acc = 0
    # Bug fix: best_ep was unbound if no epoch beat best_acc (e.g. every
    # val_acc == 0, or epochs == 0), crashing the final log line.
    best_ep = 0
    for epoch in range(1, num_epochs + 1):
        adjust_learning_rate_stage(
            optimizer=optimizer,
            cfg=cfg,
            epoch=epoch
        )
        logger.info("Teacher: %s, student: %s",
                    cfg["kd"]["teacher"]["name"], cfg["kd"]["student"]["name"])
        # Report the LR stored on the optimizer. adjust_learning_rate_stage
        # receives the optimizer and presumably writes the LR directly, so
        # lr_scheduler.get_last_lr() may not reflect the LR actually in use.
        logger.info("Start training epoch: %d, current lr: %.6f",
                    epoch, optimizer.param_groups[0]["lr"])
        train_acc, train_loss = train_epoch(
            cfg=cfg,
            epoch=epoch,
            train_loader=train_loader,
            module_dict=module_dict,
            criterion_dict=criterion_dict,
            optimizer=optimizer,
            tb_writer=tb_writer,
            device=device
        )
        tb_writer.add_scalar("epoch/train_acc", train_acc, epoch)
        tb_writer.add_scalar("epoch/train_loss", train_loss, epoch)
        val_acc, val_acc_top5, val_loss = validate(
            val_loader=val_loader,
            model=module_dict["student"],
            criterion=criterion_dict["cls"],
            device=device
        )
        tb_writer.add_scalar("epoch/val_acc", val_acc, epoch)
        tb_writer.add_scalar("epoch/val_loss", val_loss, epoch)
        tb_writer.add_scalar("epoch/val_acc_top5", val_acc_top5, epoch)
        logger.info(
            "Epoch: %04d | %04d, acc: %.4f, loss: %.5f, val_acc: %.4f, val_acc_top5: %.4f, val_loss: %.5f",
            epoch, num_epochs,
            train_acc, train_loss,
            val_acc, val_acc_top5, val_loss
        )
        lr_scheduler.step()

        state = {
            "epoch": epoch,
            "model": module_dict["student"].state_dict(),
            "acc": val_acc,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict()
        }
        if kd_name == "FitNet":
            state["loss_hint"] = module_dict["loss_hint"].state_dict()
        elif kd_name == "CRD":
            state["embed_s"] = module_dict["crd_embed_s"].state_dict()
            state["embed_t"] = module_dict["crd_embed_t"].state_dict()
        elif kd_name == "CC":
            state["embed_s"] = module_dict["cc_embed_s"].state_dict()
            state["embed_t"] = module_dict["cc_embed_t"].state_dict()

        # Only the best model is kept on disk.
        if val_acc > best_acc:
            best_acc = val_acc
            best_ep = epoch
            save_file = os.path.join(ckpt_dir, "best.pth")
            logger.info("Saving the best model with acc: %.4f", best_acc)
            torch.save(state, save_file)
        logger.info("Epoch: %04d | %04d, best acc: %.4f,", epoch, num_epochs, best_acc)
    logger.info("Final best accuracy: %.5f, at epoch: %d", best_acc, best_ep)
def main(
    cfg_filepath: str,
    file_name_cfg: str,
    logdir: str,
    gpu_preserve: bool = False,
    debug: bool = False,
    preserve_percent: Optional[float] = None,
):
    """Set up data, teacher/student, distillation criteria and optimizer, then train.

    Args:
        cfg_filepath: path to the YAML experiment config (parsed with SafeLoader).
        file_name_cfg: ``str.format`` template filled with (teacher name,
            student name, ``kd_loss.T``, dataset name). It names the
            TensorBoard run directory and the training log file.
        logdir: root output directory. Checkpoints go to ``<logdir>/ckpt``,
            TensorBoard logs to ``<logdir>/tf-logs`` and text logs to
            ``<logdir>/train-logs``.
        gpu_preserve: if True, call ``preserve_memory`` before loading data.
        debug: currently unused.
        preserve_percent: value passed to ``preserve_memory``. When None, it
            falls back to ``args.preserve_percent`` set by the CLI entry point,
            or to 0.95 (the CLI default) when no such global exists.

    Raises:
        NotImplementedError: if ``cfg["kd_loss"]["name"]`` is not handled, or
            CRD is requested for a dataset without a CRD sampling loader.
    """
    with open(cfg_filepath) as fp:
        cfg = yaml.load(fp, Loader=yaml.SafeLoader)
    seed = cfg["training"]["seed"]
    ckpt_dir = os.path.join(logdir, "ckpt")
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    formatter = (
        cfg["kd"]["teacher"]["name"],
        cfg["kd"]["student"]["name"],
        cfg["kd_loss"]["T"],
        cfg["dataset"]["name"],
    )
    run_name = file_name_cfg.format(*formatter)
    writer = SummaryWriter(
        log_dir=os.path.join(logdir, "tf-logs", run_name),
        flush_secs=1
    )
    train_log_dir = os.path.join(logdir, "train-logs")
    os.makedirs(train_log_dir, exist_ok=True)
    logger = get_logger(
        level=logging.INFO,
        mode="w",
        name=None,
        logger_fp=os.path.join(train_log_dir, "training-" + run_name + ".log")
    )
    logger.info("Start running with config: \n{}".format(yaml.dump(cfg)))

    # set seed
    make_deterministic(seed)
    logger.info("Set seed : {}".format(seed))
    if gpu_preserve:
        if preserve_percent is None:
            # Bug fix: this used to read the CLI-only global ``args`` directly,
            # raising NameError when main() was called programmatically.
            preserve_percent = getattr(globals().get("args"), "preserve_percent", 0.95)
        logger.info("Preserving memory...")
        preserve_memory(preserve_percent)
        logger.info("Preserving memory done")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    kd_name = cfg["kd_loss"]["name"]
    dataset_name = cfg["dataset"]["name"]
    is_cifar100 = dataset_name in ["cifar100", "CIFAR100", "cifar-100"]

    # get dataloaders
    logger.info("Loading datasets...")
    if kd_name == "CRD" and is_cifar100:
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            batch_size=cfg["training"]["batch_size"],
            num_workers=cfg["training"]["num_workers"],
            k=16384,
            mode='exact'
        )
        num_classes = 100
    elif kd_name == "CRD" and dataset_name in ["tiny-imagenet"]:
        num_classes = 200
        train_loader, val_loader, n_data = get_tinyimagenet_dataloaders_sample(
            batch_size=cfg["training"]["batch_size"],
            num_workers=cfg["training"]["num_workers"],
            k=16384,
            mode='exact'
        )
    elif kd_name == "CRD":
        # Previously this fell through to get_dataloader, left n_data unbound and
        # crashed much later with UnboundLocalError when building CRDLoss.
        raise NotImplementedError(
            "CRD requires a sampling dataloader; unsupported dataset: {}".format(dataset_name)
        )
    else:
        train_loader, val_loader, num_classes = get_dataloader(cfg)
    logger.info("num_classes: {}".format(num_classes))

    # get models
    logger.info("Loading teacher and student...")
    model_t = get_teacher(cfg, num_classes).to(device)
    model_s = get_student(cfg, num_classes).to(device)
    model_t.eval()
    model_s.eval()
    # A random batch is pushed through both networks only to discover the
    # shapes of their intermediate features (used to size auxiliary modules).
    # Run it under no_grad so no autograd graph is built and kept alive.
    side = 32 if is_cifar100 else 64
    data = torch.randn(2, 3, side, side).to(device)
    with torch.no_grad():
        feat_t, _ = model_t(data, begin=0, end=100, is_feat=True)
        feat_s, _ = model_s(data, begin=0, end=100, is_feat=True)
    logger.info(model_s)

    module_dict = nn.ModuleDict(dict(
        student=model_s,
        teacher=model_t,
    ))
    # Only modules in trainable_dict are handed to the optimizer.
    trainable_dict = nn.ModuleDict(dict(
        student=model_s,
    ))

    # get loss modules
    criterion_dict, loss_trainable_dict = get_loss_module(
        cfg=cfg,
        module_dict=module_dict,
        train_loader=train_loader,
        tb_writer=writer,
        device=device
    )
    if kd_name == "FitNet":
        criterion_dict["kd"] = HintLoss()
        regress_s = ConvReg(feat_s[2].shape, feat_t[2].shape).to(device)
        module_dict["loss_hint"] = regress_s
        trainable_dict["loss_hint"] = regress_s
    elif kd_name == "CRD":
        criterion_dict["kd"] = CRDLoss(
            s_dim=feat_s[-1].shape[1],
            t_dim=feat_t[-1].shape[1],
            feat_dim=cfg["kd_loss"]["feat_dim"],
            n_data=n_data
        ).to(device)
        trainable_dict["crd_embed_s"] = criterion_dict["kd"].embed_s
        trainable_dict["crd_embed_t"] = criterion_dict["kd"].embed_t
        module_dict["crd_embed_s"] = criterion_dict["kd"].embed_s
        module_dict["crd_embed_t"] = criterion_dict["kd"].embed_t
    elif kd_name == "AT":
        criterion_dict["kd"] = Attention().to(device)
    elif kd_name == "NST":
        criterion_dict["kd"] = NSTLoss().to(device)
    elif kd_name == "SP":
        criterion_dict["kd"] = Similarity().to(device)
    elif kd_name == "RKD":
        criterion_dict["kd"] = RKDLoss().to(device)
    elif kd_name == "PKT":
        criterion_dict["kd"] = PKT().to(device)
    elif kd_name == "KDSVD":
        criterion_dict["kd"] = KDSVD().to(device)
    elif kd_name == "CC":
        criterion_dict["kd"] = Correlation().to(device)
        embed_s = LinearEmbed(feat_s[-1].shape[1], cfg["kd_loss"]["feat_dim"]).to(device)
        embed_t = LinearEmbed(feat_t[-1].shape[1], cfg["kd_loss"]["feat_dim"]).to(device)
        module_dict["cc_embed_s"] = embed_s
        module_dict["cc_embed_t"] = embed_t
        trainable_dict["cc_embed_s"] = embed_s
        trainable_dict["cc_embed_t"] = embed_t
    elif kd_name == 'VID':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)]
        )
        criterion_dict["kd"] = criterion_kd.to(device)
        trainable_dict["vid"] = criterion_kd.to(device)
    elif kd_name == "FT":
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape).to(device)
        translator = Translator(s_shape, t_shape).to(device)
        # init stage training: only the paraphraser is trained here.
        init_trainable_dict = nn.ModuleDict(dict(
            paraphraser=paraphraser,
        ))
        criterion_init = nn.MSELoss().to(device)
        init_pretrain(cfg, module_dict, init_trainable_dict, criterion_init, train_loader, logger, device)
        # classification
        criterion_dict["kd"] = FactorTransfer().to(device)
        trainable_dict["translator"] = translator
        module_dict["translator"] = translator
        module_dict["paraphraser"] = paraphraser
    elif kd_name == "FSP":
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes).to(device)
        # init stage training on the student's feature modules.
        init_trainable_dict = nn.ModuleDict(dict(
            model_s=model_s.get_feat_modules(),
        ))
        init_pretrain(cfg, module_dict, init_trainable_dict, criterion_kd, train_loader, logger, device)
        criterion_dict["kd"] = criterion_kd
    elif kd_name == "AB":
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes).to(device)
        # init stage training: the connector is trained only in this stage
        # (it is not added to trainable_dict).
        init_trainable_dict = nn.ModuleDict(dict(
            connector=connector,
        ))
        criterion_kd = ABLoss(len(feat_s[1:-1])).to(device)
        init_pretrain(cfg, module_dict, init_trainable_dict, criterion_kd, train_loader, logger, device)
        # classification
        module_dict["connector"] = connector
        criterion_dict["kd"] = criterion_kd
    else:
        raise NotImplementedError(kd_name)

    trainable_dict.update(loss_trainable_dict)
    assert "teacher" not in trainable_dict.keys(), "teacher is not trainable"

    # optimizer
    optimizer = torch.optim.SGD(
        params=trainable_dict.parameters(),
        lr=cfg["training"]["lr"],
        weight_decay=cfg["training"]["optimizer"]["weight_decay_stage2"],
        momentum=cfg["training"]["optimizer"]["momentum"])
    lr_scheduler = MultiStepLR(
        optimizer=optimizer,
        milestones=cfg["training"]["lr_decay_epochs"],
        gamma=cfg["training"]["lr_decay_rate"]
    )
    # append teacher after optimizer to avoid weight_decay
    module_dict["teacher"] = model_t.to(device)
    logger.info(optimizer)
    train_kd(
        cfg=cfg,
        train_loader=train_loader,
        val_loader=val_loader,
        module_dict=module_dict,
        criterion_dict=criterion_dict,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        tb_writer=writer,
        device=device,
        ckpt_dir=ckpt_dir
    )
if __name__ == "__main__":
    # Command-line entry point. ``args`` and ``__global_values__`` are
    # deliberately bound at module level because functions in this file read
    # them as globals.
    cli = argparse.ArgumentParser()
    for str_flag in ("--config", "--logdir", "--file_name_cfg"):
        cli.add_argument(str_flag, type=str)
    for bool_flag in ("--gpu_preserve", "--debug"):
        cli.add_argument(bool_flag, type=str2bool, default=False)
    cli.add_argument("--preserve_percent", type=float, default=0.95)
    args = cli.parse_args()
    # Global training-iteration counter, bumped once per batch during training.
    __global_values__ = {"it": 0}
    main(
        cfg_filepath=args.config,
        file_name_cfg=args.file_name_cfg,
        logdir=args.logdir,
        gpu_preserve=args.gpu_preserve,
        debug=args.debug,
    )