-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
31 lines (25 loc) · 998 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "src"))
from network.trainer import *
from argparse import ArgumentParser
from utils import read_config
import random
def train(settings, seed=42):
    """Seed Python's ``random`` module, then build and run a ``Trainer``.

    Args:
        settings: Configuration passed unchanged to ``Trainer``. In this script
            it is the value returned by ``read_config``; presumably a dict —
            TODO confirm against ``utils.read_config``.
        seed: Seed for Python's built-in ``random`` generator. Defaults to 42,
            the value that used to be hard-coded, so existing callers behave
            identically.

    NOTE(review): only the stdlib ``random`` generator is seeded here. If
    ``Trainer`` draws randomness from numpy/torch, those generators are not
    affected by this call — confirm whether they are seeded elsewhere.
    """
    random.seed(seed)
    # ``Trainer`` comes from ``from network.trainer import *`` at file top.
    trainer = Trainer(settings)
    trainer.train()
def get_job_name(settings, default="unknown"):
    """Return ``settings["job_name"]``, or *default* when the key is missing.

    Used for status messages, so a config without a job name cannot raise a
    ``KeyError`` while an earlier training error is being reported.
    """
    try:
        return settings["job_name"]
    except KeyError:
        return default


if __name__ == "__main__":
    # CLI entry point: read a config, optionally override its checkpoint,
    # run training, and report success or failure.
    parser = ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='', help='Which config to read')
    parser.add_argument('--check_point', type=str, default='', help='check point')
    args = parser.parse_args()

    settings = read_config(args.config)
    # An empty --check_point means "keep whatever the config specifies".
    if args.check_point:
        settings['checkpoint'] = args.check_point

    try:
        train(settings)
    except Exception as e:
        # Look the job name up after train() ran (it may have touched
        # settings) and without risking a KeyError that would hide ``e``.
        job_name = get_job_name(settings)
        print("Training task:%s crashed when error occur:\n%s" % (job_name, e),
              "Training task:%s crashed" % (job_name,))
        raise  # bare raise preserves the original traceback
    else:
        job_name = get_job_name(settings)
        print("Training task finish", "Training task:%s finish" % (job_name,))