Skip to content

Commit

Permalink
Merge branch 'main' into sa/update_ex
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins authored May 28, 2024
2 parents ec77568 + 53e98b6 commit ea07f56
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,13 @@ def setUp(self):
self.output = Path("./finetune_output")

def test_oneshot_then_finetune(self):
import torch

import sparseml
from sparseml.transformers import oneshot, train
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot, train

recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml"
model = "Xenova/llama2.c-stories15M"
device = "cuda:0"
if not torch.cuda.is_available():
device = "cpu"
model = SparseAutoModelForCausalLM.from_pretrained(
"Xenova/llama2.c-stories15M", device_map="auto"
)
dataset = "open_platypus"
concatenate_data = False
num_calibration_samples = 64
Expand All @@ -59,11 +56,15 @@ def test_oneshot_then_finetune(self):
recipe=recipe_str,
concatenate_data=concatenate_data,
splits=splits,
oneshot_device=device,
)

recipe_str = "tests/sparseml/transformers/finetune/test_finetune_recipe.yaml"
model = self.output / "oneshot_out"
model = SparseAutoModelForCausalLM.from_pretrained(
self.output / "oneshot_out", device_map="auto"
)
distill_teacher = SparseAutoModelForCausalLM.from_pretrained(
"Xenova/llama2.c-stories15M", device_map="auto"
)
dataset = "open_platypus"
concatenate_data = False
output_dir = self.output / "finetune_out"
Expand All @@ -73,15 +74,14 @@ def test_oneshot_then_finetune(self):
with sparseml.create_session():
train(
model=model,
distill_teacher="Xenova/llama2.c-stories15M",
distill_teacher=distill_teacher,
dataset=dataset,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
recipe=recipe_str,
concatenate_data=concatenate_data,
splits=splits,
max_steps=max_steps,
oneshot_device=device,
)

def tearDown(self):
Expand Down

0 comments on commit ea07f56

Please sign in to comment.