args.py
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/data/output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="/data/dataset/cache",
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1_000_000)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=1_048_577,  # 33_554_432, 67_108_864, 134_217_728, 1_073_741_824, 2_147_483_648, 17_179_869_184
        help="Total number of training steps per epoch to perform.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    # Note: argparse's `type=bool` converts any non-empty string to True, so the
    # original `type=bool, choices=[True, False]` could not be disabled from the
    # command line (`--log_wandb False` still yielded True). BooleanOptionalAction
    # (Python 3.9+) provides a matching `--no-log_wandb` flag instead.
    parser.add_argument(
        "--log_wandb",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to use WandB to log the metrics or not.",
    )
    args = parser.parse_args()
    return args
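

if __name__ == "__main__":
    # Usage sketch (an assumption, not part of the original file): the accompanying
    # training script would typically import and call parse_args() at its entry
    # point. Running this module directly just prints the parsed defaults so the
    # argument set can be inspected without starting training.
    print(parse_args())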