-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbaseline_train_cd.py
158 lines (144 loc) · 8.67 KB
/
baseline_train_cd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import json
import os
from lib.utils import get_logger
from lib import preprocessing as prep
from model.gdbt_supervisor import GBRSupervisor
from model.gp_supervisor import GPRSupervisor
from model.arima_supervisor2 import ARIMASupervisor
from model.svr_supervisor import SVRSupervisor
from model.var_supervisor import VARSupervisor
from model.VARMAX_supervisor import VARMAXSupervisor
# ---------------------------------------------------------------------------
# Command-line flags.
# Numeric flags defaulting to -1 act as "unset" sentinels: main() copies a
# numeric flag into the model configuration only when its value is >= 0, so
# leaving them at -1 keeps whatever the JSON configuration file specifies.
# ---------------------------------------------------------------------------
flags = tf.app.flags
FLAGS = flags.FLAGS

# Training / model hyper-parameters.
flags.DEFINE_integer('batch_size', 20, 'Batch size')
flags.DEFINE_integer('cl_decay_steps', -1,
                     'Parameter to control the decay speed of the probability of feeding ground truth '
                     'instead of model output.')
flags.DEFINE_integer('epochs', 200, 'Maximum number of epochs to train.')
flags.DEFINE_string('filter_type', 'laplacian', 'laplacian/random_walk/dual_random_walk.')
flags.DEFINE_integer('horizon', 1, 'Maximum number of timestamps to predict.')
flags.DEFINE_float('l1_decay', -1.0, 'L1 Regularization')
flags.DEFINE_float('lr_decay', -1.0, 'Learning rate decay.')
flags.DEFINE_integer('lr_decay_epoch', -1, 'The epoch that starting decaying the parameter.')
flags.DEFINE_integer('lr_decay_interval', -1, 'Interval between each decay.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate. -1: select by hyperopt tuning.')
flags.DEFINE_string('log_dir', None, 'Log directory for restoring the model from a checkpoint.')
flags.DEFINE_string('loss_func', 'L2', 'MSE/MAPE/RMSE: loss function; KL/L2/EMD: loss function.')
flags.DEFINE_float('min_learning_rate', -1, 'Minimum learning rate')
flags.DEFINE_integer('nb_weeks', 17, 'How many weeks\' data should be used for train/test.')
flags.DEFINE_integer('patience', -1,
                     'Maximum number of epochs allowed for non-improving validation error before early stopping.')
flags.DEFINE_integer('seq_len', 3, 'Sequence length.')
flags.DEFINE_integer('test_every_n_epochs', 10, 'Run model on the testing dataset every n epochs.')
flags.DEFINE_bool('coarsen', False, 'Apply coarsen on input data.')
flags.DEFINE_integer('coarsening_levels', 4, 'Number of coarsened graph.')
flags.DEFINE_bool('shuffle_training', False, 'shuffle_training False: select by hyperopt tuning.')
flags.DEFINE_integer('num_gpus', 2, 'How many GPUs to use.')
flags.DEFINE_bool('use_cpu_only', False, 'Set to true to only use cpu.')
# None means "not given on the command line"; main() then keeps the config value.
flags.DEFINE_bool('use_curriculum_learning', None, 'Set to true to use Curriculum learning in decoding stage.')
flags.DEFINE_integer('verbose', -1, '1: to log individual sensor information.')

# Data-related flags (copied into the "data" section of the configuration).
flags.DEFINE_string('server_name', 'chengdu', 'The name of dataset to be processed')
flags.DEFINE_string('borough', 'SecRing', 'Selected area')
flags.DEFINE_string('zone', 'polygon', 'map partition method')
flags.DEFINE_integer('sample_rate', 15, 'Sample rate to condense the data')
flags.DEFINE_string('mode', 'hist', 'avg: for single value, hist: for multi values')
flags.DEFINE_string('data_format', 'speed', 'speed or duration')
flags.DEFINE_bool('duration_log', False, 'Apply log10 to the data, True when data_format is duration')
flags.DEFINE_bool('fill_mean', False, 'Fill HA to the data, True when need to fill in mean values')
flags.DEFINE_bool('sparse_removal', False, 'Apply sparse removal to the data, True when need to remove sparse regions')
flags.DEFINE_string('scaler', 'maxmin', 'maxmin: MaxMinScaler, std: Standard scaler')
flags.DEFINE_string('config_filename', './conf/base_fc_config_chengdu.json',
                    'Configuration filename for restoring the model.')

# Graph-construction flags.
flags.DEFINE_integer('hopk', 4, 'Hopk to construct the adjacent matrix')
flags.DEFINE_integer('sigma', 9, 'sigma used to construct the adj matrix')
flags.DEFINE_string('fc_method', 'od', 'od: od_construction, direct: as a whole')
flags.DEFINE_string('base_line', 'svr', 'Baseline methods.')
# Maps each supported --base_line value to the supervisor class that trains it.
_BASELINE_SUPERVISORS = {
    'arima': ARIMASupervisor,
    'svr': SVRSupervisor,
    'var': VARSupervisor,
    'varmax': VARMAXSupervisor,
    'gp': GPRSupervisor,
    'gb': GBRSupervisor,
}


def main(_):
    """Load the configured dataset, build the selected baseline supervisor and train it.

    Reads the JSON configuration named by ``--config_filename``, overrides its
    ``data`` and ``model`` sections with command-line flags, constructs the
    dataset with ``prep.CDData`` and trains the supervisor chosen by
    ``--base_line``.

    Args:
        _: Remaining command-line arguments passed in by ``tf.app.run``; unused.

    Raises:
        ValueError: If ``--base_line`` is not one of the supported baselines.
    """
    # Resolve the baseline up front so an invalid choice fails before the
    # dataset is loaded (the original check only happened after loading).
    supervisor_cls = _BASELINE_SUPERVISORS.get(FLAGS.base_line)
    if supervisor_cls is None:
        raise ValueError(
            'Please provide a supervisor: unknown base_line {!r}, expected one of {}'.format(
                FLAGS.base_line, ', '.join(sorted(_BASELINE_SUPERVISORS))))

    # Load the configuration file.
    with open(FLAGS.config_filename) as f:
        data_model_config = json.load(f)
    data_config = data_model_config['data']

    # Command-line flags take precedence over the "data" section.
    for name in ['server_name', 'hopk', 'sigma', 'mode', 'zone',
                 'borough', 'data_format', 'duration_log', 'sample_rate']:
        data_config[name] = getattr(FLAGS, name)
    data_config['window_size'] = FLAGS.seq_len
    data_config['predict_size'] = FLAGS.horizon
    data_config['base_dir'] = os.path.join(data_config['base_dir'],
                                           data_config['server_name'],
                                           data_config['borough'],
                                           data_config['zone'])

    # Load the graph and the data array.
    logger = get_logger('./logs/', 'info.log')
    logger.info('Loading graph...')
    dataset = prep.CDData(**data_config)
    # NOTE(review): fill_mean is only logged here; it is not passed to the
    # dataset construction below — confirm whether that is intended.
    logger.info('Loading graph tensor data with {} Mean-Fill...'.format(FLAGS.fill_mean))
    print("sparse_removal: ", FLAGS.sparse_removal)
    dataset_f = dataset.gcnn_lstm_data_construction(sparse_removal=FLAGS.sparse_removal)
    adj_mx = dataset.adj_matrix
    logger.info('Construct model and train...')
    nodes = dataset.nodes
    print("Number of edges is ", len(nodes))
    print("Shape of adjacency matrix is ", adj_mx.shape)

    # Build the supervisor ("model") configuration.
    supervisor_config = data_model_config['model']
    supervisor_config['start_date'] = data_config['start_date']
    supervisor_config['use_cpu_only'] = FLAGS.use_cpu_only
    if FLAGS.log_dir:
        supervisor_config['log_dir'] = FLAGS.log_dir
    if FLAGS.use_curriculum_learning is not None:
        supervisor_config['use_curriculum_learning'] = FLAGS.use_curriculum_learning
    if FLAGS.loss_func:
        supervisor_config['loss_func'] = FLAGS.loss_func
    if FLAGS.filter_type:
        supervisor_config['filter_type'] = FLAGS.filter_type

    # Overwrite model parameters with flags. String flags always win; numeric
    # (and bool) flags win only when >= 0, so -1 sentinels keep the config value.
    for name in ['batch_size', 'cl_decay_steps', 'epochs', 'horizon', 'learning_rate', 'l1_decay',
                 'lr_decay', 'lr_decay_epoch', 'lr_decay_interval', 'sample_rate', 'min_learning_rate',
                 'patience', 'seq_len', 'test_every_n_epochs', 'verbose', 'coarsen', 'coarsening_levels',
                 'zone', 'scaler', 'data_format', 'num_gpus', 'mode', 'shuffle_training', 'fc_method', 'hopk']:
        value = getattr(FLAGS, name)
        if isinstance(value, str) or (value is not None and value >= 0):
            supervisor_config[name] = value

    print('In ', FLAGS.base_line)
    supervisor = supervisor_cls(traffic_reading_df=dataset_f, adj_mx=adj_mx,
                                config=supervisor_config,
                                origin_df_file=dataset.origin_df_file,
                                nodes=nodes)
    supervisor.train(sess=None)
if __name__ == '__main__':
    # Hand control to tf.app.run, passing this module's entry point explicitly
    # rather than relying on its default lookup of __main__.main.
    tf.app.run(main=main)