-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test.py
112 lines (93 loc) · 2.64 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import pprint
import argparse
import torch
import numpy as np
from dataset import ShapeNet15k
from model import Generator
from trainer import Trainer
def parse_args(argv=None):
    """Parse the command-line options of the testing script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is exactly
            the previous behaviour; passing an explicit list lets the parser be
            driven programmatically (e.g. from tests).

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``ckpt_path``,
        ``submit``, ``seed``, ``cate``, ``split``, ``batch_size``,
        ``sample_size`` and ``device``.
    """
    # Default data directory lives next to this script, independent of cwd.
    root_dir = os.path.abspath(os.path.dirname(__file__))
    parser = argparse.ArgumentParser()

    # Environment settings
    parser.add_argument(
        "--data_dir",
        type=str,
        default=os.path.join(root_dir, "data"),
        help="Path to dataset directory.",
    )
    parser.add_argument(
        "--ckpt_path", type=str, required=True, help="Path to checkpoint file. "
    )

    # Testing settings
    parser.add_argument(
        "--submit",
        default=False,
        action="store_true",
        help="Generate submission for GradeScope.",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Manual seed for reproducibility."
    )
    parser.add_argument(
        "--cate", type=str, default="airplane", help="ShapeNet15k category."
    )
    parser.add_argument("--split", type=str, default="val", help="ShapeNet15k split.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Minibatch size used during training and testing.",
    )
    parser.add_argument(
        "--sample_size",
        type=int,
        default=2048,
        help="Number of points sampled from each training sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Resolved once at parse time: first GPU when CUDA is available.
        default=("cuda:0" if torch.cuda.is_available() else "cpu"),
        help="Accelerator to use.",
    )
    return parser.parse_args(argv)
def main(args):
    """
    Testing entry point.

    Steps:
    - Print the parsed options.
    - Seed NumPy and PyTorch.
    - Build a non-shuffled DataLoader over the requested ShapeNet15k
      category/split.
    - Restore the generator from ``args.ckpt_path`` and run ``Trainer.test``.
    - Pretty-print the returned metrics.
    - When ``args.submit`` is set, save the submission object returned by the
      trainer to ``submission.pth`` in the current working directory.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    # Print args
    pprint.pprint(vars(args))

    # Fix seeds (torch.manual_seed also seeds CUDA generators).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Page-locked host memory only speeds up host->GPU copies. On a CPU
    # device it is pure overhead, and recent PyTorch releases warn when
    # pin_memory is requested without an accelerator, so enable it only
    # when actually targeting CUDA.
    use_pin_memory = torch.device(args.device).type == "cuda"

    # Setup dataloaders
    test_loader = torch.utils.data.DataLoader(
        dataset=ShapeNet15k(
            root=args.data_dir,
            cate=args.cate,
            split=args.split,
            random_sample=False,
            sample_size=args.sample_size,
        ),
        batch_size=args.batch_size,
        shuffle=False,  # keep sample order fixed for evaluation/submission
        num_workers=2,
        pin_memory=use_pin_memory,
        drop_last=False,  # evaluate every sample, including a partial last batch
    )

    # Setup model and trainer
    net_g = Generator()
    trainer = Trainer(net_g=net_g, batch_size=args.batch_size, device=args.device)

    # Load checkpoint
    trainer.load_checkpoint(args.ckpt_path)

    # Start testing; the second element of the returned pair is not used here.
    (metrics, submission), _ = trainer.test(test_loader)
    torch.set_printoptions(precision=6)
    pprint.pprint(metrics)

    # Generate submission
    if args.submit:
        torch.save(submission, "submission.pth")
if __name__ == "__main__":
    # Script entry: read options from the command line, then run testing.
    cli_args = parse_args()
    main(cli_args)