diff --git a/configure.py b/configure.py index 447ab5fb..26620662 100644 --- a/configure.py +++ b/configure.py @@ -544,11 +544,11 @@ def configure_env(): ) env_contents["--gradient_checkpointing"] = "true" gradient_checkpointing_interval = prompt_user( - "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer", + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.", 0, ) try: - if int(gradient_checkpointing_interval) > 0: + if int(gradient_checkpointing_interval) > 1: env_contents["--gradient_checkpointing_interval"] = int( gradient_checkpointing_interval )