diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 2eb53cc9c..5abd2e294 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -26,7 +26,7 @@ from tuning.data import tokenizer_data_utils
 from tuning.utils.config_utils import get_hf_peft_config
 from tuning.utils.data_type_utils import get_torch_dtype
 
-from tuning.tracker.tracker import Tracker
+from tuning.tracker.tracker import Tracker, get_tracker
 from tuning.tracker.aimstack_tracker import AimStackTracker
 
 logger = logging.get_logger("sft_trainer")
@@ -92,7 +92,7 @@ def train(
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
     callbacks: Optional[List[TrainerCallback]] = None,
-    tracker: Optional[Tracker] = None,
+    tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker
 ):
     """Call the SFTTrainer
 
@@ -285,6 +285,11 @@ def main(**kwargs):
         choices=["pt", "lora", None, "none"],
         default="pt",
     )
+    parser.add_argument(
+        "--extra_metadata",
+        type=str,
+        default=None,
+    )
     (
         model_args,
         data_args,
@@ -311,14 +316,7 @@ def main(**kwargs):
         tracker_config=None
 
     # Initialize the tracker early so we can calculate custom metrics like model_load_time.
-    tracker_name = training_args.tracker
-    if tracker_name == 'aim':
-        if tracker_config is not None:
-            tracker = AimStackTracker(tracker_config)
-        else:
-            logger.error("Tracker name is set to "+tracker_name+" but config is None.")
-    else:
-        tracker = Tracker()
+    tracker = get_tracker(training_args.tracker, tracker_config)
 
     # Initialize callbacks
     file_logger_callback = FileLoggingCallback(logger)
@@ -329,6 +327,15 @@ def main(**kwargs):
     if tracker_callback is not None:
         callbacks.append(tracker_callback)
 
+    # track extra metadata
+    if additional.extra_metadata is not None:
+        try:
+            metadata = json.loads(additional.extra_metadata)
+        except json.JSONDecodeError:
+            logger.error("failed while parsing extra metadata. pass a valid json")
+        else:
+            tracker.track_metadata(metadata)
+
     train(model_args, data_args, training_args, tune_config, callbacks, tracker)
 
 if __name__ == "__main__":
diff --git a/tuning/tracker/tracker.py b/tuning/tracker/tracker.py
index d0f26bdac..013b082dd 100644
--- a/tuning/tracker/tracker.py
+++ b/tuning/tracker/tracker.py
@@ -1,5 +1,9 @@
 # Generic Tracker API
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 class Tracker:
-    def __init__(self, tracker_config) -> None:
+    def __init__(self, tracker_config=None) -> None:
         self.config = tracker_config
@@ -8,4 +12,35 @@ def get_hf_callback():
         return None
 
     def track(self, metric, name, stage):
-        pass
\ No newline at end of file
+        pass
+
+    def track_metadata(self, metadata=None, stage="metadata"):
+        """Track each key/value pair of ``metadata`` through ``track``.
+
+        ``metadata`` is meant to be a dict: the key is the name and the
+        value is the metric to track. ``None`` and non-dict values are
+        ignored. ``stage`` is forwarded because ``track`` requires it.
+        """
+        if not isinstance(metadata, dict):
+            return
+        for name, value in metadata.items():
+            self.track(metric=value, name=name, stage=stage)
+
+
+def get_tracker(tracker_name, tracker_config):
+    """Return the tracker selected by ``tracker_name``.
+
+    ``'aim'`` with a config gives an ``AimStackTracker``. ``'aim'`` without
+    a config logs an error and falls back to the no-op ``Tracker``. Any
+    other name gives the no-op ``Tracker``.
+    """
+    if tracker_name == 'aim':
+        if tracker_config is not None:
+            # Imported lazily: aimstack_tracker presumably imports Tracker
+            # from this module, which would make a top-level import
+            # circular -- TODO confirm.
+            from tuning.tracker.aimstack_tracker import AimStackTracker
+
+            return AimStackTracker(tracker_config)
+        logger.error("Tracker name is set to %s but config is None.", tracker_name)
+    return Tracker()