diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 989aaa8ca..b67dd1cc0 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -8,6 +8,7 @@
 If these things change in the future, we should consider breaking it up.
 """
+
 # Standard
 import argparse
 import json
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 3223bfa7c..279b006f4 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -59,5 +59,9 @@ class TrainingArguments(transformers.TrainingArguments):
     )
     tracker: str.lower = field(
         default=None,
-        metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"}
+        metadata={
+            "help": "Experiment tracker to use.\n" + \
+                "Available trackers are - aim, none\n" + \
+                "Requires additional configs, see tuning.configs/tracker_configs.py"
+        },
     )
diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py
index bd97253b0..49135adbf 100644
--- a/tuning/config/tracker_configs.py
+++ b/tuning/config/tracker_configs.py
@@ -1,15 +1,18 @@
+# Standard
 from dataclasses import dataclass
 
+
 @dataclass
 class AimConfig:
     # Name of the experiment
     experiment: str = None
-    # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
-    # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
+    # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
+    # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
     # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
-    aim_repo: str = None
+    # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
+    aim_repo: str = ".aim"
     aim_remote_server_ip: str = None
     aim_remote_server_port: int = None
-    # Location of where run_hash is exported, if unspecified this is output to 
+    # Location of where run_hash is exported, if unspecified this is output to
     # training_args.output_dir/.aim_run_hash if the output_dir is set else not exported.
     aim_run_hash_export_path: str = None
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 4f48ec881..cba25a67a 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -1,11 +1,12 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union, List, Dict
+from typing import Dict, List, Optional, Union
 import json
-import os, time
+import os
+import time
 
 # Third Party
-import transformers
+from peft.utils.other import fsdp_auto_wrap_policy
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -16,21 +17,22 @@
     TrainerCallback,
 )
 from transformers.utils import logging
-from peft.utils.other import fsdp_auto_wrap_policy
 from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import datasets
 import fire
+import transformers
 
 # Local
 from tuning.config import configs, peft_config, tracker_configs
 from tuning.data import tokenizer_data_utils
-from tuning.utils.config_utils import get_hf_peft_config
-from tuning.utils.data_type_utils import get_torch_dtype
 from tuning.trackers.tracker import Tracker
 from tuning.trackers.tracker_factory import get_tracker
+from tuning.utils.config_utils import get_hf_peft_config
+from tuning.utils.data_type_utils import get_torch_dtype
 
 logger = logging.get_logger("sft_trainer")
 
+
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
         checkpoint_path = os.path.join(
@@ -41,6 +43,7 @@ def on_save(self, args, state, control, **kwargs):
         if "pytorch_model.bin" in os.listdir(checkpoint_path):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -84,6 +87,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
         with open(log_file, "a") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
+
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
@@ -93,7 +97,7 @@
     ] = None,
     callbacks: Optional[List[TrainerCallback]] = None,
     tracker: Optional[Tracker] = None,
-    exp_metadata: Optional[Dict] = None
+    exp_metadata: Optional[Dict] = None,
 ):
     """Call the SFTTrainer
 
@@ -105,6 +109,11 @@
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: List of callbacks to attach with SFTTrainer.
+        tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
+                 Initialized using tuning.trackers.tracker_factory.get_tracker
+                 Using configs in tuning.config.tracker_configs
+        exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker.
     """
 
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -133,7 +142,7 @@
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-    additional_metrics['model_load_time'] = time.time() - model_load_time
+    additional_metrics["model_load_time"] = time.time() - model_load_time
 
     peft_config = get_hf_peft_config(task_type, peft_config)
 
@@ -269,9 +278,9 @@
     if tracker is not None:
         # Currently tracked only on process zero.
         if trainer.is_world_process_zero():
-            for k,v in additional_metrics.items():
-                tracker.track(metric=v, name=k, stage='additional_metrics')
-            tracker.set_params(params=exp_metadata, name='experiment_metadata')
+            for k, v in additional_metrics.items():
+                tracker.track(metric=v, name=k, stage="additional_metrics")
+            tracker.set_params(params=exp_metadata, name="experiment_metadata")
 
     if run_distributed and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
@@ -279,6 +288,7 @@
         )
     trainer.train()
 
+
 def main(**kwargs):
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -300,6 +310,7 @@
         "--exp_metadata",
         type=str,
         default=None,
+        help='Pass a json string representing K:V pairs to be associated with the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
     )
     (
         model_args,
@@ -313,18 +324,18 @@
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
     peft_method = additional.peft_method
-    if peft_method =="lora":
-        tune_config=lora_config
-    elif peft_method =="pt":
-        tune_config=prompt_tuning_config
+    if peft_method == "lora":
+        tune_config = lora_config
+    elif peft_method == "pt":
+        tune_config = prompt_tuning_config
     else:
-        tune_config=None
+        tune_config = None
 
     tracker_name = training_args.tracker
     if tracker_name == "aim":
-        tracker_config=aim_config
+        tracker_config = aim_config
     else:
-        tracker_config=None
+        tracker_config = None
 
     # Initialize callbacks
     file_logger_callback = FileLoggingCallback(logger)
@@ -343,7 +354,9 @@
         try:
             metadata = json.loads(additional.exp_metadata)
             if metadata is None or not isinstance(metadata, Dict):
-                logger.warning('metadata cannot be converted to simple k:v dict ignoring')
+                logger.warning(
+                    "metadata cannot be converted to simple k:v dict ignoring"
+                )
                 metadata = None
         except:
             logger.error("failed while parsing extra metadata. pass a valid json")
@@ -355,8 +368,9 @@
         peft_config=tune_config,
         callbacks=callbacks,
         tracker=tracker,
-        exp_metadata=metadata
+        exp_metadata=metadata,
     )
 
+
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py
index caf8fc4a1..30c000cc5 100644
--- a/tuning/trackers/aimstack_tracker.py
+++ b/tuning/trackers/aimstack_tracker.py
@@ -1,56 +1,60 @@
 # Standard
 import os
 
+# Third Party
+from aim.hugging_face import AimCallback
+
+# Local
 from .tracker import Tracker
 from tuning.config.tracker_configs import AimConfig
 
-# Third Party
-from aim.hugging_face import AimCallback
 
 class CustomAimCallback(AimCallback):
 
     # A path to export run hash generated by Aim
     # This is used to link back to the expriments from outside aimstack
-    run_hash_export_path = None
+    hash_export_path = None
 
     def on_init_end(self, args, state, control, **kwargs):
 
         if state and not state.is_world_process_zero:
             return
 
-        self.setup() # initializes the run_hash
+        self.setup()  # initializes the run_hash
 
         # Store the run hash
         # Change default run hash path to output directory
-        if self.run_hash_export_path is None:
+        if self.hash_export_path is None:
             if args and args.output_dir:
                 # args.output_dir/.aim_run_hash
-                self.run_hash_export_path = os.path.join(
-                    args.output_dir,
-                    '.aim_run_hash'
-                )
+                self.hash_export_path = os.path.join(
+                    args.output_dir, ".aim_run_hash"
+                )
 
-        if self.run_hash_export_path:
-            with open(self.run_hash_export_path, 'w') as f:
-                f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
+        if self.hash_export_path:
+            with open(self.hash_export_path, "w") as f:
+                hash = self.experiment.hash
+                f.write('{"run_hash":"' + str(hash) + '"}\n')
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         # call directly to make sure hyper parameters and model info is recorded.
         self.setup(args=args, state=state, model=model)
 
     def track_metrics(self, metric, name, context):
-        if self._run is not None:
-            self._run.track(metric, name=name, context=context)
+        run = self.experiment
+        if run is not None:
+            run.track(metric, name=name, context=context)
+
     def set_params(self, params, name):
-        if self._run is not None:
-            for key, value in params.items():
-                self._run.set((name, key), value, strict=False)
+        run = self.experiment
+        if run is not None:
+            [run.set((name, key), value, strict=False) for key, value in params.items()]
 
 
-class AimStackTracker(Tracker):
+class AimStackTracker(Tracker):
     def __init__(self, tracker_config: AimConfig):
-        super().__init__(name='aim', tracker_config=tracker_config)
+        super().__init__(name="aim", tracker_config=tracker_config)
 
     def get_hf_callback(self):
         c = self.config
@@ -60,25 +64,25 @@
         repo = c.aim_repo
         hash_export_path = c.aim_run_hash_export_path
 
-        if (ip is not None and port is not None):
+        if ip is not None and port is not None:
             aim_callback = CustomAimCallback(
-                repo="aim://" + ip +":"+ port + "/",
-                experiment=exp)
+                repo="aim://" + ip + ":" + port + "/", experiment=exp
+            )
         if repo:
             aim_callback = CustomAimCallback(repo=repo, experiment=exp)
         else:
             aim_callback = CustomAimCallback(experiment=exp)
 
-        aim_callback.run_hash_export_path = hash_export_path
+        aim_callback.hash_export_path = hash_export_path
         self.hf_callback = aim_callback
         return self.hf_callback
 
-    def track(self, metric, name, stage='additional_metrics'):
-        context={'subset' : stage}
+    def track(self, metric, name, stage="additional_metrics"):
+        context = {"subset": stage}
         self.hf_callback.track_metrics(metric, name=name, context=context)
 
-    def set_params(self, params, name='extra_params'):
+    def set_params(self, params, name="extra_params"):
         try:
             self.hf_callback.set_params(params, name)
         except:
-            pass
\ No newline at end of file
+            pass
diff --git a/tuning/trackers/tracker.py b/tuning/trackers/tracker.py
index 220109a66..71ad60183 100644
--- a/tuning/trackers/tracker.py
+++ b/tuning/trackers/tracker.py
@@ -1,5 +1,6 @@
 # Generic Tracker API
 
+
 class Tracker:
     def __init__(self, name=None, tracker_config=None) -> None:
         if tracker_config is not None:
@@ -19,4 +20,4 @@ def track(self, metric, name, stage):
     # Object passed here is supposed to be a KV object
     # for the parameters to be associated with a run
    def set_params(self, params, name):
-        pass
\ No newline at end of file
+        pass
diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py
index 3f06104d8..a9c32d641 100644
--- a/tuning/trackers/tracker_factory.py
+++ b/tuning/trackers/tracker_factory.py
@@ -1,13 +1,13 @@
-from .tracker import Tracker
+# Local
 from .aimstack_tracker import AimStackTracker
+from .tracker import Tracker
+
+REGISTERED_TRACKERS = {"aim": AimStackTracker}
 
-REGISTERED_TRACKERS = {
-    "aim" : AimStackTracker
-}
 
 def get_tracker(tracker_name, tracker_config):
     if tracker_name in REGISTERED_TRACKERS:
         T = REGISTERED_TRACKERS[tracker_name]
         return T(tracker_config)
     else:
-        return Tracker()
\ No newline at end of file
+        return Tracker()
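
For reviewers, a rough usage sketch (not part of the patch) of how the tracker pieces introduced above fit together. It only uses names defined in this diff (`AimConfig`, `get_tracker`, `get_hf_callback`, `track`, `set_params`); the experiment name, repo path, and metric/metadata values are illustrative placeholders.

```python
# Illustrative sketch only: exercises the tracker API added in this PR.
# The config values below ("llama-ft-expt", ".aim", the metric/metadata) are made up.
from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

aim_config = AimConfig(
    experiment="llama-ft-expt",      # run name shown in the Aim UI
    aim_repo=".aim",                 # local repo; set aim_remote_server_ip/port for a remote server instead
    aim_run_hash_export_path=None,   # defaults to <output_dir>/.aim_run_hash when left unset
)

tracker = get_tracker("aim", aim_config)   # unknown names fall back to the no-op base Tracker
callbacks = [tracker.get_hf_callback()]    # HF TrainerCallback wired to the Aim run, passed to train()

# Extra metrics and run metadata are recorded against the same run.
tracker.track(metric=42.0, name="model_load_time", stage="additional_metrics")
tracker.set_params(params={"gpu": "A100-80G"}, name="experiment_metadata")
```

Because the factory returns the no-op base `Tracker` for unrecognized names, and its `track`/`set_params` methods simply pass, calls like the ones above stay safe even when no tracker is configured.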