-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple-vision-dnn-cifar10.py
93 lines (72 loc) · 2.98 KB
/
simple-vision-dnn-cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from torch import optim, nn
from dataloader_cifar10_animal_bird import prepare_data
def calculate_loss(model, loss_fn, x, y, is_train):
    """Run a forward pass of ``model`` on ``x`` and return ``loss_fn(model(x), y)``.

    Autograd is switched on or off for the duration of the forward pass
    according to ``is_train``; the previous grad mode is restored afterwards.

    Args:
        model: Callable applied to ``x`` to produce predictions.
        loss_fn: Callable taking ``(predictions, y)`` and returning the loss.
        x: Input passed to ``model``.
        y: Target passed to ``loss_fn``.
        is_train: Whether gradients should be tracked for this pass.

    Returns:
        The loss returned by ``loss_fn``.

    Raises:
        AssertionError: If the loss's ``requires_grad`` flag does not match
            ``is_train`` (e.g. training was requested but nothing in the
            graph requires gradients).
    """
    with torch.set_grad_enabled(is_train):
        loss = loss_fn(model(x), y)

    # Sanity check: the grad mode requested must be reflected on the result.
    assert is_train == loss.requires_grad
    return loss
def _run_epoch(model, loss_fn, loader, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Network applied to each flattened batch.
        loss_fn: Callable taking ``(output, label_indices)``. It is assumed to
            return the per-batch *mean* loss (the default reduction of
            ``nn.CrossEntropyLoss``), so it is re-weighted by batch size here.
        loader: Iterable yielding ``(imgs, label_indices)`` batches.
        optimizer: If given, parameters are updated after every batch;
            if ``None`` the pass is evaluation-only with autograd disabled.

    Returns:
        A ``(mean_loss, accuracy)`` tuple of floats computed over every sample
        seen. An empty loader yields ``(nan, nan)`` instead of raising
        ``ZeroDivisionError``.
    """
    is_train = optimizer is not None
    # No effect on plain Linear/ReLU stacks, but keeps the loop correct if
    # mode-dependent layers (dropout, batch norm) are ever added.
    model.train(is_train)

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for imgs, label_indices in loader:
        num_imgs = imgs.shape[0]
        # Flatten every image to a 1-D vector for the fully connected layers.
        imgs_1d = imgs.view(num_imgs, -1)

        with torch.set_grad_enabled(is_train):
            output = model(imgs_1d)
            loss = loss_fn(output, label_indices)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += loss.item() * num_imgs
        # Predictions come from the forward pass made before the update.
        num_correct += int((output.argmax(dim=-1) == label_indices).sum())
        num_seen += num_imgs

    if num_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / num_seen, num_correct / num_seen


def main():
    """Train a fully connected classifier and periodically report metrics.

    Loads the loaders and class names from ``prepare_data``, builds a
    five-layer MLP, prints its parameter counts, then trains with Adam and
    cross-entropy loss. Every fifth epoch (starting with epoch 0) the
    validation set is evaluated and train/validation loss and accuracy are
    printed. Reported losses are means over the whole epoch.
    """
    learning_rate = 1e-4
    batch_size = 64
    epochs = 100
    val_interval = 5  # validate and report once every this many epochs

    train_loader, val_loader, class_names = prepare_data(batch_size)

    # Each image is flattened before the first layer, so the input width must
    # equal the flattened image size.
    n_in = 3 * 32 * 32
    n_hidden = 512
    n_out = len(class_names)

    model = nn.Sequential(
        nn.Linear(in_features=n_in, out_features=n_hidden, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=n_hidden, out_features=n_hidden, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=n_hidden, out_features=n_hidden, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=n_hidden, out_features=n_hidden, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=n_hidden, out_features=n_out, bias=True),
    )

    # Summarise the parameter tensors and how many of them are trainable.
    total_trainable_params = 0
    for name, params in model.named_parameters():
        num_params = params.numel()
        trainable_params = num_params if params.requires_grad else 0
        total_trainable_params += trainable_params
        print(f'{name}: {params.shape} params={num_params} trainable={trainable_params}')
    print(f'total trainable params : {total_trainable_params}')

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        # train one epoch
        train_loss, train_accuracy = _run_epoch(model, loss_fn, train_loader, optimizer)

        # occasionally check validation set performance
        if epoch % val_interval == 0:
            val_loss, val_accuracy = _run_epoch(model, loss_fn, val_loader)
            print(f'epoch = {epoch} train loss = {train_loss:0.6f} train accuracy = {train_accuracy} '
                  f'val loss = {val_loss:0.6f} val accuracy = {val_accuracy}')
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()