-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathbenchmarking.py
121 lines (107 loc) · 4.72 KB
/
benchmarking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import sys
sys.path.append("src")
import argparse
import os
from omegaconf import OmegaConf
from tqdm import tqdm
from functools import partialmethod
from imagen_hub.benchmark import benchmark_infer, \
infer_control_guided_ig_bench, \
infer_text_guided_ig_bench, \
infer_text_guided_ie_bench, \
infer_mask_guided_ie_bench, \
infer_subject_driven_ig_bench, \
infer_subject_driven_ie_bench, \
infer_multi_concept_ic_bench
# `inquirer` is only needed by the interactive config picker, which is used
# when --cfg is not given. The rest of the script runs without it.
try:
    import inquirer
except ImportError:
    # Catch ImportError only. A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit, and would print this "please install"
    # hint for failures that have nothing to do with a missing package.
    print("Please install inquirer package to use the interactive mode.")
    print("pip install inquirer")
def parser():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with two attributes:
          * ``cfg``   -- path to the YAML configuration file, or ``None`` when
                         neither ``-cfg`` nor ``--cfg`` is given.
          * ``quiet`` -- ``True`` when ``-quiet``/``--quiet`` is passed.
    """
    cli = argparse.ArgumentParser(
        description="benchmarking.py: Running Benchmark scripts for experiment."
    )

    # (flags, keyword options) for every supported option, in registration order.
    option_specs = (
        (("-cfg", "--cfg"),
         dict(type=str, help="Path to the YAML configuration file")),
        (("-quiet", "--quiet"),
         dict(action='store_true', help="Disable tqdm progress bar.")),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    namespace = cli.parse_args()
    return namespace
def check_arguments_errors(args):
    """Validate the parsed command-line arguments.

    Args:
        args: namespace from ``parser()``. Only ``args.cfg`` is inspected.

    Raises:
        ValueError: if ``args.cfg`` is truthy but does not name an existing
            regular file. The message contains the absolute form of the path.
    """
    cfg_path = args.cfg
    if not cfg_path:
        # No config path given: nothing to validate here.
        return
    if os.path.isfile(cfg_path):
        return
    raise ValueError(f"Invalid path {os.path.abspath(cfg_path)}")
def list_config_files(folder):
    """Return the names of YAML config files found directly inside *folder*.

    Only regular files whose names end in ``.yaml`` or ``.yml`` are kept. The
    match is case-sensitive, as before. Bare file names, not full paths, are
    returned in sorted order. ``os.listdir`` order is arbitrary, so sorting
    gives the interactive picker a stable, alphabetical menu.

    Raises:
        OSError: propagated from ``os.listdir``, for example
            ``FileNotFoundError`` when *folder* does not exist.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        # Skip directories that merely happen to be named *.yaml / *.yml.
        if name.endswith(('.yaml', '.yml'))
        and os.path.isfile(os.path.join(folder, name))
    )
def select_config_file(config_folder="benchmark_cfg"):
    """Interactively ask the user to pick a YAML config from *config_folder*.

    Args:
        config_folder: directory scanned by ``list_config_files``. Defaults to
            ``"benchmark_cfg"``, resolved relative to the current working
            directory.

    Returns:
        ``os.path.join(config_folder, <chosen file name>)``.

    Raises:
        SystemExit: if the ``inquirer`` package is not installed, if the
            folder contains no ``.yaml``/``.yml`` files, or if the prompt
            yields no answer (e.g. the user cancelled it).
    """
    # Import locally so a missing optional dependency produces a clear message
    # here, rather than a NameError on an unbound module-level name.
    try:
        import inquirer
    except ImportError as err:
        raise SystemExit(
            "Interactive mode requires the 'inquirer' package "
            "(pip install inquirer); alternatively pass --cfg <file>."
        ) from err

    config_files = list_config_files(config_folder)
    if not config_files:
        raise SystemExit(
            f"No .yaml/.yml configuration files found in {config_folder!r}."
        )

    questions = [
        inquirer.List('cfg', message="Select the configuration file", choices=config_files)
    ]
    answers = inquirer.prompt(questions)
    # NOTE(review): python-inquirer's prompt() returns None when the user
    # cancels (Ctrl-C). Indexing that would raise an opaque TypeError.
    if not answers or 'cfg' not in answers:
        raise SystemExit("No configuration file selected.")
    return os.path.join(config_folder, answers['cfg'])
def main():
    """Load a benchmark YAML config and run the inference task it selects.

    The config path comes from ``--cfg``. When it is omitted, the user picks a
    file interactively via ``select_config_file()``. The config must provide
    ``params.save_to_folder``, ``params.limit_images_amount``,
    ``params.experiment_basename``, ``info.running_models`` and
    ``info.task_id``. ``task_id`` selects which ``infer_*_bench`` function is
    handed to ``benchmark_infer`` as ``infer_dataset_fn``.

    Raises:
        NotImplementedError: if ``info.task_id`` has no registered task.
    """
    # task_id -> dataset inference function. Register new benchmark tasks here.
    task_infer_fns = {
        0: infer_text_guided_ie_bench,
        1: infer_mask_guided_ie_bench,
        2: infer_control_guided_ig_bench,
        3: infer_subject_driven_ie_bench,
        4: infer_multi_concept_ic_bench,
        5: infer_subject_driven_ig_bench,
        6: infer_text_guided_ig_bench,
    }

    args = parser()
    check_arguments_errors(args)
    if not args.cfg:
        args.cfg = select_config_file()

    config = OmegaConf.load(args.cfg)
    print("=====> Config content:")
    print(OmegaConf.to_yaml(config))
    print("======================")

    if args.quiet:
        print("Disabled tqdm.")
        # Patch the tqdm class itself, so every tqdm bar constructed after
        # this point gets disable=True (presumably including the ones inside
        # imagen_hub's benchmark loop).
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # Read the config values in the same order as before, so a missing key
    # fails the same way it always has.
    result_folder = config.params.save_to_folder
    limit_images_amount = config.params.limit_images_amount
    experiment_basename = config.params.experiment_basename
    model_list = config.info.running_models
    task_id = config.info.task_id

    infer_dataset_fn = task_infer_fns.get(task_id)
    if infer_dataset_fn is None:
        # Implement your new task by adding it to task_infer_fns above.
        raise NotImplementedError(f"Unsupported task_id: {task_id!r}")

    benchmark_infer(experiment_basename,
                    model_list=model_list,
                    limit_images_amount=limit_images_amount,
                    result_folder=result_folder,
                    infer_dataset_fn=infer_dataset_fn)


if __name__ == "__main__":
    main()