Skip to content

Commit

Permalink
Warn if no type hints provided for PythonModel (mlflow#14235)
Browse files Browse the repository at this point in the history
Signed-off-by: serena-ruan <[email protected]>
  • Loading branch information
serena-ruan authored Jan 15, 2025
1 parent 814ed74 commit 17d4887
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,12 +3027,24 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
)
elif python_model is not None:
if callable(python_model):
# TODO: support input_example and TypeFromExample for callables
# first argument is the model input
type_hints = _extract_type_hints(python_model, input_arg_index=0)
if not _signature_cannot_be_inferred_from_type_hint(type_hints.input):
signature_from_type_hints = _infer_signature_from_type_hints(
func=python_model, type_hints=type_hints, input_example=input_example
)
pyfunc_decorator_used = getattr(python_model, "_is_pyfunc", False)
# only show the warning here if @pyfunc is not applied on the function
# since @pyfunc will trigger the warning instead
if type_hints.input is None and not pyfunc_decorator_used:
# TODO: add link to documentation
color_warning(
"Add type hints to the `predict` method to enable "
"data validation and automatic signature inference.",
stacklevel=1,
color="yellow",
)
elif isinstance(python_model, PythonModel):
saved_example = _save_example(mlflow_model, input_example, path, example_no_conversion)
type_hints = python_model.predict_type_hints
Expand Down
8 changes: 8 additions & 0 deletions mlflow/pyfunc/utils/data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ def _get_func_info_if_type_hint_supported(func) -> Optional[FuncInfo]:
)
else:
return FuncInfo(input_type_hint=type_hint, input_param_name=input_param_name)
else:
# TODO: add link to documentation
color_warning(
"Add type hints to the `predict` method to enable data validation "
"and automatic signature inference during model logging.",
stacklevel=1,
color="yellow",
)


def _model_input_index_in_function_signature(func):
Expand Down
22 changes: 22 additions & 0 deletions tests/pyfunc/test_pyfunc_model_with_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,25 @@ def predict(self, model_input: list[Message], params=None):
@pyfunc
def predict(model_input: list[Message]):
return model_input


def test_python_model_without_type_hint_warning():
msg = r"Add type hints to the `predict` method"
with pytest.warns(UserWarning, match=msg):

class PythonModelWithoutTypeHint(mlflow.pyfunc.PythonModel):
def predict(self, model_input, params=None):
return model_input

with pytest.warns(UserWarning, match=msg):

@pyfunc
def predict(model_input):
return model_input

def predict(model_input):
return model_input

with mlflow.start_run():
with pytest.warns(UserWarning, match=msg):
mlflow.pyfunc.log_model("model", python_model=predict, input_example="abc")

0 comments on commit 17d4887

Please sign in to comment.