diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 989aaa8ca..b67dd1cc0 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -8,6 +8,7 @@
 If these things change in the future, we should consider breaking it up.
 """
+
 # Standard
 import argparse
 import json
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 3223bfa7c..a1c4c61bc 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -59,5 +59,8 @@ class TrainingArguments(transformers.TrainingArguments):
     )
     tracker: str.lower = field(
         default=None,
-        metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"}
+        choices=["aim", None, "none"],
+        metadata={
+            "help": "Experiment tracker to use. Requires additional configs, see tuning.configs/tracker_configs.py"
+        },
     )
diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py
index bd97253b0..057714f31 100644
--- a/tuning/config/tracker_configs.py
+++ b/tuning/config/tracker_configs.py
@@ -1,15 +1,17 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class AimConfig:
     # Name of the experiment
     experiment: str = None
-    # 'repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
-    # When 'remote_server_ip' or 'remote_server_port' is set, it designates a remote aim repo.
+    # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
+    # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
     # Otherwise, 'repo' specifies the directory, with a default of None representing '.aim'.
-    aim_repo: str = None
+    # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
+    aim_repo: str = ".aim"
     aim_remote_server_ip: str = None
     aim_remote_server_port: int = None
-    # Location of where run_hash is exported, if unspecified this is output to 
+    # Location of where run_hash is exported, if unspecified this is output to
     # training_args.output_dir/.aim_run_hash if the output_dir is set else not exported.
     aim_run_hash_export_path: str = None
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 4f48ec881..4dad940db 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -31,6 +31,7 @@
 
 logger = logging.get_logger("sft_trainer")
 
+
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
         checkpoint_path = os.path.join(
@@ -41,6 +42,7 @@ def on_save(self, args, state, control, **kwargs):
 
         if "pytorch_model.bin" in os.listdir(checkpoint_path):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -84,6 +86,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
         with open(log_file, "a") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
+
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
@@ -93,7 +96,7 @@ def train(
     ] = None,
     callbacks: Optional[List[TrainerCallback]] = None,
     tracker: Optional[Tracker] = None,
-    exp_metadata: Optional[Dict] = None
+    exp_metadata: Optional[Dict] = None,
 ):
     """Call the SFTTrainer
 
@@ -105,6 +108,11 @@ def train(
         peft_config.PromptTuningConfig for prompt tuning | \
         None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: List of callbacks to attach with SFTtrainer.
+        tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
+                 Initialized using tuning.trackers.tracker_factory.get_tracker
+                 Using configs in tuning.config.tracker_configs
+        exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker.
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -133,7 +141,7 @@ def train(
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-    additional_metrics['model_load_time'] = time.time() - model_load_time
+    additional_metrics["model_load_time"] = time.time() - model_load_time
 
     peft_config = get_hf_peft_config(task_type, peft_config)
 
@@ -269,9 +277,9 @@ def train(
     if tracker is not None:
         # Currently tracked only on process zero.
         if trainer.is_world_process_zero():
-            for k,v in additional_metrics.items():
-                tracker.track(metric=v, name=k, stage='additional_metrics')
-            tracker.set_params(params=exp_metadata, name='experiment_metadata')
+            for k, v in additional_metrics.items():
+                tracker.track(metric=v, name=k, stage="additional_metrics")
+            tracker.set_params(params=exp_metadata, name="experiment_metadata")
 
     if run_distributed and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
@@ -279,6 +287,7 @@ def train(
     )
     trainer.train()
 
+
 def main(**kwargs):
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -300,6 +309,7 @@ def main(**kwargs):
         "--exp_metadata",
         type=str,
         default=None,
+        help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
     )
     (
         model_args,
@@ -313,18 +323,18 @@ def main(**kwargs):
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
     peft_method = additional.peft_method
-    if peft_method =="lora":
-        tune_config=lora_config
-    elif peft_method =="pt":
-        tune_config=prompt_tuning_config
+    if peft_method == "lora":
+        tune_config = lora_config
+    elif peft_method == "pt":
+        tune_config = prompt_tuning_config
     else:
-        tune_config=None
+        tune_config = None
 
     tracker_name = training_args.tracker
     if tracker_name == "aim":
-        tracker_config=aim_config
+        tracker_config = aim_config
     else:
-        tracker_config=None
+        tracker_config = None
 
     # Initialize callbacks
     file_logger_callback = FileLoggingCallback(logger)
@@ -343,7 +353,9 @@ def main(**kwargs):
        try:
            metadata = json.loads(additional.exp_metadata)
            if metadata is None or not isinstance(metadata, Dict):
-                logger.warning('metadata cannot be converted to simple k:v dict ignoring')
+                logger.warning(
+                    "metadata cannot be converted to simple k:v dict ignoring"
+                )
                metadata = None
        except:
            logger.error("failed while parsing extra metadata. pass a valid json")
@@ -355,8 +367,9 @@ def main(**kwargs):
         peft_config=tune_config,
         callbacks=callbacks,
         tracker=tracker,
-        exp_metadata=metadata
+        exp_metadata=metadata,
     )
 
+
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py
index caf8fc4a1..c82b89f35 100644
--- a/tuning/trackers/aimstack_tracker.py
+++ b/tuning/trackers/aimstack_tracker.py
@@ -7,6 +7,7 @@
 # Third Party
 from aim.hugging_face import AimCallback
 
+
 class CustomAimCallback(AimCallback):
 
     # A path to export run hash generated by Aim
@@ -18,7 +19,7 @@ def on_init_end(self, args, state, control, **kwargs):
         if state and not state.is_world_process_zero:
             return
 
-        self.setup() # initializes the run_hash
+        self.setup()  # initializes the run_hash
 
         # Store the run hash
         # Change default run hash path to output directory
@@ -26,13 +27,12 @@ def on_init_end(self, args, state, control, **kwargs):
 
         if args and args.output_dir:
             # args.output_dir/.aim_run_hash
             self.run_hash_export_path = os.path.join(
-                args.output_dir,
-                '.aim_run_hash'
-            )
+                args.output_dir, ".aim_run_hash"
+            )
 
         if self.run_hash_export_path:
-            with open(self.run_hash_export_path, 'w') as f:
-                f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
+            with open(self.run_hash_export_path, "w") as f:
+                f.write('{"run_hash":"' + str(self._run.hash) + '"}\n')
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         # call directly to make sure hyper parameters and model info is recorded.
@@ -47,10 +47,11 @@ def set_params(self, params, name):
         for key, value in params.items():
             self._run.set((name, key), value, strict=False)
 
+
 class AimStackTracker(Tracker):
 
     def __init__(self, tracker_config: AimConfig):
-        super().__init__(name='aim', tracker_config=tracker_config)
+        super().__init__(name="aim", tracker_config=tracker_config)
 
     def get_hf_callback(self):
         c = self.config
@@ -60,10 +61,10 @@ def get_hf_callback(self):
         repo = c.aim_repo
         hash_export_path = c.aim_run_hash_export_path
 
-        if (ip is not None and port is not None):
+        if ip is not None and port is not None:
             aim_callback = CustomAimCallback(
-                repo="aim://" + ip +":"+ port + "/",
-                experiment=exp)
+                repo="aim://" + ip + ":" + port + "/", experiment=exp
+            )
         if repo:
             aim_callback = CustomAimCallback(repo=repo, experiment=exp)
         else:
@@ -73,12 +74,12 @@ def get_hf_callback(self):
         self.hf_callback = aim_callback
         return self.hf_callback
 
-    def track(self, metric, name, stage='additional_metrics'):
-        context={'subset' : stage}
+    def track(self, metric, name, stage="additional_metrics"):
+        context = {"subset": stage}
         self.hf_callback.track_metrics(metric, name=name, context=context)
 
-    def set_params(self, params, name='extra_params'):
+    def set_params(self, params, name="extra_params"):
         try:
             self.hf_callback.set_params(params, name)
         except:
-            pass
\ No newline at end of file
+            pass
diff --git a/tuning/trackers/tracker.py b/tuning/trackers/tracker.py
index 220109a66..71ad60183 100644
--- a/tuning/trackers/tracker.py
+++ b/tuning/trackers/tracker.py
@@ -1,5 +1,6 @@
 # Generic Tracker API
 
+
 class Tracker:
     def __init__(self, name=None, tracker_config=None) -> None:
         if tracker_config is not None:
@@ -19,4 +20,4 @@ def track(self, metric, name, stage):
     # Object passed here is supposed to be a KV object
     # for the parameters to be associated with a run
     def set_params(self, params, name):
-        pass
\ No newline at end of file
+        pass
diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py
index 3f06104d8..bd35c8b46 100644
--- a/tuning/trackers/tracker_factory.py
+++ b/tuning/trackers/tracker_factory.py
@@ -1,13 +1,12 @@
 from .tracker import Tracker
 from .aimstack_tracker import AimStackTracker
 
-REGISTERED_TRACKERS = {
-    "aim" : AimStackTracker
-}
+REGISTERED_TRACKERS = {"aim": AimStackTracker}
+
 
 def get_tracker(tracker_name, tracker_config):
     if tracker_name in REGISTERED_TRACKERS:
         T = REGISTERED_TRACKERS[tracker_name]
         return T(tracker_config)
     else:
-        return Tracker()
\ No newline at end of file
+        return Tracker()
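
A minimal usage sketch of the tracker API introduced by this patch (the sketch itself is not part of the diff). Every class, function, and argument name below comes from the changes above; the experiment name, repo path, and metadata values are illustrative assumptions, and the model/data/training arguments are still expected to be built by main() via HfArgumentParser.

# Sketch: wiring the new tracker into a tuning run.
from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

# Local Aim repo (the new default is ".aim"); set aim_remote_server_ip and
# aim_remote_server_port instead to log to a remote server ("aim://<ip>:<port>/").
aim_config = AimConfig(experiment="sft-example", aim_repo=".aim")

# "aim" resolves to AimStackTracker via REGISTERED_TRACKERS; any other name
# falls back to the no-op base Tracker.
tracker = get_tracker("aim", aim_config)

# The HF callback owns the Aim run and is appended to the callbacks handed to
# sft_trainer.train(..., tracker=tracker, exp_metadata={"gpu": "A100-80G"}).
callbacks = [tracker.get_hf_callback()]

# Inside train(), on world process zero, the same tracker is then used as:
#   tracker.track(metric=value, name="model_load_time", stage="additional_metrics")
#   tracker.set_params(params=exp_metadata, name="experiment_metadata")

From the command line the equivalent path is exercised by passing --tracker aim and, optionally, --exp_metadata '{"gpu":"A100-80G"}' to sft_trainer.py.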