Skip to content

Commit

Permalink
clarity comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins committed Apr 2, 2024
1 parent 3821606 commit 4f619bd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# recipe for maintaining model sparsity during finetuning
recipe = """
test_stage:
pruning_modifiers:
Expand All @@ -23,15 +24,14 @@
start: 0
"""


# Load gsm8k using TRL dataset tools
dataset = load_dataset("gsm8k", "main", split="train")
def formatting_prompts_func(example):
    """Format a batch of gsm8k examples into prompt strings for TRL's SFTTrainer.

    :param example: batched dataset slice with parallel lists under the
        'question' and 'answer' keys
    :return: list of "Question: ...\n Answer: ..." strings, one per example
    """
    # zip the parallel question/answer lists instead of indexing by range(len(...))
    return [
        f"Question: {question}\n Answer: {answer}"
        for question, answer in zip(example['question'], example['answer'])
    ]

response_template = "Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
tokenizer = SparseAutoTokenizer.from_pretrained(model_path)

# Load gsm8k using SparseML dataset tools
data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
Expand All @@ -25,6 +26,7 @@
train_dataset = dataset_manager.tokenize_and_process()
print(f"--> Training Set Length = {len(train_dataset)}")

# recipe for maintaining model sparsity during finetuning
recipe = """
test_stage:
pruning_modifiers:
Expand Down

0 comments on commit 4f619bd

Please sign in to comment.