diff --git a/configure.py b/configure.py index 7aced716..447ab5fb 100644 --- a/configure.py +++ b/configure.py @@ -543,6 +543,18 @@ def configure_env(): ) ) env_contents["--gradient_checkpointing"] = "true" + gradient_checkpointing_interval = prompt_user( + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer", + 0, + ) + try: + if int(gradient_checkpointing_interval) > 0: + env_contents["--gradient_checkpointing_interval"] = int( + gradient_checkpointing_interval + ) + except: + print("Could not parse gradient checkpointing interval. Not enabling.") + pass env_contents["--caption_dropout_probability"] = float( prompt_user(