diff --git a/README.md b/README.md
index aebf68900..6d18d72ab 100644
--- a/README.md
+++ b/README.md
@@ -823,12 +823,13 @@ For details about how you can use set a custom stopping criteria and perform cus
 
 ## Experiment Tracking
 
-Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/) or custom trackers built into the code like
+Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/), [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) or custom trackers built into the code like
 [FileLoggingTracker](./tuning/trackers/filelogging_tracker.py)
 
-The code supports currently two trackers out of the box,
+The code currently supports the following trackers out of the box:
 * `FileLoggingTracker` : A built in tracker which supports logging training loss to a file.
 * `Aimstack` : A popular opensource tracker which can be used to track any metrics or metadata from the experiments.
+* `MLflow Tracking` : Another popular open-source tracker which can store metrics, metadata and even artifacts from experiments.
 
 Further details on enabling and using the trackers mentioned above can be found [here](docs/experiment-tracking.md).
 
diff --git a/build/Dockerfile b/build/Dockerfile
index d8cc74877..9a6a5583f 100644
--- a/build/Dockerfile
+++ b/build/Dockerfile
@@ -19,8 +19,9 @@ ARG USER=tuning
 ARG USER_UID=1000
 ARG PYTHON_VERSION=3.11
 ARG WHEEL_VERSION=""
-## Enable Aimstack if requested via ENABLE_AIM set to "true"
+## Enable Aimstack or MLflow if requested via ENABLE_AIM/ENABLE_MLFLOW set to "true"
 ARG ENABLE_AIM=false
+ARG ENABLE_MLFLOW=false
 ARG ENABLE_FMS_ACCELERATION=true
 
 ## Base Layer ##################################################################
@@ -151,6 +152,10 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[aim]"; \
     fi
 
+RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
+    python -m pip install --user "$(head bdist_name)[mlflow]"; \
+    fi
+
 # Clean up the wheel module. It's only needed by flash-attn install
 RUN python -m pip uninstall wheel build -y && \
     # Cleanup the bdist whl file
diff --git a/docs/experiment-tracking.md b/docs/experiment-tracking.md
index edc4e5978..598a1941b 100644
--- a/docs/experiment-tracking.md
+++ b/docs/experiment-tracking.md
@@ -115,6 +115,34 @@ sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....
 
 The code expects either the `local` or `remote` repo to be specified and will result in a `ValueError` otherwise. See [AimConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/tracker_configs.py#L25) for more details.
 
+## MLflow Tracker
+
+To enable [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) users need to pass `"mlflow"` as the requested tracker as part of the [training argument](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/configs.py#L131).
+
+When using MLflow, users also need to specify the [mlflow tracking uri](https://mlflow.org/docs/latest/tracking.html#common-setups), which points to either a [mlflow supported database](https://mlflow.org/docs/latest/tracking/backend-stores.html#supported-store-types) or a running [mlflow remote tracking server](https://mlflow.org/docs/latest/tracking/server.html).
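+
+For example, the tracking uri could point to a local SQLite database such as `sqlite:///mlflow.db`, or to a running tracking server such as `http://localhost:5000` (both values are illustrative; see the MLflow documentation linked above for all supported schemes).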
+
+Example
+```
+from tuning import sft_trainer
+from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory
+
+training_args = TrainingArguments(
+    ...,
+    trackers=["mlflow"],
+)
+
+tracker_configs = TrackerConfigFactory(
+    mlflow_config=MLflowConfig(
+        mlflow_experiment="experiment-name",
+        mlflow_tracking_uri=<tracking uri>,
+    )
+)
+
+sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....)
+```
+
+The code expects a valid uri to be specified and will result in a `ValueError` otherwise.
 
 ## Running the code via command line `tuning/sft_trainer::main` function
 
@@ -123,10 +151,10 @@ If running the code via main function of [sft_trainer.py](../tuning/sft_trainer.
 
 To enable tracking please pass
 
 ```
---tracker <aim/file_logger>
+--tracker <aim/file_logger/mlflow>
 ```
 
-To further customise tracking you can specify additional arguments needed by the tracker like
+To further customise tracking you can specify additional arguments needed by the tracker like (the example below shows aim; mlflow is configured the same way with its `mlflow_*` arguments)
 
 ```
 --tracker aim --aim_repo <repo_path> --experiment <experiment_name>
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 9cfeecdc4..8301ce253 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
 dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 aim = ["aim>=3.19.0,<4.0"]
+mlflow = ["mlflow"]
 fms-accel = ["fms-acceleration>=0.1"]
 gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 2efbabb1c..1cee6a3c9 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -359,6 +359,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -384,6 +385,7 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False
@@ -394,14 +396,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)
diff --git a/tests/trackers/test_aim_tracker.py b/tests/trackers/test_aim_tracker.py
index d2aa301b7..cad19a78b 100644
--- a/tests/trackers/test_aim_tracker.py
+++ b/tests/trackers/test_aim_tracker.py
@@ -58,7 +58,7 @@ def fixture_aimrepo():
 
 
 @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
-def test_run_with_good_tracker_name_but_no_args():
+def test_run_with_aim_tracker_name_but_no_args():
     """Ensure that train() raises error with aim tracker name but no args"""
 
     with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/trackers/test_mlflow_tracker.py b/tests/trackers/test_mlflow_tracker.py
new file mode 100644
index 000000000..f83af715e
--- /dev/null
+++ b/tests/trackers/test_mlflow_tracker.py
@@ -0,0 +1,130 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+# Standard
+import copy
+import json
+import os
+import tempfile
+
+# Third Party
+from transformers.utils.import_utils import _is_package_available
+import pytest
+
+# First Party
+from tests.test_sft_trainer import (
+    DATA_ARGS,
+    MODEL_ARGS,
+    TRAIN_ARGS,
+    _get_checkpoint_path,
+    _test_run_inference,
+    _validate_training,
+)
+
+# Local
+from tuning import sft_trainer
+from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory
+
+mlflow_not_available = not _is_package_available("mlflow")
+
+
+@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
+def test_run_with_mlflow_tracker_name_but_no_args():
+    """Ensure that train() raises error with mlflow tracker name but no args"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        train_args.trackers = ["mlflow"]
+
+        with pytest.raises(
+            ValueError,
+            match="mlflow tracker requested but mlflow_uri is not specified.",
+        ):
+            sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)
+
+
+@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
+def test_e2e_run_with_mlflow_tracker():
+    """Ensure that training succeeds with mlflow tracker"""
+
+    # mlflow performs a cleanup when the callback closes, which happens after
+    # this directory would be deleted, so we run into two issues:
+    # 1. the temp directory cannot be cleared as mlflow holds an open pointer to it
+    # 2. mlflow complains that it cannot find a run which it just created.
+    # This race condition is avoided with mkdtemp(), which does not auto-delete.
+    tempdir = tempfile.mkdtemp()
+
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.output_dir = tempdir
+
+    # This should not mean the file logger is not present;
+    # the code will add it by default.
+    # The validate_training check below will test for that too.
+
+    train_args.trackers = ["mlflow"]
+
+    tracker_configs = TrackerConfigFactory(
+        mlflow_config=MLflowConfig(
+            mlflow_experiment="unit_test",
+            mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"),
+        )
+    )
+
+    sft_trainer.train(
+        MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+    )
+
+    # validate ft tuning configs
+    _validate_training(tempdir)
+
+    # validate inference
+    _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+
+
+@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
+def test_e2e_run_with_mlflow_runuri_export_default_path():
+    """Ensure that mlflow outputs run uri in the output dir by default"""
+
+    tempdir = tempfile.mkdtemp()
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.output_dir = tempdir
+
+    train_args.trackers = ["mlflow"]
+
+    tracker_configs = TrackerConfigFactory(
+        mlflow_config=MLflowConfig(
+            mlflow_experiment="unit_test",
+            mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"),
+        )
+    )
+
+    sft_trainer.train(
+        MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+    )
+
+    # validate ft tuning configs
+    _validate_training(tempdir)
+
+    run_uri_file = os.path.join(tempdir, "mlflow_tracker.json")
+
+    assert os.path.exists(run_uri_file)
+    assert os.path.getsize(run_uri_file) > 0
+
+    with open(run_uri_file, "r", encoding="utf-8") as f:
+        content = json.loads(f.read())
+        assert "run_uri" in content
diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py
index 5a8781375..a80407efe 100644
--- a/tuning/config/tracker_configs.py
+++ b/tuning/config/tracker_configs.py
@@ -63,7 +63,27 @@ def __post_init__(self):
         )
 
 
+@dataclass
+class MLflowConfig:
+    # Name of the experiment
+    mlflow_experiment: str = None
+    mlflow_tracking_uri: str = None
+    # Location where mlflow's run uri is to be exported.
+    # If mlflow_run_uri_export_path is set, the run uri is written in json format
+    # to `mlflow_run_uri_export_path/mlflow_tracker.json`.
+    # If it is not set, the default export location is
+    # `training_args.output_dir/mlflow_tracker.json`.
+    # The run uri is not exported if neither mlflow_run_uri_export_path
+    # nor output_dir is specified.
+    mlflow_run_uri_export_path: str = None
+
+    def __post_init__(self):
+        if self.mlflow_experiment is None:
+            self.mlflow_experiment = "fms-hf-tuning"
+
+
 @dataclass
 class TrackerConfigFactory:
     file_logger_config: FileLoggingTrackerConfig = None
     aim_config: AimConfig = None
+    mlflow_config: MLflowConfig = None
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 4e1ae1171..2afdd2dac 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -51,6 +51,7 @@
 from tuning.config.tracker_configs import (
     AimConfig,
     FileLoggingTrackerConfig,
+    MLflowConfig,
     TrackerConfigFactory,
 )
 from tuning.data.setup_dataprocessor import process_dataargs
@@ -444,6 +445,7 @@ def get_parser():
             QuantizedLoraConfig,
             FusedOpsAndKernelsConfig,
             AttentionAndDistributedPackingConfig,
+            MLflowConfig,
         )
     )
     parser.add_argument(
@@ -493,8 +495,10 @@ def parse_arguments(parser, json_config=None):
             Configuration for fused operations and kernels.
         AttentionAndDistributedPackingConfig
             Configuration for padding free and packing.
+        MLflowConfig
+            Configuration for mlflow tracker.
         dict[str, str]
-            Extra AIM metadata.
+            Extra tracker metadata.
""" if json_config: ( @@ -509,6 +513,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") @@ -525,6 +530,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) @@ -550,6 +556,7 @@ def parse_arguments(parser, json_config=None): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) @@ -571,6 +578,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) = parse_arguments(parser, job_config) @@ -582,7 +590,8 @@ def main(): model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ tune_config %s, file_logger_config, %s aim_config %s, \ quantized_lora_config %s, fusedops_kernels_config %s, \ - attention_and_distributed_packing_config %s exp_metadata %s", + attention_and_distributed_packing_config %s,\ + mlflow_config %s, exp_metadata %s", model_args, data_args, training_args, @@ -593,6 +602,7 @@ def main(): quantized_lora_config, fusedops_kernels_config, attention_and_distributed_packing_config, + mlflow_config, exp_metadata, ) except Exception as e: # pylint: disable=broad-except @@ -617,10 +627,11 @@ def main(): "failed while parsing extra metadata. pass a valid json %s", repr(e) ) - combined_tracker_configs = TrackerConfigFactory() - - combined_tracker_configs.file_logger_config = file_logger_config - combined_tracker_configs.aim_config = aim_config + tracker_configs = TrackerConfigFactory( + file_logger_config=file_logger_config, + aim_config=aim_config, + mlflow_config=mlflow_config, + ) if training_args.output_dir: os.makedirs(training_args.output_dir, exist_ok=True) @@ -632,7 +643,7 @@ def main(): train_args=training_args, peft_config=tune_config, trainer_controller_args=trainer_controller_args, - tracker_configs=combined_tracker_configs, + tracker_configs=tracker_configs, additional_callbacks=None, exp_metadata=metadata, quantized_lora_config=quantized_lora_config, diff --git a/tuning/trackers/mlflow_tracker.py b/tuning/trackers/mlflow_tracker.py new file mode 100644 index 000000000..687cfd656 --- /dev/null +++ b/tuning/trackers/mlflow_tracker.py @@ -0,0 +1,205 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Standard
+import json
+import logging
+import os
+
+# Third Party
+from transformers.integrations import MLflowCallback  # pylint: disable=import-error
+
+# Local
+from .tracker import Tracker
+from tuning.config.tracker_configs import MLflowConfig
+
+MLFLOW_RUN_URI_EXPORT_FILENAME = "mlflow_tracker.json"
+
+
+class RunURIExporterMlflowCallback(MLflowCallback):
+    """
+    Custom MLflowCallback which exports the run uri from MLflow
+    as soon as it is created, i.e. during setup().
+    """
+
+    run_uri_export_path: str = None
+    client = None
+
+    # Override the MLflowCallback setup function to also export the run uri.
+    def setup(self, args, state, model):
+        """Override the `setup` function in the `MLflowCallback` callback.
+
+        This function performs the following steps:
+        1. Calls `MLflowCallback.setup` to
+           initialize internal `mlflow` structures.
+        2. Exports the `MLflow` run uri:
+            - If `MLflowConfig.mlflow_run_uri_export_path` is provided, the uri
+              is exported to `mlflow_run_uri_export_path/mlflow_tracker.json`
+            - If `MLflowConfig.mlflow_run_uri_export_path` is not provided but
+              `args.output_dir` is specified, the uri is exported to
+              `args.output_dir/mlflow_tracker.json`
+            - If neither path is valid, the uri is not exported.
+
+        The exported file is formatted as '{"run_name":"<run_name>", "run_uri":"<run_uri>"}'.
+
+        Args:
+            For the arguments see the reference to transformers.TrainerCallback.
+        """
+        super().setup(args, state, model)
+
+        self.client = self._ml_flow
+
+        active_run = self.client.active_run()
+        if not active_run:
+            return
+
+        active_run_info = active_run.info
+        if not active_run_info:
+            return
+
+        experiment_id = active_run_info.experiment_id
+        experiment_url = f"/#/experiments/{experiment_id}"
+        run_id = active_run_info.run_id
+        run_name = active_run_info.run_name
+        run_uri = f"{experiment_url}/runs/{run_id}"
+
+        # Change default uri path to output directory if not specified
+        if self.run_uri_export_path is None:
+            if args is None or args.output_dir is None:
+                logging.warning(
+                    "To export mlflow run uri either output_dir "
+                    "or mlflow_run_uri_export_path should be set"
+                )
+                return
+
+            self.run_uri_export_path = args.output_dir
+
+        os.makedirs(self.run_uri_export_path, exist_ok=True)
+
+        export_path = os.path.join(
+            self.run_uri_export_path, MLFLOW_RUN_URI_EXPORT_FILENAME
+        )
+        with open(export_path, "w", encoding="utf-8") as f:
+            f.write(json.dumps({"run_name": run_name, "run_uri": run_uri}))
+        logging.info("Mlflow tracker run uri exported to %s", export_path)
+
+
+class MLflowTracker(Tracker):
+    def __init__(self, tracker_config: MLflowConfig):
+        """Tracker which uses mlflow to collect and store metrics.
+
+        Args:
+            tracker_config (MLflowConfig): A valid MLflowConfig which contains
+                the mlflow tracking uri and experiment information.
+        """
+        super().__init__(name="mlflow", tracker_config=tracker_config)
+
+    def get_hf_callback(self):
+        """Returns the MLflowCallback object associated with this tracker.
+
+        Raises:
+            ValueError: If the config passed at initialisation does not contain
+                the uri where the mlflow tracking server is present.
+
+        Returns:
+            MLflowCallback: The MLflowCallback initialised with the config
+            provided at init time.
+ """ + c = self.config + exp = c.mlflow_experiment + uri = c.mlflow_tracking_uri + run_uri_path = c.mlflow_run_uri_export_path + + if uri is None: + logging.error( + "mlflow tracker requested but mlflow_uri is not specified. " + + "Please specify mlflow uri for using mlflow." + ) + raise ValueError( + "mlflow tracker requested but mlflow_uri is not specified." + ) + + # Modify the environment expected by mlflow + os.environ["MLFLOW_TRACKING_URI"] = uri + os.environ["MLFLOW_EXPERIMENT_NAME"] = exp + + mlflow_callback = RunURIExporterMlflowCallback() + + if mlflow_callback is not None: + mlflow_callback.run_uri_export_path = run_uri_path + + self.hf_callback = mlflow_callback + return self.hf_callback + + def track(self, metric, name, stage=None): + """Track any additional metric with name under mlflow tracker. + + Args: + metric (int/float): Expected metrics to be tracked by mlflow. + name (str): Name of the metric being tracked. + stage (str, optional): Can be used to pass the namespace/metadata to + associate with metric, e.g. at the stage the metric was generated + like train, eval. Defaults to None. + If not None the metric is saved as { state.name : metric } + Raises: + ValueError: If the metric or name are passed as None. + """ + if metric is None or name is None: + raise ValueError( + "mlflow track function should not be called with None metric value or name" + ) + + if stage is not None: + name = f"{stage}.{name}" + + mlflow = self.hf_callback.client + if mlflow is not None: + mlflow.log_metric(key=name, value=metric) + + def set_params(self, params, name=None): + """Attach any extra params with the run information stored in mlflow tracker. + + Args: + params (dict): A dict of k:v pairs of parameters to be storeed in tracker. + name (str, optional): represents the namespace under which parameters + will be associated in mlflow. e.g. {name: params} + Defaults to None. + + Raises: + ValueError: the params passed is None or not of type dict + """ + if params is None: + return + if not isinstance(params, dict): + raise ValueError( + "set_params passed to mlflow should be called with a dict of params" + ) + if name and not isinstance(name, str): + raise ValueError("name passed to mlflow set_params should be a string") + + if name: + tolog = {name: params} + else: + tolog = params + + mlflow = self.hf_callback.client + if mlflow is not None: + mlflow.log_params(tolog) diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 096099306..3339347d9 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -6,7 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, aim_reposoftware # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and @@ -26,8 +26,9 @@ # Information about all registered trackers AIMSTACK_TRACKER = "aim" FILE_LOGGING_TRACKER = "file_logger" +MLFLOW_TRACKER = "mlflow" -AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER] +AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, MLFLOW_TRACKER] # Trackers which can be used @@ -35,12 +36,21 @@ # One time package check for list of external trackers. 
 _is_aim_available = _is_package_available("aim")
+_is_mlflow_available = _is_package_available("mlflow")
 
 
 def _get_tracker_class(T, C):
     return {"tracker": T, "config": C}
 
 
+def _is_tracker_installed(name):
+    if name == "aim":
+        return _is_aim_available
+    if name == "mlflow":
+        return _is_mlflow_available
+    return False
+
+
 def _register_aim_tracker():
     # pylint: disable=import-outside-toplevel
     if _is_aim_available:
@@ -60,10 +70,23 @@ def _register_aim_tracker():
         )
 
 
-def _is_tracker_installed(name):
-    if name == "aim":
-        return _is_aim_available
-    return False
+def _register_mlflow_tracker():
+    # pylint: disable=import-outside-toplevel
+    if _is_mlflow_available:
+        # Local
+        from .mlflow_tracker import MLflowTracker
+        from tuning.config.tracker_configs import MLflowConfig
+
+        mlflow_tracker = _get_tracker_class(MLflowTracker, MLflowConfig)
+
+        REGISTERED_TRACKERS[MLFLOW_TRACKER] = mlflow_tracker
+        logging.info("Registered mlflow tracker")
+    else:
+        logging.info(
+            "Not registering mlflow tracker due to unavailability of package.\n"
+            "Please install mlflow if you intend to use it.\n"
+            "\t pip install mlflow"
+        )
 
 
 def _register_file_logging_tracker():
@@ -81,6 +104,8 @@ def _register_trackers():
         _register_aim_tracker()
     if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
         _register_file_logging_tracker()
+    if MLFLOW_TRACKER not in REGISTERED_TRACKERS:
+        _register_mlflow_tracker()
 
 
 def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory):
@@ -111,6 +136,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
         Valid classes available are,
         tuning.trackers.tracker.aimstack_tracker.AimStackTracker,
         tuning.trackers.tracker.filelogging_tracker.FileLoggingTracker
+        tuning.trackers.tracker.mlflow_tracker.MLflowTracker
 
     Examples:
         file_logging_tracker = get_tracker("file_logger", TrackerConfigFactory(
@@ -120,10 +146,16 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
             ))
         aim_tracker = get_tracker("aim", TrackerConfigFactory(
             aim_config=AimConfig(
-                experiment="unit_test",
+                experiment="test",
                 aim_repo=tempdir + "/"
             )
         ))
+        mlflow_tracker = get_tracker("mlflow", TrackerConfigFactory(
+            mlflow_config=MLflowConfig(
+                mlflow_experiment="test",
+                mlflow_tracking_uri="./mlflow.sqlite"
+            )
+        ))
     """
     if not REGISTERED_TRACKERS:
         # a one time step.
         _register_trackers()
 
     if name not in REGISTERED_TRACKERS:
         if name in AVAILABLE_TRACKERS and (not _is_tracker_installed(name)):
-            e = "Requested tracker {} is not installed. Please install before proceeding".format(
-                name
-            )
+            e = (
+                f"Requested tracker {name} is not installed. "
+                "Please install before proceeding"
+            )
         else:
             available = ", ".join(str(t) for t in AVAILABLE_TRACKERS)
-            e = "Requested Tracker {} not found. List trackers available for use is - {} ".format(
-                name, available
-            )
+            e = (
+                f"Requested Tracker {name} not found. "
+                f"List of trackers available for use is - {available}"
+            )
         logging.error(e)
         raise ValueError(e)
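
Taken together, a minimal sketch of how the pieces added in this diff fit together when used programmatically (the experiment name, sqlite uri, and metric/param values below are illustrative, not part of the change):

```python
# Local
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory
from tuning.trackers.tracker_factory import get_tracker

# Build the tracker from its config, mirroring what sft_trainer.main() now does.
mlflow_tracker = get_tracker(
    "mlflow",
    TrackerConfigFactory(
        mlflow_config=MLflowConfig(
            mlflow_experiment="demo",  # illustrative experiment name
            mlflow_tracking_uri="sqlite:///mlflow.db",  # illustrative uri
        )
    ),
)

# get_hf_callback() exports MLFLOW_TRACKING_URI / MLFLOW_EXPERIMENT_NAME and
# returns the RunURIExporterMlflowCallback; attaching it to the HF Trainer makes
# setup() run at the start of training, which also writes mlflow_tracker.json.
callback = mlflow_tracker.get_hf_callback()

# Once setup() has run (i.e. training has started), additional metrics and
# params can be attached to the same run:
mlflow_tracker.track(metric=0.42, name="loss", stage="eval")  # key "eval.loss"
mlflow_tracker.set_params({"seed": 42}, name="custom")  # stored under "custom"
```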