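"""Entry point for the JAX/Flax training script.

Parses the CLI arguments, sets up the text encoder, VAE, and UNet (loading
previously saved weights from ``output_dir`` if that directory already exists),
builds the optimizer and the UNet training state, and hands everything to
``training_loop``. Flag names are defined in ``args.parse_args``; a typical
invocation is presumably something like ``python main.py --output_dir ./out
--seed 0`` (exact flag names depend on that parser).
"""
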
import os

# jax/flax
import jax
from flax.core.frozen_dict import unfreeze
from flax.training import train_state

# internal code
from architecture import setup_model
from args import parse_args
from optimizer import setup_optimizer
from training_loop import training_loop
from monitoring import wandb_close, wandb_init


def main():
    args = parse_args()

    # number of splits/partitions/devices/shards
    num_devices = jax.local_device_count()

    output_dir = args.output_dir
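    # Resume from previously saved weights when the output directory already
    # exists (presumably written by an earlier run); setup_model uses this flag.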
    load_pretrained = os.path.exists(output_dir) and os.path.isdir(output_dir)
    # Set up WandB for logging & tracking
    log_wandb = args.log_wandb
    if log_wandb:
        wandb_init(args, num_devices)

    # Initialize the random number generator
    seed = args.seed
    seed_rng = jax.random.PRNGKey(seed)
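    # Split the seed key: `rng` drives the training loop, while the second key
    # initializes model parameters when training from scratch.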
    rng, training_from_scratch_rng_params = jax.random.split(seed_rng)
    print("random generator setup...")
    # Pretrained/frozen and trainable model setup
    text_encoder, text_encoder_params, vae, vae_params, unet, unet_params = setup_model(
        seed,
        load_pretrained,
        output_dir,
        training_from_scratch_rng_params,
    )
    print("models setup...")
    # Optimization & scheduling setup
    optimizer = setup_optimizer(
        args.learning_rate,
        args.adam_beta1,
        args.adam_beta2,
        args.adam_epsilon,
        args.adam_weight_decay,
        args.max_grad_norm,
    )
    print("optimizer setup...")

    # Training state setup
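    # Only the UNet is trained here; the TrainState bundles its apply function,
    # its parameters (unfrozen to a plain dict), and the optimizer built above.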
    unet_training_state = train_state.TrainState.create(
        apply_fn=unet,
        params=unfreeze(unet_params),
        tx=optimizer,
    )
    print("training state initialized...")
    if log_wandb:
        get_validation_predictions = None  # TODO: put validation here
    else:
        get_validation_predictions = None
    # JAX device data replication
    # replicated_state = replicate(unet_training_state)
    # NOTE: These can't be replicated here, otherwise you get this whenever they
    # are used: flax.errors.ScopeParamShapeError: Initializer expected to generate
    # shape (4, 384, 1536) but got shape (384, 1536) instead for parameter
    # "embedding" in "/shared".
    # (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
    # replicated_text_encoder_params = jax_utils.replicate(text_encoder_params)
    # replicated_vae_params = jax_utils.replicate(vae_params)
    # print("states & params replicated to TPUs...")
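    # Replication is therefore presumably handled inside training_loop; a minimal
    # sketch of what that would look like with Flax utilities (an assumption, not
    # this repo's actual code):
    #
    #   from flax import jax_utils
    #   replicated_state = jax_utils.replicate(unet_training_state)
    #   ...
    #   state = jax_utils.unreplicate(replicated_state)  # e.g. before saving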
    # Train!
    print("Training loop init...")
    training_loop(
        text_encoder,
        text_encoder_params,
        vae,
        vae_params,
        unet,
        unet_training_state,
        rng,
        args.max_train_steps,
        args.num_train_epochs,
        args.train_batch_size,
        output_dir,
        log_wandb,
        get_validation_predictions,
        num_devices,
    )
    print("Training loop done...")
    if log_wandb:
        wandb_close()


if __name__ == "__main__":
    main()