diff --git a/mlflow/langchain/databricks_dependencies.py b/mlflow/langchain/databricks_dependencies.py
index 5762f040574ad..0cbbb67cfffc4 100644
--- a/mlflow/langchain/databricks_dependencies.py
+++ b/mlflow/langchain/databricks_dependencies.py
@@ -242,6 +242,11 @@ def _traverse_runnable(
         # Visit the returned graph
         for node in lc_model.get_graph().nodes.values():
             yield from _traverse_runnable(node.data, visited)
+
+        # Visit the variables of the function as well to extract dependencies
+        if hasattr(lc_model, "func") and lc_model.func is not None:
+            for node in inspect.getclosurevars(lc_model.func).globals.values():
+                yield from _traverse_runnable(node, visited)
     else:
         # No-op for non-runnable, if any
         pass
diff --git a/tests/langchain/sample_code/langgraph_agent.py b/tests/langchain/sample_code/langgraph_agent.py
new file mode 100644
index 0000000000000..1dd8fcd0f84b2
--- /dev/null
+++ b/tests/langchain/sample_code/langgraph_agent.py
@@ -0,0 +1,74 @@
+from typing import Any, List, Literal, Optional
+
+from langchain_core.runnables import RunnableLambda
+from langchain_core.tools import tool
+from langgraph.prebuilt import create_react_agent
+
+import mlflow
+
+
+def get_fake_chat_model(endpoint="fake-endpoint"):
+    from langchain.callbacks.manager import CallbackManagerForLLMRun
+    from langchain.chat_models import ChatDatabricks, ChatMlflow
+    from langchain.schema.messages import BaseMessage
+    from langchain_core.outputs import ChatResult
+
+    class FakeChatModel(ChatDatabricks):
+        """Fake Chat Model wrapper for testing purposes."""
+
+        endpoint: str = "fake-endpoint"
+
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            response = {
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {
+                            "role": "assistant",
+                            "content": "test_content",
+                        },
+                        "finish_reason": None,
+                    }
+                ],
+            }
+            return ChatMlflow._create_chat_result(response)
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake chat model"
+
+    return FakeChatModel(endpoint=endpoint)
+
+
+@tool
+def get_weather(city: Literal["nyc", "sf"]):
+    """Use this to get weather information."""
+    if city == "nyc":
+        return "It might be cloudy in nyc"
+    elif city == "sf":
+        return "It's always sunny in sf"
+
+
+llm = get_fake_chat_model()
+tools = [get_weather]
+agent = create_react_agent(llm, tools)
+
+
+def wrap_lg(input):
+    if not isinstance(input, dict):
+        if isinstance(input, list) and len(input) > 0:
+            # Extract the content from the HumanMessage
+            content = input[0].content.strip('"')
+            input = {"messages": [{"role": "user", "content": content}]}
+    return agent.invoke(input)
+
+
+chain = RunnableLambda(wrap_lg)
+
+mlflow.models.set_model(chain)
diff --git a/tests/langchain/test_langchain_model_export.py b/tests/langchain/test_langchain_model_export.py
index d43cf1bed15e2..964751d6f5621 100644
--- a/tests/langchain/test_langchain_model_export.py
+++ b/tests/langchain/test_langchain_model_export.py
@@ -3502,3 +3502,26 @@ def test_signature_inference_fails(monkeypatch: pytest.MonkeyPatch):
         input_example={"chat": []},
     )
     assert model_info.signature is None
+
+
+@pytest.mark.skipif(
+    Version(langchain.__version__) < Version("0.2.0"),
+    reason="Langgraph are not supported the way we want in earlier versions",
+)
+def test_langgraph_agent_log_model_from_code():
+    input_example = {"messages": [{"role": "user", "content": "what is the weather in sf?"}]}
+
+    pyfunc_artifact_path = "weather_agent"
+    with mlflow.start_run() as run:
+        mlflow.langchain.log_model(
+            lc_model="tests/langchain/sample_code/langgraph_agent.py",
+            artifact_path=pyfunc_artifact_path,
+            input_example=input_example,
+        )
+    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
+    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
+    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
+    actual = reloaded_model.resources["databricks"]
+    expected = {"serving_endpoint": [{"name": "fake-endpoint"}]}
+    assert all(item in actual["serving_endpoint"] for item in expected["serving_endpoint"])
+    assert all(item in expected["serving_endpoint"] for item in actual["serving_endpoint"])