-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
config_utils.py
171 lines (143 loc) · 7.53 KB
/
config_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import argparse
import json
import os
from glob import glob
from mmengine.config import Config
def parse_args(training=False):
    """Build the command-line parser and parse ``sys.argv``.

    Almost every option defaults to ``None`` so that ``merge_args`` can tell
    "not given on the command line" apart from an explicit value and leave
    the config file's setting alone in the former case.

    Args:
        training: If True, register the training-only options; otherwise
            register the inference-only options.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    # model config
    parser.add_argument("config", help="model config file path")
    # ======================================================
    # General
    # ======================================================
    parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility")
    parser.add_argument(
        "--ckpt-path",
        default=None,
        type=str,
        help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified",
    )
    parser.add_argument("--batch-size", default=None, type=int, help="batch size")
    parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
    parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
    parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
    parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
    parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
    parser.add_argument("--dtype", default=None, type=str, help="data type")
    # ======================================================
    # Inference
    # ======================================================
    if not training:
        # output
        parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
        parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx")
        parser.add_argument("--start-index", default=None, type=int, help="start index for sample name")
        parser.add_argument("--end-index", default=None, type=int, help="end index for sample name")
        parser.add_argument("--num-sample", default=None, type=int, help="number of samples to generate for one prompt")
        parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples")
        parser.add_argument("--verbose", default=None, type=int, help="verbose level")
        # prompt
        parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
        parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")
        parser.add_argument("--llm-refine", default=None, type=str2bool, help="enable LLM refine")
        parser.add_argument("--prompt-generator", default=None, type=str, help="prompt generator")
        # image/video
        parser.add_argument("--num-frames", default=None, type=str, help="number of frames")
        parser.add_argument("--fps", default=None, type=int, help="fps")
        parser.add_argument("--save-fps", default=None, type=int, help="save fps")
        parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
        parser.add_argument("--frame-interval", default=None, type=int, help="frame interval")
        parser.add_argument("--aspect-ratio", default=None, type=str, help="aspect ratio (h:w)")
        parser.add_argument("--watermark", default=None, type=str2bool, help="watermark video")
        # hyperparameters
        parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
        parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
        # reference
        parser.add_argument("--loop", default=None, type=int, help="loop")
        parser.add_argument("--condition-frame-length", default=None, type=int, help="condition frame length")
        parser.add_argument("--reference-path", default=None, type=str, nargs="+", help="reference path")
        parser.add_argument("--mask-strategy", default=None, type=str, nargs="+", help="mask strategy")
        parser.add_argument("--aes", default=None, type=float, help="aesthetic score")
        parser.add_argument("--flow", default=None, type=float, help="flow score")
        parser.add_argument("--camera-motion", default=None, type=str, help="camera motion")
    # ======================================================
    # Training
    # ======================================================
    else:
        parser.add_argument("--lr", default=None, type=float, help="learning rate")
        # str2bool, not bool: bool("False") is True, so `type=bool` turned any
        # given value (including "False"/"0") into True.
        parser.add_argument("--wandb", default=None, type=str2bool, help="enable wandb")
        parser.add_argument("--load", default=None, type=str, help="path to continue training")
        parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
        parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps")
        parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")
    return parser.parse_args()
def merge_args(cfg, args, training=False):
    """Fold command-line overrides from ``args`` into ``cfg``.

    Options that target a nested config section (``model``, ``discriminator``,
    ``dataset``, ``scheduler``) are written into that section. Most of them
    are then reset to ``None`` on ``args`` so the generic copy at the end does
    not also add them as top-level keys. ``image_size`` and ``num_frames`` are
    deliberately NOT reset, so they land both in ``cfg.dataset`` and at the
    top level. Every remaining non-``None`` attribute of ``args`` is copied
    into ``cfg`` as a top-level key.

    Args:
        cfg: Config object loaded from the config file; mutated in place.
        args: Namespace produced by ``parse_args``; mutated in place
            (consumed options are set to ``None``).
        training: Should match the value given to ``parse_args``. When False,
            the inference-only attributes (``image_size``, ``num_frames``,
            ``cfg_scale``, ``num_sampling_steps``) are read from ``args``.

    Returns:
        The same ``cfg`` object, updated.
    """
    if args.ckpt_path is not None:
        cfg.model["from_pretrained"] = args.ckpt_path
        if cfg.get("discriminator") is not None:
            cfg.discriminator["from_pretrained"] = args.ckpt_path
        args.ckpt_path = None
    if args.flash_attn is not None:
        cfg.model["enable_flash_attn"] = args.flash_attn
        # Clear the attribute that was actually consumed. The previous code
        # cleared a nonexistent `enable_flash_attn`, leaking `flash_attn` into
        # cfg as a stray top-level key.
        args.flash_attn = None
    if args.layernorm_kernel is not None:
        cfg.model["enable_layernorm_kernel"] = args.layernorm_kernel
        # Same fix as above: clear `layernorm_kernel`, not a new attribute.
        args.layernorm_kernel = None
    if args.data_path is not None:
        cfg.dataset["data_path"] = args.data_path
        args.data_path = None
    # NOTE: for vae inference (reconstruction)
    if not training and "dataset" in cfg:
        if args.image_size is not None:
            cfg.dataset["image_size"] = args.image_size
        if args.num_frames is not None:
            cfg.dataset["num_frames"] = args.num_frames
    if not training:
        if args.cfg_scale is not None:
            cfg.scheduler["cfg_scale"] = args.cfg_scale
            args.cfg_scale = None
        if args.num_sampling_steps is not None:
            cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps
            args.num_sampling_steps = None
    # Everything still set on args becomes a top-level config entry.
    for k, v in vars(args).items():
        if v is not None:
            cfg[k] = v
    return cfg
def read_config(config_path):
    """Load the config file at ``config_path`` via mmengine's ``Config.fromfile``."""
    return Config.fromfile(config_path)
def parse_configs(training=False):
    """Parse the command line, load the named config file and apply overrides.

    Args:
        training: Forwarded to both ``parse_args`` and ``merge_args`` to select
            the training or inference option set.

    Returns:
        The loaded config with command-line overrides merged in.
    """
    cli_args = parse_args(training)
    return merge_args(read_config(cli_args.config), cli_args, training)
def define_experiment_workspace(cfg, get_last_workspace=False):
    """Compute the name and path of the experiment folder under ``cfg.outputs``.

    Ensures the ``cfg.outputs`` root exists, then numbers the experiment by
    the count of non-hidden entries (files or folders) already in that root.
    The experiment folder itself is NOT created here.

    Args:
        cfg: Config providing ``outputs`` (root output directory) and
            ``model["type"]`` (used in the folder name, with ``/`` replaced
            by ``-``).
        get_last_workspace: If True, use the index of the most recent entry
            (count - 1) instead of the next free index.

    Returns:
        Tuple ``(exp_name, exp_dir)``: ``exp_name`` is
        ``"<index:03d>-<model_name>"`` and ``exp_dir`` is
        ``"<cfg.outputs>/<exp_name>"``.
    """
    # Local import: the module imports only the `glob` function, not the module.
    from glob import escape

    # Make outputs folder (holds all experiment subfolders)
    os.makedirs(cfg.outputs, exist_ok=True)
    # Escape the root so characters like "[", "*" or "?" in the outputs path are
    # matched literally instead of being treated as wildcards.
    experiment_index = len(glob(f"{escape(str(cfg.outputs))}/*"))
    if get_last_workspace:
        # NOTE: with an empty outputs root this gives -1, formatted as "-01".
        experiment_index -= 1
    model_name = cfg.model["type"].replace("/", "-")
    exp_name = f"{experiment_index:03d}-{model_name}"
    exp_dir = f"{cfg.outputs}/{exp_name}"
    return exp_name, exp_dir
def save_training_config(cfg, experiment_dir):
    """Write ``cfg`` as 4-space-indented JSON to ``<experiment_dir>/config.txt``.

    ``cfg`` must be JSON-serializable (e.g. a plain dict); the file is
    overwritten if it already exists.
    """
    config_file = f"{experiment_dir}/config.txt"
    with open(config_file, "w") as handle:
        json.dump(cfg, handle, indent=4)
def str2bool(v):
    """Convert a command-line string to a bool (argparse ``type=`` helper).

    Bools pass through unchanged. Strings are compared case-insensitively:
    "yes"/"true"/"t"/"y"/"1" map to True and "no"/"false"/"f"/"n"/"0" map
    to False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")