-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsaver.py
73 lines (64 loc) · 2.62 KB
/
saver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torchvision
from tensorboardX import SummaryWriter
import numpy as np
from PIL import Image
def tensor2img(img):
    """Convert the first image of a batched tensor into an HWC uint8 array.

    ``img[0]`` is taken from the batch, moved to CPU as float32 and treated
    as channel-first (C, H, W). A single-channel image is replicated to three
    channels. Pixel values are mapped linearly from [-1, 1] to [0, 255];
    anything outside that range is clipped so it cannot wrap around when
    cast to uint8.

    Returns:
        numpy.ndarray of dtype uint8 with shape (H, W, C), suitable for
        ``PIL.Image.fromarray``.
    """
    img = img[0].cpu().float().numpy()
    if img.shape[0] == 1:
        # Grayscale -> 3 channels so the result can be saved as an RGB image.
        img = np.tile(img, (3, 1, 1))
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Float -> uint8 conversion of out-of-range values is undefined/wrapping;
    # clip first so overshoot saturates at 0 / 255 instead.
    return np.clip(img, 0.0, 255.0).astype(np.uint8)
def save_imgs(imgs, names, path):
    """Save each image tensor in ``imgs`` as ``<path>/<name>.png``.

    ``imgs`` and ``names`` are paired with ``zip``, so surplus items in the
    longer sequence are ignored. Each tensor is converted with ``tensor2img``,
    which keeps only the first element of its batch. ``path`` is created,
    including any missing parent directories.
    """
    # makedirs(exist_ok=True) creates missing parents and tolerates a directory
    # that already exists (or appears concurrently), unlike exists() + mkdir().
    os.makedirs(path, exist_ok=True)
    for img, name in zip(imgs, names):
        Image.fromarray(tensor2img(img)).save(os.path.join(path, name + '.png'))
class Saver:
    """Persist training artefacts: TensorBoard logs, sample images and checkpoints.

    Directories are derived from ``opts``:

    * ``<opts.display_dir>/<opts.name>``         -- TensorBoard event files
    * ``<opts.result_dir>/<opts.name>``          -- model checkpoints
    * ``<opts.result_dir>/<opts.name>/images``   -- assembled sample images

    Passing ``ep == -1`` to :meth:`write_img` / :meth:`write_model` requests a
    final "last" snapshot, independent of the save frequency.
    """

    def __init__(self, opts):
        self.display_dir = os.path.join(opts.display_dir, opts.name)
        self.model_dir = os.path.join(opts.result_dir, opts.name)
        self.image_dir = os.path.join(self.model_dir, 'images')
        self.display_freq = opts.display_freq
        self.img_save_freq = opts.img_save_freq
        self.model_save_freq = opts.model_save_freq

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        for directory in (self.display_dir, self.model_dir, self.image_dir):
            os.makedirs(directory, exist_ok=True)

        self.writer = SummaryWriter(logdir=self.display_dir)

    def write_display(self, total_it, model):
        """Log losses and an image grid to TensorBoard every ``display_freq`` iterations.

        Every non-dunder, non-callable attribute of ``model`` whose name
        contains ``'loss'`` is logged as a scalar. ``model.image_display`` is
        tiled into a grid with ``size(0) // 2`` images per row and rescaled by
        ``x / 2 + 0.5`` before being logged.
        """
        if (total_it + 1) % self.display_freq != 0:
            return

        # Filter on the name first so getattr is only evaluated for
        # candidate loss attributes, not for every attribute dir() reports.
        loss_names = [
            attr for attr in dir(model)
            if 'loss' in attr
            and not attr.startswith('__')
            and not callable(getattr(model, attr))
        ]
        for name in loss_names:
            self.writer.add_scalar(name, getattr(model, name), total_it)

        images = model.image_display
        grid = torchvision.utils.make_grid(images, nrow=images.size(0) // 2) / 2 + 0.5
        self.writer.add_image('Image', grid, total_it)

    def write_img(self, ep, model):
        """Save ``model.assemble_outputs()`` as a JPEG under ``image_dir``.

        ``ep == -1`` writes ``gen_last.jpg``. Otherwise an image named
        ``gen_<ep:05d>.jpg`` is written every ``img_save_freq`` epochs.
        Outputs are rescaled by ``x / 2 + 0.5`` before saving.
        """
        # Test the sentinel first: (-1 + 1) % freq == 0 for every freq, so a
        # frequency-first check would swallow ep == -1.
        if ep == -1:
            img_filename = '%s/gen_last.jpg' % self.image_dir
        elif (ep + 1) % self.img_save_freq == 0:
            img_filename = '%s/gen_%05d.jpg' % (self.image_dir, ep)
        else:
            return
        assembled_images = model.assemble_outputs()
        torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1)

    def write_model(self, ep, total_it, model):
        """Checkpoint ``model`` via ``model.save(path, ep, total_it)``.

        ``ep == -1`` saves ``<model_dir>/last.pth``. Otherwise a checkpoint
        named ``<model_dir>/<ep:05d>.pth`` is saved every
        ``model_save_freq`` epochs.
        """
        # Sentinel first for the same reason as in write_img.
        if ep == -1:
            model.save('%s/last.pth' % self.model_dir, ep, total_it)
        elif (ep + 1) % self.model_save_freq == 0:
            print('--- save the model @ ep %d ---' % (ep))
            model.save('%s/%05d.pth' % (self.model_dir, ep), ep, total_it)