diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 5abd2e294..0938aa650 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,6 +1,6 @@ # Standard from datetime import datetime -from typing import Optional, Union, List +from typing import Optional, Union, List, Dict import json import os, time @@ -26,8 +26,7 @@ from tuning.data import tokenizer_data_utils from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.tracker.tracker import Tracker, get_tracker -from tuning.tracker.aimstack_tracker import AimStackTracker +from tuning.tracker.tracker import Tracker, TrackerFactory logger = logging.get_logger("sft_trainer") @@ -92,7 +91,8 @@ def train( Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] ] = None, callbacks: Optional[List[TrainerCallback]] = None, - tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker + tracker: Optional[Tracker] = None, + exp_metadata: Optional[Dict] = None ): """Call the SFTTrainer @@ -123,6 +123,7 @@ def train( train_args.fsdp_config = {"xla": False} task_type = "CAUSAL_LM" + additional_metrics = {} model_load_time = time.time() model = AutoModelForCausalLM.from_pretrained( @@ -131,8 +132,7 @@ def train( torch_dtype=get_torch_dtype(model_args.torch_dtype), use_flash_attention_2=model_args.use_flash_attn, ) - model_load_time = time.time() - model_load_time - tracker.track(metric=model_load_time, name='model_load_time') + additional_metrics['model_load_time'] = time.time() - model_load_time peft_config = get_hf_peft_config(task_type, peft_config) @@ -262,6 +262,16 @@ def train( peft_config=peft_config, ) + # We track additional metrics and experiment metadata after + # Trainer object creation to ensure that this is not repeated + # multiple times for FSDP runs. + if tracker is not None: + # Currently tracked only on process zero. 
+ if trainer.is_world_process_zero(): + for k,v in additional_metrics.items(): + tracker.track(metric=v, name=k, stage='additional_metrics') + tracker.set_params(params=exp_metadata, name='experiment_metadata') + if run_distributed and peft_config is not None: trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( model @@ -286,7 +296,7 @@ def main(**kwargs): default="pt", ) parser.add_argument( - "--extra_metadata", + "--exp_metadata", type=str, default=None, ) @@ -315,27 +325,37 @@ def main(**kwargs): else: tracker_config=None - # Initialize the tracker early so we can calculate custom metrics like model_load_time. - tracker = get_tracker(tracker_name, tracker_config) - # Initialize callbacks file_logger_callback = FileLoggingCallback(logger) peft_saving_callback = PeftSavingCallback() callbacks = [peft_saving_callback, file_logger_callback] + # Initialize the tracker + tracker = TrackerFactory.get_tracker(tracker_name, tracker_config) tracker_callback = tracker.get_hf_callback() if tracker_callback is not None: callbacks.append(tracker_callback) - # track extra metadata - if additional.extra_metadata is not None: + # extra metadata passed via client + metadata = None + if additional.exp_metadata is not None: try: - metadata = json.loads(additional.extra_metadata) - tracker.track_metadata(metadata) + metadata = json.loads(additional.exp_metadata) + if metadata is None or not isinstance(metadata, Dict): + logger.warning('metadata cannot be converted to simple k:v dict ignoring') + metadata = None except: logger.error("failed while parsing extra metadata. 
pass a valid json") - train(model_args, data_args, training_args, tune_config, callbacks, tracker) + train( + model_args=model_args, + data_args=data_args, + train_args=training_args, + peft_config=tune_config, + callbacks=callbacks, + tracker=tracker, + exp_metadata=metadata + ) if __name__ == "__main__": fire.Fire(main) diff --git a/tuning/tracker/aimstack_tracker.py b/tuning/tracker/aimstack_tracker.py index 138f6c673..7640b1db7 100644 --- a/tuning/tracker/aimstack_tracker.py +++ b/tuning/tracker/aimstack_tracker.py @@ -1,15 +1,49 @@ # Standard import os -from tuning.tracker.tracker import Tracker +from .tracker import Tracker +from tuning.config.tracker_configs import AimConfig # Third Party from aim.hugging_face import AimCallback +class CustomAimCallback(AimCallback): + + # A path to export run hash generated by Aim + # This is used to link back to the experiments from outside aimstack + aim_run_hash_export_path = None + + def on_init_end(self, args, state, control, **kwargs): + + if state and not state.is_world_process_zero: + return + + self.setup() # initializes the run_hash + + # store the run hash + if self.aim_run_hash_export_path: + with open(self.aim_run_hash_export_path, 'w') as f: + f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n') + + def on_train_begin(self, args, state, control, model=None, **kwargs): + # call directly to make sure hyper parameters and model info is recorded. 
+ self.setup(args=args, state=state, model=model) + + def track_metrics(self, metric, name, context): + if self._run is not None: + self._run.track(metric, name=name, context=context) + + def set_params(self, params, name): + if self._run is not None: + for key, value in params.items(): + self._run.set((name, key), value, strict=False) + class AimStackTracker(Tracker): - def __init__(self, tracker_config): - super().__init__(tracker_config) + def __init__(self, tracker_config: AimConfig): + super().__init__(name='aim', tracker_config=tracker_config) + + def get_hf_callback(self): c = self.config exp = c.experiment ip = c.aim_remote_server_ip @@ -18,30 +52,24 @@ def __init__(self, tracker_config): hash_export_path = c.aim_run_hash_export_path if (ip is not None and port is not None): - aim_callback = AimCallback( + aim_callback = CustomAimCallback( repo="aim://" + ip +":"+ port + "/", - experiment=exp - ) + experiment=exp) if repo: - aim_callback = AimCallback(repo=repo, experiment=exp) + aim_callback = CustomAimCallback(repo=repo, experiment=exp) else: - aim_callback = AimCallback(experiment=exp) + aim_callback = CustomAimCallback(experiment=exp) - run = aim_callback.experiment # Initialize Aim run - run_hash = run.hash # Extract the hash - - # store the run hash - if hash_export_path: - with open(hash_export_path, 'w') as f: - f.write(str(run_hash)+'\n') - - # Save Internal State + aim_callback.aim_run_hash_export_path = hash_export_path self.hf_callback = aim_callback - self.run = run - - def get_hf_callback(self): return self.hf_callback def track(self, metric, name, stage='additional_metrics'): context={'subset' : stage} - self.run.track(metric, name=name, context=context) \ No newline at end of file + self.hf_callback.track_metrics(metric, name=name, context=context) + + def set_params(self, params, name='extra_params'): + try: + self.hf_callback.set_params(params, name) + except: + pass \ No newline at end of file diff --git a/tuning/tracker/tracker.py 
b/tuning/tracker/tracker.py index 013b082dd..a1d2d2ba0 100644 --- a/tuning/tracker/tracker.py +++ b/tuning/tracker/tracker.py @@ -1,10 +1,13 @@ # Generic Tracker API -from tuning.tracker.aimstack_tracker import AimStackTracker - class Tracker: - def __init__(self, tracker_config) -> None: - self.config = tracker_config + def __init__(self, name=None, tracker_config=None) -> None: + if tracker_config is not None: + self.config = tracker_config + if name is None: + self._name = "None" + else: + self._name = name def get_hf_callback(): return None @@ -12,18 +15,15 @@ def get_hf_callback(): def track(self, metric, name, stage): pass - # Metadata passed here is supposed to be a KV object - # Key being the name and value being the metric to track. - def track_metadata(self, metadata=None): - if metadata is None or not isinstance(metadata, dict): - return - for k, v in metadata.items(): - self.track(name=k, metric=v) + # Object passed here is supposed to be a KV object + # for the parameters to be associated with a run + def set_params(self, params, name): + pass -def get_tracker(tracker_name, tracker_config): - if tracker_name == 'aim': - if tracker_config is not None: - tracker = AimStackTracker(tracker_config) - else: - tracker = Tracker() - return tracker \ No newline at end of file +class TrackerFactory: + def get_tracker(tracker_name, tracker_config): + for T in Tracker.__subclasses__(): + if T._name == tracker_name: + return T(tracker_config) + else: + return Tracker() \ No newline at end of file