-
Notifications
You must be signed in to change notification settings - Fork 0
/
single_inference.py
56 lines (44 loc) · 2 KB
/
single_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
from tsdf_dataset import ShapeNet
from model.pvqvae.vqvae import VQVAE
import argparse
import pickle
from utils import shape2patch, patch2shape, display_tsdf
def _parse_args(argv=None):
    """Parse the command-line options for single-shape PVQVAE inference.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``mesh_path``, ``model_path``, ``num_embed``
        and ``embed_dim``.
    """
    parser = argparse.ArgumentParser(description='Inference PVQVAE')
    parser.add_argument('--mesh_path', type=str, default='./dataset/table/table_126.pkl', help='path to input mesh')
    parser.add_argument('--model_path', type=str, default='./best_model.pth', help='Path to model')
    parser.add_argument('--num_embed', type=int, default=512, help='Number of embeddings')
    parser.add_argument('--embed_dim', type=int, default=256, help='Embedding dimension')
    return parser.parse_args(argv)


def _load_tsdf(mesh_path):
    """Read the pickled sample at ``mesh_path`` and return its ``'tsdf'`` entry
    as a float32 CPU tensor.

    Raises:
        KeyError: if the unpickled object has no ``'tsdf'`` key.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # point --mesh_path at files you trust.
    with open(mesh_path, 'rb') as f:
        sample = pickle.load(f)
    # torch.from_numpy keeps numpy's dtype (frequently float64). Cast to
    # float32 so the input matches the model's parameters, which are assumed
    # to use PyTorch's default float32 dtype -- confirm if VQVAE is ever run
    # in another precision.
    return torch.from_numpy(sample['tsdf']).float()


def _load_model(model_path, num_embed, embed_dim, device):
    """Build a VQVAE, load its state dict from ``model_path``, and put it on
    ``device`` in eval mode."""
    model = VQVAE(num_embeddings=num_embed, embed_dim=embed_dim)
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint) so loading also works on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def main(argv=None):
    """Reconstruct one TSDF through a trained PVQVAE, print the reconstruction
    MSE and shapes, and display the input and reconstruction."""
    args = _parse_args(argv)

    tsdf = _load_tsdf(args.mesh_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: ", device)

    model = _load_model(args.model_path, args.num_embed, args.embed_dim, device)

    tsdf = tsdf.to(device)
    # Add leading batch and channel dimensions before patching.
    input_tsdf = torch.reshape(tsdf, (1, 1, *tsdf.shape))
    patched_tsdf = shape2patch(input_tsdf)
    with torch.no_grad():
        reconstructed_data = model(patched_tsdf, is_training=False)
        # Compare against the (1, 1, ...) input so the broadcast is explicit;
        # assumes the model returns a full-shape reconstruction -- see the
        # shape printed below.
        test_recon_loss = torch.mean((reconstructed_data - input_tsdf) ** 2)
    print(f'{test_recon_loss=}')
    print(input_tsdf.shape)
    print(reconstructed_data.shape)

    display_tsdf(tsdf.cpu(), 0)
    reconstructed_data = torch.squeeze(reconstructed_data).cpu()
    # Use the midpoint of the reconstructed value range as the display level,
    # passed as a plain Python number like the literal 0 above.
    iso_level = ((reconstructed_data.max() + reconstructed_data.min()) / 2).item()
    display_tsdf(reconstructed_data, iso_level)


if __name__ == '__main__':
    main()