Skip to content

Commit

Permalink
change default output path for aim run export
Browse files Browse the repository at this point in the history
Signed-off-by: Dushyant Behl <[email protected]>
  • Loading branch information
dushyantbehl committed Feb 23, 2024
1 parent a28aa7b commit 8661cc0
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 17 deletions.
5 changes: 3 additions & 2 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.tracker.tracker import Tracker, TrackerFactory
from tuning.trackers.tracker import Tracker
from tuning.trackers.tracker_factory import get_tracker

logger = logging.get_logger("sft_trainer")

Expand Down Expand Up @@ -331,7 +332,7 @@ def main(**kwargs):
callbacks = [peft_saving_callback, file_logger_callback]

# Initialize the tracker
tracker = TrackerFactory.get_tracker(tracker_name, tracker_config)
tracker = get_tracker(tracker_name, tracker_config)
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
callbacks.append(tracker_callback)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback):

# A path to export run hash generated by Aim
# This is used to link back to the expriments from outside aimstack
aim_run_hash_export_path = None
run_hash_export_path = None

def on_init_end(self, args, state, control, **kwargs):

Expand All @@ -20,9 +20,18 @@ def on_init_end(self, args, state, control, **kwargs):

self.setup() # initializes the run_hash

# store the run hash
if self.aim_run_hash_export_path:
with open(self.aim_run_hash_export_path, 'w') as f:
# Store the run hash
# Change default run hash path to output directory
if self.run_hash_export_path is None:
if args and args.output_dir:
# args.output_dir/.aim_run_hash
self.run_hash_export_path = os.path.join(
args.output_dir,
'.aim_run_hash'
)

if self.run_hash_export_path:
with open(self.run_hash_export_path, 'w') as f:
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')

def on_train_begin(self, args, state, control, model=None, **kwargs):
Expand Down Expand Up @@ -60,7 +69,7 @@ def get_hf_callback(self):
else:
aim_callback = CustomAimCallback(experiment=exp)

aim_callback.aim_run_hash_export_path = hash_export_path
aim_callback.run_hash_export_path = hash_export_path
self.hf_callback = aim_callback
return self.hf_callback

Expand Down
13 changes: 3 additions & 10 deletions tuning/tracker/tracker.py → tuning/trackers/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None:
else:
self._name = name

def get_hf_callback():
# we use args here to denote any argument.
def get_hf_callback(self):
return None

def track(self, metric, name, stage):
Expand All @@ -18,12 +19,4 @@ def track(self, metric, name, stage):
# Object passed here is supposed to be a KV object
# for the parameters to be associated with a run
def set_params(self, params, name):
pass

class TrackerFactory:
def get_tracker(tracker_name, tracker_config):
for T in Tracker.__subclasses__():
if T._name == tracker_name:
return T(tracker_config)
else:
return Tracker()
pass
13 changes: 13 additions & 0 deletions tuning/trackers/tracker_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .tracker import Tracker
from .aimstack_tracker import AimStackTracker

REGISTERED_TRACKERS = {
"aim" : AimStackTracker
}

def get_tracker(tracker_name, tracker_config):
if tracker_name in REGISTERED_TRACKERS:
T = REGISTERED_TRACKERS[tracker_name]
return T(tracker_config)
else:
return Tracker()

0 comments on commit 8661cc0

Please sign in to comment.