Skip to content

Commit

Permalink
change default output path for aim run export
Browse files Browse the repository at this point in the history
  • Loading branch information
dushyantbehl committed Feb 23, 2024
1 parent a28aa7b commit e0b2d2e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
import datasets
import fire
import dataclasses

# Local
from tuning.config import configs, peft_config, tracker_configs
Expand Down Expand Up @@ -332,7 +333,7 @@ def main(**kwargs):

# Initialize the tracker
tracker = TrackerFactory.get_tracker(tracker_name, tracker_config)
tracker_callback = tracker.get_hf_callback()
tracker_callback = tracker.get_hf_callback(dataclasses.asdict(training_args))
if tracker_callback is not None:
callbacks.append(tracker_callback)

Expand Down
16 changes: 11 additions & 5 deletions tuning/tracker/aimstack_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback):

# A path to export run hash generated by Aim
# This is used to link back to the expriments from outside aimstack
aim_run_hash_export_path = None
run_hash_export_path = None

def on_init_end(self, args, state, control, **kwargs):

Expand All @@ -21,8 +21,8 @@ def on_init_end(self, args, state, control, **kwargs):
self.setup() # initializes the run_hash

# store the run hash
if self.aim_run_hash_export_path:
with open(self.aim_run_hash_export_path, 'w') as f:
if self.run_hash_export_path:
with open(self.run_hash_export_path, 'w') as f:
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')

def on_train_begin(self, args, state, control, model=None, **kwargs):
Expand All @@ -43,14 +43,20 @@ class AimStackTracker(Tracker):
def __init__(self, tracker_config: AimConfig):
super().__init__(name='aim', tracker_config=tracker_config)

def get_hf_callback(self):
def get_hf_callback(self, **kwargs):
c = self.config
exp = c.experiment
ip = c.aim_remote_server_ip
port = c.aim_remote_server_port
repo = c.aim_repo
hash_export_path = c.aim_run_hash_export_path

# Change default run hash path to output directory
if hash_export_path is None:
if kwargs is not None and 'output_dir' in kwargs:
# training_args.output_dir/.aim_run_hash
p = os.path.join(kwargs['output_dir'], '.aim_run_hash')

if (ip is not None and port is not None):
aim_callback = CustomAimCallback(
repo="aim://" + ip +":"+ port + "/",
Expand All @@ -60,7 +66,7 @@ def get_hf_callback(self):
else:
aim_callback = CustomAimCallback(experiment=exp)

aim_callback.aim_run_hash_export_path = hash_export_path
aim_callback.run_hash_export_path = hash_export_path
self.hf_callback = aim_callback
return self.hf_callback

Expand Down
3 changes: 2 additions & 1 deletion tuning/tracker/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None:
else:
self._name = name

def get_hf_callback():
# we use args here to denote any argument.
def get_hf_callback(self, **kwargs):
return None

def track(self, metric, name, stage):
Expand Down

0 comments on commit e0b2d2e

Please sign in to comment.