"""pytorch_example.py

Hyperactive example that tunes the hidden-layer sizes of a small
PyTorch network on MNIST. Derived from the optuna example:
https://github.com/optuna/optuna/blob/master/examples/pytorch_simple.py
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms

from hyperactive import Hyperactive

DEVICE = torch.device("cpu")
BATCHSIZE = 256
CLASSES = 10
DIR = os.getcwd()
EPOCHS = 10
LOG_INTERVAL = 10
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
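# Each objective evaluation trains on at most 30 batches per epoch and
# validates on at most 10 batches, so one search iteration stays cheap.
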
# Get the MNIST dataset (downloaded into DIR on first run).
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor()),
    batch_size=BATCHSIZE,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DIR, train=False, transform=transforms.ToTensor()),
    batch_size=BATCHSIZE,
    shuffle=True,
)
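# Note: the validation loader is also shuffled, so the capped validation
# subset differs between evaluations; identical parameters can therefore
# receive slightly different scores.
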
def pytorch_cnn(params):
    # Despite the name, this builds a fully connected network (MLP), not a CNN.
    linear0 = params["linear.0"]
    linear1 = params["linear.1"]

    layers = []
    in_features = 28 * 28
    layers.append(nn.Linear(in_features, linear0))
    layers.append(nn.ReLU())
    layers.append(nn.Dropout(0.2))
    layers.append(nn.Linear(linear0, linear1))
    layers.append(nn.ReLU())
    layers.append(nn.Dropout(0.2))
    layers.append(nn.Linear(linear1, CLASSES))
    layers.append(nn.LogSoftmax(dim=1))
    model = nn.Sequential(*layers).to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Training of the model.
    for epoch in range(EPOCHS):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Limiting training data for faster epochs.
            if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break
            # Flatten 28x28 images into 784-dimensional vectors.
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    # Validation of the model.
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limiting validation data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
            output = model(data)
            # Get the index of the max log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

    # Hyperactive maximizes the returned score.
    return accuracy
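
# A single evaluation can be sanity-checked outside the search, e.g.:
#   acc = pytorch_cnn({"linear.0": 64, "linear.1": 32})
# (hypothetical sizes; any values from the search space below work)
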
search_space = {
    "linear.0": list(range(10, 200, 10)),
    "linear.1": list(range(10, 200, 10)),
}
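# Each dimension holds 19 candidate sizes (10, 20, ..., 190), a grid of
# 19 * 19 = 361 configurations; n_iter=5 evaluates only 5 of them.
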
hyper = Hyperactive()
hyper.add_search(pytorch_cnn, search_space, n_iter=5)
hyper.run()
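
# After the search, the best parameters and score can be read back. The
# best_para/best_score accessors exist in recent Hyperactive versions;
# verify against the installed release.
print("best parameters:", hyper.best_para(pytorch_cnn))
print("best accuracy:", hyper.best_score(pytorch_cnn))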