This repository has been archived by the owner on Oct 19, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 245
/
gan.py
145 lines (111 loc) · 4.89 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python
# coding: utf-8
# # GAN으로 새로운 패션아이템 생성하기
# *GAN을 이용하여 새로운 패션 아이템을 만들어봅니다*
# 이 프로젝트는 최윤제님의 파이토치 튜토리얼 사용 허락을 받아 참고했습니다.
# * [yunjey/pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial) - MIT License
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
# EPOCHS 과 BATCH_SIZE 등 학습에 필요한 하이퍼 파라미터 들을 설정해 줍니다.
# 하이퍼파라미터
EPOCHS = 500
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Using Device:", DEVICE)
# 학습에 필요한 데이터셋을 로딩합니다.
# Fashion MNIST 데이터셋
trainset = datasets.FashionMNIST(
'./.data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
)
train_loader = torch.utils.data.DataLoader(
dataset = trainset,
batch_size = BATCH_SIZE,
shuffle = True
)
# 생성자는 64차원의 랜덤한 텐서를 입력받아 이에 행렬곱(Linear)과 활성화 함수(ReLU, Tanh) 연산을 실행합니다. 생성자의 결과값은 784차원, 즉 Fashion MNIST 속의 이미지와 같은 차원의 텐서입니다.
# 생성자 (Generator)
G = nn.Sequential(
nn.Linear(64, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 784),
nn.Tanh())
# 판별자는 784차원의 텐서를 입력받습니다. 판별자 역시 입력된 데이터에 행렬곱과 활성화 함수를 실행시키지만, 생성자와 달리 판별자의 결과값은 입력받은 텐서가 진짜인지 구분하는 예측값입니다.
# 판별자 (Discriminator)
D = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid())
# 생성자와 판별자 학습에 쓰일 오차 함수와 최적화 알고리즘도 정의해 줍니다.
# 모델의 가중치를 지정한 장치로 보내기
D = D.to(DEVICE)
G = G.to(DEVICE)
# 이진 크로스 엔트로피 (Binary cross entropy) 오차 함수와
# 생성자와 판별자를 최적화할 Adam 모듈
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
# 모델 학습에 필요한 준비는 끝났습니다. 그럼 본격적으로 GAN을 학습시키는 loop을 만들어 보겠습니다.
total_step = len(train_loader)
for epoch in range(EPOCHS):
for i, (images, _) in enumerate(train_loader):
images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
# '진짜'와 '가짜' 레이블 생성
real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
# 판별자가 진짜 이미지를 진짜로 인식하는 오차를 예산
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# 무작위 텐서로 가짜 이미지 생성
z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
fake_images = G(z)
# 판별자가 가짜 이미지를 가짜로 인식하는 오차를 계산
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산
d_loss = d_loss_real + d_loss_fake
# 역전파 알고리즘으로 판별자 모델의 학습을 진행
d_optimizer.zero_grad()
g_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# 생성자가 판별자를 속였는지에 대한 오차를 계산
fake_images = G(z)
outputs = D(fake_images)
g_loss = criterion(outputs, real_labels)
# 역전파 알고리즘으로 생성자 모델의 학습을 진행
d_optimizer.zero_grad()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
# 학습 진행 알아보기
print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, EPOCHS, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 학습이 끝난 생성자의 결과물을 한번 확인해 보겠습니다.
z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
fake_images = G(z)
for i in range(10):
fake_images_img = np.reshape(fake_images.data.cpu().numpy()[i],(28, 28))
plt.imshow(fake_images_img, cmap = 'gray')
plt.show()