train.py
import torch
import torch.nn as nn
import torch.utils.data as utils
import sys, os, json

# Network dimensions: 304 input features, 200 hidden units, 2 output classes.
n_in, n_h, n_out = 304, 200, 2

# Load the pre-normalised dataset and shuffle its rows.
_alldata = torch.load('traintest_normalised.pt')
alldata = _alldata[torch.randperm(_alldata.size(0))]

# Each row holds 304 feature columns followed by 1 label column.
x, y = torch.split(alldata, [304, 1], dim=1)
y = y.long().view(-1)
print(x.size(), y.size())
# 80/20 train/test split.
train_split_size = int(0.8 * x.size(0))
test_split_size = x.size(0) - train_split_size
x_train, x_test = torch.split(x, [train_split_size, test_split_size], dim=0)
y_train, y_test = torch.split(y, [train_split_size, test_split_size], dim=0)

# Move all tensors to the GPU.
x_train = x_train.cuda()
x_test = x_test.cuda()
y_train = y_train.cuda()
y_test = y_test.cuda()
# Two hidden layers with ReLU activations. The final layer outputs raw logits:
# nn.CrossEntropyLoss applies log-softmax internally, so no Softmax layer is
# needed (or wanted) at the end of the network.
model = nn.Sequential(nn.Linear(n_in, n_h),
                      nn.ReLU(),
                      nn.Linear(n_h, n_h),
                      nn.ReLU(),
                      nn.Linear(n_h, n_out)).cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epoch = 1
while True:
    # Forward propagation
    y_pred = model(x_train)
    # Compute the loss
    loss = criterion(y_pred, y_train)
    # Zero the gradients
    optimizer.zero_grad()
    # Perform a backward pass (backpropagation)
    loss.backward()
    # Update the parameters
    optimizer.step()

    epoch += 1
    # Validate and checkpoint every 1000 epochs.
    if epoch % 1000 == 0:
        print('epoch: ', epoch, ' loss: ', loss.item())
        model.eval()
        with torch.no_grad():
            y_test_pred = model(x_test)
            # The predicted class is the index of the larger logit.
            predicted = y_test_pred.argmax(dim=1)
            correct = (predicted == y_test).sum().item()
        print("Acc = %f" % (correct / float(test_split_size)))
        model.train()
        torch.save(model.state_dict(), 'er.model')
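
For reference, a minimal sketch of how a compatible traintest_normalised.pt file could be produced. The sample count, the random placeholder features and labels, and the min-max normalisation shown here are assumptions inferred from how train.py reads the file (an N x 305 float tensor: 304 feature columns plus one 0/1 label column); the repository's actual preprocessing is not shown.

import torch

# Hypothetical example: build a dataset file in the layout train.py expects.
num_samples = 1000                                       # assumed sample count
features = torch.rand(num_samples, 304)                  # placeholder features
labels = torch.randint(0, 2, (num_samples, 1)).float()   # placeholder binary labels

# Min-max normalise each feature column to [0, 1] (one plausible reading of
# "normalised" in the filename).
mins = features.min(dim=0, keepdim=True).values
maxs = features.max(dim=0, keepdim=True).values
features = (features - mins) / (maxs - mins + 1e-8)

alldata = torch.cat([features, labels], dim=1)           # shape (N, 305)
torch.save(alldata, 'traintest_normalised.pt')

The checkpoint written by the training loop can later be restored with model.load_state_dict(torch.load('er.model')) on a model built with the same architecture.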