Skip to content

Commit

Permalink
enable tracker to track extra metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Dushyant Behl <[email protected]>
  • Loading branch information
dushyantbehl committed Feb 21, 2024
1 parent b0c170c commit 4f93a7c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
26 changes: 16 additions & 10 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.tracker.tracker import Tracker
from tuning.tracker.tracker import Tracker, get_tracker
from tuning.tracker.aimstack_tracker import AimStackTracker

logger = logging.get_logger("sft_trainer")
Expand Down Expand Up @@ -92,7 +92,7 @@ def train(
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
] = None,
callbacks: Optional[List[TrainerCallback]] = None,
tracker: Optional[Tracker] = None,
tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker
):
"""Call the SFTTrainer
Expand Down Expand Up @@ -285,6 +285,11 @@ def main(**kwargs):
choices=["pt", "lora", None, "none"],
default="pt",
)
parser.add_argument(
"--extra_metadata",
type=str,
default=None,
)
(
model_args,
data_args,
Expand All @@ -311,14 +316,7 @@ def main(**kwargs):
tracker_config=None

# Initialize the tracker early so we can calculate custom metrics like model_load_time.
tracker_name = training_args.tracker
if tracker_name == 'aim':
if tracker_config is not None:
tracker = AimStackTracker(tracker_config)
else:
logger.error("Tracker name is set to "+tracker_name+" but config is None.")
else:
tracker = Tracker()
tracker = get_tracker(tracker_name, tracker_config)

# Initialize callbacks
file_logger_callback = FileLoggingCallback(logger)
Expand All @@ -329,6 +327,14 @@ def main(**kwargs):
if tracker_callback is not None:
callbacks.append(tracker_callback)

# track extra metadata
if additional.extra_metadata is not None:
try:
metadata = json.loads(additional.extra_metadata)
tracker.track_metadata(metadata)
except:
logger.error("failed while parsing extra metadata. pass a valid json")

train(model_args, data_args, training_args, tune_config, callbacks, tracker)

if __name__ == "__main__":
Expand Down
20 changes: 19 additions & 1 deletion tuning/tracker/tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generic Tracker API

from tuning.tracker.aimstack_tracker import AimStackTracker

class Tracker:
def __init__(self, tracker_config) -> None:
self.config = tracker_config
Expand All @@ -8,4 +10,20 @@ def get_hf_callback():
return None

def track(self, metric, name, stage):
    """Record a single metric value; the base tracker does nothing.

    Args:
        metric: Value to record.
        name: Name under which the value is recorded.
        stage: Label for the phase the value belongs to.

    Returns:
        None. This base implementation is a deliberate no-op.
    """
    return None

# Metadata passed here is supposed to be a KV object:
# the key is the name and the value is the metric to track.
def track_metadata(self, metadata=None, stage="metadata"):
    """Forward every key/value pair of ``metadata`` to ``self.track``.

    Args:
        metadata: Mapping of metric name to metric value. Anything that is
            not a ``dict`` (including ``None``) is ignored silently.
        stage: Stage label passed to ``track`` for every entry. ``track``
            requires this argument, so it must always be supplied; the
            default label is a choice made here.

    Returns:
        None.
    """
    # isinstance() already rejects None, so one check covers both cases.
    if not isinstance(metadata, dict):
        return
    for name, value in metadata.items():
        # track() has three required parameters; omitting `stage` raised
        # TypeError on the very first item.
        self.track(metric=value, name=name, stage=stage)

def get_tracker(tracker_name, tracker_config):
    """Build the tracker selected by ``tracker_name``.

    Args:
        tracker_name: Name of the requested backend. Only ``'aim'`` selects
            a concrete backend; any other value yields the base ``Tracker``.
        tracker_config: Configuration handed to the tracker constructor.

    Returns:
        An ``AimStackTracker`` when ``tracker_name`` is ``'aim'`` and a
        config is supplied, otherwise a base ``Tracker``. A tracker is
        always returned, so callers never need to handle ``None``.
    """
    if tracker_name == 'aim':
        if tracker_config is not None:
            return AimStackTracker(tracker_config)
        # Previously this path could leave `tracker` unassigned and fail
        # with UnboundLocalError; report the misconfiguration and fall
        # back to the base tracker instead.
        import logging

        logging.getLogger(__name__).error(
            "Tracker name is set to %s but config is None.", tracker_name
        )
    # Tracker.__init__ requires the config argument, so pass it through
    # (it may be None) rather than calling Tracker() with no arguments.
    return Tracker(tracker_config)

0 comments on commit 4f93a7c

Please sign in to comment.