From b0c170c6f625185098eee3a9e388efad525e3d56 Mon Sep 17 00:00:00 2001
From: Dushyant Behl
Date: Wed, 21 Feb 2024 11:44:44 +0530
Subject: [PATCH] separate callbacks from train

Signed-off-by: Dushyant Behl
---
 tuning/aim_loader.py  | 25 --------------------
 tuning/sft_trainer.py | 53 ++++++++++++++++++++----------------------
 2 files changed, 25 insertions(+), 53 deletions(-)
 delete mode 100644 tuning/aim_loader.py

diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py
deleted file mode 100644
index 6ee617a42..000000000
--- a/tuning/aim_loader.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Standard
-import os
-
-# Third Party
-from aim.hugging_face import AimCallback
-
-
-def get_aimstack_callback():
-    # Initialize a new run
-    aim_server = os.environ.get("AIMSTACK_SERVER")
-    aim_db = os.environ.get("AIMSTACK_DB")
-    aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT")
-    if aim_experiment is None:
-        aim_experiment = ""
-
-    if aim_server:
-        aim_callback = AimCallback(
-            repo="aim://" + aim_server + "/", experiment=aim_experiment
-        )
-    if aim_db:
-        aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment)
-    else:
-        aim_callback = AimCallback(experiment=aim_experiment)
-
-    return aim_callback
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index a4840f015..2eb53cc9c 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,11 +1,11 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import Optional, Union, List
 import json
 import os, time
 
 # Third Party
-from peft.utils.other import fsdp_auto_wrap_policy
+import transformers
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -16,10 +16,10 @@
     TrainerCallback,
 )
 from transformers.utils import logging
+from peft.utils.other import fsdp_auto_wrap_policy
 from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import datasets
 import fire
-import transformers
 
 # Local
 from tuning.config import configs, peft_config, tracker_configs
@@ -29,6 +29,7 @@
 from tuning.tracker.tracker import Tracker
 from tuning.tracker.aimstack_tracker import AimStackTracker
 
+logger = logging.get_logger("sft_trainer")
 
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
@@ -83,7 +84,6 @@ def _track_loss(self, loss_key, log_file, logs, state):
             with open(log_file, "a") as f:
                 f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
-
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
@@ -91,7 +91,8 @@ def train(
     peft_config: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
-    tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
+    callbacks: Optional[List[TrainerCallback]] = None,
+    tracker: Optional[Tracker] = None,
 ):
     """Call the SFTTrainer
 
@@ -105,7 +106,6 @@ def train(
             The peft configuration to pass to trainer
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
-    logger = logging.get_logger("sft_trainer")
 
     # Validate parameters
     if (not isinstance(train_args.num_train_epochs, float)) or (
@@ -122,17 +122,6 @@ def train(
         train_args.fsdp = ""
         train_args.fsdp_config = {"xla": False}
 
-    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
-    tracker_name = train_args.tracker
-    if tracker_name == 'aim':
-        if tracker_config is not None:
-            tracker = AimStackTracker(tracker_config)
-        else:
-            logger.error("Tracker name is set to "+tracker_name+" but config is None.")
-    else:
-        logger.info('No tracker set so just set a dummy API which does nothing')
-        tracker = Tracker()
-
     task_type = "CAUSAL_LM"
 
     model_load_time = time.time()
@@ -259,15 +248,6 @@ def train(
         )
         packing = False
 
-    # club and register callbacks
-    file_logger_callback = FileLoggingCallback(logger)
-    peft_saving_callback = PeftSavingCallback()
-    callbacks = [peft_saving_callback, file_logger_callback]
-
-    tracker_callback = tracker.get_hf_callback()
-    if tracker_callback is not None:
-        callbacks.append(tracker_callback)
-
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -288,7 +268,6 @@ def train(
     )
     trainer.train()
 
-
 def main(**kwargs):
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -331,8 +310,26 @@ def main(**kwargs):
     else:
         tracker_config=None
 
-    train(model_args, data_args, training_args, tune_config, tracker_config)
+    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
+    tracker_name = training_args.tracker
+    if tracker_name == 'aim':
+        if tracker_config is not None:
+            tracker = AimStackTracker(tracker_config)
+        else:
+            logger.error("Tracker name is set to "+tracker_name+" but config is None.")
+    else:
+        tracker = Tracker()
+
+    # Initialize callbacks
+    file_logger_callback = FileLoggingCallback(logger)
+    peft_saving_callback = PeftSavingCallback()
+    callbacks = [peft_saving_callback, file_logger_callback]
+
+    tracker_callback = tracker.get_hf_callback()
+    if tracker_callback is not None:
+        callbacks.append(tracker_callback)
+
+    train(model_args, data_args, training_args, tune_config, callbacks, tracker)
 
 if __name__ == "__main__":
     fire.Fire(main)
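
Reviewer note (illustrative, not part of the patch): with this change, callbacks and the
tracker are built by the caller and passed into train(). The sketch below mirrors what
main() does after the patch, for any other entry point. It assumes the module layout used
here (tuning.sft_trainer, tuning.config, tuning.tracker.*); the config field values and
the direct construction of configs.TrainingArguments / tracker_configs.AimConfig are
hypothetical placeholders, not something this patch defines.

# Minimal sketch of calling the refactored train() from outside main().
from tuning import sft_trainer
from tuning.config import configs, tracker_configs
from tuning.tracker.aimstack_tracker import AimStackTracker

# Hypothetical argument values; main() normally builds these via HfArgumentParser.
model_args = configs.ModelArguments(model_name_or_path="base-model")
data_args = configs.DataArguments(training_data_path="train.jsonl")
training_args = configs.TrainingArguments(output_dir="out", tracker="aim")

# The tracker is now created before train(), as main() does after this patch.
tracker = AimStackTracker(tracker_configs.AimConfig())

# Assemble the callbacks that train() used to build internally.
callbacks = [
    sft_trainer.PeftSavingCallback(),
    sft_trainer.FileLoggingCallback(sft_trainer.logger),
]
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
    callbacks.append(tracker_callback)

# peft_config is left as None here; pass a LoraConfig or PromptTuningConfig to tune with PEFT.
sft_trainer.train(model_args, data_args, training_args, None, callbacks, tracker)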