forked from jvanvugt/pytorch-domain-adaptation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
54 lines (39 loc) · 1.34 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from PIL import Image
import numpy as np
import torch
from torch.autograd import Function
def set_requires_grad(model: torch.nn.Module, requires_grad: bool = True) -> None:
    """Set the ``requires_grad`` flag on every parameter of *model*.

    Args:
        model: Module whose parameters (as returned by ``model.parameters()``)
            are updated in place.
        requires_grad: Value assigned to each parameter's ``requires_grad``
            attribute. Pass ``False`` to freeze the module.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad
def loop_iterable(iterable):
    """Yield the items of *iterable* over and over.

    Each time *iterable* is exhausted it is iterated again from the start,
    so a re-iterable object (list, data loader, ...) produces an endless
    stream.

    The generator finishes as soon as one full pass over *iterable* produces
    no items. This covers an empty iterable and a one-shot iterator that has
    already been consumed; without this guard the loop would spin forever
    without ever yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Nothing came out of this pass, so no later pass will either.
            return
class GrayscaleToRgb:
    """Convert a grayscale image to rgb"""

    def __call__(self, image):
        """Return *image* with its single channel repeated three times.

        The input is converted to a NumPy array, stacked with itself along
        the depth (third) axis, and turned back into a PIL image.
        """
        # asarray avoids an extra copy; dstack allocates the new array anyway.
        pixels = np.asarray(image)
        stacked = np.dstack([pixels, pixels, pixels])
        return Image.fromarray(stacked)
class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    Forward pass is the identity function. In the backward pass,
    the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        """Return a copy of *x* and remember *lambda_* for the backward pass."""
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        """Return ``-lambda_ * grads`` for *x* and ``None`` for *lambda_*."""
        lambda_ = ctx.lambda_
        if isinstance(lambda_, torch.Tensor):
            # new_tensor() on an existing tensor copy-constructs and warns;
            # detach and match the gradient's device/dtype instead.
            scale = lambda_.detach().to(device=grads.device, dtype=grads.dtype)
        else:
            scale = grads.new_tensor(lambda_)
        dx = -scale * grads
        # lambda_ is a hyper-parameter, so it receives no gradient.
        return dx, None
class GradientReversal(torch.nn.Module):
    """Module wrapper around ``GradientReversalFunction``.

    Args:
        lambda_: Factor passed to ``GradientReversalFunction.apply`` on every
            forward call. Defaults to 1.
    """

    def __init__(self, lambda_=1):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        """Apply ``GradientReversalFunction`` to *x* with the stored ``lambda_``."""
        return GradientReversalFunction.apply(x, self.lambda_)