StyleTransfer.py
import os
import time

import torch
from torch import optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

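
# Neural style transfer in the spirit of Gatys et al. (2016): a pretrained
# VGG19 serves as a frozen feature extractor, and the pixels of a target image
# are optimized so that its deep features match the content image while its
# feature Gram matrices match the style image.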
class NeuralStyleTransfer:
    def __init__(self, content_image_path, style_image_path):
        # Use the convolutional part of a pretrained VGG19 as a fixed feature
        # extractor; only the target image's pixels are optimized.
        self.model = models.vgg19(pretrained=True).features
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.content = self.load_image(content_image_path).to(self.device)
        # Initialize the target as a copy of the content image.
        self.target = self.content.clone().requires_grad_(True).to(self.device)
        self.style = self.load_image(style_image_path).to(self.device)
    def load_image(self, path, max_size=400, shape=None, gray=False):
        # Load an image, optionally grayscale it, resize it, and normalize it
        # with the ImageNet statistics VGG19 was trained on.
        image = Image.open(path).convert('RGB')
        if shape is not None:
            size = shape
        elif max(image.size) > max_size:
            size = max_size
        else:
            size = max(image.size)
        if gray:
            # Drop color, then restore three channels so VGG still accepts it.
            image = image.convert('L')
            image = image.convert('RGB')
        in_transform = transforms.Compose([transforms.Resize(size),
                                           transforms.ToTensor(),
                                           transforms.Normalize(
                                               (0.485, 0.456, 0.406),
                                               (0.229, 0.224, 0.225))])
        return in_transform(image).unsqueeze(0)
    def get_features(self, model, image):
        # Map indices in vgg19.features to the layer names used by Gatys et al.
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content representation
            '28': 'conv5_1'
        }
        features = {}
        x = image
        # Run the image through the network, capturing the selected activations.
        for name, layer in model._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features
    def gramian(self, tensor):
        # Compute the Gram matrix over all channels of a single conv layer.
        t = tensor.view(tensor.shape[1], -1)
        return t @ t.T
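
    # For a feature map reshaped to F with shape (d, h*w), the Gram matrix
    # G = F @ F.T has entries G[i, j] = sum_k F[i, k] * F[j, k]: the pairwise
    # correlations between channel activations, which capture texture
    # independently of spatial layout. With a batch of size 1, tensor.shape[1]
    # is the channel dimension d.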
    def content_loss(self, c_features, t_features):
        # Mean squared error between target and content features at conv4_2.
        loss = 0.5 * (t_features['conv4_2'] - c_features['conv4_2']) ** 2
        return torch.mean(loss)
    def style_loss(self, s_grams, t_features, s_features, weights):
        # Style loss: the weighted sum over layers of the MSE between the
        # target's and the style image's Gram matrices, each normalized by the
        # layer's feature-map size.
        loss = 0
        for layer in weights:
            _, d, h, w = s_features[layer].shape
            t_gram = self.gramian(t_features[layer])
            layer_loss = torch.mean((t_gram - s_grams[layer]) ** 2) / (d * h * w)
            loss += layer_loss * weights[layer]
        return loss
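
    # forward() combines the two objectives as
    #   total_loss = c_weight * content_loss + s_weight * style_loss,
    # the usual alpha/beta content-style trade-off.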
    def im_convert(self, tensor):
        # Convert a normalized tensor back to a displayable HxWxC numpy image.
        image = tensor.to("cpu").clone().detach()
        image = image.numpy().squeeze(0)
        image = image.transpose(1, 2, 0)
        # Undo the ImageNet normalization applied in load_image.
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
        image = image.clip(0, 1)
        return image
    def save_image(self):
        picture = self.im_convert(self.target)
        os.makedirs('Result', exist_ok=True)
        plt.imsave('Result/target.jpg', picture)
    def forward(self):
        start = time.time()
        # Equal weights for the five style layers.
        style_weights = {'conv1_1': .2,
                         'conv2_1': .2,
                         'conv3_1': .2,
                         'conv4_1': .2,
                         'conv5_1': .2}
        show = 500      # display the intermediate result every `show` steps
        steps = 10000
        c_weight = 2    # content weight (alpha)
        s_weight = 50   # style weight (beta)
        # Content and style features are fixed, so compute them once.
        s_features = self.get_features(self.model, self.style)
        c_features = self.get_features(self.model, self.content)
        s_grams = {layer: self.gramian(features) for layer, features in s_features.items()}
        opt = optim.Adam([self.target], lr=0.009)
        print('Creating Style...')
        for step in range(1, steps + 1):
            opt.zero_grad()
            t_features = self.get_features(self.model, self.target)
            c_loss = self.content_loss(c_features, t_features)
            s_loss = self.style_loss(s_grams, t_features, s_features, style_weights)
            total_loss = c_weight * c_loss + s_weight * s_loss
            total_loss.backward()
            opt.step()
            if step % show == 0:
                print('========= Total loss:', total_loss.item(), 'after', step, 'steps =========')
                plt.imshow(self.im_convert(self.target))
                plt.show()
        end = time.time()
        print('time required:', end - start)
        self.save_image()
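
# Usage: adjust the paths below for your own images. save_image() writes the
# result to Result/target.jpg (the directory is created if missing).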
if __name__ == '__main__':
    content_path = 'content_images/content_image.jpg'
    style_path = 'style_images/style.jpg'
    transfer = NeuralStyleTransfer(content_image_path=content_path, style_image_path=style_path)
    transfer.forward()