diff --git a/fixtures/accelerate_fsdp_defaults.yaml b/fixtures/accelerate_fsdp_defaults.yaml index f70d74faa..28cc0faec 100644 --- a/fixtures/accelerate_fsdp_defaults.yaml +++ b/fixtures/accelerate_fsdp_defaults.yaml @@ -14,7 +14,7 @@ fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP # this controls the FSDP pipelining - fsdp_backward_prefetch_policy: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline + fsdp_backward_prefetch: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline # but requires the most memory. BACKWARD_POST is the less # memory intensive option diff --git a/tests/build/dummy_job_config.json b/tests/build/dummy_job_config.json index 315a5b527..ed5abfa85 100644 --- a/tests/build/dummy_job_config.json +++ b/tests/build/dummy_job_config.json @@ -5,7 +5,7 @@ "dynamo_use_dynamic": true, "num_machines": 1, "main_process_port": 1234, - "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_backward_prefetch": "TRANSFORMER_BASED_WRAP", "fsdp_sharding_strategy": 1, "fsdp_state_dict_type": "FULL_STATE_DICT", "fsdp_cpu_ram_efficient_loading": true, diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py index fde0ffb2c..4ad228879 100644 --- a/tests/build/test_utils.py +++ b/tests/build/test_utils.py @@ -44,7 +44,7 @@ def test_process_accelerate_launch_args(job_config): args = process_accelerate_launch_args(job_config) # json config values used assert args.use_fsdp is True - assert args.fsdp_backward_prefetch_policy == "TRANSFORMER_BASED_WRAP" + assert args.fsdp_backward_prefetch == "TRANSFORMER_BASED_WRAP" assert args.env == ["env1", "env2"] assert args.training_script == "tuning.sft_trainer" assert args.config_file == "fixtures/accelerate_fsdp_defaults.yaml"