-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmnist_main.py
89 lines (74 loc) · 3.2 KB
/
mnist_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/python3
# -*-coding:utf-8 -*-
# Reference:**********************************************
# @Time : 4/13/2020 11:01 PM
# @Author : Gaopeng.Bai
# @File : mnist_main.py
# @User : gaope
# @Software: PyCharm
# @Description:
# Reference:**********************************************
import numpy as np
import argparse
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from utils.model import model_select
# --- Command-line interface -------------------------------------------------
# All hyper-parameters for the MNIST run are exposed as CLI flags; `args` is
# parsed once at import time and consumed by main() below.
parser = argparse.ArgumentParser(description='PyTorch mnist Training')

# Architecture to train (resolved by utils.model.model_select).
parser.add_argument(
    '--model', default="simply_cnn", type=str, metavar='N',
    help=' (lenet5, simply_cnn, -alexnet-) ')

# Optimization schedule.
parser.add_argument(
    '--epochs', default=15, type=int, metavar='N',
    help='number of total epochs to run')
parser.add_argument(
    '-b', '--batch-size', default=128, type=int, metavar='N',
    help='mini-batch size (default: 128),only used for train')

# SGD hyper-parameters.
parser.add_argument(
    '--lr', '--learning-rate', default=1e-2, type=float, metavar='LR',
    help='initial learning rate')
parser.add_argument(
    '--momentum', default=0.9, type=float, metavar='M',
    help='momentum')
parser.add_argument(
    '--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
    help='weight decay (default: 1e-4)')

args = parser.parse_args()
def main(args):
    """Train the selected model on MNIST and print test accuracy per epoch.

    Args:
        args: parsed CLI namespace with attributes ``model``, ``epochs``,
            ``batch_size``, ``lr``, ``momentum`` and ``weight_decay``.
    """
    # Grayscale MNIST statistics used for input normalization.
    normalize = transforms.Normalize(mean=[0.131], std=[0.308])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = mnist.MNIST(root='../data', train=True, download=True,
                                transform=transform)
    test_dataset = mnist.MNIST(root='../data', train=False,
                               transform=transform)
    # shuffle=True: the original iterated the training set in a fixed order
    # every epoch, which degrades SGD convergence.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    model = model_select(args.model)
    sgd = SGD(model.parameters(), args.lr, momentum=args.momentum,
              weight_decay=args.weight_decay)
    cross_error = CrossEntropyLoss()

    for _epoch in range(args.epochs):
        # 1-based display so "Epoch 15/15" is the last epoch.
        print("Epoch {}/{}".format(_epoch + 1, args.epochs))
        print("-" * 10)

        # Training pass: make sure dropout/batch-norm are in train mode.
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            sgd.zero_grad()
            predict_y = model(train_x.float())
            _error = cross_error(predict_y, train_label.long())
            if idx % 100 == 0:
                # .item() logs the scalar loss instead of a tensor repr.
                print('idx: {}, _error: {}'.format(idx, _error.item()))
            _error.backward()
            sgd.step()

        # Evaluation pass: eval mode and no autograd bookkeeping.
        model.eval()
        correct = 0
        _sum = 0
        with torch.no_grad():
            for test_x, test_label in test_loader:
                predict_y = model(test_x.float())
                predict_ys = torch.argmax(predict_y, dim=-1)
                matches = predict_ys == test_label
                correct += matches.sum().item()
                _sum += matches.shape[0]
        print('test accuracy: {:.2f}'.format(correct / _sum))
# Script entry point: run training with the module-level CLI arguments.
if __name__ == '__main__':
    main(args)