-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_gp.py
60 lines (51 loc) · 2.15 KB
/
train_gp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import wandb
import numpy as np
import os
from warhmm_gp import TWARHMM_GP, LinearRegressionObservations_GP
from data_util import load_dataset, standardize_pcs, precompute_ar_covariates, log_wandb_model
import datetime
from kernels import RBF
import matplotlib.pyplot as plt
# Number of principal components to load (load_dataset's num_pcs); also
# recorded in the run config as data_dim.
data_dim = 10
# Number of autoregressive lags passed to precompute_ar_covariates.
num_lags = 1

train_dataset, test_dataset = load_dataset(num_pcs=data_dim)
# Standardize with the training-set statistics and reuse them for the test
# set so both splits share the same scaling.
train_dataset, mean, std = standardize_pcs(train_dataset)
test_dataset, _, _ = standardize_pcs(test_dataset, mean, std)
print("data loaded")

# First compute the autoregression covariates (with an intercept column,
# since fit_intercept=True).
precompute_ar_covariates(train_dataset, num_lags=num_lags, fit_intercept=True)
precompute_ar_covariates(test_dataset, num_lags=num_lags, fit_intercept=True)
# Then precompute the sufficient statistics
LinearRegressionObservations_GP.precompute_suff_stats(train_dataset)
LinearRegressionObservations_GP.precompute_suff_stats(test_dataset)

# Derive the covariate dimension from the data instead of hard-coding it, so
# changing data_dim or num_lags cannot leave the run config out of sync with
# the covariates actually produced above.
# NOTE(review): assumes each entry's 'covariates' is a 2-D array with the
# covariate dimension on axis 1 -- confirm against precompute_ar_covariates.
covariates_dim = train_dataset[0]['covariates'].shape[1]

# Default hyperparameters for the W&B run; the values exposed via wandb.config
# may be overridden by W&B (e.g. when launched from a sweep).
hyperparameter_defaults = dict(
    num_discrete_states=20,
    data_dim=data_dim,
    covariates_dim=covariates_dim,
    tau_scale=1,
    num_taus=31,
    kappa=10000,
    alpha=5,
    covariance_reg=1e-4,
    lengthscale=1,
)

projectname = "twarhmm_gp"
wandb.init(config=hyperparameter_defaults, entity="twss", project=projectname)
config = wandb.config

# num_taus evenly spaced values covering [-tau_scale, tau_scale].
taus = np.linspace(-config['tau_scale'], config['tau_scale'], config['num_taus'])
twarhmm_gp = TWARHMM_GP(
    config,
    taus,
    kernel=RBF(config['num_discrete_states'], config['lengthscale']),
)

# Stochastic fit with transition fitting enabled; fit_tau and
# fit_kernel_params are disabled (presumably keeping taus and the kernel
# hyperparameters fixed -- confirm in TWARHMM_GP.fit_stoch).
train_lls, test_lls, train_posteriors, test_posteriors = twarhmm_gp.fit_stoch(
    train_dataset,
    test_dataset,
    num_epochs=50,
    fit_transitions=True,
    fit_tau=False,
    fit_kernel_params=False,
    wandb_log=True,
)

# Log the trained model under a name encoding K (number of discrete states)
# and T (number of tau values).
log_wandb_model(
    twarhmm_gp,
    f"twarhmm_gp_K{twarhmm_gp.num_discrete_states}_T{len(twarhmm_gp.taus)}",
    type="model",
)

wandb.finish()