-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathvgg_loss.py
149 lines (118 loc) · 5.7 KB
/
vgg_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""A VGG-based perceptual loss function for PyTorch."""
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models, transforms
class Lambda(nn.Module):
    """Adapter that exposes a plain callable as an :class:`nn.Module`.

    The callable is installed directly as the instance's ``forward``
    attribute via :func:`object.__setattr__`, which bypasses
    :meth:`nn.Module.__setattr__`; the callable is therefore not registered
    on the wrapper, even when it is itself a module.
    """

    def __init__(self, func):
        super().__init__()
        # Deliberately skip nn.Module.__setattr__ so `func` is stored as an
        # ordinary instance attribute rather than being registered.
        object.__setattr__(self, 'forward', func)

    def extra_repr(self):
        """Return ``'<name>()'`` for the wrapped callable.

        The callable's ``__name__`` is used when present; otherwise the name
        of its type is used (e.g. for callable objects).
        """
        wrapped = self.forward
        try:
            label = wrapped.__name__
        except AttributeError:
            label = type(wrapped).__name__
        return label + '()'
class WeightedLoss(nn.ModuleList):
    """A weighted combination of multiple loss functions.

    Every loss is called with the arguments passed to :meth:`forward`, its
    result is multiplied by the matching weight, and the weighted terms are
    summed.

    Args:
        losses: Iterable of loss callables. Entries that are not already
            :class:`nn.Module` instances are wrapped in :class:`Lambda` so
            they can be stored in the module list.
        weights: One multiplier per loss, in the same order as ``losses``.
            Sized containers are stored as given; unsized iterables (e.g.
            generators) are copied into a list first.
        verbose: If true, print each weighted term on every forward call.

    Raises:
        ValueError: If the number of weights differs from the number of
            losses.
    """

    def __init__(self, losses, weights, verbose=False):
        super().__init__()
        for loss in losses:
            self.append(loss if isinstance(loss, nn.Module) else Lambda(loss))
        # The weights are re-iterated on every forward pass. A one-shot
        # iterator would be exhausted after the first call, so materialize
        # anything that is not already a sized container.
        if not hasattr(weights, '__len__'):
            weights = list(weights)
        # zip() in forward() would silently drop the unmatched terms.
        if len(weights) != len(self):
            raise ValueError(
                f'expected {len(self)} weights (one per loss), '
                f'got {len(weights)}')
        self.weights = weights
        self.verbose = verbose

    def _print_losses(self, losses):
        """Print the index, module class name and value of each term."""
        for i, loss in enumerate(losses):
            print(f'({i}) {type(self[i]).__name__}: {loss.item()}')

    def forward(self, *args, **kwargs):
        """Return the sum of all weighted losses evaluated on the arguments."""
        losses = [
            loss(*args, **kwargs) * weight
            for loss, weight in zip(self, self.weights)
        ]
        if self.verbose:
            self._print_losses(losses)
        return sum(losses)
class TVLoss(nn.Module):
    """Total variation loss (Lp penalty on image gradient magnitude).

    The input must be 4D. If a target (second parameter) is passed in, it is
    ignored.

    ``p=1`` yields the vectorial total variation norm. It is a generalization
    of the originally proposed (isotropic) 2D total variation norm (see
    https://en.wikipedia.org/wiki/Total_variation_denoising) for color
    images. On images with a single channel it is equal to the 2D TV norm.

    ``p=2`` yields a variant that is often used for smoothing out noise in
    reconstructions of images from neural network feature maps (see Mahendran
    and Vedaldi, "Understanding Deep Image Representations by Inverting
    Them", https://arxiv.org/abs/1412.0035)

    :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
    similarly to the loss functions in :mod:`torch.nn`. The default is
    ``'mean'``.

    Args:
        p: Either 1 or 2; selects the variant described above.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.
        eps: Added to the squared differences before the square root is
            taken when ``p=1`` (presumably to keep the gradient of the
            square root finite where the differences are zero).

    Raises:
        ValueError: If ``p`` or ``reduction`` is not one of the accepted
            values, or if :meth:`forward` receives an input that is not 4D.
    """

    def __init__(self, p, reduction='mean', eps=1e-8):
        super().__init__()
        if p not in {1, 2}:
            raise ValueError('p must be 1 or 2')
        if reduction not in {'mean', 'sum', 'none'}:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.p = p
        self.reduction = reduction
        self.eps = eps

    def forward(self, input, target=None):
        """Compute the total variation of ``input``; ``target`` is ignored.

        With ``reduction='none'`` the result keeps the spatial size of the
        input; for ``p=1`` the channel dimension is reduced to size 1.
        """
        # The channel reduction below is hard-wired to dim 1, so any other
        # rank would be reduced over the wrong axis instead of failing.
        if input.dim() != 4:
            raise ValueError(
                f'input must be 4D, got a {input.dim()}D tensor')
        # Replicate the last column and row so the forward differences at
        # the right/bottom border are zero and the output keeps the input's
        # height and width.
        padded = F.pad(input, (0, 1, 0, 1), 'replicate')
        anchor = padded[..., :-1, :-1]
        x_diff = anchor - padded[..., :-1, 1:]
        y_diff = anchor - padded[..., 1:, :-1]
        diff = x_diff**2 + y_diff**2
        if self.p == 1:
            diff = (diff + self.eps).mean(dim=1, keepdim=True).sqrt()
        if self.reduction == 'mean':
            return diff.mean()
        if self.reduction == 'sum':
            return diff.sum()
        return diff
class VGGLoss(nn.Module):
    """Computes the VGG perceptual loss between two batches of images.

    The input and target must be 4D tensors with three channels
    ``(B, 3, H, W)`` and must have equivalent shapes. Pixel values should be
    normalized to the range 0–1.

    The VGG perceptual loss is the mean squared difference between the features
    computed for the input and target at layer :attr:`layer` (default 8, or
    ``relu2_2``) of the pretrained model specified by :attr:`model` (either
    ``'vgg16'`` (default) or ``'vgg19'``).

    If :attr:`shift` is nonzero, a random shift of at most :attr:`shift`
    pixels in both height and width will be applied to all images in the input
    and target. The shift will only be applied when the loss function is in
    training mode, and will not be applied if a precomputed feature map is
    supplied as the target.

    :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
    similarly to the loss functions in :mod:`torch.nn`. The default is
    ``'mean'``.

    :meth:`get_features()` may be used to precompute the features for the
    target, to speed up the case where inputs are compared against the same
    target over and over. To use the precomputed features, pass them in as
    :attr:`target` and set :attr:`target_is_features` to :code:`True`.

    Instances of :class:`VGGLoss` must be manually converted to the same
    device and dtype as their inputs.

    Raises:
        KeyError: If :attr:`model` is not a key of :attr:`models`.
        ValueError: If :attr:`reduction` is not ``'mean'``, ``'sum'`` or
            ``'none'``.
    """

    # Maps the accepted `model` names to their torchvision constructors.
    models = {'vgg16': models.vgg16, 'vgg19': models.vgg19}

    def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'):
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if reduction not in {'mean', 'sum', 'none'}:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.shift = shift
        self.reduction = reduction
        # Channel-wise normalization applied before the VGG layers (the
        # values match the ImageNet statistics documented for torchvision's
        # pretrained models).
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        # NOTE(review): newer torchvision releases document a ``weights=``
        # argument in place of ``pretrained``; confirm against the pinned
        # torchvision version before changing this call.
        self.model = self.models[model](pretrained=True).features[:layer+1]
        # The feature extractor is frozen: eval mode, no gradients for its
        # parameters.
        self.model.eval()
        self.model.requires_grad_(False)

    def get_features(self, input):
        """Return the VGG feature map for a batch of images."""
        return self.model(self.normalize(input))

    def train(self, mode=True):
        """Set the training flag on the loss only and return ``self``.

        Unlike :meth:`nn.Module.train`, the flag is not propagated to
        submodules, so the wrapped VGG stays in eval mode. The flag only
        controls whether the random shift is applied in :meth:`forward`.
        """
        self.training = mode
        # nn.Module.eval() returns the result of train(False); returning
        # self keeps `loss = VGGLoss().eval()` from evaluating to None.
        return self

    def forward(self, input, target, target_is_features=False):
        """Return the MSE between the features of ``input`` and ``target``.

        When ``target_is_features`` is true, ``target`` is used directly as
        the target feature map and no random shift is applied.
        """
        if target_is_features:
            input_feats = self.get_features(input)
            target_feats = target
        else:
            sep = input.shape[0]
            # Process input and target as one batch so that a single crop
            # call and a single VGG pass cover both.
            batch = torch.cat([input, target])
            if self.shift and self.training:
                padded = F.pad(batch, [self.shift] * 4, mode='replicate')
                batch = transforms.RandomCrop(batch.shape[2:])(padded)
            feats = self.get_features(batch)
            input_feats, target_feats = feats[:sep], feats[sep:]
        return F.mse_loss(input_feats, target_feats, reduction=self.reduction)