Skip to content

Commit

Permalink
Change to custom aim callback to disable multiple instantiation for FSDP.
Browse files: browse the repository at this point in the history.

Add tracker factory.

Signed-off-by: Dushyant Behl <[email protected]>
  • Loading branch information
dushyantbehl committed Feb 22, 2024
1 parent 4f93a7c commit a28aa7b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 54 deletions.
50 changes: 35 additions & 15 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Standard
from datetime import datetime
from typing import Optional, Union, List
from typing import Optional, Union, List, Dict
import json
import os, time

Expand All @@ -26,8 +26,7 @@
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.tracker.tracker import Tracker, get_tracker
from tuning.tracker.aimstack_tracker import AimStackTracker
from tuning.tracker.tracker import Tracker, TrackerFactory

logger = logging.get_logger("sft_trainer")

Expand Down Expand Up @@ -92,7 +91,8 @@ def train(
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
] = None,
callbacks: Optional[List[TrainerCallback]] = None,
tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker
tracker: Optional[Tracker] = None,
exp_metadata: Optional[Dict] = None
):
"""Call the SFTTrainer
Expand Down Expand Up @@ -123,6 +123,7 @@ def train(
train_args.fsdp_config = {"xla": False}

task_type = "CAUSAL_LM"
additional_metrics = {}

model_load_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
Expand All @@ -131,8 +132,7 @@ def train(
torch_dtype=get_torch_dtype(model_args.torch_dtype),
use_flash_attention_2=model_args.use_flash_attn,
)
model_load_time = time.time() - model_load_time
tracker.track(metric=model_load_time, name='model_load_time')
additional_metrics['model_load_time'] = time.time() - model_load_time

peft_config = get_hf_peft_config(task_type, peft_config)

Expand Down Expand Up @@ -262,6 +262,16 @@ def train(
peft_config=peft_config,
)

# We track additional metrics and experiment metadata after
# Trainer object creation to ensure that this is not repeated
# multiple times for FSDP runs.
if tracker is not None:
# Currently tracked only on process zero.
if trainer.is_world_process_zero():
for k,v in additional_metrics.items():
tracker.track(metric=v, name=k, stage='additional_metrics')
tracker.set_params(params=exp_metadata, name='experiment_metadata')

if run_distributed and peft_config is not None:
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
Expand All @@ -286,7 +296,7 @@ def main(**kwargs):
default="pt",
)
parser.add_argument(
"--extra_metadata",
"--exp_metadata",
type=str,
default=None,
)
Expand Down Expand Up @@ -315,27 +325,37 @@ def main(**kwargs):
else:
tracker_config=None

# Initialize the tracker early so we can calculate custom metrics like model_load_time.
tracker = get_tracker(tracker_name, tracker_config)

# Initialize callbacks
file_logger_callback = FileLoggingCallback(logger)
peft_saving_callback = PeftSavingCallback()
callbacks = [peft_saving_callback, file_logger_callback]

# Initialize the tracker
tracker = TrackerFactory.get_tracker(tracker_name, tracker_config)
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
callbacks.append(tracker_callback)

# track extra metadata
if additional.extra_metadata is not None:
# extra metadata passed via client
metadata = None
if additional.exp_metadata is not None:
try:
metadata = json.loads(additional.extra_metadata)
tracker.track_metadata(metadata)
metadata = json.loads(additional.exp_metadata)
if metadata is None or not isinstance(metadata, Dict):
logger.warning('metadata cannot be converted to simple k:v dict ignoring')
metadata = None
except:
logger.error("failed while parsing extra metadata. pass a valid json")

train(model_args, data_args, training_args, tune_config, callbacks, tracker)
train(
model_args=model_args,
data_args=data_args,
train_args=training_args,
peft_config=tune_config,
callbacks=callbacks,
tracker=tracker,
exp_metadata=metadata
)

if __name__ == "__main__":
fire.Fire(main)
70 changes: 49 additions & 21 deletions tuning/tracker/aimstack_tracker.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
# Standard
import os

from tuning.tracker.tracker import Tracker
from .tracker import Tracker
from tuning.config.tracker_configs import AimConfig

# Third Party
from aim.hugging_face import AimCallback

class CustomAimCallback(AimCallback):
    """AimCallback variant that sets up the Aim run on world process zero only.

    The run is created from the trainer hooks (``on_init_end`` /
    ``on_train_begin``) instead of at construction time, and the helper
    methods silently do nothing while no run exists.
    """

    # A path to export the run hash generated by Aim.
    # This is used to link back to the experiments from outside aimstack.
    aim_run_hash_export_path = None

    def on_init_end(self, args, state, control, **kwargs):
        """Set up the run and, if a path is configured, export its hash.

        Returns immediately when ``state`` is truthy and reports that this
        is not world process zero.
        """
        if state and not state.is_world_process_zero:
            return

        self.setup()  # initializes the run_hash

        # Store the run hash as a one-line JSON object: {"run_hash":"<hash>"}
        if self.aim_run_hash_export_path:
            # Local import: this module does not import json at file level.
            import json

            # json.dumps escapes the value properly; compact separators keep
            # the output identical to the previous hand-built string.
            payload = json.dumps(
                {"run_hash": str(self._run.hash)}, separators=(",", ":")
            )
            with open(
                self.aim_run_hash_export_path, "w", encoding="utf-8"
            ) as f:
                f.write(payload + "\n")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Call ``setup`` with the training arguments, state and model."""
        # call directly to make sure hyper parameters and model info is recorded.
        self.setup(args=args, state=state, model=model)

    def track_metrics(self, metric, name, context):
        """Track ``metric`` under ``name``/``context`` if a run exists."""
        if self._run is not None:
            self._run.track(metric, name=name, context=context)

    def set_params(self, params, name):
        """Store each ``params`` item on the run under the ``(name, key)`` path.

        Does nothing when no run exists or when ``params`` is ``None``/empty
        (callers pass ``None`` when no experiment metadata was supplied).
        """
        if self._run is None or not params:
            return
        for key, value in params.items():
            self._run.set((name, key), value, strict=False)

class AimStackTracker(Tracker):

def __init__(self, tracker_config):
super().__init__(tracker_config)
def __init__(self, tracker_config: AimConfig):
super().__init__(name='aim', tracker_config=tracker_config)

def get_hf_callback(self):
c = self.config
exp = c.experiment
ip = c.aim_remote_server_ip
Expand All @@ -18,30 +52,24 @@ def __init__(self, tracker_config):
hash_export_path = c.aim_run_hash_export_path

if (ip is not None and port is not None):
aim_callback = AimCallback(
aim_callback = CustomAimCallback(
repo="aim://" + ip +":"+ port + "/",
experiment=exp
)
experiment=exp)
if repo:
aim_callback = AimCallback(repo=repo, experiment=exp)
aim_callback = CustomAimCallback(repo=repo, experiment=exp)
else:
aim_callback = AimCallback(experiment=exp)
aim_callback = CustomAimCallback(experiment=exp)

run = aim_callback.experiment # Initialize Aim run
run_hash = run.hash # Extract the hash

# store the run hash
if hash_export_path:
with open(hash_export_path, 'w') as f:
f.write(str(run_hash)+'\n')

# Save Internal State
aim_callback.aim_run_hash_export_path = hash_export_path
self.hf_callback = aim_callback
self.run = run

def get_hf_callback(self):
return self.hf_callback

def track(self, metric, name, stage='additional_metrics'):
context={'subset' : stage}
self.run.track(metric, name=name, context=context)
self.hf_callback.track_metrics(metric, name=name, context=context)

def set_params(self, params, name='extra_params'):
try:
self.hf_callback.set_params(params, name)
except:
pass
36 changes: 18 additions & 18 deletions tuning/tracker/tracker.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
# Generic Tracker API

from tuning.tracker.aimstack_tracker import AimStackTracker

class Tracker:
def __init__(self, tracker_config) -> None:
self.config = tracker_config
def __init__(self, name=None, tracker_config=None) -> None:
    """Store the tracker configuration and name.

    Args:
        name: Tracker name; stored as the string "None" when omitted.
        tracker_config: Configuration object, or None when not configured.
    """
    # Always define the attribute so that reading `self.config` on a
    # tracker built without a configuration yields None instead of
    # raising AttributeError.
    self.config = tracker_config
    self._name = "None" if name is None else name

def get_hf_callback(self):
    """Return the HF trainer callback for this tracker.

    The base tracker has no callback, so this returns None. `self` is
    required so the method can be called on an instance, e.g.
    `tracker.get_hf_callback()`.
    """
    return None

def track(self, metric, name, stage):
pass

# Metadata passed here is supposed to be a KV object
# Key being the name and value being the metric to track.
def track_metadata(self, metadata=None):
if metadata is None or not isinstance(metadata, dict):
return
for k, v in metadata.items():
self.track(name=k, metric=v)
# Object passed here is supposed to be a KV object
# for the parameters to be associated with a run
def set_params(self, params, name):
pass

def get_tracker(tracker_name, tracker_config):
if tracker_name == 'aim':
if tracker_config is not None:
tracker = AimStackTracker(tracker_config)
else:
tracker = Tracker()
return tracker
class TrackerFactory:
    """Factory that selects a Tracker implementation by its name."""

    @staticmethod
    def get_tracker(tracker_name, tracker_config):
        """Return the tracker registered under ``tracker_name``.

        Searches the direct subclasses of ``Tracker`` and returns an instance
        built with ``tracker_config`` for the first one whose ``_name``
        matches. Falls back to the no-op base ``Tracker`` when none matches.

        NOTE: ``__subclasses__()`` only lists classes that have already been
        defined, so a tracker module must be imported before it can be found.
        """
        for tracker_cls in Tracker.__subclasses__():
            # Use a class-level name when the subclass declares one.
            declared = getattr(tracker_cls, "_name", None)
            if declared is not None:
                if declared == tracker_name:
                    return tracker_cls(tracker_config)
                continue

            # Otherwise the name is only assigned in __init__, so it can be
            # read from an instance but not from the class itself.
            candidate = tracker_cls(tracker_config)
            if getattr(candidate, "_name", None) == tracker_name:
                return candidate

        return Tracker()

0 comments on commit a28aa7b

Please sign in to comment.