diff --git a/tests/torchtune/modules/peft/test_dora.py b/tests/torchtune/modules/peft/test_dora.py
index b52dc9702..3a7724735 100644
--- a/tests/torchtune/modules/peft/test_dora.py
+++ b/tests/torchtune/modules/peft/test_dora.py
@@ -269,7 +269,7 @@ def embed_dim(self):
     def test_dora_distributed_init(self):
         self.run_subtests(
             {
-                "load_dora_weights": [True],
+                "load_dora_weights": [True, False],
             },
             self._test_dora_distributed_init,
         )
@@ -327,8 +327,7 @@ def _test_dora_distributed_init(self, load_dora_weights):
             assert dora_linear.magnitude.is_meta

         # Optionally load adapter weights (as though we are resuming from checkpoint)
-        # Now lora_a, lora_b, and magnitude should not be on meta device, but base weight should be
-        # Additionally since the weights are randomly initialized we should have magnitude != ||W+(alpha/rank)BA||
+        # Now lora_a, lora_b, and magnitude should not be on meta device, but base weight should be.
         if load_dora_weights:
             training.load_from_full_model_state_dict(
                 ffn,