-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
42 lines (32 loc) · 1.02 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
def log(data, name):
    """Print a quick debugging summary of an array-like value.

    Three things are printed to stdout: the full contents, the shape, and
    the strictly positive elements selected with a boolean mask.

    Args:
        data: A torch tensor or NumPy array (anything exposing ``.shape``
            and supporting ``data[data > 0]``).
        name: Label inserted at the start of every printed line.

    Returns:
        None.
    """
    print('{} data is:\n'.format(name), data)
    # ``.shape`` prints the same ``torch.Size([...])`` as ``.size()`` for
    # tensors, and also works for NumPy arrays, where ``size`` is an int
    # attribute and cannot be called.
    print('{} size is:'.format(name), data.shape)
    print('{} data(>0):'.format(name), data[data > 0])
def toTensor(img):
    """Convert a BGR image array into a batched, normalized float tensor.

    The image is converted from BGR to RGB channel order with OpenCV, its
    axes are reordered from (H, W, C) to (C, H, W), the values are cast to
    float and divided by 255, and a leading batch dimension of size 1 is
    added.

    Args:
        img: ``numpy.ndarray`` (or subclass) holding the image.
            NOTE(review): presumably an 8-bit H x W x 3 BGR image as
            produced by ``cv2.imread`` -- confirm against callers.

    Returns:
        A float ``torch.Tensor`` of shape (1, C, H, W).

    Raises:
        AssertionError: If ``img`` is not a NumPy array.
    """
    # isinstance (rather than an exact type comparison) also admits ndarray
    # subclasses; the exception type is kept for existing callers.
    assert isinstance(img, np.ndarray), \
        'the img type is {}, but ndarray expected'.format(type(img))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # (H, W, C) -> (C, H, W): channel-first layout.
    chw = torch.from_numpy(rgb.transpose((2, 0, 1)))
    return chw.float().div(255).unsqueeze(0)
def tensor_to_np(tensor):
    """Convert a single-image float tensor into an 8-bit channel-last array.

    Values are clamped to [0, 1], scaled by 255 and truncated to unsigned
    8-bit integers. The result is moved to the CPU, the leading batch axis
    is removed and the axes are reordered from (C, H, W) to (H, W, C).

    Args:
        tensor: ``torch.Tensor`` of shape (1, C, H, W). Values are expected
            to lie in [0, 1]; anything outside that range is saturated.

    Returns:
        ``numpy.ndarray`` of dtype uint8 and shape (H, W, C).

    Raises:
        ValueError: If the leading (batch) axis does not have size 1.
    """
    # Clamp first: casting an out-of-range float to uint8 wraps around (or is
    # undefined), turning slightly-too-bright pixels into dark ones.
    scaled = tensor.detach().clamp(0, 1).mul(255).byte()
    batched = scaled.cpu().numpy()
    # Drop the batch axis, then (C, H, W) -> (H, W, C).
    return batched.squeeze(0).transpose((1, 2, 0))
def show_from_cv(img, title=None):
    """Display an OpenCV image in a new matplotlib figure.

    The image is converted from BGR to RGB channel order before being
    handed to matplotlib.

    Args:
        img: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB``.
        title: Optional figure title; nothing is set when it is ``None``.

    Returns:
        None.
    """
    rgb_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.figure()
    plt.imshow(rgb_image)
    wants_title = title is not None
    if wants_title:
        plt.title(title)
    # A very short pause gives the GUI event loop a chance to draw the figure.
    plt.pause(0.001)
def show_from_tensor(tensor, title=None):
    """Display an image tensor in a new matplotlib figure.

    The tensor is converted to an image array with ``tensor_to_np`` and
    then shown.

    Args:
        tensor: Image tensor in the form accepted by ``tensor_to_np``.
        title: Optional figure title; nothing is set when it is ``None``.

    Returns:
        None.
    """
    # tensor_to_np only uses out-of-place operations, so the input is never
    # modified and a defensive clone would just be a wasted copy.
    img = tensor_to_np(tensor)
    plt.figure()
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    # A very short pause gives the GUI event loop a chance to draw the figure.
    plt.pause(0.001)