-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathtest.py
92 lines (66 loc) · 2.48 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import sys
import time
import copy
import shutil
import random
import pdb
import torch
import numpy as np
from tqdm import tqdm
import config
import myutils
from torch.utils.data import DataLoader
##### Parse CmdLine Arguments #####
# Default to GPU 7 only when the caller has not already selected devices, so an
# externally exported CUDA_VISIBLE_DEVICES is respected instead of overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
args, unparsed = config.get_args()
cwd = os.getcwd()

# Single source of truth for where the model and the inputs live.
device = torch.device('cuda' if args.cuda else 'cpu')

# Seed PyTorch's RNGs with the configured seed.
torch.manual_seed(args.random_seed)
if args.cuda:
    torch.cuda.manual_seed(args.random_seed)

##### Build the test data loader for the requested dataset #####
# Each dataset module exposes its own `get_loader`; the import is done lazily
# so only the selected dataset's module is loaded.
if args.dataset == "vimeo90K_septuplet":
    from dataset.vimeo90k_septuplet import get_loader
    test_loader = get_loader('test', args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "ucf101":
    from dataset.ucf101_test import get_loader
    test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers)
elif args.dataset == "gopro":
    from dataset.GoPro import get_loader
    test_loader = get_loader(args.data_root, args.test_batch_size, shuffle=False, num_workers=args.num_workers, test_mode=True, interFrames=args.n_outputs)
else:
    raise NotImplementedError(
        "Unsupported dataset %r; expected one of: vimeo90K_septuplet, ucf101, gopro"
        % (args.dataset,)
    )

##### Build the model #####
from model.FLAVR_arch import UNet_3D_3D
print("Building model: %s"%args.model.lower())
model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType)
# Wrapped in DataParallel, so state-dict keys carry the wrapper's prefix.
model = torch.nn.DataParallel(model).to(device)
print("#params" , sum(p.numel() for p in model.parameters()))
def test(args):
    """Evaluate the module-level ``model`` on ``test_loader`` and report metrics.

    Every batch is run through the model in eval mode with gradients
    disabled. PSNR/SSIM meters are updated through ``myutils.eval_metrics``,
    and the average PSNR, SSIM and per-batch wall-clock time are printed.

    Args:
        args: Parsed command-line arguments. Only ``args.loss`` is read here
            (forwarded to ``myutils.init_meters``).

    Returns:
        The running average stored on the PSNR meter (``psnrs.avg``).
    """
    time_taken = []
    # init_meters also returns loss meters; they are not used during testing.
    _, psnrs, ssims = myutils.init_meters(args.loss)
    model.eval()

    # Only synchronize with the GPU when we are actually running on one.
    on_cuda = device.type == 'cuda'

    with torch.no_grad():
        for images, gt_image in tqdm(test_loader):
            # Move inputs to the same device the model was placed on.
            images = [img_.to(device) for img_ in images]
            gt = [g_.to(device) for g_ in gt_image]

            # Drain pending GPU work so the timer brackets only this batch.
            if on_cuda:
                torch.cuda.synchronize()
            start_time = time.time()

            out = model(images)
            out = torch.cat(out)
            gt = torch.cat(gt)

            if on_cuda:
                torch.cuda.synchronize()
            # NOTE: the measured interval covers the forward pass plus the
            # two concatenations above, as in the original timing.
            time_taken.append(time.time() - start_time)

            myutils.eval_metrics(out, gt, psnrs, ssims)

    print("PSNR: %f, SSIM: %f\n" %
          (psnrs.avg, ssims.avg))
    # Guard against an empty loader, which would otherwise divide by zero.
    mean_time = sum(time_taken) / len(time_taken) if time_taken else 0.0
    print("Time , " , mean_time)

    return psnrs.avg
def main(args):
    """Entry point: load the checkpoint and evaluate the model.

    Loads the ``"state_dict"`` entry of the checkpoint at ``args.load_from``
    into the module-level ``model`` (strict key matching) and then runs
    ``test(args)``.

    Args:
        args: Parsed command-line arguments; ``args.load_from`` must be the
            path of a checkpoint file containing a ``"state_dict"`` entry.

    Raises:
        ValueError: If ``args.load_from`` is ``None``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.load_from is None:
        raise ValueError("args.load_from must point to a checkpoint to evaluate")

    # map_location remaps tensors saved on another device (e.g. a GPU index
    # that is not visible here) onto the device selected for this run.
    checkpoint = torch.load(args.load_from, map_location=device)
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    test(args)
# Run the evaluation only when executed as a script, using the module-level
# `args` parsed above.
if __name__ == "__main__":
    main(args)