-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
executable file
·60 lines (44 loc) · 1.77 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import inspect
import logging
import os

import torch
import torch.nn as nn

from lib import evaluation
from lib.datasets import image_caption
from lib.scanpp import SCANpp
from lib.vocab import Vocabulary, deserialize_vocab
# Give the root logger the default stderr handler (a no-op when it already has
# handlers), then lower its threshold so INFO records are emitted.
logging.basicConfig()
logger = logging.root
logger.setLevel('INFO')
def main(model_path, split, gpuid='0', fold5=False):
    """Evaluate a trained SCANpp checkpoint on one data split.

    Args:
        model_path: Path to a checkpoint dict holding at least the keys
            'opt' (training options), 'model' (state dict) and 'epoch'.
        split: Split name forwarded to ``image_caption.get_test_loader`` and
            ``evaluation.evalrank`` (the script uses 'test' and 'testall').
        gpuid: Value written to the ``CUDA_VISIBLE_DEVICES`` environment
            variable before the model is built.
        fold5: Forwarded unchanged to ``evaluation.evalrank``; presumably
            toggles 5-fold evaluation — TODO confirm against lib.evaluation.
    """
    print("use GPU:", gpuid)
    # NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet in this process; when main() is called several times in
    # one process, later values are effectively ignored.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpuid)

    # Load onto the CPU: the checkpoint may reference a CUDA device index that
    # is no longer visible once CUDA_VISIBLE_DEVICES narrows the GPU set.
    # load_state_dict() below copies the weights onto the model's device.
    load_kwargs = {'map_location': 'cpu'}
    # Besides tensors, the checkpoint carries an attribute-bearing options
    # object ('opt'), which torch versions defaulting to weights_only=True
    # refuse to unpickle. Opt out explicitly where the argument exists (older
    # torch has no such parameter). This unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(model_path, **load_kwargs)
    opt = checkpoint['opt']

    # load vocabulary used by the model
    if 'coco' in opt.data_name:
        vocab_file = 'coco_precomp_vocab.json'
    else:
        vocab_file = 'f30k_precomp_vocab.json'
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, vocab_file))
    # Extra '<mask>' token on top of the serialized vocabulary; presumably
    # mirrors training so vocab_size matches the checkpoint — TODO confirm.
    vocab.add_word('<mask>')
    opt.vocab_size = len(vocab)

    # construct model
    model = SCANpp(opt)
    model.cuda()
    model = nn.DataParallel(model)
    # The state dict is loaded into the DataParallel wrapper, so its keys
    # presumably carry the 'module.' prefix — TODO confirm against training.
    model.load_state_dict(checkpoint['model'])

    data_loader = image_caption.get_test_loader(split, opt.data_name, vocab,
                                                opt.batch_size, opt.workers, opt)

    logger.info(opt)
    logger.info('Computing results with checkpoint_{}'.format(checkpoint['epoch']))
    # Evaluate the wrapped module itself rather than the DataParallel wrapper.
    evaluation.evalrank(model.module, data_loader, opt, split, fold5)
if __name__ == '__main__':
    # Every evaluation to run, in order: (checkpoint, split, GPU id, fold5).
    evaluation_runs = [
        ('runs/f30k_t2i_rcar2/model_best.pth', 'test', '0', False),
        ('runs/f30k_i2t_rcar2/model_best.pth', 'test', '0', False),
        ('runs/coco_t2i_rcar2/model_best.pth', 'testall', '0', True),
        ('runs/coco_i2t_rcar2/model_best.pth', 'testall', '0', True),
        ('runs/coco_t2i_rcar2/model_best.pth', 'testall', '0', False),
        ('runs/coco_i2t_rcar2/model_best.pth', 'testall', '0', False),
    ]
    for ckpt_path, split_name, gpu, use_fold5 in evaluation_runs:
        main(ckpt_path, split_name, gpu, use_fold5)
    logger.info('finished')