Add mlflow tracker and unit testing for the same.
Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl committed Dec 19, 2024
1 parent 4441948 commit 2676159
Showing 7 changed files with 412 additions and 23 deletions.
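In short, this commit wires an MLflow tracker into sft_trainer alongside the existing file logger and Aim trackers. A minimal usage sketch, assuming the caller already has model_args, data_args and train_args prepared (the experiment name and tracking URI below are placeholders, mirroring the new end-to-end test added in this commit):

# Standard
import os

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory

# "mlflow" is requested explicitly; train() still adds the file logger by default.
train_args.trackers = ["mlflow"]

tracker_configs = TrackerConfigFactory(
    mlflow_config=MLflowConfig(
        mlflow_experiment="my_experiment",
        mlflow_tracking_uri=os.path.join(train_args.output_dir, "mlflowdb.sqlite"),
    )
)

sft_trainer.train(model_args, data_args, train_args, tracker_configs=tracker_configs)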
6 changes: 4 additions & 2 deletions tests/test_sft_trainer.py
@@ -356,6 +356,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -381,6 +382,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -391,14 +393,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
2 changes: 1 addition & 1 deletion tests/trackers/test_aim_tracker.py
@@ -58,7 +58,7 @@ def fixture_aimrepo():


@pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
def test_run_with_good_tracker_name_but_no_args():
def test_run_with_aim_tracker_name_but_no_args():
"""Ensure that train() raises error with aim tracker name but no args"""

with tempfile.TemporaryDirectory() as tempdir:
124 changes: 124 additions & 0 deletions tests/trackers/test_mlflow_tracker.py
@@ -0,0 +1,124 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import copy
import json
import os
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
DATA_ARGS,
MODEL_ARGS,
TRAIN_ARGS,
_get_checkpoint_path,
_test_run_inference,
_validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory

mlflow_not_available = not _is_package_available("mlflow")


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_run_with_mlflow_tracker_name_but_no_args():
"""Ensure that train() raises error with mlflow tracker name but no args"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

train_args.trackers = ["mlflow"]

with pytest.raises(
ValueError,
match="mlflow tracker requested but mlflow_uri is not specified.",
):
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_tracker():
"""Ensure that training succeeds with mlflow tracker"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

# This should not mean file logger is not present.
# code will add it by default
# The below validate_training check will test for that too.
train_args.trackers = ["mlflow"]

tracker_configs = TrackerConfigFactory(
mlflow_config=MLflowConfig(
mlflow_experiment="unit_test",
mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"),
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate ft tuning configs
_validate_training(tempdir)

# validate inference
_test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_runuri_export_default_path():
"""Ensure that mlflow outputs run uri in the output dir by default"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

train_args.trackers = ["mlflow"]

tracker_configs = TrackerConfigFactory(
mlflow_config=MLflowConfig(
mlflow_experiment="unit_test",
mlflow_tracking_uri=os.path.join(tempdir, "mlflowdb.sqlite"),
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate ft tuning configs
_validate_training(tempdir)

run_uri_file = os.path.join(tempdir, "mlflow_tracker.json")

assert os.path.exists(run_uri_file) is True
assert os.path.getsize(run_uri_file) > 0

with open(run_uri_file, "r", encoding="utf-8") as f:
content = json.loads(f.read())
assert "run_uri" in content
20 changes: 20 additions & 0 deletions tuning/config/tracker_configs.py
@@ -63,7 +63,27 @@ def __post_init__(self):
)


@dataclass
class MLflowConfig:
# Name of the experiment
mlflow_experiment: str = None
mlflow_tracking_uri: str = None
# Location of where mlflow's run uri is to be exported.
# If mlflow_run_uri_export_path is set the run uri will be output in a json format
# to the location pointed to by `mlflow_run_uri_export_path/mlflow_tracker.json`
# If this is not set then the default location where run uri will be exported
# is training_args.output_dir/mlflow_tracker.json
# Run uri is not exported if mlflow_run_uri_export_path variable is not set
# and output_dir is not specified.
mlflow_run_uri_export_path: str = None

def __post_init__(self):
if self.mlflow_experiment is None:
self.mlflow_experiment = "fms-hf-tuning"


@dataclass
class TrackerConfigFactory:
file_logger_config: FileLoggingTrackerConfig = None
aim_config: AimConfig = None
mlflow_config: MLflowConfig = None
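As a follow-up sketch (not part of the commit): per the comments above and the new test, when mlflow_run_uri_export_path is unset the run uri is written to training_args.output_dir/mlflow_tracker.json as JSON with a "run_uri" key, so it can be read back like this (output_dir is a placeholder for the caller's training output directory):

# Standard
import json
import os

output_dir = "path/to/output_dir"  # placeholder; same value as training_args.output_dir

run_uri_file = os.path.join(output_dir, "mlflow_tracker.json")
with open(run_uri_file, "r", encoding="utf-8") as f:
    run_uri = json.load(f)["run_uri"]
print(run_uri)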
25 changes: 18 additions & 7 deletions tuning/sft_trainer.py
@@ -51,6 +51,7 @@
from tuning.config.tracker_configs import (
AimConfig,
FileLoggingTrackerConfig,
MLflowConfig,
TrackerConfigFactory,
)
from tuning.data.setup_dataprocessor import process_dataargs
@@ -434,6 +435,7 @@ def get_parser():
QuantizedLoraConfig,
FusedOpsAndKernelsConfig,
AttentionAndDistributedPackingConfig,
MLflowConfig,
)
)
parser.add_argument(
@@ -483,8 +485,10 @@ def parse_arguments(parser, json_config=None):
Configuration for fused operations and kernels.
AttentionAndDistributedPackingConfig
Configuration for padding free and packing.
MLflowConfig
Configuration for mlflow tracker.
dict[str, str]
Extra AIM metadata.
Extra tracker metadata.
"""
if json_config:
(
@@ -499,6 +503,7 @@
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
mlflow_config,
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
@@ -515,6 +520,7 @@
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
mlflow_config,
additional,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
@@ -540,6 +546,7 @@
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
mlflow_config,
exp_metadata,
)

@@ -561,6 +568,7 @@ def main():
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
mlflow_config,
exp_metadata,
) = parse_arguments(parser, job_config)

@@ -572,7 +580,8 @@
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
quantized_lora_config %s, fusedops_kernels_config %s, \
attention_and_distributed_packing_config %s exp_metadata %s",
attention_and_distributed_packing_config %s,\
mlflow_config %s, exp_metadata %s",
model_args,
data_args,
training_args,
@@ -583,6 +592,7 @@
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
mlflow_config,
exp_metadata,
)
except Exception as e: # pylint: disable=broad-except
@@ -607,10 +617,11 @@ def main():
"failed while parsing extra metadata. pass a valid json %s", repr(e)
)

combined_tracker_configs = TrackerConfigFactory()

combined_tracker_configs.file_logger_config = file_logger_config
combined_tracker_configs.aim_config = aim_config
tracker_configs = TrackerConfigFactory(
file_logger_config=file_logger_config,
aim_config=aim_config,
mlflow_config=mlflow_config,
)

if training_args.output_dir:
os.makedirs(training_args.output_dir, exist_ok=True)
@@ -622,7 +633,7 @@
train_args=training_args,
peft_config=tune_config,
trainer_controller_args=trainer_controller_args,
tracker_configs=combined_tracker_configs,
tracker_configs=tracker_configs,
additional_callbacks=None,
exp_metadata=metadata,
quantized_lora_config=quantized_lora_config,