Skip to content

Commit

Permalink
updated demo for longer sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
Zach Teed committed Oct 5, 2020
1 parent 25eb2ac commit d3f3840
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,9 @@
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img
return img[None].to(DEVICE)


def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))

images = torch.stack(images, dim=0)
images = images.to(DEVICE)

padder = InputPadder(images.shape)
return padder.pad(images)[0]


def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
Expand All @@ -43,6 +31,10 @@ def viz(img, flo):
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)

# import matplotlib.pyplot as plt
# plt.imshow(img_flo / 255.0)
# plt.show()

cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()

Expand All @@ -58,11 +50,14 @@ def demo(args):
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))

images = sorted(images)
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
image2 = load_image(imfile2)

images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
Expand Down

0 comments on commit d3f3840

Please sign in to comment.