code format using black and minor nit to use public interface
Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl committed Mar 4, 2024
1 parent c93329f commit f178c21
Showing 7 changed files with 85 additions and 58 deletions.
1 change: 1 addition & 0 deletions scripts/run_inference.py
@@ -8,6 +8,7 @@
If these things change in the future, we should consider breaking it up.
"""

# Standard
import argparse
import json
6 changes: 5 additions & 1 deletion tuning/config/configs.py
@@ -59,5 +59,9 @@ class TrainingArguments(transformers.TrainingArguments):
)
tracker: str.lower = field(
default=None,
metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"}
metadata={
"help": "Experiment tracker to use.\n" + \
"Available trackers are - aim, none\n" + \
"Requires additional configs, see tuning.configs/tracker_configs.py"
},
)
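The tracker field added above selects an experiment tracker by name; per the help text the recognized values are "aim" and none, and the chosen tracker needs its matching config from tuning.config/tracker_configs.py. A minimal sketch, assuming the arguments are built programmatically rather than parsed from the CLI (the output directory is a placeholder, not part of this commit):

# Hedged sketch: enabling the Aim tracker via the new TrainingArguments field.
from tuning.config import configs

training_args = configs.TrainingArguments(
    output_dir="/tmp/tuning-output",  # placeholder path
    tracker="aim",                    # or leave as None to disable tracking
)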
11 changes: 7 additions & 4 deletions tuning/config/tracker_configs.py
@@ -1,15 +1,18 @@
# Standard
from dataclasses import dataclass


@dataclass
class AimConfig:
# Name of the experiment
experiment: str = None
# 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
# When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
# 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
# When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
# Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
aim_repo: str = None
# See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
aim_repo: str = ".aim"
aim_remote_server_ip: str = None
aim_remote_server_port: int = None
# Location of where run_hash is exported, if unspecified this is output to
# Location of where run_hash is exported, if unspecified this is output to
# training_args.output_dir/.aim_run_hash if the output_dir is set else not exported.
aim_run_hash_export_path: str = None
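As the comments above describe, AimConfig can point at a locally accessible Aim repo or at a remote Aim server, and can optionally export the run hash for later lookup. A short sketch of both variants, with placeholder names, addresses, and paths:

from tuning.config.tracker_configs import AimConfig

# Local repo: the default is the ".aim" directory; the experiment name is a placeholder.
local_cfg = AimConfig(experiment="tuning-demo", aim_repo="~/.aim")

# Remote server: setting ip and port designates a remote Aim repo (placeholder values).
remote_cfg = AimConfig(
    experiment="tuning-demo",
    aim_remote_server_ip="10.0.0.5",
    aim_remote_server_port=53800,
    aim_run_hash_export_path="/tmp/aim_run_hash",  # optional; defaults under output_dir
)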
54 changes: 34 additions & 20 deletions tuning/sft_trainer.py
@@ -1,11 +1,12 @@
# Standard
from datetime import datetime
from typing import Optional, Union, List, Dict
from typing import Dict, List, Optional, Union
import json
import os, time
import os
import time

# Third Party
import transformers
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -16,21 +17,22 @@
TrainerCallback,
)
from transformers.utils import logging
from peft.utils.other import fsdp_auto_wrap_policy
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
import datasets
import fire
import transformers

# Local
from tuning.config import configs, peft_config, tracker_configs
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.trackers.tracker import Tracker
from tuning.trackers.tracker_factory import get_tracker
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype

logger = logging.get_logger("sft_trainer")


class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(
@@ -41,6 +43,7 @@ def on_save(self, args, state, control, **kwargs):
if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

@@ -84,6 +87,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
with open(log_file, "a") as f:
f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")


def train(
model_args: configs.ModelArguments,
data_args: configs.DataArguments,
@@ -93,7 +97,7 @@ def train(
] = None,
callbacks: Optional[List[TrainerCallback]] = None,
tracker: Optional[Tracker] = None,
exp_metadata: Optional[Dict] = None
exp_metadata: Optional[Dict] = None,
):
"""Call the SFTTrainer
@@ -105,6 +109,11 @@
peft_config.PromptTuningConfig for prompt tuning | \
None for fine tuning
The peft configuration to pass to trainer
callbacks: List of callbacks to attach to the SFTTrainer.
tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
Initialized using tuning.trackers.tracker_factory.get_tracker
Using configs in tuning.config.tracker_configs
exp_metadata: Dict of key-value pairs passed to train to be recorded by the tracker.
"""
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

@@ -133,7 +142,7 @@ def train(
torch_dtype=get_torch_dtype(model_args.torch_dtype),
use_flash_attention_2=model_args.use_flash_attn,
)
additional_metrics['model_load_time'] = time.time() - model_load_time
additional_metrics["model_load_time"] = time.time() - model_load_time

peft_config = get_hf_peft_config(task_type, peft_config)

@@ -269,16 +278,17 @@ def train(
if tracker is not None:
# Currently tracked only on process zero.
if trainer.is_world_process_zero():
for k,v in additional_metrics.items():
tracker.track(metric=v, name=k, stage='additional_metrics')
tracker.set_params(params=exp_metadata, name='experiment_metadata')
for k, v in additional_metrics.items():
tracker.track(metric=v, name=k, stage="additional_metrics")
tracker.set_params(params=exp_metadata, name="experiment_metadata")

if run_distributed and peft_config is not None:
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
)
trainer.train()


def main(**kwargs):
parser = transformers.HfArgumentParser(
dataclass_types=(
@@ -300,6 +310,7 @@ def main(**kwargs):
"--exp_metadata",
type=str,
default=None,
help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
)
(
model_args,
@@ -313,18 +324,18 @@ def main(**kwargs):
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

peft_method = additional.peft_method
if peft_method =="lora":
tune_config=lora_config
elif peft_method =="pt":
tune_config=prompt_tuning_config
if peft_method == "lora":
tune_config = lora_config
elif peft_method == "pt":
tune_config = prompt_tuning_config
else:
tune_config=None
tune_config = None

tracker_name = training_args.tracker
if tracker_name == "aim":
tracker_config=aim_config
tracker_config = aim_config
else:
tracker_config=None
tracker_config = None

# Initialize callbacks
file_logger_callback = FileLoggingCallback(logger)
@@ -343,7 +354,9 @@ def main(**kwargs):
try:
metadata = json.loads(additional.exp_metadata)
if metadata is None or not isinstance(metadata, Dict):
logger.warning('metadata cannot be converted to simple k:v dict ignoring')
logger.warning(
"metadata cannot be converted to simple k:v dict ignoring"
)
metadata = None
except:
logger.error("failed while parsing extra metadata. pass a valid json")
@@ -355,8 +368,9 @@ def main(**kwargs):
peft_config=tune_config,
callbacks=callbacks,
tracker=tracker,
exp_metadata=metadata
exp_metadata=metadata,
)


if __name__ == "__main__":
fire.Fire(main)
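Taken together, main() now parses --exp_metadata into a dict and hands it, together with the selected tracker, to train(). A hedged sketch of the equivalent programmatic path, mirroring how main() wires things up; the metadata value echoes the --exp_metadata help example, and the model/data/training arguments are still supplied on the command line:

# Sketch only: wiring a tracker and experiment metadata into train() by hand.
import transformers

from tuning import sft_trainer
from tuning.config import configs
from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

parser = transformers.HfArgumentParser(
    dataclass_types=(configs.ModelArguments, configs.DataArguments, configs.TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

tracker = get_tracker("aim", AimConfig(experiment="tuning-demo"))  # placeholder experiment name
sft_trainer.train(
    model_args,
    data_args,
    training_args,
    peft_config=None,                  # None selects full fine tuning
    callbacks=[tracker.get_hf_callback()],
    tracker=tracker,
    exp_metadata={"gpu": "A100-80G"},  # same shape as the --exp_metadata help example
)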
58 changes: 31 additions & 27 deletions tuning/trackers/aimstack_tracker.py
@@ -1,56 +1,60 @@
# Standard
import os

# Third Party
from aim.hugging_face import AimCallback

# Local
from .tracker import Tracker
from tuning.config.tracker_configs import AimConfig

# Third Party
from aim.hugging_face import AimCallback

class CustomAimCallback(AimCallback):

# A path to export run hash generated by Aim
# This is used to link back to the experiments from outside aimstack
run_hash_export_path = None
hash_export_path = None

def on_init_end(self, args, state, control, **kwargs):

if state and not state.is_world_process_zero:
return

self.setup() # initializes the run_hash
self.setup() # initializes the run_hash

# Store the run hash
# Change default run hash path to output directory
if self.run_hash_export_path is None:
if self.hash_export_path is None:
if args and args.output_dir:
# args.output_dir/.aim_run_hash
self.run_hash_export_path = os.path.join(
args.output_dir,
'.aim_run_hash'
)
self.hash_export_path = os.path.join(
args.output_dir, ".aim_run_hash"
)

if self.run_hash_export_path:
with open(self.run_hash_export_path, 'w') as f:
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
if self.hash_export_path:
with open(self.hash_export_path, "w") as f:
hash = self.experiment.hash
f.write('{"run_hash":"' + str(hash) + '"}\n')

def on_train_begin(self, args, state, control, model=None, **kwargs):
# call directly to make sure hyper parameters and model info is recorded.
self.setup(args=args, state=state, model=model)

def track_metrics(self, metric, name, context):
if self._run is not None:
self._run.track(metric, name=name, context=context)
run = self.experiment
if run is not None:
run.track(metric, name=name, context=context)


def set_params(self, params, name):
if self._run is not None:
for key, value in params.items():
self._run.set((name, key), value, strict=False)
run = self.experiment
if run is not None:
[run.set((name, key), value, strict=False) for key, value in params.items()]

class AimStackTracker(Tracker):

class AimStackTracker(Tracker):
def __init__(self, tracker_config: AimConfig):
super().__init__(name='aim', tracker_config=tracker_config)
super().__init__(name="aim", tracker_config=tracker_config)

def get_hf_callback(self):
c = self.config
@@ -60,25 +64,25 @@ def get_hf_callback(self):
repo = c.aim_repo
hash_export_path = c.aim_run_hash_export_path

if (ip is not None and port is not None):
if ip is not None and port is not None:
aim_callback = CustomAimCallback(
repo="aim://" + ip +":"+ port + "/",
experiment=exp)
repo="aim://" + ip + ":" + port + "/", experiment=exp
)
if repo:
aim_callback = CustomAimCallback(repo=repo, experiment=exp)
else:
aim_callback = CustomAimCallback(experiment=exp)

aim_callback.run_hash_export_path = hash_export_path
aim_callback.hash_export_path = hash_export_path
self.hf_callback = aim_callback
return self.hf_callback

def track(self, metric, name, stage='additional_metrics'):
context={'subset' : stage}
def track(self, metric, name, stage="additional_metrics"):
context = {"subset": stage}
self.hf_callback.track_metrics(metric, name=name, context=context)

def set_params(self, params, name='extra_params'):
def set_params(self, params, name="extra_params"):
try:
self.hf_callback.set_params(params, name)
except:
pass
pass
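With the switch to the public experiment property and the hash_export_path rename, the callback still writes the Aim run hash as a one-line JSON object, by default to output_dir/.aim_run_hash. A small sketch of reading it back to link an external system to the Aim run (the output directory is a placeholder):

import json
import os

output_dir = "/tmp/tuning-output"  # placeholder; the callback defaults to training_args.output_dir
with open(os.path.join(output_dir, ".aim_run_hash")) as f:
    run_hash = json.load(f)["run_hash"]
print("Aim run hash:", run_hash)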
3 changes: 2 additions & 1 deletion tuning/trackers/tracker.py
@@ -1,5 +1,6 @@
# Generic Tracker API


class Tracker:
def __init__(self, name=None, tracker_config=None) -> None:
if tracker_config is not None:
@@ -19,4 +20,4 @@ def track(self, metric, name, stage):
# Object passed here is supposed to be a KV object
# for the parameters to be associated with a run
def set_params(self, params, name):
pass
pass
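The base Tracker keeps every hook a no-op, which is what the factory falls back to for unrecognized names. A purely illustrative sketch of what a new backend would override; this stdout "tracker" is not part of this commit:

# Illustrative only: shows the three hooks a tracker backend is expected to provide.
from tuning.trackers.tracker import Tracker

class StdoutTracker(Tracker):
    def __init__(self, tracker_config=None):
        super().__init__(name="stdout", tracker_config=tracker_config)

    def get_hf_callback(self):
        # A real backend returns a transformers TrainerCallback here.
        return None

    def track(self, metric, name, stage="additional_metrics"):
        print(f"[{stage}] {name} = {metric}")

    def set_params(self, params, name="extra_params"):
        for key, value in (params or {}).items():
            print(f"[{name}] {key} = {value}")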
10 changes: 5 additions & 5 deletions tuning/trackers/tracker_factory.py
@@ -1,13 +1,13 @@
from .tracker import Tracker
# Local
from .aimstack_tracker import AimStackTracker
from .tracker import Tracker

REGISTERED_TRACKERS = {"aim": AimStackTracker}

REGISTERED_TRACKERS = {
"aim" : AimStackTracker
}

def get_tracker(tracker_name, tracker_config):
if tracker_name in REGISTERED_TRACKERS:
T = REGISTERED_TRACKERS[tracker_name]
return T(tracker_config)
else:
return Tracker()
return Tracker()
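get_tracker resolves the name through REGISTERED_TRACKERS and quietly falls back to the no-op base Tracker for anything unrecognized, so callers never branch on the tracker name themselves. A short usage sketch with placeholder config values:

from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

aim_tracker = get_tracker("aim", AimConfig(experiment="tuning-demo"))  # -> AimStackTracker
noop_tracker = get_tracker("none", None)                               # unregistered name -> base Tracker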
