From e0b2d2e653078b0518cd0303fb3b1ce7e1a533eb Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Fri, 23 Feb 2024 14:08:57 +0530 Subject: [PATCH] change default output path for aim run export --- tuning/sft_trainer.py | 3 ++- tuning/tracker/aimstack_tracker.py | 16 +++++++++++----- tuning/tracker/tracker.py | 3 ++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 0938aa650..be09d4030 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -20,6 +20,7 @@ from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire +import dataclasses # Local from tuning.config import configs, peft_config, tracker_configs @@ -332,7 +333,7 @@ def main(**kwargs): # Initialize the tracker tracker = TrackerFactory.get_tracker(tracker_name, tracker_config) - tracker_callback = tracker.get_hf_callback() + tracker_callback = tracker.get_hf_callback(dataclasses.asdict(training_args)) if tracker_callback is not None: callbacks.append(tracker_callback) diff --git a/tuning/tracker/aimstack_tracker.py b/tuning/tracker/aimstack_tracker.py index 7640b1db7..0cfaeec90 100644 --- a/tuning/tracker/aimstack_tracker.py +++ b/tuning/tracker/aimstack_tracker.py @@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback): # A path to export run hash generated by Aim # This is used to link back to the expriments from outside aimstack - aim_run_hash_export_path = None + run_hash_export_path = None def on_init_end(self, args, state, control, **kwargs): @@ -21,8 +21,8 @@ def on_init_end(self, args, state, control, **kwargs): self.setup() # initializes the run_hash # store the run hash - if self.aim_run_hash_export_path: - with open(self.aim_run_hash_export_path, 'w') as f: + if self.run_hash_export_path: + with open(self.run_hash_export_path, 'w') as f: f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n') def on_train_begin(self, args, state, control, model=None, **kwargs): @@ -43,7 +43,7 @@ class AimStackTracker(Tracker): def __init__(self, tracker_config: AimConfig): super().__init__(name='aim', tracker_config=tracker_config) - def get_hf_callback(self): + def get_hf_callback(self, **kwargs): c = self.config exp = c.experiment ip = c.aim_remote_server_ip @@ -51,6 +51,12 @@ def get_hf_callback(self): repo = c.aim_repo hash_export_path = c.aim_run_hash_export_path + # Change default run hash path to output directory + if hash_export_path is None: + if kwargs is not None and 'output_dir' in kwargs: + # training_args.output_dir/.aim_run_hash + p = os.path.join(kwargs['output_dir'], '.aim_run_hash') + if (ip is not None and port is not None): aim_callback = CustomAimCallback( repo="aim://" + ip +":"+ port + "/", @@ -60,7 +66,7 @@ def get_hf_callback(self): else: aim_callback = CustomAimCallback(experiment=exp) - aim_callback.aim_run_hash_export_path = hash_export_path + aim_callback.run_hash_export_path = hash_export_path self.hf_callback = aim_callback return self.hf_callback diff --git a/tuning/tracker/tracker.py b/tuning/tracker/tracker.py index a1d2d2ba0..a7300f042 100644 --- a/tuning/tracker/tracker.py +++ b/tuning/tracker/tracker.py @@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None: else: self._name = name - def get_hf_callback(): + # we use args here to denote any argument. + def get_hf_callback(self, **kwargs): return None def track(self, metric, name, stage):