predict.py
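"""Run a trained MaskFlownet / MaskFlownet_S checkpoint on a chosen dataset split and
write color-coded optical flow visualisations to --output_dir.

Expects a model config and a dataset config (Chairs, KITTI, Sintel or DENSO) in
./config_folder and a checkpoint file in ./weights.
"""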
import argparse
import os
import torch
import yaml
import numpy as np
import torch.nn.functional as F
import cv2
import config_folder as cf
from data_loaders.Chairs import Chairs
from data_loaders.kitti import KITTI
from data_loaders.sintel import Sintel
from data_loaders.denso import DENSO
from model import MaskFlownet, MaskFlownet_S, Upsample, EpeLossWithMask
import matplotlib.pyplot as plt
from utils.flow_utils import flow_to_image


def centralize(img1, img2):
    # Remove the mean RGB value shared by both frames (inputs are NCHW tensors).
    rgb_mean = torch.cat((img1, img2), 2)
    rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, -1).mean(2)
    rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, 1, 1)
    return img1 - rgb_mean, img2 - rgb_mean, rgb_mean
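

# Command-line interface. A typical invocation might look like this
# (the config and checkpoint file names are illustrative):
#   python predict.py MaskFlownet.yaml --dataset_cfg sintel.yaml -c <checkpoint>.pth \
#       -f /path/to/dataset --output_dir ./flow_vis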
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, nargs='?', default=None)
parser.add_argument('--dataset_cfg', type=str, default='chairs.yaml')
parser.add_argument('-c', '--checkpoint', type=str, default=None,
                    help='model checkpoint to load')
parser.add_argument('-b', '--batch', type=int, default=1,
                    help='batch size')
parser.add_argument('-f', '--root_folder', type=str, default=None,
                    help='root folder of the dataset')
parser.add_argument('--split_file', type=str, default='')
parser.add_argument('--resize', type=str, default='')
parser.add_argument('--output_dir', type=str, default='')
args = parser.parse_args()

resize = (int(args.resize.split(',')[0]), int(args.resize.split(',')[1])) if args.resize else None
num_workers = 2

with open(os.path.join('config_folder', args.dataset_cfg)) as f:
    config = cf.Reader(yaml.load(f, Loader=yaml.FullLoader))
with open(os.path.join('config_folder', args.config)) as f:
    config_model = cf.Reader(yaml.load(f, Loader=yaml.FullLoader))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the network class named in the model config (MaskFlownet or MaskFlownet_S)
# and load the requested checkpoint from ./weights.
net = eval(config_model.value['network']['class'])(config)
checkpoint = torch.load(os.path.join('weights', args.checkpoint))
net.load_state_dict(checkpoint)
net = net.to(device)
net.eval()

if config.value['dataset'] == 'kitti':
    dataset = KITTI(args.root_folder, split='train', editions='mixed', resize=resize, parts='valid')
elif config.value['dataset'] == 'chairs':
    dataset = Chairs(args.root_folder, split='valid')
elif config.value['dataset'] == 'sintel':
    dataset = Sintel(args.root_folder, split='valid', subset='final')
elif config.value['dataset'] == 'denso':
    dataset = DENSO(args.root_folder, args.split_file, resize=resize)
else:
    raise ValueError('unsupported dataset: {}'.format(config.value['dataset']))

data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          shuffle=False,
                                          batch_size=args.batch,
                                          num_workers=num_workers,
                                          drop_last=False,
                                          pin_memory=True)

epe = []
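
# Inference loop: run the network on each frame pair and save a color-coded
# visualisation of the predicted flow at the original input resolution.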
for idx, sample in enumerate(data_loader):
    with torch.no_grad():
        im0, im1, label, mask, path = sample
        # if isinstance(mask, list):
        #     mask = torch.ones((label.shape[0], label.shape[1], label.shape[2], 1), dtype=label.dtype, device=device)

        # NHWC -> NCHW, then subtract the shared RGB mean.
        im0 = im0.permute(0, 3, 1, 2)
        im1 = im1.permute(0, 3, 1, 2)
        im0, im1, _ = centralize(im0, im1)
        # label = label.permute(0, 3, 1, 2).to(device).flip(1)
        # mask = mask.permute(0, 3, 1, 2).to(device)

        # Resize so height and width are multiples of 64, as the network expects.
        shape = im0.shape
        pad_h = (64 - shape[2] % 64) % 64
        pad_w = (64 - shape[3] % 64) % 64
        if pad_h != 0 or pad_w != 0:
            im0 = F.interpolate(im0, size=[shape[2] + pad_h, shape[3] + pad_w], mode='bilinear')
            im1 = F.interpolate(im1, size=[shape[2] + pad_h, shape[3] + pad_w], mode='bilinear')
        im0 = im0.to(device)
        im1 = im1.to(device)

        # Forward pass; the finest prediction and the occlusion mask are upsampled
        # by a factor of 4 to the (padded) input resolution.
        pred, flows, warpeds = net(im0, im1)
        up_flow = Upsample(pred[-1], 4)
        up_occ_mask = Upsample(flows[0], 4)  # computed but not saved below

        # Undo the padding resize and rescale the flow vectors to the original size.
        if pad_h != 0 or pad_w != 0:
            up_flow = F.interpolate(up_flow, size=[shape[2], shape[3]], mode='bilinear') * \
                      torch.tensor([shape[d] / up_flow.shape[d] for d in (2, 3)], device=device).view(1, 2, 1, 1)
            up_occ_mask = F.interpolate(up_occ_mask, size=[shape[2], shape[3]], mode='bilinear')

        # Color-code the flow and write it out (assumes --batch 1, the default).
        np_flow = up_flow.detach().cpu().numpy().squeeze(0).transpose([1, 2, 0])
        rgb_flow = flow_to_image(np_flow)
        filename = os.path.join(args.output_dir, '{}.png'.format(idx + 1))
        plt.imsave(filename, rgb_flow, format='png')