diff --git a/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py b/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py index eb010083343..ef7a1b30a5b 100644 --- a/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py +++ b/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py @@ -34,16 +34,13 @@ def setUp(self): self.output = Path("./finetune_output") def test_oneshot_then_finetune(self): - import torch - import sparseml - from sparseml.transformers import oneshot, train + from sparseml.transformers import SparseAutoModelForCausalLM, oneshot, train recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml" - model = "Xenova/llama2.c-stories15M" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" + model = SparseAutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", device_map="auto" + ) dataset = "open_platypus" concatenate_data = False num_calibration_samples = 64 @@ -59,11 +56,15 @@ def test_oneshot_then_finetune(self): recipe=recipe_str, concatenate_data=concatenate_data, splits=splits, - oneshot_device=device, ) recipe_str = "tests/sparseml/transformers/finetune/test_finetune_recipe.yaml" - model = self.output / "oneshot_out" + model = SparseAutoModelForCausalLM.from_pretrained( + self.output / "oneshot_out", device_map="auto" + ) + distill_teacher = SparseAutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", device_map="auto" + ) dataset = "open_platypus" concatenate_data = False output_dir = self.output / "finetune_out" @@ -73,7 +74,7 @@ def test_oneshot_then_finetune(self): with sparseml.create_session(): train( model=model, - distill_teacher="Xenova/llama2.c-stories15M", + distill_teacher=distill_teacher, dataset=dataset, output_dir=output_dir, num_calibration_samples=num_calibration_samples, @@ -81,7 +82,6 @@ def test_oneshot_then_finetune(self): concatenate_data=concatenate_data, splits=splits, max_steps=max_steps, - oneshot_device=device, ) def tearDown(self):