
Commit 7494e32

[MLflow] Updating langchain databricks_dependency to use resources format as well (mlflow#11869)

Signed-off-by: Sunish Sheth <[email protected]>
sunishsheth2009 authored May 2, 2024
1 parent d38872e commit 7494e32
Showing 6 changed files with 266 additions and 58 deletions.
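Note: before this change, detected Databricks dependencies were recorded only under the legacy "databricks_dependency" key in the langchain flavor config. This commit keeps that key and additionally writes the newer resources format onto the MLmodel metadata. A minimal sketch of the two shapes (the index and endpoint names are placeholders, not values from this diff):

    # Legacy flavor-config format; the two keys shown are defined in
    # databricks_dependencies.py below, the remaining endpoint keys are analogous.
    legacy = {
        "databricks_vector_search_index_name": ["catalog.schema.my_index"],
        "databricks_vector_search_endpoint_name": ["my-vs-endpoint"],
    }

    # New format: typed Resource objects, serialized via _ResourceBuilder at save time.
    from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

    resources = [
        DatabricksVectorSearchIndex(index_name="catalog.schema.my_index"),
        DatabricksServingEndpoint(endpoint_name="my-vs-endpoint"),
    ]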
7 changes: 6 additions & 1 deletion mlflow/langchain/__init__.py
@@ -57,6 +57,7 @@
 from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info
 from mlflow.models.model import MLMODEL_FILE_NAME
 from mlflow.models.model_config import _set_model_config
+from mlflow.models.resources import _ResourceBuilder
 from mlflow.models.signature import _infer_signature_from_input_example
 from mlflow.models.utils import _convert_llm_input_data, _save_example
 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
@@ -367,8 +368,12 @@ def load_retriever(persist_directory):
         )

     if Version(langchain.__version__) >= Version("0.0.311"):
-        if databricks_dependency := _detect_databricks_dependencies(lc_model):
+        (databricks_dependency, databricks_resources) = _detect_databricks_dependencies(lc_model)
+        if databricks_dependency:
             flavor_conf[_DATABRICKS_DEPENDENCY_KEY] = databricks_dependency
+        if databricks_resources:
+            serialized_databricks_resources = _ResourceBuilder.from_resources(databricks_resources)
+            mlflow_model.resources = serialized_databricks_resources

     mlflow_model.add_flavor(
         FLAVOR_NAME,
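The exact serialized layout of mlflow_model.resources is not part of this diff; the sketch below infers it from how get_model_version_dependencies consumes it in rest_store.py further down (a top-level "databricks" key whose per-resource-type lists hold {"name": ...} entries). The resource-type string values are assumptions:

    # Hedged sketch of what the new save-path branch records, not a verbatim dump.
    from mlflow.models.resources import (
        DatabricksServingEndpoint,
        DatabricksVectorSearchIndex,
        _ResourceBuilder,
    )

    databricks_resources = [
        DatabricksVectorSearchIndex(index_name="catalog.schema.my_index"),  # placeholder name
        DatabricksServingEndpoint(endpoint_name="my-llm-endpoint"),  # placeholder name
    ]

    serialized = _ResourceBuilder.from_resources(databricks_resources)
    # Assumed shape of `serialized`, inferred from the reader in rest_store.py:
    # {
    #     "databricks": {
    #         "vector_search_index": [{"name": "catalog.schema.my_index"}],
    #         "serving_endpoint": [{"name": "my-llm-endpoint"}],
    #     },
    # }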
58 changes: 36 additions & 22 deletions mlflow/langchain/databricks_dependencies.py
@@ -2,6 +2,8 @@
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Set

+from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex, Resource
+
 _DATABRICKS_DEPENDENCY_KEY = "databricks_dependency"
 _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY = "databricks_vector_search_index_name"
 _DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY = "databricks_vector_search_endpoint_name"
@@ -14,7 +16,7 @@


 def _extract_databricks_dependencies_from_retriever(
-    retriever, dependency_dict: DefaultDict[str, List[Any]]
+    retriever, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
 ):
     try:
         from langchain.embeddings import DatabricksEmbeddings as LegacyDatabricksEmbeddings
@@ -38,20 +40,18 @@ def _extract_databricks_dependencies_from_retriever(
         index = vectorstore.index
         dependency_dict[_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY].append(index.name)
         dependency_dict[_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY].append(index.endpoint_name)
+        dependency_list.append(DatabricksVectorSearchIndex(index_name=index.name))
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=index.endpoint_name))

         embeddings = getattr(vectorstore, "embeddings", None)
         if isinstance(embeddings, (DatabricksEmbeddings, LegacyDatabricksEmbeddings)):
             dependency_dict[_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY].append(embeddings.endpoint)
-        elif (
-            callable(getattr(vectorstore, "_is_databricks_managed_embeddings", None))
-            and vectorstore._is_databricks_managed_embeddings()
-        ):
-            dependency_dict[_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY].append(
-                "_is_databricks_managed_embeddings"
-            )
+            dependency_list.append(DatabricksServingEndpoint(endpoint_name=embeddings.endpoint))


-def _extract_databricks_dependencies_from_llm(llm, dependency_dict: DefaultDict[str, List[Any]]):
+def _extract_databricks_dependencies_from_llm(
+    llm, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
+):
     try:
         from langchain.llms import Databricks as LegacyDatabricks
     except ImportError:
@@ -61,10 +61,11 @@ def _extract_databricks_dependencies_from_llm(

     if isinstance(llm, (LegacyDatabricks, Databricks)):
         dependency_dict[_DATABRICKS_LLM_ENDPOINT_NAME_KEY].append(llm.endpoint_name)
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=llm.endpoint_name))


 def _extract_databricks_dependencies_from_chat_model(
-    chat_model, dependency_dict: DefaultDict[str, List[Any]]
+    chat_model, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
 ):
     try:
         from langchain.chat_models import ChatDatabricks as LegacyChatDatabricks
@@ -77,6 +78,7 @@ def _extract_databricks_dependencies_from_chat_model(

     if isinstance(chat_model, (LegacyChatDatabricks, ChatDatabricks)):
         dependency_dict[_DATABRICKS_CHAT_ENDPOINT_NAME_KEY].append(chat_model.endpoint)
+        dependency_list.append(DatabricksServingEndpoint(endpoint_name=chat_model.endpoint))


 _LEGACY_MODEL_ATTR_SET = {
@@ -92,7 +94,9 @@ def _extract_databricks_dependencies_from_chat_model(
 }


-def _extract_dependency_dict_from_lc_model(lc_model, dependency_dict: DefaultDict[str, List[Any]]):
+def _extract_dependency_dict_from_lc_model(
+    lc_model, dependency_dict: DefaultDict[str, List[Any]], dependency_list: List[Resource]
+):
     """
     This function contains the logic to examine a non-Runnable component of a langchain model.
     The logic here does not cover all legacy chains. If you need to support a custom chain,
@@ -102,16 +106,23 @@ def _extract_dependency_dict_from_lc_model(
         return

     # leaf node
-    _extract_databricks_dependencies_from_chat_model(lc_model, dependency_dict)
-    _extract_databricks_dependencies_from_retriever(lc_model, dependency_dict)
-    _extract_databricks_dependencies_from_llm(lc_model, dependency_dict)
+    _extract_databricks_dependencies_from_chat_model(lc_model, dependency_dict, dependency_list)
+    _extract_databricks_dependencies_from_retriever(lc_model, dependency_dict, dependency_list)
+    _extract_databricks_dependencies_from_llm(lc_model, dependency_dict, dependency_list)

     # recursively inspect legacy chain
     for attr_name in _LEGACY_MODEL_ATTR_SET:
-        _extract_dependency_dict_from_lc_model(getattr(lc_model, attr_name, None), dependency_dict)
+        _extract_dependency_dict_from_lc_model(
+            getattr(lc_model, attr_name, None), dependency_dict, dependency_list
+        )


-def _traverse_runnable(lc_model, dependency_dict: DefaultDict[str, List[Any]], visited: Set[str]):
+def _traverse_runnable(
+    lc_model,
+    dependency_dict: DefaultDict[str, List[Any]],
+    dependency_list: List[Resource],
+    visited: Set[str],
+):
     """
     This function contains the logic to traverse a langchain_core.runnables.RunnableSerializable
     object. It first inspects the current object using _extract_dependency_dict_from_lc_model
@@ -127,19 +138,21 @@ def _traverse_runnable(

     # Visit the current object
     visited.add(current_object_id)
-    _extract_dependency_dict_from_lc_model(lc_model, dependency_dict)
+    _extract_dependency_dict_from_lc_model(lc_model, dependency_dict, dependency_list)

     if isinstance(lc_model, Runnable):
         # Visit the returned graph
         for node in lc_model.get_graph().nodes.values():
-            _traverse_runnable(node.data, dependency_dict, visited)
+            _traverse_runnable(node.data, dependency_dict, dependency_list, visited)
     else:
         # No-op for non-runnable, if any
         pass
     return


-def _detect_databricks_dependencies(lc_model, log_errors_as_warnings=True) -> Dict[str, List[Any]]:
+def _detect_databricks_dependencies(
+    lc_model, log_errors_as_warnings=True
+) -> (Dict[str, List[Any]], List[Resource]):
     """
     Detects the databricks dependencies of a langchain model and returns a dictionary of
     detected endpoint names and index names.
@@ -162,14 +175,15 @@ def _detect_databricks_dependencies(
     """
     try:
         dependency_dict = defaultdict(list)
-        _traverse_runnable(lc_model, dependency_dict, set())
-        return dict(dependency_dict)
+        dependency_list = []
+        _traverse_runnable(lc_model, dependency_dict, dependency_list, set())
+        return (dict(dependency_dict), dependency_list)
     except Exception:
         if log_errors_as_warnings:
             _logger.warning(
                 "Unable to detect Databricks dependencies. "
                 "Set logging level to DEBUG to see the full traceback."
             )
             _logger.debug("", exc_info=True)
-            return {}
+            return {}, []
         raise
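Taken together, the traversal now threads dependency_list alongside the legacy dependency_dict, so a single pass yields both formats. A usage sketch under stated assumptions (the ChatDatabricks import path varies across langchain versions; the endpoint name is a placeholder):

    from langchain_community.chat_models import ChatDatabricks  # path may differ by version

    from mlflow.langchain.databricks_dependencies import (
        _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
        _detect_databricks_dependencies,
    )

    chat_model = ChatDatabricks(endpoint="my-chat-endpoint", max_tokens=500)  # placeholder

    dependency_dict, dependency_list = _detect_databricks_dependencies(chat_model)
    # Expected, by analogy with the tests below:
    # dependency_dict[_DATABRICKS_CHAT_ENDPOINT_NAME_KEY] == ["my-chat-endpoint"]
    # dependency_list == [DatabricksServingEndpoint(endpoint_name="my-chat-endpoint")]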
57 changes: 40 additions & 17 deletions mlflow/store/_unity_catalog/registry/rest_store.py
@@ -176,31 +176,54 @@ def get_model_version_dependencies(model_dir):
         _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
         _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY,
     )
+    from mlflow.models.resources import ResourceType

     model = _load_model(model_dir)
     model_info = model.get_model_info()
     dependencies = []
-    index_names = _fetch_langchain_dependency_from_model_info(
-        model_info, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
-    )
-    for index_name in index_names:
-        dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", "name": index_name})
-    for key in (
-        _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
-        _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
-        _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
-    ):
-        endpoint_names = _fetch_langchain_dependency_from_model_info(model_info, key)
-        for endpoint_name in endpoint_names:
-            dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name})
-    return dependencies
+
+    databricks_resources = getattr(model, "resources", {})
+
+    if databricks_resources:
+        databricks_dependencies = databricks_resources.get("databricks", {})
+        index_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, ResourceType.VECTOR_SEARCH_INDEX.value
+        )
+        for index_name in index_names:
+            dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", **index_name})
+        endpoint_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, ResourceType.SERVING_ENDPOINT.value
+        )
+        for endpoint_name in endpoint_names:
+            dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", **endpoint_name})
+    else:
+        # import here to work around circular imports
+        from mlflow.langchain.databricks_dependencies import _DATABRICKS_DEPENDENCY_KEY
+
+        databricks_dependencies = model_info.flavors.get("langchain", {}).get(
+            _DATABRICKS_DEPENDENCY_KEY, {}
+        )
+
+        index_names = _fetch_langchain_dependency_from_model_info(
+            databricks_dependencies, _DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY
+        )
+        for index_name in index_names:
+            dependencies.append({"type": "DATABRICKS_VECTOR_INDEX", "name": index_name})
+        for key in (
+            _DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY,
+            _DATABRICKS_LLM_ENDPOINT_NAME_KEY,
+            _DATABRICKS_CHAT_ENDPOINT_NAME_KEY,
+        ):
+            endpoint_names = _fetch_langchain_dependency_from_model_info(
+                databricks_dependencies, key
+            )
+            for endpoint_name in endpoint_names:
+                dependencies.append({"type": "DATABRICKS_MODEL_ENDPOINT", "name": endpoint_name})
+    return dependencies

-def _fetch_langchain_dependency_from_model_info(model_info, key):
-    # import here to work around circular imports
-    from mlflow.langchain.databricks_dependencies import _DATABRICKS_DEPENDENCY_KEY
-
-    return model_info.flavors.get("langchain", {}).get(_DATABRICKS_DEPENDENCY_KEY, {}).get(key, [])
+def _fetch_langchain_dependency_from_model_info(databricks_dependencies, key):
+    return databricks_dependencies.get(key, [])


 @experimental
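Both branches normalize to one flat list: the resources branch spreads {"name": ...} entries (**index_name, **endpoint_name), while the legacy branch wraps raw strings under "name", so callers see the same shape either way. For a model with one index and one endpoint (placeholder names), the expected result would look like:

    # Assumed output of get_model_version_dependencies(model_dir) under either format:
    expected = [
        {"type": "DATABRICKS_VECTOR_INDEX", "name": "catalog.schema.my_index"},
        {"type": "DATABRICKS_MODEL_ENDPOINT", "name": "my-llm-endpoint"},
    ]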
31 changes: 26 additions & 5 deletions tests/langchain/test_langchain_databricks_dependency_extraction.py
@@ -17,6 +17,7 @@
     _extract_databricks_dependencies_from_retriever,
 )
 from mlflow.langchain.utils import IS_PICKLE_SERIALIZATION_RESTRICTED
+from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex


 class MockDatabricksServingEndpointClient:
@@ -54,8 +55,12 @@ def test_parsing_dependency_from_databricks_llm(monkeypatch: pytest.MonkeyPatch)

     llm = Databricks(**llm_kwargs)
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_llm(llm, d)
+    resources = []
+    _extract_databricks_dependencies_from_llm(llm, d, resources)
     assert d.get(_DATABRICKS_LLM_ENDPOINT_NAME_KEY) == ["databricks-mixtral-8x7b-instruct"]
+    assert resources == [
+        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
+    ]


 class MockVectorSearchIndex:
@@ -96,10 +101,16 @@ def test_parsing_dependency_from_databricks_retriever(monkeypatch: pytest.MonkeyPatch)
     vectorstore = DatabricksVectorSearch(vs_index, text_column="content", embedding=embedding_model)
     retriever = vectorstore.as_retriever()
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_retriever(retriever, d)
+    resources = []
+    _extract_databricks_dependencies_from_retriever(retriever, d, resources)
     assert d.get(_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY) == ["databricks-bge-large-en"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY) == ["mlflow.rag.vs_index"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY) == ["dbdemos_vs_endpoint"]
+    assert resources == [
+        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
+        DatabricksServingEndpoint(endpoint_name="dbdemos_vs_endpoint"),
+        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
+    ]


 @pytest.mark.skipif(
@@ -124,10 +135,16 @@ def test_parsing_dependency_from_databricks_retriever(monkeypatch: pytest.MonkeyPatch)
     vectorstore = DatabricksVectorSearch(vs_index, text_column="content", embedding=embedding_model)
     retriever = vectorstore.as_retriever()
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_retriever(retriever, d)
+    resources = []
+    _extract_databricks_dependencies_from_retriever(retriever, d, resources)
     assert d.get(_DATABRICKS_EMBEDDINGS_ENDPOINT_NAME_KEY) == ["databricks-bge-large-en"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_INDEX_NAME_KEY) == ["mlflow.rag.vs_index"]
     assert d.get(_DATABRICKS_VECTOR_SEARCH_ENDPOINT_NAME_KEY) == ["dbdemos_vs_endpoint"]
+    assert resources == [
+        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
+        DatabricksServingEndpoint(endpoint_name="dbdemos_vs_endpoint"),
+        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
+    ]


 @pytest.mark.skipif(
@@ -142,8 +159,10 @@ def test_parsing_dependency_from_databricks_chat(monkeypatch: pytest.MonkeyPatch)

     chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_chat_model(chat_model, d)
+    resources = []
+    _extract_databricks_dependencies_from_chat_model(chat_model, d, resources)
     assert d.get(_DATABRICKS_CHAT_ENDPOINT_NAME_KEY) == ["databricks-llama-2-70b-chat"]
+    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]


 @pytest.mark.skipif(
@@ -158,5 +177,7 @@ def test_parsing_dependency_from_databricks_chat(monkeypatch: pytest.MonkeyPatch)

     chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
     d = defaultdict(list)
-    _extract_databricks_dependencies_from_chat_model(chat_model, d)
+    resources = []
+    _extract_databricks_dependencies_from_chat_model(chat_model, d, resources)
     assert d.get(_DATABRICKS_CHAT_ENDPOINT_NAME_KEY) == ["databricks-llama-2-70b-chat"]
+    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]