forked from Zheng222/IMDN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
78 lines (57 loc) · 2.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from skimage.measure import compare_psnr as psnr
from skimage.measure import compare_ssim as ssim
import numpy as np
import os
import torch
from collections import OrderedDict
def compute_psnr(im1, im2):
    """Return the peak signal-to-noise ratio between two images.

    Both arguments are forwarded unchanged to the ``psnr`` callable imported
    at the top of this file; no extra options are passed.

    Args:
        im1: First image.
        im2: Second image.

    Returns:
        Whatever ``psnr(im1, im2)`` returns.
    """
    return psnr(im1, im2)
def compute_ssim(im1, im2):
    """Return the structural similarity between two images.

    The images are handed to the ``ssim`` callable imported at the top of
    this file with a fixed configuration: ``K1=0.01``, ``K2=0.03``, Gaussian
    weighting with ``sigma=1.5`` and ``use_sample_covariance=False``.

    Args:
        im1: First image; must expose ``shape``.
        im2: Second image.

    Returns:
        Whatever ``ssim`` returns for the two images.
    """
    dims = im1.shape
    # Only a 3-axis array whose last axis has length 3 is treated as
    # multichannel; everything else is compared as a single plane.
    is_rgb = len(dims) == 3 and dims[-1] == 3
    options = dict(
        K1=0.01,
        K2=0.03,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        multichannel=is_rgb,
    )
    return ssim(im1, im2, **options)
def shave(im, border):
    """Crop ``border`` pixels from every side of the first two axes of ``im``.

    Args:
        im: Array exposing ``shape`` and supporting slicing. Only the first
            two axes are cropped; any trailing axes are kept whole.
        border: Non-negative number of pixels to remove from each side.
            ``0`` returns the full image.

    Returns:
        ``im[border:H - border, border:W - border, ...]`` where ``H`` and
        ``W`` are the sizes of the first two axes.
    """
    height, width = im.shape[0], im.shape[1]
    # Use explicit end indices: a ``-border`` stop becomes ``-0`` when
    # border == 0, which slices to an empty array instead of the whole image.
    return im[border:height - border, border:width - border, ...]
def modcrop(im, modulo):
    """Crop ``im`` so its first two axes are multiples of ``modulo``.

    The crop keeps the top-left corner: rows and columns beyond the largest
    multiple of ``modulo`` are dropped, trailing axes are left untouched.

    Args:
        im: Array exposing ``shape`` and supporting slicing.
        modulo: Value the first two axis lengths must be divisible by.

    Returns:
        The cropped array.
    """
    height, width = im.shape[0], im.shape[1]
    # Truncate each extent down to a whole number of ``modulo`` steps.
    crop_h, crop_w = (
        np.int32(extent / modulo) * modulo for extent in (height, width)
    )
    return im[:crop_h, :crop_w, ...]
def get_list(path, ext):
    """Return the entries of directory ``path`` whose names end with ``ext``.

    Args:
        path: Directory to list.
        ext: Suffix to match; passed straight to ``str.endswith``, so a tuple
            of suffixes is also accepted. Matching is case-sensitive.

    Returns:
        A list of ``os.path.join(path, name)`` strings, sorted by name.
    """
    # os.listdir yields entries in arbitrary order; sort them so callers get
    # the same sequence on every run and platform.
    return [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.endswith(ext)
    ]
def convert_shape(img):
    """Scale an image by 255, move axis 0 to the end and cast to ``uint8``.

    Args:
        img: 3-axis array; it is multiplied by 255.0 and rounded before the
            axes are reordered from ``(0, 1, 2)`` to ``(1, 2, 0)``.

    Returns:
        ``numpy.uint8`` array with values clipped to ``[0, 255]``.
    """
    scaled = (img * 255.0).round()
    reordered = np.transpose(scaled, (1, 2, 0))
    # Clip before the cast so out-of-range values saturate instead of wrapping.
    return np.clip(reordered, 0, 255).astype(np.uint8)
def quantize(img):
    """Clip ``img`` to ``[0, 255]``, round it and cast it to ``uint8``.

    Args:
        img: Array providing ``clip``, ``round`` and ``astype`` methods.

    Returns:
        The quantized ``numpy.uint8`` array.
    """
    bounded = img.clip(0, 255)
    rounded = bounded.round()
    return rounded.astype(np.uint8)
def tensor2np(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert a 3-axis tensor to a NumPy image with axis 0 moved last.

    The values are clamped to ``min_max``, rescaled to ``[0, 1]``, and the
    axes are reordered from ``(0, 1, 2)`` to ``(1, 2, 0)``. The input tensor
    is never modified.

    Args:
        tensor: 3-axis ``torch.Tensor`` (any dtype/device accepted by
            ``.float().cpu()``).
        out_type: NumPy dtype of the result. For ``np.uint8`` the values are
            additionally scaled to ``[0, 255]`` and rounded.
        min_max: ``(low, high)`` range the input is clamped to before
            normalization.

    Returns:
        ``numpy.ndarray`` of dtype ``out_type``.
    """
    low, high = min_max
    # Clamp out of place: ``.float().cpu()`` hands back the input itself when
    # it is already a float32 CPU tensor, so an in-place ``clamp_`` would
    # silently overwrite the caller's data. ``detach`` lets tensors that
    # track gradients be converted as well.
    tensor = tensor.detach().float().cpu().clamp(low, high)
    tensor = (tensor - low) / (high - low)  # to range [0, 1]
    img_np = np.transpose(tensor.numpy(), (1, 2, 0))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
def convert2np(tensor):
    """Convert a tensor to a ``uint8`` NumPy image with axis 0 moved last.

    The tensor is moved to the CPU, multiplied by 255, clamped to
    ``[0, 255]`` and cast with ``byte()`` (no rounding step is applied).
    All size-1 axes are then squeezed away and the remaining three axes are
    permuted from ``(0, 1, 2)`` to ``(1, 2, 0)``.

    Args:
        tensor: ``torch.Tensor`` that has exactly three axes left after
            ``squeeze()``.

    Returns:
        ``numpy.ndarray`` produced by ``.numpy()`` on the permuted tensor.
    """
    scaled = tensor.cpu().mul(255)
    as_bytes = scaled.clamp(0, 255).byte()
    # squeeze() removes every singleton axis, not only a leading batch axis.
    image = as_bytes.squeeze().permute(1, 2, 0)
    return image.numpy()
def adjust_learning_rate(optimizer, epoch, step_size, lr_init, gamma):
    """Set a step-decayed learning rate on every parameter group.

    The rate is ``lr_init * gamma ** (epoch // step_size)`` and is written
    to the ``'lr'`` entry of each item in ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of mappings.
        epoch: Current epoch index.
        step_size: Number of epochs between decay steps.
        lr_init: Learning rate used before the first decay step.
        gamma: Multiplicative factor applied at each decay step.

    Returns:
        None. The parameter groups are updated in place.
    """
    scheduled_lr = lr_init * (gamma ** (epoch // step_size))
    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr
def load_state_dict(path):
    """Load a checkpoint and strip a leading ``module.`` from its keys.

    Keys that start with ``module.`` (e.g. as written when a model is saved
    while wrapped in ``torch.nn.DataParallel``) have that prefix removed; all
    other keys are kept unchanged. Entry order is preserved.

    Args:
        path: Location of the checkpoint; passed to ``torch.load``.

    Returns:
        ``OrderedDict`` mapping the cleaned key names to the stored values.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    state_dict = torch.load(path)
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Match the prefix only at the start of the key: a substring test
        # would also hit names such as "submodule.bias" and cut off their
        # first characters.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    return new_state_dict