diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index e4abc7e5e..b4f0d798f 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -243,6 +243,7 @@ def test_lora_state_dict(self): fsdp_model_to_load.parameters(), weight_decay=0.01, lr=0.01 ) training.load_from_full_optimizer_state_dict( + fsdp_model_to_load, fsdp_optim_to_load, # mimic mmap=True where every rank see full SD copy.deepcopy(self._broadcast_full_state_dict(optim_full_sd)),