Update Dependency Extraction for Agents (mlflow#13105)
Signed-off-by: aravind-segu <[email protected]>
aravind-segu authored Sep 10, 2024
1 parent cf70cae commit 22f6bc1
Showing 3 changed files with 102 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlflow/langchain/databricks_dependencies.py
@@ -242,6 +242,11 @@ def _traverse_runnable(
            # Visit the returned graph
            for node in lc_model.get_graph().nodes.values():
                yield from _traverse_runnable(node.data, visited)

            # Visit the variables of the function as well to extract dependencies
            if hasattr(lc_model, "func") and lc_model.func is not None:
                for node in inspect.getclosurevars(lc_model.func).globals.values():
                    yield from _traverse_runnable(node, visited)
        else:
            # No-op for non-runnable, if any
            pass
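Why the new branch matters (commentary, not part of the diff): a RunnableLambda exposes only its wrapped function in the runnable graph, so any model or retriever that function refers to lives in its global scope rather than in `get_graph()`. A minimal, self-contained sketch of the mechanism the added lines rely on, with illustrative names that are not from this commit:

import inspect

from langchain_core.runnables import RunnableLambda

agent = object()  # stand-in for a module-level agent/chat model the lambda uses


def wrap(payload):
    # The body refers to the global `agent`, so `agent` never appears as a graph node.
    return agent


runnable = RunnableLambda(wrap)

# RunnableLambda keeps the wrapped callable on `.func`; getclosurevars() surfaces the
# globals that callable references, which is what the traversal above recurses into.
assert "agent" in inspect.getclosurevars(runnable.func).globals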
74 changes: 74 additions & 0 deletions tests/langchain/sample_code/langgraph_agent.py
@@ -0,0 +1,74 @@
from typing import Any, List, Literal, Optional

from langchain_core.runnables import RunnableLambda
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

import mlflow


def get_fake_chat_model(endpoint="fake-endpoint"):
    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.chat_models import ChatDatabricks, ChatMlflow
    from langchain.schema.messages import BaseMessage
    from langchain_core.outputs import ChatResult

    class FakeChatModel(ChatDatabricks):
        """Fake Chat Model wrapper for testing purposes."""

        endpoint: str = "fake-endpoint"

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            response = {
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "test_content",
                        },
                        "finish_reason": None,
                    }
                ],
            }
            return ChatMlflow._create_chat_result(response)

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    return FakeChatModel(endpoint=endpoint)


@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"


llm = get_fake_chat_model()
tools = [get_weather]
agent = create_react_agent(llm, tools)


def wrap_lg(input):
    if not isinstance(input, dict):
        if isinstance(input, list) and len(input) > 0:
            # Extract the content from the HumanMessage
            content = input[0].content.strip('"')
            input = {"messages": [{"role": "user", "content": content}]}
    return agent.invoke(input)


chain = RunnableLambda(wrap_lg)

mlflow.models.set_model(chain)
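Commentary (not part of the sample file): `wrap_lg` accepts either the raw chat-completions dict or a list of LangChain messages that the pyfunc wrapper may hand it, and normalizes both into the `{"messages": [...]}` shape LangGraph expects; the module-level `agent` it calls is exactly the kind of closure global the dependency-extraction change above is meant to discover. A small sketch of the two accepted input shapes, with illustrative values:

from langchain_core.messages import HumanMessage

# Both of these end up as the same payload inside wrap_lg: the quoted content of the
# HumanMessage is unwrapped and re-packaged as a chat-completions style dict.
as_dict = {"messages": [{"role": "user", "content": "what is the weather in sf?"}]}
as_messages = [HumanMessage(content='"what is the weather in sf?"')]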
23 changes: 23 additions & 0 deletions tests/langchain/test_langchain_model_export.py
@@ -3502,3 +3502,26 @@ def test_signature_inference_fails(monkeypatch: pytest.MonkeyPatch):
        input_example={"chat": []},
    )
    assert model_info.signature is None


@pytest.mark.skipif(
    Version(langchain.__version__) < Version("0.2.0"),
    reason="LangGraph is not supported the way we want in earlier versions",
)
def test_langgraph_agent_log_model_from_code():
    input_example = {"messages": [{"role": "user", "content": "what is the weather in sf?"}]}

    pyfunc_artifact_path = "weather_agent"
    with mlflow.start_run() as run:
        mlflow.langchain.log_model(
            lc_model="tests/langchain/sample_code/langgraph_agent.py",
            artifact_path=pyfunc_artifact_path,
            input_example=input_example,
        )
    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    actual = reloaded_model.resources["databricks"]
    expected = {"serving_endpoint": [{"name": "fake-endpoint"}]}
    assert all(item in actual["serving_endpoint"] for item in expected["serving_endpoint"])
    assert all(item in expected["serving_endpoint"] for item in actual["serving_endpoint"])
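Commentary (not part of the diff): the two `all(...)` assertions amount to an order-insensitive equality check between the endpoint entries recorded for the logged model and the expected list. Because `FakeChatModel` subclasses `ChatDatabricks` with `endpoint="fake-endpoint"` and is only reachable through the globals of `wrap_lg`, the endpoint can only be detected by the new closure traversal. A sketch of the expected shape, reconstructed from the assertions above rather than taken from the committed code:

# Resources expected to be recorded for the logged model; other MLmodel fields omitted.
expected_resources = {"databricks": {"serving_endpoint": [{"name": "fake-endpoint"}]}}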
