From 17d4887795a6f9ad23478291ec18f59cb54b467d Mon Sep 17 00:00:00 2001 From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com> Date: Wed, 15 Jan 2025 14:45:31 +0800 Subject: [PATCH] Warn if no type hints provided for PythonModel (#14235) Signed-off-by: serena-ruan --- mlflow/pyfunc/__init__.py | 12 ++++++++++ mlflow/pyfunc/utils/data_validation.py | 8 +++++++ .../test_pyfunc_model_with_type_hints.py | 22 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index dcd22a68da245..9869f0817cab4 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -3027,12 +3027,24 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]: ) elif python_model is not None: if callable(python_model): + # TODO: support input_example and TypeFromExample for callables # first argument is the model input type_hints = _extract_type_hints(python_model, input_arg_index=0) if not _signature_cannot_be_inferred_from_type_hint(type_hints.input): signature_from_type_hints = _infer_signature_from_type_hints( func=python_model, type_hints=type_hints, input_example=input_example ) + pyfunc_decorator_used = getattr(python_model, "_is_pyfunc", False) + # only show the warning here if @pyfunc is not applied on the function + # since @pyfunc will trigger the warning instead + if type_hints.input is None and not pyfunc_decorator_used: + # TODO: add link to documentation + color_warning( + "Add type hints to the `predict` method to enable " + "data validation and automatic signature inference.", + stacklevel=1, + color="yellow", + ) elif isinstance(python_model, PythonModel): saved_example = _save_example(mlflow_model, input_example, path, example_no_conversion) type_hints = python_model.predict_type_hints diff --git a/mlflow/pyfunc/utils/data_validation.py b/mlflow/pyfunc/utils/data_validation.py index 2176a6e9d8d91..1f6eeae89546b 100644 --- a/mlflow/pyfunc/utils/data_validation.py +++ b/mlflow/pyfunc/utils/data_validation.py @@ -130,6 +130,14 @@ def _get_func_info_if_type_hint_supported(func) -> Optional[FuncInfo]: ) else: return FuncInfo(input_type_hint=type_hint, input_param_name=input_param_name) + else: + # TODO: add link to documentation + color_warning( + "Add type hints to the `predict` method to enable data validation " + "and automatic signature inference during model logging.", + stacklevel=1, + color="yellow", + ) def _model_input_index_in_function_signature(func): diff --git a/tests/pyfunc/test_pyfunc_model_with_type_hints.py b/tests/pyfunc/test_pyfunc_model_with_type_hints.py index 599558e09f77b..ea19ceff4df58 100644 --- a/tests/pyfunc/test_pyfunc_model_with_type_hints.py +++ b/tests/pyfunc/test_pyfunc_model_with_type_hints.py @@ -907,3 +907,25 @@ def predict(self, model_input: list[Message], params=None): @pyfunc def predict(model_input: list[Message]): return model_input + + +def test_python_model_without_type_hint_warning(): + msg = r"Add type hints to the `predict` method" + with pytest.warns(UserWarning, match=msg): + + class PythonModelWithoutTypeHint(mlflow.pyfunc.PythonModel): + def predict(self, model_input, params=None): + return model_input + + with pytest.warns(UserWarning, match=msg): + + @pyfunc + def predict(model_input): + return model_input + + def predict(model_input): + return model_input + + with mlflow.start_run(): + with pytest.warns(UserWarning, match=msg): + mlflow.pyfunc.log_model("model", python_model=predict, input_example="abc")