From 1836a67d912842c06561d15263b2d6a8a63b06fe Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Mon, 19 Feb 2024 18:51:28 +0530 Subject: [PATCH 1/9] Generic Tracker API with command line arguments. Tracker now takes command line arguments as config. Aim stack is the default tracker and code contains example to measure additional metrics seamlessly into aimstack like 'model_load_time' Signed-off-by: Dushyant Behl --- tuning/aim_loader.py | 19 ---------- tuning/config/tracker_configs.py | 14 +++++++ tuning/sft_trainer.py | 60 +++++++++++++++++++++++------- tuning/tracker/__init__.py | 0 tuning/tracker/aimstack_tracker.py | 40 ++++++++++++++++++++ tuning/tracker/tracker.py | 11 ++++++ 6 files changed, 112 insertions(+), 32 deletions(-) delete mode 100644 tuning/aim_loader.py create mode 100644 tuning/config/tracker_configs.py create mode 100644 tuning/tracker/__init__.py create mode 100644 tuning/tracker/aimstack_tracker.py create mode 100644 tuning/tracker/tracker.py diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py deleted file mode 100644 index 44aa46748..000000000 --- a/tuning/aim_loader.py +++ /dev/null @@ -1,19 +0,0 @@ -import os -from aim.hugging_face import AimCallback - -def get_aimstack_callback(): - # Initialize a new run - aim_server = os.environ.get('AIMSTACK_SERVER') - aim_db = os.environ.get('AIMSTACK_DB') - aim_experiment = os.environ.get('AIMSTACK_EXPERIMENT') - if aim_experiment is None: - aim_experiment = "" - - if aim_server: - aim_callback = AimCallback(repo='aim://'+aim_server+'/', experiment=aim_experiment) - if aim_db: - aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment) - else: - aim_callback = AimCallback(experiment=aim_experiment) - - return aim_callback diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py new file mode 100644 index 000000000..4d9d193b0 --- /dev/null +++ b/tuning/config/tracker_configs.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +@dataclass +class AimConfig: + # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server. + # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo. + # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'. 
+ repo: str = None + remote_server_ip: str = None + remote_server_port: int = None + # Name of the experiment + experiment: str = None + # Location of where run_hash is exported + run_hash_export_location: str = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index dcb83245d..c7c133bbf 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,4 +1,5 @@ import os +import time from typing import Optional, Union import datasets @@ -10,11 +11,12 @@ from transformers.utils import logging from transformers import TrainerCallback from trl import SFTTrainer, DataCollatorForCompletionOnlyLM -from tuning.aim_loader import get_aimstack_callback -from tuning.config import configs, peft_config +from tuning.config import configs, peft_config, tracker_configs from tuning.data import tokenizer_data_utils from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype +from tuning.tracker.tracker import Tracker +from tuning.tracker.aimstack_tracker import AimStackTracker class PeftSavingCallback(TrainerCallback): def on_save(self, args, state, control, **kwargs): @@ -24,13 +26,13 @@ def on_save(self, args, state, control, **kwargs): if "pytorch_model.bin" in os.listdir(checkpoint_path): os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) - - def train( model_args: configs.ModelArguments, data_args: configs.DataArguments, train_args: configs.TrainingArguments, peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None, + tracker_name: Optional[str] = None, + tracker_config: Optional[Union[tracker_configs.AimConfig]] = None ): """Call the SFTTrainer @@ -44,7 +46,6 @@ def train( The peft configuration to pass to trainer """ run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - logger = logging.get_logger("sft_trainer") # Validate parameters @@ -58,14 +59,29 @@ def train( train_args.fsdp = "" train_args.fsdp_config = {'xla':False} + # Initialize the tracker early so we can calculate custom metrics like model_load_time. 
+ + if tracker_name == 'aim': + if tracker_config is not None: + tracker = AimStackTracker(tracker_config) + else: + logger.error("Tracker name is set to "+tracker_name+" but config is None.") + else: + logger.info('No tracker set so just set a dummy API which does nothing') + tracker = Tracker() + task_type = "CAUSAL_LM" + + model_load_time = time.time() model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=train_args.cache_dir, torch_dtype=get_torch_dtype(model_args.torch_dtype), use_flash_attention_2=model_args.use_flash_attn, ) - + model_load_time = time.time() - model_load_time + tracker.track(metric=model_load_time, name='model_load_time') + peft_config = get_hf_peft_config(task_type, peft_config) model.gradient_checkpointing_enable() @@ -130,8 +146,12 @@ def train( formatted_dataset = json_dataset['train'].map(lambda example : {f"{data_args.dataset_text_field}" : example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token}) logger.info(f"Dataset length is {len(formatted_dataset)}") - aim_callback = get_aimstack_callback() - callbacks=[aim_callback,PeftSavingCallback()] + # club and register callbacks + callbacks = [PeftSavingCallback()] + + tracker_callback = tracker.get_hf_callback() + if tracker_callback is not None: + callbacks.append(tracker_callback) if train_args.packing: logger.info("Packing is set to True") @@ -173,16 +193,30 @@ def main(**kwargs): configs.DataArguments, configs.TrainingArguments, peft_config.LoraConfig, - peft_config.PromptTuningConfig)) + peft_config.PromptTuningConfig, + tracker_configs.AimConfig)) parser.add_argument('--peft_method', type=str.lower, choices=['pt', 'lora', None, 'none'], default="pt") - model_args, data_args, training_args, lora_config, prompt_tuning_config, peft_method, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - if peft_method.peft_method =="lora": + parser.add_argument('--tracker', type=str.lower, choices=['aim', None, 'none'], default="aim") + (model_args, data_args, training_args, + lora_config, prompt_tuning_config, aim_config, + additional, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + peft_method = additional.peft_method + tracker_name = additional.tracker + + if peft_method =="lora": tune_config=lora_config - elif peft_method.peft_method =="pt": + elif peft_method =="pt": tune_config=prompt_tuning_config else: tune_config=None - train(model_args, data_args, training_args, tune_config) + + if tracker_name == "aim": + tracker_config=aim_config + else: + tracker_config=None + + train(model_args, data_args, training_args, tune_config, tracker_name, tracker_config) if __name__ == "__main__": fire.Fire(main) diff --git a/tuning/tracker/__init__.py b/tuning/tracker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tuning/tracker/aimstack_tracker.py b/tuning/tracker/aimstack_tracker.py new file mode 100644 index 000000000..8efce867d --- /dev/null +++ b/tuning/tracker/aimstack_tracker.py @@ -0,0 +1,40 @@ +# Standard +import os + +from tuning.tracker.tracker import Tracker + +# Third Party +from aim.hugging_face import AimCallback + +class AimStackTracker(Tracker): + + def __init__(self, tracker_config): + super().__init__(tracker_config) + c = self.config + if (c.remote_server_ip is not None and + c.remote_server_port is not None): + aim_callback = AimCallback(repo="aim://" + c.remote_server_ip+":"+ c.remote_server_port+ "/", + experiment=c.experiment) + if c.repo: + aim_callback = AimCallback(repo=c.repo, 
experiment=c.experiment) + else: + aim_callback = AimCallback(experiment=c.experiment) + + run = aim_callback.experiment # Initialize Aim run + run_hash = run.hash # Extract the hash + + # store the run hash + if c.run_hash_export_location: + with open(c.run_hash_export_location, 'w') as f: + f.write(str(run_hash)+'\n') + + # Save Internal State + self.hf_callback = aim_callback + self.run = run + + def get_hf_callback(self): + return self.hf_callback + + def track(self, metric, name, stage='additional_metrics'): + context={'subset' : stage} + self.run.track(metric, name=name, context=context) \ No newline at end of file diff --git a/tuning/tracker/tracker.py b/tuning/tracker/tracker.py new file mode 100644 index 000000000..d0f26bdac --- /dev/null +++ b/tuning/tracker/tracker.py @@ -0,0 +1,11 @@ +# Generic Tracker API + +class Tracker: + def __init__(self, tracker_config) -> None: + self.config = tracker_config + + def get_hf_callback(): + return None + + def track(self, metric, name, stage): + pass \ No newline at end of file From 2d5ffa5a1029b07942fda72aee622ddf18e3297f Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Tue, 20 Feb 2024 10:57:11 +0530 Subject: [PATCH 2/9] Bump aim version. Fix duplicate argument Signed-off-by: Dushyant Behl --- requirements.txt | 2 +- tuning/config/configs.py | 4 ++-- tuning/sft_trainer.py | 6 ------ 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index c4e8fa1ef..8f658ca25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy accelerate>=0.20.3 transformers>=4.34.1 torch -aim==3.17.5 +aim==3.18.1 sentencepiece tokenizers>=0.13.3 tqdm diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 7919841c0..10a024e6d 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -57,7 +57,7 @@ class TrainingArguments(transformers.TrainingArguments): default=False, metadata={"help": "Packing to be enabled in SFT Trainer, default is False"}, ) - tracker: str = field( + tracker: str.lower = field( default="aim", - metadata={"help", "Default experiment tracker to integrate with. requires additional configs. see tuning.configs/tracker_configs.py"} + metadata={"help": "Default experiment tracker to integrate with. requires additional configs. see tuning.configs/tracker_configs.py"} ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 75d335bd9..c9c82303b 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -299,12 +299,6 @@ def main(**kwargs): choices=["pt", "lora", None, "none"], default="pt", ) - parser.add_argument( - "--tracker", - type=str.lower, - choices=['aim', None, 'none'], - default="aim" - ) ( model_args, data_args, From 80453c49772c488c32e5ce259cf7c1d107f833c3 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Tue, 20 Feb 2024 18:00:28 +0530 Subject: [PATCH 3/9] default tracker none Signed-off-by: Dushyant Behl --- tuning/config/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 10a024e6d..3223bfa7c 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -58,6 +58,6 @@ class TrainingArguments(transformers.TrainingArguments): metadata={"help": "Packing to be enabled in SFT Trainer, default is False"}, ) tracker: str.lower = field( - default="aim", - metadata={"help": "Default experiment tracker to integrate with. requires additional configs. 
see tuning.configs/tracker_configs.py"} + default=None, + metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"} ) From b0c170c6f625185098eee3a9e388efad525e3d56 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Wed, 21 Feb 2024 11:44:44 +0530 Subject: [PATCH 4/9] separate callbacks from train Signed-off-by: Dushyant Behl --- tuning/aim_loader.py | 25 -------------------- tuning/sft_trainer.py | 53 ++++++++++++++++++++----------------------- 2 files changed, 25 insertions(+), 53 deletions(-) delete mode 100644 tuning/aim_loader.py diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py deleted file mode 100644 index 6ee617a42..000000000 --- a/tuning/aim_loader.py +++ /dev/null @@ -1,25 +0,0 @@ -# Standard -import os - -# Third Party -from aim.hugging_face import AimCallback - - -def get_aimstack_callback(): - # Initialize a new run - aim_server = os.environ.get("AIMSTACK_SERVER") - aim_db = os.environ.get("AIMSTACK_DB") - aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT") - if aim_experiment is None: - aim_experiment = "" - - if aim_server: - aim_callback = AimCallback( - repo="aim://" + aim_server + "/", experiment=aim_experiment - ) - if aim_db: - aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment) - else: - aim_callback = AimCallback(experiment=aim_experiment) - - return aim_callback diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index a4840f015..2eb53cc9c 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,11 +1,11 @@ # Standard from datetime import datetime -from typing import Optional, Union +from typing import Optional, Union, List import json import os, time # Third Party -from peft.utils.other import fsdp_auto_wrap_policy +import transformers from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -16,10 +16,10 @@ TrainerCallback, ) from transformers.utils import logging +from peft.utils.other import fsdp_auto_wrap_policy from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire -import transformers # Local from tuning.config import configs, peft_config, tracker_configs @@ -29,6 +29,7 @@ from tuning.tracker.tracker import Tracker from tuning.tracker.aimstack_tracker import AimStackTracker +logger = logging.get_logger("sft_trainer") class PeftSavingCallback(TrainerCallback): def on_save(self, args, state, control, **kwargs): @@ -83,7 +84,6 @@ def _track_loss(self, loss_key, log_file, logs, state): with open(log_file, "a") as f: f.write(f"{json.dumps(log_obj, sort_keys=True)}\n") - def train( model_args: configs.ModelArguments, data_args: configs.DataArguments, @@ -91,7 +91,8 @@ def train( peft_config: Optional[ Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] ] = None, - tracker_config: Optional[Union[tracker_configs.AimConfig]] = None + callbacks: Optional[List[TrainerCallback]] = None, + tracker: Optional[Tracker] = None, ): """Call the SFTTrainer @@ -105,7 +106,6 @@ def train( The peft configuration to pass to trainer """ run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - logger = logging.get_logger("sft_trainer") # Validate parameters if (not isinstance(train_args.num_train_epochs, float)) or ( @@ -122,17 +122,6 @@ def train( train_args.fsdp = "" train_args.fsdp_config = {"xla": False} - # Initialize the tracker early so we can calculate custom metrics like model_load_time. 
- tracker_name = train_args.tracker - if tracker_name == 'aim': - if tracker_config is not None: - tracker = AimStackTracker(tracker_config) - else: - logger.error("Tracker name is set to "+tracker_name+" but config is None.") - else: - logger.info('No tracker set so just set a dummy API which does nothing') - tracker = Tracker() - task_type = "CAUSAL_LM" model_load_time = time.time() @@ -259,15 +248,6 @@ def train( ) packing = False - # club and register callbacks - file_logger_callback = FileLoggingCallback(logger) - peft_saving_callback = PeftSavingCallback() - callbacks = [peft_saving_callback, file_logger_callback] - - tracker_callback = tracker.get_hf_callback() - if tracker_callback is not None: - callbacks.append(tracker_callback) - trainer = SFTTrainer( model=model, tokenizer=tokenizer, @@ -288,7 +268,6 @@ def train( ) trainer.train() - def main(**kwargs): parser = transformers.HfArgumentParser( dataclass_types=( @@ -331,8 +310,26 @@ def main(**kwargs): else: tracker_config=None - train(model_args, data_args, training_args, tune_config, tracker_config) + # Initialize the tracker early so we can calculate custom metrics like model_load_time. + tracker_name = training_args.tracker + if tracker_name == 'aim': + if tracker_config is not None: + tracker = AimStackTracker(tracker_config) + else: + logger.error("Tracker name is set to "+tracker_name+" but config is None.") + else: + tracker = Tracker() + + # Initialize callbacks + file_logger_callback = FileLoggingCallback(logger) + peft_saving_callback = PeftSavingCallback() + callbacks = [peft_saving_callback, file_logger_callback] + + tracker_callback = tracker.get_hf_callback() + if tracker_callback is not None: + callbacks.append(tracker_callback) + train(model_args, data_args, training_args, tune_config, callbacks, tracker) if __name__ == "__main__": fire.Fire(main) From 4f93a7cadb23bc23288760f3db723ba2b8e4c889 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Wed, 21 Feb 2024 16:27:16 +0530 Subject: [PATCH 5/9] enable tracker to track extra metadata Signed-off-by: Dushyant Behl --- tuning/sft_trainer.py | 26 ++++++++++++++++---------- tuning/tracker/tracker.py | 20 +++++++++++++++++++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2eb53cc9c..5abd2e294 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -26,7 +26,7 @@ from tuning.data import tokenizer_data_utils from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.tracker.tracker import Tracker +from tuning.tracker.tracker import Tracker, get_tracker from tuning.tracker.aimstack_tracker import AimStackTracker logger = logging.get_logger("sft_trainer") @@ -92,7 +92,7 @@ def train( Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] ] = None, callbacks: Optional[List[TrainerCallback]] = None, - tracker: Optional[Tracker] = None, + tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker ): """Call the SFTTrainer @@ -285,6 +285,11 @@ def main(**kwargs): choices=["pt", "lora", None, "none"], default="pt", ) + parser.add_argument( + "--extra_metadata", + type=str, + default=None, + ) ( model_args, data_args, @@ -311,14 +316,7 @@ def main(**kwargs): tracker_config=None # Initialize the tracker early so we can calculate custom metrics like model_load_time. 
- tracker_name = training_args.tracker - if tracker_name == 'aim': - if tracker_config is not None: - tracker = AimStackTracker(tracker_config) - else: - logger.error("Tracker name is set to "+tracker_name+" but config is None.") - else: - tracker = Tracker() + tracker = get_tracker(tracker_name, tracker_config) # Initialize callbacks file_logger_callback = FileLoggingCallback(logger) @@ -329,6 +327,14 @@ def main(**kwargs): if tracker_callback is not None: callbacks.append(tracker_callback) + # track extra metadata + if additional.extra_metadata is not None: + try: + metadata = json.loads(additional.extra_metadata) + tracker.track_metadata(metadata) + except: + logger.error("failed while parsing extra metadata. pass a valid json") + train(model_args, data_args, training_args, tune_config, callbacks, tracker) if __name__ == "__main__": diff --git a/tuning/tracker/tracker.py b/tuning/tracker/tracker.py index d0f26bdac..013b082dd 100644 --- a/tuning/tracker/tracker.py +++ b/tuning/tracker/tracker.py @@ -1,5 +1,7 @@ # Generic Tracker API +from tuning.tracker.aimstack_tracker import AimStackTracker + class Tracker: def __init__(self, tracker_config) -> None: self.config = tracker_config @@ -8,4 +10,20 @@ def get_hf_callback(): return None def track(self, metric, name, stage): - pass \ No newline at end of file + pass + + # Metadata passed here is supposed to be a KV object + # Key being the name and value being the metric to track. + def track_metadata(self, metadata=None): + if metadata is None or not isinstance(metadata, dict): + return + for k, v in metadata.items(): + self.track(name=k, metric=v) + +def get_tracker(tracker_name, tracker_config): + if tracker_name == 'aim': + if tracker_config is not None: + tracker = AimStackTracker(tracker_config) + else: + tracker = Tracker() + return tracker \ No newline at end of file From a28aa7b24a66f6b74d8349387f159ee28b4b84a0 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 22 Feb 2024 20:53:02 +0530 Subject: [PATCH 6/9] Change to custom aim callback to disable multiple instantiation for FSDP. Add tracker factory. 
Signed-off-by: Dushyant Behl --- tuning/sft_trainer.py | 50 ++++++++++++++------- tuning/tracker/aimstack_tracker.py | 70 +++++++++++++++++++++--------- tuning/tracker/tracker.py | 36 +++++++-------- 3 files changed, 102 insertions(+), 54 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 5abd2e294..0938aa650 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,6 +1,6 @@ # Standard from datetime import datetime -from typing import Optional, Union, List +from typing import Optional, Union, List, Dict import json import os, time @@ -26,8 +26,7 @@ from tuning.data import tokenizer_data_utils from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.tracker.tracker import Tracker, get_tracker -from tuning.tracker.aimstack_tracker import AimStackTracker +from tuning.tracker.tracker import Tracker, TrackerFactory logger = logging.get_logger("sft_trainer") @@ -92,7 +91,8 @@ def train( Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] ] = None, callbacks: Optional[List[TrainerCallback]] = None, - tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker + tracker: Optional[Tracker] = None, + exp_metadata: Optional[Dict] = None ): """Call the SFTTrainer @@ -123,6 +123,7 @@ def train( train_args.fsdp_config = {"xla": False} task_type = "CAUSAL_LM" + additional_metrics = {} model_load_time = time.time() model = AutoModelForCausalLM.from_pretrained( @@ -131,8 +132,7 @@ def train( torch_dtype=get_torch_dtype(model_args.torch_dtype), use_flash_attention_2=model_args.use_flash_attn, ) - model_load_time = time.time() - model_load_time - tracker.track(metric=model_load_time, name='model_load_time') + additional_metrics['model_load_time'] = time.time() - model_load_time peft_config = get_hf_peft_config(task_type, peft_config) @@ -262,6 +262,16 @@ def train( peft_config=peft_config, ) + # We track additional metrics and experiment metadata after + # Trainer object creation to ensure that this is not repeated + # multiple times for FSDP runs. + if tracker is not None: + # Currently tracked only on process zero. + if trainer.is_world_process_zero(): + for k,v in additional_metrics.items(): + tracker.track(metric=v, name=k, stage='additional_metrics') + tracker.set_params(params=exp_metadata, name='experiment_metadata') + if run_distributed and peft_config is not None: trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( model @@ -286,7 +296,7 @@ def main(**kwargs): default="pt", ) parser.add_argument( - "--extra_metadata", + "--exp_metadata", type=str, default=None, ) @@ -315,27 +325,37 @@ def main(**kwargs): else: tracker_config=None - # Initialize the tracker early so we can calculate custom metrics like model_load_time. 
- tracker = get_tracker(tracker_name, tracker_config) - # Initialize callbacks file_logger_callback = FileLoggingCallback(logger) peft_saving_callback = PeftSavingCallback() callbacks = [peft_saving_callback, file_logger_callback] + # Initialize the tracker + tracker = TrackerFactory.get_tracker(tracker_name, tracker_config) tracker_callback = tracker.get_hf_callback() if tracker_callback is not None: callbacks.append(tracker_callback) - # track extra metadata - if additional.extra_metadata is not None: + # extra metadata passed via client + metadata = None + if additional.exp_metadata is not None: try: - metadata = json.loads(additional.extra_metadata) - tracker.track_metadata(metadata) + metadata = json.loads(additional.exp_metadata) + if metadata is None or not isinstance(metadata, Dict): + logger.warning('metadata cannot be converted to simple k:v dict ignoring') + metadata = None except: logger.error("failed while parsing extra metadata. pass a valid json") - train(model_args, data_args, training_args, tune_config, callbacks, tracker) + train( + model_args=model_args, + data_args=data_args, + train_args=training_args, + peft_config=tune_config, + callbacks=callbacks, + tracker=tracker, + exp_metadata=metadata + ) if __name__ == "__main__": fire.Fire(main) diff --git a/tuning/tracker/aimstack_tracker.py b/tuning/tracker/aimstack_tracker.py index 138f6c673..7640b1db7 100644 --- a/tuning/tracker/aimstack_tracker.py +++ b/tuning/tracker/aimstack_tracker.py @@ -1,15 +1,49 @@ # Standard import os -from tuning.tracker.tracker import Tracker +from .tracker import Tracker +from tuning.config.tracker_configs import AimConfig # Third Party from aim.hugging_face import AimCallback +class CustomAimCallback(AimCallback): + + # A path to export run hash generated by Aim + # This is used to link back to the expriments from outside aimstack + aim_run_hash_export_path = None + + def on_init_end(self, args, state, control, **kwargs): + + if state and not state.is_world_process_zero: + return + + self.setup() # initializes the run_hash + + # store the run hash + if self.aim_run_hash_export_path: + with open(self.aim_run_hash_export_path, 'w') as f: + f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n') + + def on_train_begin(self, args, state, control, model=None, **kwargs): + # call directly to make sure hyper parameters and model info is recorded. 
+ self.setup(args=args, state=state, model=model) + + def track_metrics(self, metric, name, context): + if self._run is not None: + self._run.track(metric, name=name, context=context) + + def set_params(self, params, name): + if self._run is not None: + for key, value in params.items(): + self._run.set((name, key), value, strict=False) + class AimStackTracker(Tracker): - def __init__(self, tracker_config): - super().__init__(tracker_config) + def __init__(self, tracker_config: AimConfig): + super().__init__(name='aim', tracker_config=tracker_config) + + def get_hf_callback(self): c = self.config exp = c.experiment ip = c.aim_remote_server_ip @@ -18,30 +52,24 @@ def __init__(self, tracker_config): hash_export_path = c.aim_run_hash_export_path if (ip is not None and port is not None): - aim_callback = AimCallback( + aim_callback = CustomAimCallback( repo="aim://" + ip +":"+ port + "/", - experiment=exp - ) + experiment=exp) if repo: - aim_callback = AimCallback(repo=repo, experiment=exp) + aim_callback = CustomAimCallback(repo=repo, experiment=exp) else: - aim_callback = AimCallback(experiment=exp) + aim_callback = CustomAimCallback(experiment=exp) - run = aim_callback.experiment # Initialize Aim run - run_hash = run.hash # Extract the hash - - # store the run hash - if hash_export_path: - with open(hash_export_path, 'w') as f: - f.write(str(run_hash)+'\n') - - # Save Internal State + aim_callback.aim_run_hash_export_path = hash_export_path self.hf_callback = aim_callback - self.run = run - - def get_hf_callback(self): return self.hf_callback def track(self, metric, name, stage='additional_metrics'): context={'subset' : stage} - self.run.track(metric, name=name, context=context) \ No newline at end of file + self.hf_callback.track_metrics(metric, name=name, context=context) + + def set_params(self, params, name='extra_params'): + try: + self.hf_callback.set_params(params, name) + except: + pass \ No newline at end of file diff --git a/tuning/tracker/tracker.py b/tuning/tracker/tracker.py index 013b082dd..a1d2d2ba0 100644 --- a/tuning/tracker/tracker.py +++ b/tuning/tracker/tracker.py @@ -1,10 +1,13 @@ # Generic Tracker API -from tuning.tracker.aimstack_tracker import AimStackTracker - class Tracker: - def __init__(self, tracker_config) -> None: - self.config = tracker_config + def __init__(self, name=None, tracker_config=None) -> None: + if tracker_config is not None: + self.config = tracker_config + if name is None: + self._name = "None" + else: + self._name = name def get_hf_callback(): return None @@ -12,18 +15,15 @@ def get_hf_callback(): def track(self, metric, name, stage): pass - # Metadata passed here is supposed to be a KV object - # Key being the name and value being the metric to track. 
- def track_metadata(self, metadata=None): - if metadata is None or not isinstance(metadata, dict): - return - for k, v in metadata.items(): - self.track(name=k, metric=v) + # Object passed here is supposed to be a KV object + # for the parameters to be associated with a run + def set_params(self, params, name): + pass -def get_tracker(tracker_name, tracker_config): - if tracker_name == 'aim': - if tracker_config is not None: - tracker = AimStackTracker(tracker_config) - else: - tracker = Tracker() - return tracker \ No newline at end of file +class TrackerFactory: + def get_tracker(tracker_name, tracker_config): + for T in Tracker.__subclasses__(): + if T._name == tracker_name: + return T(tracker_config) + else: + return Tracker() \ No newline at end of file From 550604d0354ef9cf797e9f6c82ac6d499c804f3f Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Fri, 23 Feb 2024 14:08:57 +0530 Subject: [PATCH 7/9] change default output path for aim run export Signed-off-by: Dushyant Behl --- tuning/config/tracker_configs.py | 3 ++- tuning/sft_trainer.py | 5 +++-- tuning/{tracker => trackers}/__init__.py | 0 .../{tracker => trackers}/aimstack_tracker.py | 19 ++++++++++++++----- tuning/{tracker => trackers}/tracker.py | 13 +++---------- tuning/trackers/tracker_factory.py | 13 +++++++++++++ 6 files changed, 35 insertions(+), 18 deletions(-) rename tuning/{tracker => trackers}/__init__.py (100%) rename tuning/{tracker => trackers}/aimstack_tracker.py (78%) rename tuning/{tracker => trackers}/tracker.py (64%) create mode 100644 tuning/trackers/tracker_factory.py diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index 904ad3889..bd97253b0 100644 --- a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -10,5 +10,6 @@ class AimConfig: aim_repo: str = None aim_remote_server_ip: str = None aim_remote_server_port: int = None - # Location of where run_hash is exported + # Location of where run_hash is exported, if unspecified this is output to + # training_args.output_dir/.aim_run_hash if the output_dir is set else not exported. 
aim_run_hash_export_path: str = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 0938aa650..4f48ec881 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -26,7 +26,8 @@ from tuning.data import tokenizer_data_utils from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.tracker.tracker import Tracker, TrackerFactory +from tuning.trackers.tracker import Tracker +from tuning.trackers.tracker_factory import get_tracker logger = logging.get_logger("sft_trainer") @@ -331,7 +332,7 @@ def main(**kwargs): callbacks = [peft_saving_callback, file_logger_callback] # Initialize the tracker - tracker = TrackerFactory.get_tracker(tracker_name, tracker_config) + tracker = get_tracker(tracker_name, tracker_config) tracker_callback = tracker.get_hf_callback() if tracker_callback is not None: callbacks.append(tracker_callback) diff --git a/tuning/tracker/__init__.py b/tuning/trackers/__init__.py similarity index 100% rename from tuning/tracker/__init__.py rename to tuning/trackers/__init__.py diff --git a/tuning/tracker/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py similarity index 78% rename from tuning/tracker/aimstack_tracker.py rename to tuning/trackers/aimstack_tracker.py index 7640b1db7..caf8fc4a1 100644 --- a/tuning/tracker/aimstack_tracker.py +++ b/tuning/trackers/aimstack_tracker.py @@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback): # A path to export run hash generated by Aim # This is used to link back to the expriments from outside aimstack - aim_run_hash_export_path = None + run_hash_export_path = None def on_init_end(self, args, state, control, **kwargs): @@ -20,9 +20,18 @@ def on_init_end(self, args, state, control, **kwargs): self.setup() # initializes the run_hash - # store the run hash - if self.aim_run_hash_export_path: - with open(self.aim_run_hash_export_path, 'w') as f: + # Store the run hash + # Change default run hash path to output directory + if self.run_hash_export_path is None: + if args and args.output_dir: + # args.output_dir/.aim_run_hash + self.run_hash_export_path = os.path.join( + args.output_dir, + '.aim_run_hash' + ) + + if self.run_hash_export_path: + with open(self.run_hash_export_path, 'w') as f: f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n') def on_train_begin(self, args, state, control, model=None, **kwargs): @@ -60,7 +69,7 @@ def get_hf_callback(self): else: aim_callback = CustomAimCallback(experiment=exp) - aim_callback.aim_run_hash_export_path = hash_export_path + aim_callback.run_hash_export_path = hash_export_path self.hf_callback = aim_callback return self.hf_callback diff --git a/tuning/tracker/tracker.py b/tuning/trackers/tracker.py similarity index 64% rename from tuning/tracker/tracker.py rename to tuning/trackers/tracker.py index a1d2d2ba0..220109a66 100644 --- a/tuning/tracker/tracker.py +++ b/tuning/trackers/tracker.py @@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None: else: self._name = name - def get_hf_callback(): + # we use args here to denote any argument. 
+ def get_hf_callback(self): return None def track(self, metric, name, stage): @@ -18,12 +19,4 @@ def track(self, metric, name, stage): # Object passed here is supposed to be a KV object # for the parameters to be associated with a run def set_params(self, params, name): - pass - -class TrackerFactory: - def get_tracker(tracker_name, tracker_config): - for T in Tracker.__subclasses__(): - if T._name == tracker_name: - return T(tracker_config) - else: - return Tracker() \ No newline at end of file + pass \ No newline at end of file diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py new file mode 100644 index 000000000..3f06104d8 --- /dev/null +++ b/tuning/trackers/tracker_factory.py @@ -0,0 +1,13 @@ +from .tracker import Tracker +from .aimstack_tracker import AimStackTracker + +REGISTERED_TRACKERS = { + "aim" : AimStackTracker +} + +def get_tracker(tracker_name, tracker_config): + if tracker_name in REGISTERED_TRACKERS: + T = REGISTERED_TRACKERS[tracker_name] + return T(tracker_config) + else: + return Tracker() \ No newline at end of file From f178c21ed3679b6781e05ac4b2f507d97e9694e8 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Tue, 27 Feb 2024 11:48:10 +0530 Subject: [PATCH 8/9] code format using black and minor nit to use public interface Signed-off-by: Dushyant Behl --- scripts/run_inference.py | 1 + tuning/config/configs.py | 6 ++- tuning/config/tracker_configs.py | 11 ++++-- tuning/sft_trainer.py | 54 +++++++++++++++++---------- tuning/trackers/aimstack_tracker.py | 58 +++++++++++++++-------------- tuning/trackers/tracker.py | 3 +- tuning/trackers/tracker_factory.py | 10 ++--- 7 files changed, 85 insertions(+), 58 deletions(-) diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 989aaa8ca..b67dd1cc0 100644 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -8,6 +8,7 @@ If these things change in the future, we should consider breaking it up. """ + # Standard import argparse import json diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 3223bfa7c..279b006f4 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -59,5 +59,9 @@ class TrainingArguments(transformers.TrainingArguments): ) tracker: str.lower = field( default=None, - metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"} + metadata={ + "help": "Experiment tracker to use.\n" + \ + "Available trackers are - aim, none\n" + \ + "Requires additional configs, see tuning.configs/tracker_configs.py" + }, ) diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index bd97253b0..49135adbf 100644 --- a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -1,15 +1,18 @@ +# Standard from dataclasses import dataclass + @dataclass class AimConfig: # Name of the experiment experiment: str = None - # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server. - # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo. + # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server. + # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo. # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'. 
- aim_repo: str = None + # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking. + aim_repo: str = ".aim" aim_remote_server_ip: str = None aim_remote_server_port: int = None - # Location of where run_hash is exported, if unspecified this is output to + # Location of where run_hash is exported, if unspecified this is output to # training_args.output_dir/.aim_run_hash if the output_dir is set else not exported. aim_run_hash_export_path: str = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 4f48ec881..cba25a67a 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -1,11 +1,12 @@ # Standard from datetime import datetime -from typing import Optional, Union, List, Dict +from typing import Dict, List, Optional, Union import json -import os, time +import os +import time # Third Party -import transformers +from peft.utils.other import fsdp_auto_wrap_policy from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -16,21 +17,22 @@ TrainerCallback, ) from transformers.utils import logging -from peft.utils.other import fsdp_auto_wrap_policy from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire +import transformers # Local from tuning.config import configs, peft_config, tracker_configs from tuning.data import tokenizer_data_utils -from tuning.utils.config_utils import get_hf_peft_config -from tuning.utils.data_type_utils import get_torch_dtype from tuning.trackers.tracker import Tracker from tuning.trackers.tracker_factory import get_tracker +from tuning.utils.config_utils import get_hf_peft_config +from tuning.utils.data_type_utils import get_torch_dtype logger = logging.get_logger("sft_trainer") + class PeftSavingCallback(TrainerCallback): def on_save(self, args, state, control, **kwargs): checkpoint_path = os.path.join( @@ -41,6 +43,7 @@ def on_save(self, args, state, control, **kwargs): if "pytorch_model.bin" in os.listdir(checkpoint_path): os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) + class FileLoggingCallback(TrainerCallback): """Exports metrics, e.g., training loss to a file in the checkpoint directory.""" @@ -84,6 +87,7 @@ def _track_loss(self, loss_key, log_file, logs, state): with open(log_file, "a") as f: f.write(f"{json.dumps(log_obj, sort_keys=True)}\n") + def train( model_args: configs.ModelArguments, data_args: configs.DataArguments, @@ -93,7 +97,7 @@ def train( ] = None, callbacks: Optional[List[TrainerCallback]] = None, tracker: Optional[Tracker] = None, - exp_metadata: Optional[Dict] = None + exp_metadata: Optional[Dict] = None, ): """Call the SFTTrainer @@ -105,6 +109,11 @@ def train( peft_config.PromptTuningConfig for prompt tuning | \ None for fine tuning The peft configuration to pass to trainer + callbacks: List of callbacks to attach with SFTtrainer. + tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS + Initialized using tuning.trackers.tracker_factory.get_tracker + Using configs in tuning.config.tracker_configs + exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker. 
""" run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 @@ -133,7 +142,7 @@ def train( torch_dtype=get_torch_dtype(model_args.torch_dtype), use_flash_attention_2=model_args.use_flash_attn, ) - additional_metrics['model_load_time'] = time.time() - model_load_time + additional_metrics["model_load_time"] = time.time() - model_load_time peft_config = get_hf_peft_config(task_type, peft_config) @@ -269,9 +278,9 @@ def train( if tracker is not None: # Currently tracked only on process zero. if trainer.is_world_process_zero(): - for k,v in additional_metrics.items(): - tracker.track(metric=v, name=k, stage='additional_metrics') - tracker.set_params(params=exp_metadata, name='experiment_metadata') + for k, v in additional_metrics.items(): + tracker.track(metric=v, name=k, stage="additional_metrics") + tracker.set_params(params=exp_metadata, name="experiment_metadata") if run_distributed and peft_config is not None: trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( @@ -279,6 +288,7 @@ def train( ) trainer.train() + def main(**kwargs): parser = transformers.HfArgumentParser( dataclass_types=( @@ -300,6 +310,7 @@ def main(**kwargs): "--exp_metadata", type=str, default=None, + help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'', ) ( model_args, @@ -313,18 +324,18 @@ def main(**kwargs): ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) peft_method = additional.peft_method - if peft_method =="lora": - tune_config=lora_config - elif peft_method =="pt": - tune_config=prompt_tuning_config + if peft_method == "lora": + tune_config = lora_config + elif peft_method == "pt": + tune_config = prompt_tuning_config else: - tune_config=None + tune_config = None tracker_name = training_args.tracker if tracker_name == "aim": - tracker_config=aim_config + tracker_config = aim_config else: - tracker_config=None + tracker_config = None # Initialize callbacks file_logger_callback = FileLoggingCallback(logger) @@ -343,7 +354,9 @@ def main(**kwargs): try: metadata = json.loads(additional.exp_metadata) if metadata is None or not isinstance(metadata, Dict): - logger.warning('metadata cannot be converted to simple k:v dict ignoring') + logger.warning( + "metadata cannot be converted to simple k:v dict ignoring" + ) metadata = None except: logger.error("failed while parsing extra metadata. 
pass a valid json") @@ -355,8 +368,9 @@ def main(**kwargs): peft_config=tune_config, callbacks=callbacks, tracker=tracker, - exp_metadata=metadata + exp_metadata=metadata, ) + if __name__ == "__main__": fire.Fire(main) diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py index caf8fc4a1..30c000cc5 100644 --- a/tuning/trackers/aimstack_tracker.py +++ b/tuning/trackers/aimstack_tracker.py @@ -1,56 +1,60 @@ # Standard import os +# Third Party +from aim.hugging_face import AimCallback + +# Local from .tracker import Tracker from tuning.config.tracker_configs import AimConfig -# Third Party -from aim.hugging_face import AimCallback class CustomAimCallback(AimCallback): # A path to export run hash generated by Aim # This is used to link back to the expriments from outside aimstack - run_hash_export_path = None + hash_export_path = None def on_init_end(self, args, state, control, **kwargs): if state and not state.is_world_process_zero: return - self.setup() # initializes the run_hash + self.setup() # initializes the run_hash # Store the run hash # Change default run hash path to output directory - if self.run_hash_export_path is None: + if self.hash_export_path is None: if args and args.output_dir: # args.output_dir/.aim_run_hash - self.run_hash_export_path = os.path.join( - args.output_dir, - '.aim_run_hash' - ) + self.hash_export_path = os.path.join( + args.output_dir, ".aim_run_hash" + ) - if self.run_hash_export_path: - with open(self.run_hash_export_path, 'w') as f: - f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n') + if self.hash_export_path: + with open(self.hash_export_path, "w") as f: + hash = self.experiment.hash + f.write('{"run_hash":"' + str(hash) + '"}\n') def on_train_begin(self, args, state, control, model=None, **kwargs): # call directly to make sure hyper parameters and model info is recorded. 
self.setup(args=args, state=state, model=model) def track_metrics(self, metric, name, context): - if self._run is not None: - self._run.track(metric, name=name, context=context) + run = self.experiment + if run is not None: + run.track(metric, name=name, context=context) + def set_params(self, params, name): - if self._run is not None: - for key, value in params.items(): - self._run.set((name, key), value, strict=False) + run = self.experiment + if run is not None: + [run.set((name, key), value, strict=False) for key, value in params.items()] -class AimStackTracker(Tracker): +class AimStackTracker(Tracker): def __init__(self, tracker_config: AimConfig): - super().__init__(name='aim', tracker_config=tracker_config) + super().__init__(name="aim", tracker_config=tracker_config) def get_hf_callback(self): c = self.config @@ -60,25 +64,25 @@ def get_hf_callback(self): repo = c.aim_repo hash_export_path = c.aim_run_hash_export_path - if (ip is not None and port is not None): + if ip is not None and port is not None: aim_callback = CustomAimCallback( - repo="aim://" + ip +":"+ port + "/", - experiment=exp) + repo="aim://" + ip + ":" + port + "/", experiment=exp + ) if repo: aim_callback = CustomAimCallback(repo=repo, experiment=exp) else: aim_callback = CustomAimCallback(experiment=exp) - aim_callback.run_hash_export_path = hash_export_path + aim_callback.hash_export_path = hash_export_path self.hf_callback = aim_callback return self.hf_callback - def track(self, metric, name, stage='additional_metrics'): - context={'subset' : stage} + def track(self, metric, name, stage="additional_metrics"): + context = {"subset": stage} self.hf_callback.track_metrics(metric, name=name, context=context) - def set_params(self, params, name='extra_params'): + def set_params(self, params, name="extra_params"): try: self.hf_callback.set_params(params, name) except: - pass \ No newline at end of file + pass diff --git a/tuning/trackers/tracker.py b/tuning/trackers/tracker.py index 220109a66..71ad60183 100644 --- a/tuning/trackers/tracker.py +++ b/tuning/trackers/tracker.py @@ -1,5 +1,6 @@ # Generic Tracker API + class Tracker: def __init__(self, name=None, tracker_config=None) -> None: if tracker_config is not None: @@ -19,4 +20,4 @@ def track(self, metric, name, stage): # Object passed here is supposed to be a KV object # for the parameters to be associated with a run def set_params(self, params, name): - pass \ No newline at end of file + pass diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 3f06104d8..a9c32d641 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -1,13 +1,13 @@ -from .tracker import Tracker +# Local from .aimstack_tracker import AimStackTracker +from .tracker import Tracker + +REGISTERED_TRACKERS = {"aim": AimStackTracker} -REGISTERED_TRACKERS = { - "aim" : AimStackTracker -} def get_tracker(tracker_name, tracker_config): if tracker_name in REGISTERED_TRACKERS: T = REGISTERED_TRACKERS[tracker_name] return T(tracker_config) else: - return Tracker() \ No newline at end of file + return Tracker() From 2ceb54fbc464384fd63b429ade75357182aa479d Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Tue, 12 Mar 2024 14:41:29 +0530 Subject: [PATCH 9/9] custom aim callback should be minimal Signed-off-by: Dushyant Behl --- tuning/trackers/aimstack_tracker.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py index 
30c000cc5..82f424f36 100644 --- a/tuning/trackers/aimstack_tracker.py +++ b/tuning/trackers/aimstack_tracker.py @@ -40,18 +40,6 @@ def on_train_begin(self, args, state, control, model=None, **kwargs): # call directly to make sure hyper parameters and model info is recorded. self.setup(args=args, state=state, model=model) - def track_metrics(self, metric, name, context): - run = self.experiment - if run is not None: - run.track(metric, name=name, context=context) - - - def set_params(self, params, name): - run = self.experiment - if run is not None: - [run.set((name, key), value, strict=False) for key, value in params.items()] - - class AimStackTracker(Tracker): def __init__(self, tracker_config: AimConfig): super().__init__(name="aim", tracker_config=tracker_config) @@ -79,10 +67,16 @@ def get_hf_callback(self): def track(self, metric, name, stage="additional_metrics"): context = {"subset": stage} - self.hf_callback.track_metrics(metric, name=name, context=context) + callback = self.hf_callback + run = callback.experiment + if run is not None: + run.track(metric, name=name, context=context) def set_params(self, params, name="extra_params"): try: - self.hf_callback.set_params(params, name) - except: - pass + callback = self.hf_callback + run = callback.experiment + if run is not None: + [run.set((name, key), value, strict=False) for key, value in params.items()] + except Exception as e: + raise e
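
For reference, the tracker API that this series ends up with can also be driven programmatically, outside of the sft_trainer.py command line. The snippet below is a minimal usage sketch and not part of the patches: it assumes the tuning package at the state of PATCH 9/9 plus aim are installed, and the repo path, experiment name, metric value, and metadata are illustrative placeholders; only the config fields and function names come from the diffs above.

    # Hypothetical standalone sketch of the tracker API added in this series.
    from tuning.config.tracker_configs import AimConfig
    from tuning.trackers.tracker_factory import get_tracker

    aim_config = AimConfig(
        experiment="demo-experiment",              # illustrative name
        aim_repo=".aim",                           # local repo; set aim_remote_server_ip/port for a remote server instead
        aim_run_hash_export_path="/tmp/aim_run_hash.json",
    )

    # Unregistered tracker names fall back to the no-op Tracker.
    tracker = get_tracker("aim", aim_config)

    # The HF callback is what gets appended to the SFTTrainer callback list in sft_trainer.main().
    callbacks = []
    tracker_callback = tracker.get_hf_callback()
    if tracker_callback is not None:
        callbacks.append(tracker_callback)

    # Custom metrics and run metadata, mirroring what sft_trainer.train() records after the trainer is built.
    tracker.track(metric=42.0, name="model_load_time", stage="additional_metrics")
    tracker.set_params(params={"gpu": "A100-80G"}, name="experiment_metadata")

From the command line the same wiring is exercised with the flags added here, e.g. --tracker aim together with the AimConfig fields and --exp_metadata '{"gpu":"A100-80G"}'.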