Skip to content

Commit

Permalink
Add mlflow tracker and unit testing for the same.
Browse files Browse the repository at this point in the history
Also add mlflow docs and add mlflow to docker file and as optional requirement

Signed-off-by: Dushyant Behl <[email protected]>
  • Loading branch information
dushyantbehl committed Dec 19, 2024
1 parent 42e3077 commit 609b98e
Show file tree
Hide file tree
Showing 13 changed files with 481 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ venv/
# Aim
.aim

# Mlflow
mlruns/

# Backup files and folders
*.bkp
*.bkp.*
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -823,12 +823,13 @@ For details about how you can use set a custom stopping criteria and perform cus

## Experiment Tracking

Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/) or custom trackers built into the code like
Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/), [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) or custom trackers built into the code like
[FileLoggingTracker](./tuning/trackers/filelogging_tracker.py)

The code supports currently two trackers out of the box,
* `FileLoggingTracker` : A built in tracker which supports logging training loss to a file.
* `Aimstack` : A popular opensource tracker which can be used to track any metrics or metadata from the experiments.
* `MLflow Tracking` : Another popular opensource tracker which stores metrics, metadata or even artifacts from experiments.

Further details on enabling and using the trackers mentioned above can be found [here](docs/experiment-tracking.md).

Expand Down
7 changes: 6 additions & 1 deletion build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ ARG USER=tuning
ARG USER_UID=1000
ARG PYTHON_VERSION=3.11
ARG WHEEL_VERSION=""
## Enable Aimstack if requested via ENABLE_AIM set to "true"
## Enable Aimstack or MLflow if requested via ENABLE_AIM/MLFLOW set to "true"
ARG ENABLE_AIM=false
ARG ENABLE_MLFLOW=false
ARG ENABLE_FMS_ACCELERATION=true

## Base Layer ##################################################################
Expand Down Expand Up @@ -151,6 +152,10 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[aim]"; \
fi

RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
python -m pip install --user "$(head bdist_name)[mlflow]"; \
fi

# Clean up the wheel module. It's only needed by flash-attn install
RUN python -m pip uninstall wheel build -y && \
# Cleanup the bdist whl file
Expand Down
32 changes: 30 additions & 2 deletions docs/experiment-tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,34 @@ sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....
The code expects either the `local` or `remote` repo to be specified and will result in a `ValueError` otherwise.
See [AimConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/tracker_configs.py#L25) for more details.

## MLflow Tracker

To enable [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) users need to pass `"mlflow"` as the requested tracker as part of the [training argument](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/configs.py#L131).


When using MLflow, users need to specify additional arguments which specify [mlflow tracking uri](https://mlflow.org/docs/latest/tracking.html#common-setups) location where either a [mlflow supported database](https://mlflow.org/docs/latest/tracking/backend-stores.html#supported-store-types) or [mlflow remote tracking server](https://mlflow.org/docs/latest/tracking/server.html) is running.

Example
```
from tuning import sft_trainer
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory
training_args = TrainingArguments(
...,
trackers = ["mlflow"],
)
tracker_configs = TrackerConfigFactory(
mlflow_config=MLflowConfig(
mlflow_experiment="experiment-name",
mlflow_tracking_uri=<tracking uri>
)
)
sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....)
```

The code expects a valid uri to be specified and will result in a `ValueError` otherwise.

## Running the code via command line `tuning/sft_trainer::main` function

Expand All @@ -123,10 +151,10 @@ If running the code via main function of [sft_trainer.py](../tuning/sft_trainer.
To enable tracking please pass

```
--tracker <aim/file_logger>
--tracker <aim/file_logger/mlflow>
```

To further customise tracking you can specify additional arguments needed by the tracker like
To further customise tracking you can specify additional arguments needed by the tracker like (example shows aim follow similarly for mlflow)

```
--tracker aim --aim_repo <path-to-aimrepo> --experiment <experiment-name>
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
mlflow = ["mlflow"]
fms-accel = ["fms-acceleration>=0.1"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]

Expand Down
6 changes: 4 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
Expand All @@ -384,6 +385,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
Expand All @@ -394,14 +396,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
Expand Down
2 changes: 1 addition & 1 deletion tests/trackers/test_aim_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def fixture_aimrepo():


@pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
def test_run_with_good_tracker_name_but_no_args():
def test_run_with_aim_tracker_name_but_no_args():
"""Ensure that train() raises error with aim tracker name but no args"""

with tempfile.TemporaryDirectory() as tempdir:
Expand Down
138 changes: 138 additions & 0 deletions tests/trackers/test_mlflow_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import copy
import json
import os
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
DATA_ARGS,
MODEL_ARGS,
TRAIN_ARGS,
_get_checkpoint_path,
_test_run_inference,
_validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory

mlflow_not_available = not _is_package_available("mlflow")


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_run_with_mlflow_tracker_name_but_no_args():
"""Ensure that train() raises error with mlflow tracker name but no args"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

train_args.trackers = ["mlflow"]

with pytest.raises(
ValueError,
match="mlflow tracker requested but mlflow_uri is not specified.",
):
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_tracker():
"""Ensure that training succeeds with mlflow tracker"""

# mlflow performs a cleanup at callback close time which happens post the
# delete of this directory so we run into two issues
# 1. the temp directory cannot be cleared as it has open pointer by mlflow
# 2. mlflow complaints that it cannot find a run which it just created.
# this is a race condition which is fixed with mkdtemp() which doesn't delete
tempdir = tempfile.mkdtemp()

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

# This should not mean file logger is not present.
# code will add it by default
# The below validate_training check will test for that too.
train_args.trackers = ["mlflow"]

mlflow_path = os.path.join(tempdir, "mlflow")

tracker_configs = TrackerConfigFactory(
mlflow_config=MLflowConfig(
mlflow_experiment="unit_test",
mlflow_tracking_uri=f"file://{mlflow_path}",
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate ft tuning configs
_validate_training(tempdir)

assert os.path.exists(mlflow_path) and os.path.isdir(mlflow_path)

# validate inference
_test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_runuri_export_default_path():
"""Ensure that mlflow outputs run uri in the output dir by default"""

tempdir = tempfile.mkdtemp()
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

train_args.trackers = ["mlflow"]

mlflow_path = os.path.join(tempdir, "mlflow")

tracker_configs = TrackerConfigFactory(
mlflow_config=MLflowConfig(
mlflow_experiment="unit_test",
mlflow_tracking_uri=f"file://{mlflow_path}",
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate ft tuning configs
_validate_training(tempdir)

assert os.path.exists(mlflow_path) and os.path.isdir(mlflow_path)

run_uri_file = os.path.join(tempdir, "mlflow_tracker.json")

assert os.path.exists(run_uri_file) is True
assert os.path.getsize(run_uri_file) > 0

with open(run_uri_file, "r", encoding="utf-8") as f:
content = json.loads(f.read())
assert "run_uri" in content
21 changes: 17 additions & 4 deletions tuning/config/tracker_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FileLoggingTrackerConfig:
@dataclass
class AimConfig:
# Name of the experiment
experiment: str = None
experiment: str = "fms-hf-tuning"
# aim_repo can point to a locally accessible directory
# or a remote repository hosted on a server.
# When 'aim_remote_server_ip' or 'aim_remote_server_port' is set,
Expand All @@ -47,9 +47,6 @@ class AimConfig:
aim_run_id_export_path: str = None

def __post_init__(self):
if self.experiment is None:
self.experiment = "fms-hf-tuning"

if (
self.aim_remote_server_ip is not None
and self.aim_remote_server_port is not None
Expand All @@ -63,7 +60,23 @@ def __post_init__(self):
)


@dataclass
class MLflowConfig:
# Name of the experiment
mlflow_experiment: str = "fms-hf-tuning"
mlflow_tracking_uri: str = None
# Location of where mlflow's run uri is to be exported.
# If mlflow_run_uri_export_path is set the run uri will be output in a json format
# to the location pointed to by `mlflow_run_uri_export_path/mlflow_tracker.json`
# If this is not set then the default location where run uri will be exported
# is training_args.output_dir/mlflow_tracker.json
# Run uri is not exported if mlflow_run_uri_export_path variable is not set
# and output_dir is not specified.
mlflow_run_uri_export_path: str = None


@dataclass
class TrackerConfigFactory:
file_logger_config: FileLoggingTrackerConfig = None
aim_config: AimConfig = None
mlflow_config: MLflowConfig = None
Loading

0 comments on commit 609b98e

Please sign in to comment.