This repository has been archived by the owner on Oct 19, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 245
/
Copy pathdenoising_autoencoder.py
159 lines (123 loc) · 4.91 KB
/
denoising_autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python
# coding: utf-8
# # 오토인코더로 망가진 이미지 복원하기
# 잡음제거 오토인코더(Denoising Autoencoder)는 2008년 몬트리올 대학에서 발표한 논문
# ["Extracting and Composing Robust Features with Denoising AutoEncoder"](http://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf)
# 에서 처음 제안되었습니다.
# 앞서 오토인코더는 일종의 "압축"을 한다고 했습니다.
# 그리고 압축은 데이터의 특성에 중요도로 우선순위를 매기고
# 낮은 우선순위의 데이터를 버린다는 뜻이기도 합니다.
# 잡음제거 오토인코더의 아이디어는
# 중요한 특징을 추출하는 오토인코더의 특성을 이용하여 비교적
# "덜 중요한 데이터"인 잡음을 버려 원래의 데이터를 복원한다는 것 입니다.
# 원래 배웠던 오토인코더와 큰 차이점은 없으며,
# 학습을 할때 입력에 잡음을 더하는 방식으로 복원 능력을 강화한 것이 핵심입니다.
# 앞서 다룬 코드와 동일하며 `add_noise()` 함수로 학습시 이미지에 노이즈를 더해주는 부분만 추가됐습니다.
import torch
import torchvision
import torch.nn.functional as F
from torch import nn, optim
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
# 하이퍼파라미터
EPOCH = 10
BATCH_SIZE = 64
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("다음 기기로 학습합니다:", DEVICE)
# Fashion MNIST 학습 데이터셋
trainset = datasets.FashionMNIST(
root = './.data/',
train = True,
download = True,
transform = transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(
dataset = trainset,
batch_size = BATCH_SIZE,
shuffle = True,
num_workers = 2
)
class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3), # 입력의 특징을 3차원으로 압축합니다
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28*28),
nn.Sigmoid(), # 픽셀당 0과 1 사이로 값을 출력합니다
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decoded
autoencoder = Autoencoder().to(DEVICE)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.005)
criterion = nn.MSELoss()
def add_noise(img):
noise = torch.randn(img.size()) * 0.2
noisy_img = img + noise
return noisy_img
def train(autoencoder, train_loader):
autoencoder.train()
avg_loss = 0
for step, (x, label) in enumerate(train_loader):
noisy_x = add_noise(x) # 입력에 노이즈 더하기
noisy_x = noisy_x.view(-1, 28*28).to(DEVICE)
y = x.view(-1, 28*28).to(DEVICE)
label = label.to(DEVICE)
encoded, decoded = autoencoder(noisy_x)
loss = criterion(decoded, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item()
return avg_loss / len(train_loader)
for epoch in range(1, EPOCH+1):
loss = train(autoencoder, train_loader)
print("[Epoch {}] loss:{}".format(epoch, loss))
# 이번 예제에선 학습시 시각화를 건너 뜁니다
# # 이미지 복원 시각화 하기
# 모델이 학습시 본적이 없는 데이터로 검증하기 위해 테스트 데이터셋을 가져옵니다.
testset = datasets.FashionMNIST(
root = './.data/',
train = False,
download = True,
transform = transforms.ToTensor()
)
# 테스트셋에서 이미지 한장을 가져옵니다.
sample_data = testset.data[0].view(-1, 28*28)
sample_data = sample_data.type(torch.FloatTensor)/255.
# 이미지를 add_noise로 오염시킨 후, 모델에 통과시킵니다.
original_x = sample_data[0]
noisy_x = add_noise(original_x).to(DEVICE)
_, recovered_x = autoencoder(noisy_x)
f, a = plt.subplots(1, 3, figsize=(15, 15))
# 시각화를 위해 넘파이 행렬로 바꿔줍니다.
original_img = np.reshape(original_x.to("cpu").data.numpy(), (28, 28))
noisy_img = np.reshape(noisy_x.to("cpu").data.numpy(), (28, 28))
recovered_img = np.reshape(recovered_x.to("cpu").data.numpy(), (28, 28))
# 원본 사진
a[0].set_title('Original')
a[0].imshow(original_img, cmap='gray')
# 오염된 원본 사진
a[1].set_title('Noisy')
a[1].imshow(noisy_img, cmap='gray')
# 복원된 사진
a[2].set_title('Recovered')
a[2].imshow(recovered_img, cmap='gray')
plt.show()