Commit: style
Sara Adkins committed Apr 3, 2024
1 parent e425523 commit 3c9a2e3

Showing 3 changed files with 72 additions and 34 deletions.
50 changes: 36 additions & 14 deletions src/sparseml/transformers/finetune/examples/test_trl_sft_data.py
@@ -1,16 +1,33 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from datasets import load_dataset
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

 from sparseml.transformers import (
     SFTTrainer,
-    TrainingArguments,
-    SparseAutoModelForCausalLM,
-    SparseAutoTokenizer
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TrainingArguments,
 )
+from trl import DataCollatorForCompletionOnlyLM


 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
-model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
 tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
 tokenizer.pad_token = tokenizer.eos_token

@@ -19,27 +36,32 @@
 test_stage:
   pruning_modifiers:
     ConstantPruningModifier:
-      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight',
-      're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight']
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+      're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight',
+      're:.*down_proj.weight']
       start: 0
 """

 # Load gsm8k using TRL dataset tools
 dataset = load_dataset("gsm8k", "main", split="train")


 def formatting_prompts_func(example):
     output_texts = []
-    for i in range(len(example['question'])):
+    for i in range(len(example["question"])):
         text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}"
         output_texts.append(text)
     return output_texts


 response_template = "Answer:"
 collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

 training_args = TrainingArguments(
-    output_dir=output_dir,
-    num_train_epochs=0.6,
-    logging_steps=50,
-    gradient_checkpointing=True
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
 )

 trainer = SFTTrainer(
@@ -50,7 +72,7 @@ def formatting_prompts_func(example):
     formatting_func=formatting_prompts_func,
     data_collator=collator,
     args=training_args,
-    max_seq_length=512
+    max_seq_length=512,
 )
 trainer.train()
-trainer.save_model(output_dir=trainer.args.output_dir)
\ No newline at end of file
+trainer.save_model(output_dir=trainer.args.output_dir)
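
A note on the collator in this example: with response_template = "Answer:", trl's DataCollatorForCompletionOnlyLM sets the label of every token up to and including the template to -100, so the loss is computed only on the answer tokens. The snippet below is purely illustrative and not part of this commit; the token ids are invented and the search is a simplified stand-in for what the collator does internally.

# Illustrative only: hypothetical token ids, simplified version of the collator's masking.
input_ids = [11, 12, 13, 900, 901, 200, 201]  # prompt tokens, "Answer:" tokens, answer tokens
response_template_ids = [900, 901]            # pretend tokenization of "Answer:"

labels = list(input_ids)
for start in range(len(input_ids) - len(response_template_ids) + 1):
    if input_ids[start : start + len(response_template_ids)] == response_template_ids:
        answer_start = start + len(response_template_ids)
        labels[:answer_start] = [-100] * answer_start  # -100 is ignored by the loss
        break

print(labels)  # [-100, -100, -100, -100, -100, 200, 201]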
50 changes: 35 additions & 15 deletions src/sparseml/transformers/finetune/examples/test_trl_trainer.py
@@ -1,22 +1,41 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from transformers import DefaultDataCollator

 from sparseml.transformers import (
+    DataTrainingArguments,
     SFTTrainer,
-    DataTrainingArguments,
-    TrainingArguments,
-    TextGenerationDataset,
-    SparseAutoModelForCausalLM,
-    SparseAutoTokenizer
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TextGenerationDataset,
+    TrainingArguments,
 )


 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k"

-model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
 tokenizer = SparseAutoTokenizer.from_pretrained(model_path)

 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512)
+data_args = DataTrainingArguments(
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+)
 dataset_manager = TextGenerationDataset.load_from_registry(
     data_args.dataset,
     data_args=data_args,
@@ -31,17 +50,18 @@
 test_stage:
   pruning_modifiers:
     ConstantPruningModifier:
-      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight',
-      're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight']
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+      're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight',
+      're:.*down_proj.weight']
       start: 0
 """

 data_collator = DefaultDataCollator()
 training_args = TrainingArguments(
-    output_dir=output_dir,
-    num_train_epochs=0.6,
-    logging_steps=50,
-    gradient_checkpointing=True
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
 )
 trainer = SFTTrainer(
     model=model,
@@ -51,7 +71,7 @@
     data_collator=data_collator,
     args=training_args,
     max_seq_length=data_args.max_seq_length,
-    packing=True
+    packing=True,
 )
 trainer.train()
-trainer.save_model(output_dir=trainer.args.output_dir)
\ No newline at end of file
+trainer.save_model(output_dir=trainer.args.output_dir)
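
Both example scripts define a ConstantPruningModifier recipe over the attention and MLP projection weights; as I understand SparseML's modifier, it holds the existing sparsity masks on those weights constant during fine-tuning, so the 50%-pruned checkpoint stays sparse rather than densifying. A rough sanity check along those lines (illustrative only, not part of this commit; it assumes the `model` object loaded in the script above):

# Illustrative sanity check: the targeted projection weights should sit at roughly the
# checkpoint's 50% sparsity both before and after training if the masks are held constant.
def mean_sparsity(model, suffix="q_proj.weight"):
    ratios = [
        (param == 0).float().mean().item()
        for name, param in model.named_parameters()
        if name.endswith(suffix)
    ]
    return sum(ratios) / max(len(ratios), 1)


print(f"mean q_proj sparsity: {mean_sparsity(model):.2%}")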
6 changes: 1 addition & 5 deletions src/sparseml/transformers/finetune/trainer.py
@@ -12,10 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Dict, Optional, Union
-
-import torch
-from torch.nn import Module
 from transformers import Trainer as HFTransformersTrainer

 from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
@@ -25,4 +21,4 @@


 class Trainer(SessionManagerMixIn, HFTransformersTrainer):
-    pass
\ No newline at end of file
+    pass
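
The slimmed-down Trainer above leans entirely on cooperative multiple inheritance: SessionManagerMixIn comes first in the MRO, so its overrides run before the Hugging Face Trainer's and can delegate upward with super(). A toy sketch of that pattern (the Fake* names are invented for illustration, and the real mixin's hooks may differ):

class FakeHFTrainer:
    def train(self):
        return "hf training loop"


class FakeSessionMixIn:
    def train(self, *args, **kwargs):
        print("initialize sparsification session")  # e.g. apply recipe modifiers
        result = super().train(*args, **kwargs)
        print("finalize sparsification session")
        return result


class FakeTrainer(FakeSessionMixIn, FakeHFTrainer):
    pass  # same shape as Trainer(SessionManagerMixIn, HFTransformersTrainer)


print(FakeTrainer().train())  # the mixin's hooks wrap the underlying training loop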
