From 3c9a2e311132ea966be822861bdd33b4f34afd8e Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 3 Apr 2024 12:46:34 +0000
Subject: [PATCH] style

---
 .../finetune/examples/test_trl_sft_data.py    | 50 +++++++++++++------
 .../finetune/examples/test_trl_trainer.py     | 50 +++++++++++++------
 src/sparseml/transformers/finetune/trainer.py |  6 +--
 3 files changed, 72 insertions(+), 34 deletions(-)

diff --git a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py
index 68292fc85cf..ddbbd7ae623 100644
--- a/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py
+++ b/src/sparseml/transformers/finetune/examples/test_trl_sft_data.py
@@ -1,16 +1,33 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from datasets import load_dataset
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
 
 from sparseml.transformers import (
     SFTTrainer,
-    TrainingArguments,
-    SparseAutoModelForCausalLM,
-    SparseAutoTokenizer
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TrainingArguments,
 )
+from trl import DataCollatorForCompletionOnlyLM
+
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
-model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
 tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
 tokenizer.pad_token = tokenizer.eos_token
 
@@ -19,27 +36,32 @@
 test_stage:
   pruning_modifiers:
     ConstantPruningModifier:
-      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight',
-        're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight']
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+        're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight',
+        're:.*down_proj.weight']
       start: 0
 """
 
 # Load gsm8k using TRL dataset tools
 dataset = load_dataset("gsm8k", "main", split="train")
+
+
 def formatting_prompts_func(example):
     output_texts = []
-    for i in range(len(example['question'])):
+    for i in range(len(example["question"])):
         text = f"Question: {example['question'][i]}\n Answer: {example['answer'][i]}"
         output_texts.append(text)
     return output_texts
+
+
 response_template = "Answer:"
 collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
 
 training_args = TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=0.6,
-        logging_steps=50,
-        gradient_checkpointing=True
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
 )
 
 trainer = SFTTrainer(
@@ -50,7 +72,7 @@ def formatting_prompts_func(example):
     formatting_func=formatting_prompts_func,
     data_collator=collator,
     args=training_args,
-    max_seq_length=512
+    max_seq_length=512,
 )
 trainer.train()
-trainer.save_model(output_dir=trainer.args.output_dir)
\ No newline at end of file
+trainer.save_model(output_dir=trainer.args.output_dir)
diff --git a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py
index c1af0b67500..7b4ecda49b5 100644
--- a/src/sparseml/transformers/finetune/examples/test_trl_trainer.py
+++ b/src/sparseml/transformers/finetune/examples/test_trl_trainer.py
@@ -1,22 +1,41 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from transformers import DefaultDataCollator
 
 from sparseml.transformers import (
+    DataTrainingArguments,
     SFTTrainer,
-    DataTrainingArguments,
-    TrainingArguments,
-    TextGenerationDataset,
-    SparseAutoModelForCausalLM,
-    SparseAutoTokenizer
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    TextGenerationDataset,
+    TrainingArguments,
 )
+
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k"
-model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_path, torch_dtype="auto", device_map="auto"
+)
 tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
 
 
 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(dataset = "gsm8k", dataset_config_name="main", max_seq_length=512)
+data_args = DataTrainingArguments(
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+)
 dataset_manager = TextGenerationDataset.load_from_registry(
     data_args.dataset,
     data_args=data_args,
@@ -31,17 +50,18 @@
 test_stage:
   pruning_modifiers:
     ConstantPruningModifier:
-      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight',
-        're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight']
+      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight',
+        're:.*o_proj.weight','re:.*gate_proj.weight', 're:.*up_proj.weight',
+        're:.*down_proj.weight']
       start: 0
 """
 
 data_collator = DefaultDataCollator()
 training_args = TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=0.6,
-        logging_steps=50,
-        gradient_checkpointing=True
+    output_dir=output_dir,
+    num_train_epochs=0.6,
+    logging_steps=50,
+    gradient_checkpointing=True,
 )
 trainer = SFTTrainer(
     model=model,
@@ -51,7 +71,7 @@
     data_collator=data_collator,
     args=training_args,
     max_seq_length=data_args.max_seq_length,
-    packing=True
+    packing=True,
 )
 trainer.train()
-trainer.save_model(output_dir=trainer.args.output_dir)
\ No newline at end of file
+trainer.save_model(output_dir=trainer.args.output_dir)
diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py
index 9fed69d82be..d918f50880b 100644
--- a/src/sparseml/transformers/finetune/trainer.py
+++ b/src/sparseml/transformers/finetune/trainer.py
@@ -12,10 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, Dict, Optional, Union
-
-import torch
-from torch.nn import Module
 from transformers import Trainer as HFTransformersTrainer
 
 from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
@@ -25,4 +21,4 @@
 
 
 class Trainer(SessionManagerMixIn, HFTransformersTrainer):
-    pass
\ No newline at end of file
+    pass