feat: change tracker API to initialize tracker early and track additional metrics. #50

Closed · wants to merge 13 commits
Changes from 4 commits
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ numpy
accelerate>=0.20.3
transformers>=4.34.1
torch
aim==3.17.5
aim==3.18.1
Review comment:
since there's a desire to implement multiple trackers, did we want to make the dependency (and imports) optional, just used when available and configured?

Contributor Author:

Can be done... but then do we still list these in requirements.txt?

Or do we throw an error and ask the user to install the required tracker before importing it in the code?
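For reference, a minimal sketch of the optional-dependency pattern being discussed (hypothetical, not part of this PR; the helper name `make_tracker` and the `_HAS_AIM` flag are made up):

```python
# Hypothetical sketch (not part of this PR): treat `aim` as an optional extra.
from tuning.tracker.tracker import Tracker

try:
    from aim.hugging_face import AimCallback  # noqa: F401  third-party, may be absent

    _HAS_AIM = True
except ImportError:
    _HAS_AIM = False


def make_tracker(tracker_name, tracker_config=None):
    """Return a tracker only when its backing package is installed."""
    if tracker_name == "aim":
        if not _HAS_AIM:
            # Fail with an actionable message instead of crashing at import time.
            raise ImportError("tracker 'aim' requested but the 'aim' package is not installed")
        from tuning.tracker.aimstack_tracker import AimStackTracker

        return AimStackTracker(tracker_config)
    return Tracker()  # no-op fallback
```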

sentencepiece
tokenizers>=0.13.3
tqdm
4 changes: 4 additions & 0 deletions tuning/config/configs.py
@@ -57,3 +57,7 @@ class TrainingArguments(transformers.TrainingArguments):
        default=False,
        metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
    )
    tracker: str.lower = field(
        default=None,
        metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning/config/tracker_configs.py"},
    )
14 changes: 14 additions & 0 deletions tuning/config/tracker_configs.py
@@ -0,0 +1,14 @@
from dataclasses import dataclass

@dataclass
class AimConfig:
    # Name of the experiment
    experiment: str = None
    # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
    # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
    # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
    aim_repo: str = None
    aim_remote_server_ip: str = None
    aim_remote_server_port: int = None
    # Location of where run_hash is exported
    aim_run_hash_export_path: str = None
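To illustrate the comments above, a hypothetical configuration (all values are made up) for a remote Aim server versus a local repo:

```python
from tuning.config.tracker_configs import AimConfig

# When both remote fields are set, the tracker targets a remote Aim repo;
# otherwise `aim_repo` (or its default of '.aim') is used as a local directory.
remote_cfg = AimConfig(
    experiment="llama-7b-sft",
    aim_remote_server_ip="10.0.0.5",
    aim_remote_server_port=53800,
    aim_run_hash_export_path="/tmp/aim_run_hash.txt",
)

local_cfg = AimConfig(experiment="llama-7b-sft", aim_repo="~/.aim")
```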
64 changes: 47 additions & 17 deletions tuning/sft_trainer.py
@@ -2,7 +2,7 @@
from datetime import datetime
from typing import Optional, Union
import json
import os
import os, time

# Third Party
from peft.utils.other import fsdp_auto_wrap_policy
@@ -22,11 +22,12 @@
import transformers

# Local
from tuning.aim_loader import get_aimstack_callback
from tuning.config import configs, peft_config
from tuning.config import configs, peft_config, tracker_configs
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.tracker.tracker import Tracker
from tuning.tracker.aimstack_tracker import AimStackTracker


class PeftSavingCallback(TrainerCallback):
@@ -39,7 +40,6 @@ def on_save(self, args, state, control, **kwargs):
if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

@@ -84,6 +84,7 @@ def train(
    peft_config: Optional[
        Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
    ] = None,
    tracker_config: Optional[Union[tracker_configs.AimConfig]] = None,
):
"""Call the SFTTrainer

Expand All @@ -97,7 +98,6 @@ def train(
The peft configuration to pass to trainer
"""
    run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

    logger = logging.get_logger("sft_trainer")

    # Validate parameters
@@ -115,13 +115,28 @@
train_args.fsdp = ""
train_args.fsdp_config = {"xla": False}

    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
    tracker_name = train_args.tracker
    if tracker_name == "aim":
        if tracker_config is not None:
            tracker = AimStackTracker(tracker_config)
        else:
            logger.error("Tracker name is set to " + tracker_name + " but config is None.")
            tracker = Tracker()
    else:
        logger.info("No tracker is set; using a no-op Tracker API that does nothing.")
        tracker = Tracker()

task_type = "CAUSAL_LM"

    model_load_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=train_args.cache_dir,
        torch_dtype=get_torch_dtype(model_args.torch_dtype),
        use_flash_attention_2=model_args.use_flash_attn,
    )
    model_load_time = time.time() - model_load_time
    tracker.track(metric=model_load_time, name='model_load_time')

    peft_config = get_hf_peft_config(task_type, peft_config)

@@ -212,11 +227,6 @@ def train(
        formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")

    aim_callback = get_aimstack_callback()
    file_logger_callback = FileLoggingCallback(logger)
    peft_saving_callback = PeftSavingCallback()
    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]

    if train_args.packing:
        logger.info("Packing is set to True")
        data_collator = None
@@ -242,6 +252,15 @@
        )
        packing = False

    # club and register callbacks
    file_logger_callback = FileLoggingCallback(logger)
    peft_saving_callback = PeftSavingCallback()
    callbacks = [peft_saving_callback, file_logger_callback]

    tracker_callback = tracker.get_hf_callback()
    if tracker_callback is not None:
        callbacks.append(tracker_callback)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
@@ -271,6 +290,7 @@ def main(**kwargs):
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
            tracker_configs.AimConfig,
        )
    )
    parser.add_argument(
@@ -285,16 +305,26 @@
        training_args,
        lora_config,
        prompt_tuning_config,
        peft_method,
        aim_config,
        additional,
        _,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    if peft_method.peft_method == "lora":
        tune_config = lora_config
    elif peft_method.peft_method == "pt":
        tune_config = prompt_tuning_config
    else:
        tune_config = None
    train(model_args, data_args, training_args, tune_config)

    peft_method = additional.peft_method
    if peft_method == "lora":
        tune_config = lora_config
    elif peft_method == "pt":
        tune_config = prompt_tuning_config
    else:
        tune_config = None

    tracker_name = training_args.tracker
    if tracker_name == "aim":
        tracker_config = aim_config
    else:
        tracker_config = None

    train(model_args, data_args, training_args, tune_config, tracker_config)


if __name__ == "__main__":
Empty file added tuning/tracker/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tuning/tracker/aimstack_tracker.py
@@ -0,0 +1,47 @@
# Standard
import os

# Third Party
from aim.hugging_face import AimCallback

# Local
from tuning.tracker.tracker import Tracker


class AimStackTracker(Tracker):

    def __init__(self, tracker_config):
        super().__init__(tracker_config)
        c = self.config
        exp = c.experiment
        ip = c.aim_remote_server_ip
        port = c.aim_remote_server_port
        repo = c.aim_repo
        hash_export_path = c.aim_run_hash_export_path

        if ip is not None and port is not None:
            aim_callback = AimCallback(
                repo="aim://" + ip + ":" + str(port) + "/",
                experiment=exp,
            )
        elif repo:
            aim_callback = AimCallback(repo=repo, experiment=exp)
        else:
            aim_callback = AimCallback(experiment=exp)

        run = aim_callback.experiment  # Initialize Aim run
        run_hash = run.hash  # Extract the hash

        # store the run hash
        if hash_export_path:
            with open(hash_export_path, 'w') as f:
                f.write(str(run_hash) + '\n')

        # Save internal state
        self.hf_callback = aim_callback
        self.run = run

    def get_hf_callback(self):
        return self.hf_callback

    def track(self, metric, name, stage='additional_metrics'):
        context = {'subset': stage}
        self.run.track(metric, name=name, context=context)
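A short usage sketch of the class above (hypothetical values; it mirrors how sft_trainer.py wires the tracker in):

```python
from tuning.config.tracker_configs import AimConfig
from tuning.tracker.aimstack_tracker import AimStackTracker

tracker = AimStackTracker(AimConfig(experiment="sft-run", aim_repo="~/.aim"))

callbacks = [tracker.get_hf_callback()]  # pass to SFTTrainer(callbacks=...)
tracker.track(metric=42.7, name="model_load_time")  # custom scalar under 'additional_metrics'
```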
11 changes: 11 additions & 0 deletions tuning/tracker/tracker.py
@@ -0,0 +1,11 @@
# Generic Tracker API

class Tracker:
    def __init__(self, tracker_config=None) -> None:
        self.config = tracker_config

    def get_hf_callback(self):
        return None

    def track(self, metric, name, stage):
        pass
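To illustrate the "multiple trackers" direction raised in review, here is a hypothetical second implementation of this API (not part of this PR; the class name, file path, and JSON-lines format are made up) that just appends metrics to a file and exposes no HF callback:

```python
# Hypothetical example of another Tracker implementation (not in this PR).
import json

from tuning.tracker.tracker import Tracker


class FileTracker(Tracker):
    """Writes each tracked metric as one JSON line; returns no HF callback."""

    def __init__(self, tracker_config=None, path="metrics.jsonl"):
        super().__init__(tracker_config)
        self.path = path

    def get_hf_callback(self):
        return None  # nothing to hook into the HF Trainer

    def track(self, metric, name, stage="additional_metrics"):
        with open(self.path, "a") as f:
            f.write(json.dumps({"name": name, "value": metric, "stage": stage}) + "\n")
```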