From 26761592457409be29a2f63b205e2c2134e352a4 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 19 Dec 2024 16:55:30 +0530 Subject: [PATCH] Add mlflow tracker and unit testing for the same. Signed-off-by: Dushyant Behl --- tests/test_sft_trainer.py | 6 +- tests/trackers/test_aim_tracker.py | 2 +- tests/trackers/test_mlflow_tracker.py | 124 ++++++++++++++++ tuning/config/tracker_configs.py | 20 +++ tuning/sft_trainer.py | 25 +++- tuning/trackers/mlflow_tracker.py | 204 ++++++++++++++++++++++++++ tuning/trackers/tracker_factory.py | 54 +++++-- 7 files changed, 412 insertions(+), 23 deletions(-) create mode 100644 tests/trackers/test_mlflow_tracker.py create mode 100644 tuning/trackers/mlflow_tracker.py diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 8dcdf3087..f74bc09e7 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -356,6 +356,7 @@ def test_parse_arguments(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -381,6 +382,7 @@ def test_parse_arguments_defaults(job_config): _, _, _, + _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -391,14 +393,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _, _ = 
sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) diff --git a/tests/trackers/test_aim_tracker.py b/tests/trackers/test_aim_tracker.py index d2aa301b7..cad19a78b 100644 --- a/tests/trackers/test_aim_tracker.py +++ b/tests/trackers/test_aim_tracker.py @@ -58,7 +58,7 @@ def fixture_aimrepo(): @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed") -def test_run_with_good_tracker_name_but_no_args(): +def test_run_with_aim_tracker_name_but_no_args(): """Ensure that train() raises error with aim tracker name but no args""" with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/trackers/test_mlflow_tracker.py b/tests/trackers/test_mlflow_tracker.py new file mode 100644 index 000000000..6501d6280 --- /dev/null +++ b/tests/trackers/test_mlflow_tracker.py @@ -0,0 +1,124 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +import copy +import json +import os +import tempfile + +# Third Party +from transformers.utils.import_utils import _is_package_available +import pytest + +# First Party +from tests.test_sft_trainer import ( + DATA_ARGS, + MODEL_ARGS, + TRAIN_ARGS, + _get_checkpoint_path, + _test_run_inference, + _validate_training, +) + +# Local +from tuning import sft_trainer +from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory + +mlflow_not_available = not _is_package_available("mlflow") + + +@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed") +def test_run_with_mlflow_tracker_name_but_no_args(): + """Ensure that train() raises error with mlflow tracker name but no args""" + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + train_args.trackers = ["mlflow"] + + with pytest.raises( + ValueError, + match="mlflow tracker requested but mlflow_uri is not specified.", + ): + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args) + + +@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed") +def test_e2e_run_with_mlflow_tracker(): + """Ensure that training succeeds with mlflow tracker""" + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + # This should not mean file logger is not present. + # code will add it by default + # The below validate_training check will test for that too. 
+ train_args.trackers = ["mlflow"] + + tracker_configs = TrackerConfigFactory( + mlflow_config=MLflowConfig( + mlflow_experiment="unit_test", + mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"), + ) + ) + + sft_trainer.train( + MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs + ) + + # validate ft tuning configs + _validate_training(tempdir) + + # validate inference + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) + + +@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed") +def test_e2e_run_with_mlflow_runuri_export_default_path(): + """Ensure that mlflow outputs run uri in the output dir by default""" + + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + train_args.trackers = ["mlflow"] + + tracker_configs = TrackerConfigFactory( + mlflow_config=MLflowConfig( + mlflow_experiment="unit_test", + mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"), + ) + ) + + sft_trainer.train( + MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs + ) + + # validate ft tuning configs + _validate_training(tempdir) + + run_uri_file = os.path.join(tempdir, "mlflow_tracker.json") + + assert os.path.exists(run_uri_file) is True + assert os.path.getsize(run_uri_file) > 0 + + with open(run_uri_file, "r", encoding="utf-8") as f: + content = json.loads(f.read()) + assert "run_uri" in content diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index 5a8781375..a80407efe 100644 --- a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -63,7 +63,27 @@ def __post_init__(self): ) +@dataclass +class MLflowConfig: + # Name of the experiment + mlflow_experiment: str = None + mlflow_tracking_uri: str = None + # Location of where mlflow's run uri is to be exported. 
+ # If mlflow_run_uri_export_path is set the run uri will be output in a json format + # to the location pointed to by `mlflow_run_uri_export_path/mlflow_tracker.json` + # If this is not set then the default location where run uri will be exported + # is training_args.output_dir/mlflow_tracker.json + # Run uri is not exported if mlflow_run_uri_export_path variable is not set + # and output_dir is not specified. + mlflow_run_uri_export_path: str = None + + def __post_init__(self): + if self.mlflow_experiment is None: + self.mlflow_experiment = "fms-hf-tuning" + + @dataclass class TrackerConfigFactory: file_logger_config: FileLoggingTrackerConfig = None aim_config: AimConfig = None + mlflow_config: MLflowConfig = None diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index dc3e3733e..c4798484d 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -51,6 +51,7 @@ from tuning.config.tracker_configs import ( AimConfig, FileLoggingTrackerConfig, + MLflowConfig, TrackerConfigFactory, ) from tuning.data.setup_dataprocessor import process_dataargs @@ -434,6 +435,7 @@ def get_parser(): QuantizedLoraConfig, FusedOpsAndKernelsConfig, AttentionAndDistributedPackingConfig, + MLflowConfig, ) ) parser.add_argument( @@ -483,8 +485,10 @@ def parse_arguments(parser, json_config=None): Configuration for fused operations and kernels. AttentionAndDistributedPackingConfig Configuration for padding free and packing. + MLflowConfig + Configuration for mlflow tracker. dict[str, str] - Extra AIM metadata. + Extra tracker metadata. 
""" if json_config: ( @@ -499,6 +503,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") @@ -515,6 +520,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) @@ -540,6 +546,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) @@ -561,6 +568,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) = parse_arguments(parser, job_config) @@ -572,7 +580,8 @@ def main(): model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ tune_config %s, file_logger_config, %s aim_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ - attention_and_distributed_packing_config %s exp_metadata %s", + attention_and_distributed_packing_config %s,\ + mlflow_config %s, exp_metadata %s", model_args, data_args, training_args, @@ -583,6 +592,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) except Exception as e: # pylint: disable=broad-except @@ -607,10 +617,11 @@ def main(): "failed while parsing extra metadata. 
pass a valid json %s", repr(e) ) - combined_tracker_configs = TrackerConfigFactory() - - combined_tracker_configs.file_logger_config = file_logger_config - combined_tracker_configs.aim_config = aim_config + tracker_configs = TrackerConfigFactory( + file_logger_config=file_logger_config, + aim_config=aim_config, + mlflow_config=mlflow_config, + ) if training_args.output_dir: os.makedirs(training_args.output_dir, exist_ok=True) @@ -622,7 +633,7 @@ def main(): train_args=training_args, peft_config=tune_config, trainer_controller_args=trainer_controller_args, - tracker_configs=combined_tracker_configs, + tracker_configs=tracker_configs, additional_callbacks=None, exp_metadata=metadata, quantized_lora_config=quantized_lora_config, diff --git a/tuning/trackers/mlflow_tracker.py b/tuning/trackers/mlflow_tracker.py new file mode 100644 index 000000000..1368f2325 --- /dev/null +++ b/tuning/trackers/mlflow_tracker.py @@ -0,0 +1,204 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Standard
+import json
+import logging
+import os
+
+# Third Party
+from transformers.integrations import MLflowCallback  # pylint: disable=import-error
+
+# Local
+from .tracker import Tracker
+from tuning.config.tracker_configs import MLflowConfig
+
+MLFLOW_RUN_URI_EXPORT_FILENAME = "mlflow_tracker.json"
+
+
+class RunIDExporterMlFlowCallback(MLflowCallback):
+    """
+    Custom MlflowCallBack callback is used to export run id from Mlflow
+    as soon as it is created, which is on setup().
+    """
+
+    run_uri_export_path: str = None
+
+    # Override ml flow callback setup function
+    # Initialise mlflow callback and export Mlflow's run url.
+    # Export location is MlflowConfig.mlflow_run_uri_export_path if it is passed
+    # or, training_args.output_dir/mlflow_tracker.json if output_dir is present
+    # Exported file holds a json object with "run_name" and "run_uri" keys.
+    # url is not exported if both paths are invalid
+    def setup(self, args, state, model):
+        """Override the `setup` function in the `MLflowCallback` callback.
+
+        This function performs the following steps:
+        1. Calls `MLFlowCallBack.setup` to
+            initialize internal `mlflow` structures.
+        2. Exports the `Mlflow` run uri:
+            - If `MLflowConfig.mlflow_run_uri_export_path` is provided, the uri
+            is exported to `mlflow_run_uri_export_path/mlflow_tracker.json`
+            - If `MLflowConfig.mlflow_run_uri_export_path` is not provided but
+            `args.output_dir` is specified, the uri is exported to
+            `args.output_dir/mlflow_tracker.json`
+            - If neither path is valid, the uri is not exported.
+
+        The exported json is '{"run_name": <run name>, "run_uri": <run uri>}'.
+
+        Args:
+            For the arguments see reference to transformers.TrainingCallback
+        """
+        super().setup(args, state, model)
+
+        active_run = self._ml_flow.active_run()
+        if not active_run:
+            return
+
+        active_run_info = active_run.info
+        if not active_run_info:
+            # Without run info there is no run id to build the uri from.
+            return
+
+        experiment_id = active_run_info.experiment_id
+        experiment_url = f"/#/experiments/{experiment_id}"
+        run_id = active_run_info.run_id
+        run_name = active_run_info.run_name
+        run_uri = f"{experiment_url}/runs/{run_id}"
+
+        # Change default uri path to output directory if not specified
+        if self.run_uri_export_path is None:
+            if args is None or args.output_dir is None:
+                logging.warning(
+                    "To export mlflow uri either output_dir "
+                    "or mlflow_run_uri_export_path should be set"
+                )
+                return
+
+            self.run_uri_export_path = args.output_dir
+
+        if not os.path.exists(self.run_uri_export_path):
+            os.makedirs(self.run_uri_export_path, exist_ok=True)
+
+        export_path = os.path.join(
+            self.run_uri_export_path, MLFLOW_RUN_URI_EXPORT_FILENAME
+        )
+        with open(export_path, "w", encoding="utf-8") as f:
+            f.write(json.dumps({"run_name": run_name, "run_uri": run_uri}))
+            logging.info("Mlflow tracker run uri id dumped to %s", export_path)
+
+
+class MLflowTracker(Tracker):
+    def __init__(self, tracker_config: MLflowConfig):
+        """Tracker which uses mlflow to collect and store metrics.
+
+        Args:
+            tracker_config (MLflowConfig): A valid MLflowConfig which contains
+                information on where the mlflow tracking uri is present.
+        """
+        super().__init__(name="mlflow", tracker_config=tracker_config)
+
+    def get_hf_callback(self):
+        """Returns the MLFlowCallBack object associated with this tracker.
+
+        Raises:
+            ValueError: If the config passed at initialise does not contain
+                the uri where the mlflow tracking server is present
+
+        Returns:
+            MLFlowCallBack: The MLFlowCallBack initialised with the config
+                provided at init time.
+        """
+        c = self.config
+        exp = c.mlflow_experiment
+        uri = c.mlflow_tracking_uri
+        run_uri_path = c.mlflow_run_uri_export_path
+
+        if uri is None:
+            logging.error(
+                "mlflow tracker requested but mlflow_uri is not specified. "
+                + "Please specify mlflow uri for using mlflow."
+            )
+            raise ValueError(
+                "mlflow tracker requested but mlflow_uri is not specified."
+            )
+
+        # Modify the environment expected by mlflow
+        os.environ["MLFLOW_TRACKING_URI"] = uri
+        os.environ["MLFLOW_EXPERIMENT_NAME"] = exp
+
+        mlflow_callback = RunIDExporterMlFlowCallback()
+
+        if mlflow_callback is not None:
+            mlflow_callback.run_uri_export_path = run_uri_path
+
+        self.hf_callback = mlflow_callback
+        return self.hf_callback
+
+    def track(self, metric, name, stage=None):
+        """Track any additional metric with name under mlflow tracker.
+
+        Args:
+            metric (int/float): Expected metrics to be tracked by mlflow.
+            name (str): Name of the metric being tracked.
+            stage (str, optional): Can be used to pass the namespace/metadata to
+                associate with metric, e.g. at the stage the metric was generated
+                like train, eval. Defaults to None.
+                If not None the metric is saved as { stage.name : metric }
+        Raises:
+            ValueError: If the metric or name are passed as None.
+        """
+        if metric is None or name is None:
+            raise ValueError(
+                "mlflow track function should not be called with None metric value or name"
+            )
+
+        if stage is not None:
+            name = f"{stage}.{name}"
+
+        callback = self.hf_callback
+        mlflow = callback._ml_flow
+        if mlflow is not None:
+            mlflow.log_metric(key=name, value=metric)
+
+    def set_params(self, params, name=None):
+        """Attach any extra params with the run information stored in mlflow tracker.
+
+        Args:
+            params (dict): A dict of k:v pairs of parameters to be stored in tracker.
+            name (str, optional): represents the namespace under which parameters
+                will be associated in mlflow. e.g. {name: params}
+                Defaults to None.
+
+        Raises:
+            ValueError: params is not a dict, or name is set but not a string
+        """
+        if params is None:
+            return
+        if not isinstance(params, dict):
+            raise ValueError(
+                "set_params passed to mlflow should be called with a dict of params"
+            )
+        if name and not isinstance(name, str):
+            raise ValueError("name passed to mlflow set_params should be a string")
+
+        if name:
+            tolog = {name: params}
+        else:
+            tolog = params
+
+        callback = self.hf_callback
+        mlflow = callback._ml_flow
+        if mlflow is not None:
+            mlflow.log_params(tolog)
diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py
index 096099306..f86cbb1a1 100644
--- a/tuning/trackers/tracker_factory.py
+++ b/tuning/trackers/tracker_factory.py
@@ -26,8 +26,9 @@
 # Information about all registered trackers
 AIMSTACK_TRACKER = "aim"
 FILE_LOGGING_TRACKER = "file_logger"
+MLFLOW_TRACKER = "mlflow"
 
-AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER]
+AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, MLFLOW_TRACKER]
 
 
 # Trackers which can be used
@@ -35,12 +36,19 @@
 
 # One time package check for list of external trackers.
 _is_aim_available = _is_package_available("aim")
+_is_mlflow_available = _is_package_available("mlflow")
 
 
 def _get_tracker_class(T, C):
     return {"tracker": T, "config": C}
 
 
+def _is_tracker_installed(name):
+    # Map each optional tracker to the result of its one time package check.
+    installed = {"aim": _is_aim_available, "mlflow": _is_mlflow_available}
+    return installed.get(name, False)
+
+
 def _register_aim_tracker():
     # pylint: disable=import-outside-toplevel
     if _is_aim_available:
@@ -60,10 +68,23 @@ def _register_aim_tracker():
         )
 
 
-def _is_tracker_installed(name):
-    if name == "aim":
-        return _is_aim_available
-    return False
+def _register_mlflow_tracker():
+    # pylint: disable=import-outside-toplevel
+    if _is_mlflow_available:
+        # Local
+        from .mlflow_tracker import MLflowTracker
+        from tuning.config.tracker_configs import MLflowConfig
+
+        mlflow_tracker = _get_tracker_class(MLflowTracker, MLflowConfig)
+
+        REGISTERED_TRACKERS[MLFLOW_TRACKER] = mlflow_tracker
+        logging.info("Registered mlflow tracker")
+    else:
+        logging.info(
+            "Not registering mlflow tracker due to unavailablity of package.\n"
+            "Please install mlflow if you intend to use it.\n"
+            "\t pip install mlflow"
+        )
 
 
 def _register_file_logging_tracker():
@@ -81,6 +102,8 @@ def _register_trackers():
         _register_aim_tracker()
     if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
         _register_file_logging_tracker()
+    if MLFLOW_TRACKER not in REGISTERED_TRACKERS:
+        _register_mlflow_tracker()
 
 
 def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory):
@@ -111,6 +134,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
             Valid classes available are,
             tuning.trackers.tracker.aimstack_tracker.AimStackTracker,
             tuning.trackers.tracker.filelogging_tracker.FileLoggingTracker
+            tuning.trackers.tracker.mlflow_tracker.MLflowTracker
 
     Examples:
         file_logging_tracker = get_tracker("file_logger", TrackerConfigFactory(
@@ -120,10 +144,16 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
         ))
         aim_tracker = get_tracker("aim", TrackerConfigFactory(
             aim_config=AimConfig(
-                experiment="unit_test",
+                experiment="test",
                 aim_repo=tempdir + "/"
                 )
         ))
+        mlflow_tracker = get_tracker("mlflow", TrackerConfigFactory(
+            mlflow_config=MLflowConfig(
+                mlflow_experiment="test",
+                mlflow_tracking_uri="./mlflow.sqlite"
+                )
+        ))
     """
     if not REGISTERED_TRACKERS:
         # a one time step.
@@ -131,14 +161,12 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
 
     if name not in REGISTERED_TRACKERS:
         if name in AVAILABLE_TRACKERS and (not _is_tracker_installed(name)):
-            e = "Requested tracker {} is not installed. Please install before proceeding".format(
-                name
-            )
+            e = f"Requested tracker {name} is not installed. "
+            e += "Please install before proceeding"
         else:
             available = ", ".join(str(t) for t in AVAILABLE_TRACKERS)
-            e = "Requested Tracker {} not found. List trackers available for use is - {} ".format(
-                name, available
-            )
+            e = f"Requested Tracker {name} not found. "
+            e += f"List trackers available for use is - {available} "
         logging.error(e)
         raise ValueError(e)
 