Commit 803784b

code format using black and minor nit

Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl committed Feb 28, 2024
1 parent c93329f commit 803784b
Showing 7 changed files with 58 additions and 38 deletions.
1 change: 1 addition & 0 deletions scripts/run_inference.py
@@ -8,6 +8,7 @@
 If these things change in the future, we should consider breaking it up.
 """
 
+# Standard
 import argparse
 import json
5 changes: 4 additions & 1 deletion tuning/config/configs.py
@@ -59,5 +59,8 @@ class TrainingArguments(transformers.TrainingArguments):
     )
     tracker: str.lower = field(
         default=None,
-        metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"}
+        choices=["aim", None, "none"],
+        metadata={
+            "help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"
+        },
     )
10 changes: 6 additions & 4 deletions tuning/config/tracker_configs.py
@@ -1,15 +1,17 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class AimConfig:
     # Name of the experiment
     experiment: str = None
-    # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
-    # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
+    # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
+    # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
     # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
-    aim_repo: str = None
+    # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
+    aim_repo: str = ".aim"
     aim_remote_server_ip: str = None
     aim_remote_server_port: int = None
-    # Location of where run_hash is exported, if unspecified this is output to 
+    # Location of where run_hash is exported, if unspecified this is output to
     # training_args.output_dir/.aim_run_hash if the output_dir is set else not exported.
     aim_run_hash_export_path: str = None
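
A minimal sketch of how these fields might be filled in — illustrative only, not part of the commit, and all values below are hypothetical. A local repo simply names a directory (now defaulting to '.aim'), while setting the remote server fields designates a remote Aim repo instead:

# Illustrative only; all values below are made up.
from tuning.config.tracker_configs import AimConfig

# Local tracking: runs are stored under the .aim directory (the new default).
local_cfg = AimConfig(experiment="my-experiment", aim_repo=".aim")

# Remote tracking: ip/port designate a remote Aim server.
remote_cfg = AimConfig(
    experiment="my-experiment",
    aim_remote_server_ip="10.0.0.5",
    aim_remote_server_port=53800,
    aim_run_hash_export_path="/tmp/aim_run_hash.json",
)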
41 changes: 27 additions & 14 deletions tuning/sft_trainer.py
@@ -31,6 +31,7 @@
 
 logger = logging.get_logger("sft_trainer")
 
+
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
         checkpoint_path = os.path.join(
@@ -41,6 +42,7 @@ def on_save(self, args, state, control, **kwargs):
         if "pytorch_model.bin" in os.listdir(checkpoint_path):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -84,6 +86,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
         with open(log_file, "a") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
+
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
@@ -93,7 +96,7 @@ def train(
     ] = None,
     callbacks: Optional[List[TrainerCallback]] = None,
     tracker: Optional[Tracker] = None,
-    exp_metadata: Optional[Dict] = None
+    exp_metadata: Optional[Dict] = None,
 ):
     """Call the SFTTrainer
@@ -105,6 +108,11 @@
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: List of callbacks to attach with SFTTrainer.
+        tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
+            Initialized using tuning.trackers.tracker_factory.get_tracker
+            Using configs in tuning.config.tracker_configs
+        exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker.
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
 
@@ -133,7 +141,7 @@ def train(
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-    additional_metrics['model_load_time'] = time.time() - model_load_time
+    additional_metrics["model_load_time"] = time.time() - model_load_time
 
     peft_config = get_hf_peft_config(task_type, peft_config)
 
@@ -269,16 +277,17 @@ def train(
     if tracker is not None:
         # Currently tracked only on process zero.
         if trainer.is_world_process_zero():
-            for k,v in additional_metrics.items():
-                tracker.track(metric=v, name=k, stage='additional_metrics')
-            tracker.set_params(params=exp_metadata, name='experiment_metadata')
+            for k, v in additional_metrics.items():
+                tracker.track(metric=v, name=k, stage="additional_metrics")
+            tracker.set_params(params=exp_metadata, name="experiment_metadata")
 
     if run_distributed and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
     trainer.train()
 
+
 def main(**kwargs):
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -300,6 +309,7 @@ def main(**kwargs):
         "--exp_metadata",
         type=str,
         default=None,
+        help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
     )
     (
         model_args,
@@ -313,18 +323,18 @@ def main(**kwargs):
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
     peft_method = additional.peft_method
-    if peft_method =="lora":
-        tune_config=lora_config
-    elif peft_method =="pt":
-        tune_config=prompt_tuning_config
+    if peft_method == "lora":
+        tune_config = lora_config
+    elif peft_method == "pt":
+        tune_config = prompt_tuning_config
     else:
-        tune_config=None
+        tune_config = None
 
     tracker_name = training_args.tracker
     if tracker_name == "aim":
-        tracker_config=aim_config
+        tracker_config = aim_config
     else:
-        tracker_config=None
+        tracker_config = None
 
     # Initialize callbacks
     file_logger_callback = FileLoggingCallback(logger)
@@ -343,7 +353,9 @@
     try:
         metadata = json.loads(additional.exp_metadata)
         if metadata is None or not isinstance(metadata, Dict):
-            logger.warning('metadata cannot be converted to simple k:v dict ignoring')
+            logger.warning(
+                "metadata cannot be converted to simple k:v dict ignoring"
+            )
             metadata = None
     except:
         logger.error("failed while parsing extra metadata. pass a valid json")
@@ -355,8 +367,9 @@
         peft_config=tune_config,
         callbacks=callbacks,
         tracker=tracker,
-        exp_metadata=metadata
+        exp_metadata=metadata,
     )
 
+
 if __name__ == "__main__":
     fire.Fire(main)
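
Taken together, main() wires the tracker roughly like this — a sketch under stated assumptions (aim installed, get_tracker and AimConfig from this commit; the JSON literal mirrors what --exp_metadata would receive), not a verbatim excerpt:

# Sketch of the tracker flow in main()/train(); not verbatim from the commit.
import json

from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

# training_args.tracker == "aim" selects the Aim tracker; the config is the
# AimConfig dataclass parsed alongside the other argument groups.
tracker = get_tracker("aim", AimConfig(experiment="demo", aim_repo=".aim"))

# --exp_metadata expects a JSON object of simple K:V pairs.
metadata = json.loads('{"gpu":"A100-80G"}')

# train() appends this callback to the SFTTrainer callback list; once the
# trainer has initialized the run, process zero records additional metrics
# via tracker.track(...) and the metadata via tracker.set_params(...).
aim_hf_callback = tracker.get_hf_callback()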
29 changes: 15 additions & 14 deletions tuning/trackers/aimstack_tracker.py
@@ -7,6 +7,7 @@
 # Third Party
 from aim.hugging_face import AimCallback
 
+
 class CustomAimCallback(AimCallback):
 
     # A path to export run hash generated by Aim
@@ -18,21 +19,20 @@ def on_init_end(self, args, state, control, **kwargs):
         if state and not state.is_world_process_zero:
             return
 
-        self.setup() # initializes the run_hash
+        self.setup()  # initializes the run_hash
 
         # Store the run hash
         # Change default run hash path to output directory
         if self.run_hash_export_path is None:
             if args and args.output_dir:
                 # args.output_dir/.aim_run_hash
                 self.run_hash_export_path = os.path.join(
-                    args.output_dir,
-                    '.aim_run_hash'
-                )
+                    args.output_dir, ".aim_run_hash"
+                )
 
         if self.run_hash_export_path:
-            with open(self.run_hash_export_path, 'w') as f:
-                f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
+            with open(self.run_hash_export_path, "w") as f:
+                f.write('{"run_hash":"' + str(self._run.hash) + '"}\n')
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         # call directly to make sure hyper parameters and model info is recorded.
@@ -47,10 +47,11 @@ def set_params(self, params, name):
         for key, value in params.items():
             self._run.set((name, key), value, strict=False)
 
+
 class AimStackTracker(Tracker):
 
     def __init__(self, tracker_config: AimConfig):
-        super().__init__(name='aim', tracker_config=tracker_config)
+        super().__init__(name="aim", tracker_config=tracker_config)
 
     def get_hf_callback(self):
         c = self.config
@@ -60,10 +61,10 @@ def get_hf_callback(self):
         repo = c.aim_repo
         hash_export_path = c.aim_run_hash_export_path
 
-        if (ip is not None and port is not None):
+        if ip is not None and port is not None:
             aim_callback = CustomAimCallback(
-                repo="aim://" + ip +":"+ port + "/",
-                experiment=exp)
+                repo="aim://" + ip + ":" + port + "/", experiment=exp
+            )
         if repo:
             aim_callback = CustomAimCallback(repo=repo, experiment=exp)
         else:
@@ -73,12 +74,12 @@ def get_hf_callback(self):
         self.hf_callback = aim_callback
         return self.hf_callback
 
-    def track(self, metric, name, stage='additional_metrics'):
-        context={'subset' : stage}
+    def track(self, metric, name, stage="additional_metrics"):
+        context = {"subset": stage}
         self.hf_callback.track_metrics(metric, name=name, context=context)
 
-    def set_params(self, params, name='extra_params'):
+    def set_params(self, params, name="extra_params"):
         try:
             self.hf_callback.set_params(params, name)
         except:
-            pass
\ No newline at end of file
+            pass
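
Because the callback exports the run hash as a one-line JSON object, it can be read back directly; a sketch assuming the default export location under training_args.output_dir ("outputs" below is a hypothetical stand-in):

# Illustrative sketch; "outputs" stands in for training_args.output_dir.
import json
import os

with open(os.path.join("outputs", ".aim_run_hash"), "r") as f:
    run_hash = json.load(f)["run_hash"]  # the file holds {"run_hash":"<hash>"}
print(run_hash)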
3 changes: 2 additions & 1 deletion tuning/trackers/tracker.py
@@ -1,5 +1,6 @@
 # Generic Tracker API
 
+
 class Tracker:
     def __init__(self, name=None, tracker_config=None) -> None:
         if tracker_config is not None:
@@ -19,4 +20,4 @@ def track(self, metric, name, stage):
     # Object passed here is supposed to be a KV object
     # for the parameters to be associated with a run
     def set_params(self, params, name):
-        pass
\ No newline at end of file
+        pass
7 changes: 3 additions & 4 deletions tuning/trackers/tracker_factory.py
@@ -1,13 +1,12 @@
 from .tracker import Tracker
 from .aimstack_tracker import AimStackTracker
 
-REGISTERED_TRACKERS = {
-    "aim" : AimStackTracker
-}
+REGISTERED_TRACKERS = {"aim": AimStackTracker}
 
+
 def get_tracker(tracker_name, tracker_config):
     if tracker_name in REGISTERED_TRACKERS:
         T = REGISTERED_TRACKERS[tracker_name]
         return T(tracker_config)
     else:
-        return Tracker()
\ No newline at end of file
+        return Tracker()
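
One consequence of the else branch: unknown tracker names fall back to the base Tracker, whose hooks do nothing, so callers never need to branch on whether tracking is enabled. A small sketch:

# Sketch of the no-op fallback path.
from tuning.trackers.tracker_factory import get_tracker

tracker = get_tracker(None, None)  # not in REGISTERED_TRACKERS -> Tracker()
tracker.track(metric=0.7, name="loss", stage="additional_metrics")  # no-op
tracker.set_params(params={"gpu": "A100-80G"}, name="extra_params")  # no-op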
