Commit b0c170c

separate callbacks from train
Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl committed Feb 21, 2024
1 parent 9357243 commit b0c170c
Showing 2 changed files with 25 additions and 53 deletions.
tuning/aim_loader.py: 0 additions & 25 deletions

This file was deleted.

tuning/sft_trainer.py: 25 additions & 28 deletions
@@ -1,11 +1,11 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import Optional, Union, List
 import json
 import os, time
 
 # Third Party
+from peft.utils.other import fsdp_auto_wrap_policy
+import transformers
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -16,10 +16,10 @@
     TrainerCallback,
 )
 from transformers.utils import logging
-from peft.utils.other import fsdp_auto_wrap_policy
 from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import datasets
 import fire
-import transformers
 
 # Local
 from tuning.config import configs, peft_config, tracker_configs
@@ -29,6 +29,7 @@
 from tuning.tracker.tracker import Tracker
 from tuning.tracker.aimstack_tracker import AimStackTracker
 
+logger = logging.get_logger("sft_trainer")
 
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
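The body of PeftSavingCallback is collapsed in this diff. For context, a minimal sketch of what such an on_save hook commonly looks like in PEFT/trl-based trainers; the body below is an assumption based on that common pattern, not this repository's exact code:

import os

from transformers import TrainerCallback


class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # Write only the PEFT adapter weights into the checkpoint directory
        # that the Trainer just created for this global step.
        checkpoint_path = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        kwargs["model"].save_pretrained(checkpoint_path)

        # Drop the full model weights, keeping the checkpoint adapter-only.
        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))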
@@ -83,15 +84,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
         with open(log_file, "a") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
-
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
     peft_config: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
-    tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
+    callbacks: Optional[List[TrainerCallback]] = None,
+    tracker: Optional[Tracker] = None,
 ):
     """Call the SFTTrainer
@@ -105,7 +106,6 @@ def train(
             The peft configuration to pass to trainer
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
-    logger = logging.get_logger("sft_trainer")
 
     # Validate parameters
     if (not isinstance(train_args.num_train_epochs, float)) or (
@@ -122,17 +122,6 @@ def train(
         train_args.fsdp = ""
         train_args.fsdp_config = {"xla": False}
 
-    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
-    tracker_name = train_args.tracker
-    if tracker_name == 'aim':
-        if tracker_config is not None:
-            tracker = AimStackTracker(tracker_config)
-        else:
-            logger.error("Tracker name is set to "+tracker_name+" but config is None.")
-    else:
-        logger.info('No tracker set so just set a dummy API which does nothing')
-        tracker = Tracker()
-
     task_type = "CAUSAL_LM"
 
     model_load_time = time.time()
@@ -259,15 +248,6 @@
     )
     packing = False
 
-    # club and register callbacks
-    file_logger_callback = FileLoggingCallback(logger)
-    peft_saving_callback = PeftSavingCallback()
-    callbacks = [peft_saving_callback, file_logger_callback]
-
-    tracker_callback = tracker.get_hf_callback()
-    if tracker_callback is not None:
-        callbacks.append(tracker_callback)
-
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
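The tracker.get_hf_callback() call in the removed block above relies on the Tracker class imported at the top of the file, whose definition lives in tuning/tracker/tracker.py and is not part of this diff. A plausible sketch of the no-op base class implied by the "dummy API which does nothing" log message; the real file may differ:

from typing import Optional

from transformers import TrainerCallback


class Tracker:
    """No-op tracker: every hook does nothing unless a subclass overrides it."""

    def __init__(self, tracker_config=None):
        self.config = tracker_config

    def get_hf_callback(self) -> Optional[TrainerCallback]:
        # Concrete trackers (e.g. AimStackTracker) return a real
        # TrainerCallback; the base class returns None so callers can
        # skip registration, as the removed block above did.
        return None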
@@ -288,7 +268,6 @@
     )
     trainer.train()
 
-
 def main(**kwargs):
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -331,8 +310,26 @@ def main(**kwargs):
     else:
         tracker_config=None
 
-    train(model_args, data_args, training_args, tune_config, tracker_config)
+    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
+    tracker_name = training_args.tracker
+    if tracker_name == 'aim':
+        if tracker_config is not None:
+            tracker = AimStackTracker(tracker_config)
+        else:
+            logger.error("Tracker name is set to "+tracker_name+" but config is None.")
+    else:
+        tracker = Tracker()
+
+    # Initialize callbacks
+    file_logger_callback = FileLoggingCallback(logger)
+    peft_saving_callback = PeftSavingCallback()
+    callbacks = [peft_saving_callback, file_logger_callback]
+
+    tracker_callback = tracker.get_hf_callback()
+    if tracker_callback is not None:
+        callbacks.append(tracker_callback)
+
+    train(model_args, data_args, training_args, tune_config, callbacks, tracker)
 
 if __name__ == "__main__":
     fire.Fire(main)
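After this change, the caller owns callback and tracker construction and passes both into train(). A hypothetical usage sketch, assuming the argument dataclasses have already been parsed (e.g. via transformers.HfArgumentParser as in main()):

from transformers import ProgressCallback

from tuning import sft_trainer
from tuning.tracker.tracker import Tracker

# model_args, data_args, training_args, tune_config are assumed to be
# already-parsed dataclasses, as produced by HfArgumentParser in main().
callbacks = [sft_trainer.PeftSavingCallback(), ProgressCallback()]
tracker = Tracker()  # no-op tracker; swap in AimStackTracker(config) to log to Aim

sft_trainer.train(
    model_args,
    data_args,
    training_args,
    tune_config,
    callbacks=callbacks,
    tracker=tracker,
)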
