Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Databricks Vector Search langchain native tool VectorSearchRetrieverTool #24

Merged
merged 13 commits into from
Dec 20, 2024
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.vectorstores import DatabricksVectorSearch
from databricks_langchain.vector_search import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-langchain
__all__ = [
"ChatDatabricks",
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool"
]
61 changes: 61 additions & 0 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from urllib.parse import urlparse

import numpy as np
from enum import Enum
import json

from typing import (
Dict,
Optional
)

def get_deployment_client(target_uri: str) -> Any:
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
Expand Down Expand Up @@ -95,3 +101,58 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity

class IndexType(str, Enum):
    """Kinds of vector search index, keyed by their ``index_type`` string.

    Members are also ``str`` instances, so they compare equal to the raw
    strings they wrap.
    """

    # Compared against the "index_type" entry of an index description.
    DIRECT_ACCESS = "DIRECT_ACCESS"
    DELTA_SYNC = "DELTA_SYNC"

class IndexDetails:
    """Accessor wrapper around the dictionary returned by ``index.describe()``.

    The description is fetched once in the constructor; every property and
    predicate below reads from that stored dictionary.
    """

    def __init__(self, index: Any):
        # ``index`` only needs to expose a ``describe()`` method returning a mapping.
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        """Value stored under the ``"name"`` key of the description."""
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        """Parsed ``schema_json`` of a direct-access index, otherwise ``None``."""
        if not self.is_direct_access_index():
            return None
        raw_schema = self.index_spec.get("schema_json")
        return None if raw_schema is None else json.loads(raw_schema)

    @property
    def primary_key(self) -> str:
        """Value stored under the ``"primary_key"`` key of the description."""
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        """Type-specific spec dictionary, or ``{}`` when the description lacks one."""
        # Any index that is not delta-sync falls back to the direct-access spec key.
        spec_key = (
            "delta_sync_index_spec"
            if self.is_delta_sync_index()
            else "direct_access_index_spec"
        )
        return self._index_details.get(spec_key, {})

    @property
    def embedding_vector_column(self) -> Dict:
        """First entry of ``embedding_vector_columns``, or ``{}`` if absent/empty."""
        vector_columns = self.index_spec.get("embedding_vector_columns")
        return vector_columns[0] if vector_columns else {}

    @property
    def embedding_source_column(self) -> Dict:
        """First entry of ``embedding_source_columns``, or ``{}`` if absent/empty."""
        source_columns = self.index_spec.get("embedding_source_columns")
        return source_columns[0] if source_columns else {}

    def is_delta_sync_index(self) -> bool:
        """Return whether the described ``index_type`` is ``DELTA_SYNC``."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        """Return whether the described ``index_type`` is ``DIRECT_ACCESS``."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        """Return True for a delta-sync index whose source column has a name."""
        if not self.is_delta_sync_index():
            return False
        return self.embedding_source_column.get("name") is not None
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator, PrivateAttr

from databricks_langchain import DatabricksVectorSearch
from databricks_langchain.utils import IndexDetails
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool


class VectorSearchRetrieverTool(BaseTool):
"""
A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
This class integrates with a Databricks Vector Search and provides a convenient interface
for building a retriever tool for agents.
"""

index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.")
num_results: int = Field(10, description="The number of results to return.")
columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.")
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ, does this get sent to the LLM as the parameter description? If so I wonder if it's worth including examples like the ones in https://docs.databricks.com/api/workspace/vectorsearchindexes/queryindex

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nvm, this is in the init, not in the tool call

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But seems like there is a way we can specify the description of the params for the LLM too: https://chatgpt.com/share/6764d76f-69a0-8009-8a8f-f58977753057

Copy link
Collaborator

@smurching smurching Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to include VectorSearchRetrieverToolInput as an args_schema

query_type: str = Field("ANN", description="The type of query to run.")
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
tool_description: Optional[str] = Field(None, description="A description of the tool.")
# TODO: Confirm if we can add these two to the API to support direct-access indexes or delta-sync indexes with self-managed embeddings.
text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.")
embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two fields are required for direct-access indexes or delta-sync indexes with self-managed embeddings. Should we support these additional fields?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like if we support it for DatabricksVectorSearch it makes sense to support it here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, seems reasonable to support these, though I'd say it's worth asking vector search folks how commonly direct access indexes are used, if it's infrequent we could drop this to start with to simplify the API/testing surface

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to block this PR on that though, I figure we'll need this eventually anyways, would just be good for us to know

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# TODO: Confirm if we can add this endpoint field
endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This field was added because of this restriction in databricks-langchain. I felt that if we threw this error without giving the ability for the user to rectify it, it would be a poor user experience. Alternatively maybe we pin databricks-vectorsearch to be >=0.35.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's valid to require databricks-vectorsearch >= 0.35, especially because this is new - that might be the better option, considering we don't need endpoint for any other reason.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah reasonable to require new versions of other clients!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out we already mark the "databricks-vectorsearch>=0.40" as a dependency here, so I'll just remove this argument.


# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
name: str = Field(default="", description="The name of the tool")
description: str = Field(default="", description="The description of the tool")

_vector_store: DatabricksVectorSearch = PrivateAttr()

@model_validator(mode='after')
def validate_tool_inputs(self):
# Construct the vector store using provided params
kwargs = {
"index_name": self.index_name,
"endpoint": self.endpoint,
"embedding": self.embedding,
"text_column": self.text_column,
"columns": self.columns,
}
dbvs = DatabricksVectorSearch(**kwargs)
self._vector_store = dbvs

def get_tool_description():
default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings."
index_details = IndexDetails(dbvs.index)
if index_details.is_delta_sync_index():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

direct access indexes don't have an associated source table so we'll just use the default tool description.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what the existing langchain-databricks DatabricksVectorSearch.as_retriever(...).as_tool(...) ends up generating as the tool description

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One way to tell would be to use it as a tool with payload logging enabled & see what the tools argument to the LLM API in model serving looks like

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generally looks reasonable, just curious if we can keep it in sync with the existing behavior/default

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image Tested it out in a notebook. The default seems to be extremely basic and lacking in content. Does this answer your question?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol yep makes sense, the updated version in this PR is definitely better

from databricks.sdk import WorkspaceClient

source_table = index_details.index_spec.get('source_table', "")
w = WorkspaceClient()
source_table_comment = w.tables.get(full_name=source_table).comment
if source_table_comment:
return (
default_tool_description +
f" The queried index uses the source table {source_table} with the description: " +
source_table_comment
)
else:
return default_tool_description + f" The queried index uses the source table {source_table}"
return default_tool_description

self.name = self.tool_name or self.index_name
self.description = self.tool_description or get_tool_description()

return self


def _run(self, query: str) -> List[Any]:
    """Run a similarity search for ``query`` against the tool's vector store.

    The search is limited to ``self.num_results`` hits and forwards the
    tool-level ``self.filters`` and ``self.query_type`` unchanged.

    Args:
        query: Query text passed straight to the vector store.

    Returns:
        The value returned by ``self._vector_store.similarity_search``. This is
        a list of search results (presumably LangChain ``Document`` objects,
        per the ``VectorStore`` interface - confirm), not a string, which is
        why the return annotation is a list type rather than ``str``.
    """
    return self._vector_store.similarity_search(
        query,
        k=self.num_results,
        filter=self.filters,
        query_type=self.query_type,
    )
60 changes: 1 addition & 59 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import json
import logging
import re
import uuid
from enum import Enum
from functools import partial
from typing import (
Any,
Expand All @@ -23,16 +21,11 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from databricks_langchain.utils import maximal_marginal_relevance
from databricks_langchain.utils import maximal_marginal_relevance, IndexDetails

logger = logging.getLogger(__name__)


class IndexType(str, Enum):
    """Kinds of vector search index, keyed by their ``index_type`` string.

    Members are also ``str`` instances, so they compare equal to the raw
    strings they wrap.
    """

    # Compared against the "index_type" entry of an index description.
    DIRECT_ACCESS = "DIRECT_ACCESS"
    DELTA_SYNC = "DELTA_SYNC"


_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = "`%s` is not supported for index with Databricks-managed embeddings."
_INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$")
Expand Down Expand Up @@ -783,54 +776,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe
f"not match with the index configuration '{index_embedding_dimension}'."
)


class IndexDetails:
    """Accessor wrapper around the dictionary returned by ``index.describe()``.

    The description is fetched once in the constructor; every property and
    predicate below reads from that stored dictionary.
    """

    def __init__(self, index: Any):
        # ``index`` only needs to expose a ``describe()`` method returning a mapping.
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        """Value stored under the ``"name"`` key of the description."""
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        """Parsed ``schema_json`` of a direct-access index, otherwise ``None``."""
        if not self.is_direct_access_index():
            return None
        raw_schema = self.index_spec.get("schema_json")
        return None if raw_schema is None else json.loads(raw_schema)

    @property
    def primary_key(self) -> str:
        """Value stored under the ``"primary_key"`` key of the description."""
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        """Type-specific spec dictionary, or ``{}`` when the description lacks one."""
        # Any index that is not delta-sync falls back to the direct-access spec key.
        spec_key = (
            "delta_sync_index_spec"
            if self.is_delta_sync_index()
            else "direct_access_index_spec"
        )
        return self._index_details.get(spec_key, {})

    @property
    def embedding_vector_column(self) -> Dict:
        """First entry of ``embedding_vector_columns``, or ``{}`` if absent/empty."""
        vector_columns = self.index_spec.get("embedding_vector_columns")
        return vector_columns[0] if vector_columns else {}

    @property
    def embedding_source_column(self) -> Dict:
        """First entry of ``embedding_source_columns``, or ``{}`` if absent/empty."""
        source_columns = self.index_spec.get("embedding_source_columns")
        return source_columns[0] if source_columns else {}

    def is_delta_sync_index(self) -> bool:
        """Return whether the described ``index_type`` is ``DELTA_SYNC``."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        """Return whether the described ``index_type`` is ``DIRECT_ACCESS``."""
        index_type = self._index_details["index_type"]
        return index_type == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        """Return True for a delta-sync index whose source column has a name."""
        if not self.is_delta_sync_index():
            return False
        return self.embedding_source_column.get("name") is not None
128 changes: 2 additions & 126 deletions integrations/langchain/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,133 +30,9 @@
_convert_message_to_dict,
)

# Canned non-streaming chat-completion payload; the `mock_client` fixture
# returns it from `client.predict`.
_MOCK_CHAT_RESPONSE = {
    "id": "chatcmpl_id",
    "object": "chat.completion",
    "created": 1721875529,
    "model": "meta-llama-3.1-70b-instruct-072424",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "To calculate the result of 36939 multiplied by 8922.4, I get:\n\n36939 x 8922.4 = 329,511,111.6",
            },
            "finish_reason": "stop",
            "logprobs": None,
        },
    ],
    "usage": {
        "prompt_tokens": 30,
        "completion_tokens": 36,
        "total_tokens": 66,
    },
}

# Canned streaming payload; the `mock_client` fixture returns it from
# `client.predict_stream`. The chunks differ only in the delta content, the
# finish reason and the completion-token count, so they are generated from
# those three values; total_tokens is always prompt_tokens (30) plus
# completion_tokens.
_MOCK_STREAM_RESPONSE = [
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": chunk_text},
                "finish_reason": chunk_finish_reason,
                "logprobs": None,
            }
        ],
        "usage": {
            "prompt_tokens": 30,
            "completion_tokens": chunk_completion_tokens,
            "total_tokens": 30 + chunk_completion_tokens,
        },
    }
    for chunk_text, chunk_finish_reason, chunk_completion_tokens in (
        ("36939", None, 20),
        ("x", None, 22),
        ("8922.4", None, 24),
        (" = ", None, 28),
        ("329,511,111.6", None, 30),
        ("", "stop", 36),
    )
]


@pytest.fixture(autouse=True)
def mock_client() -> Generator:
    """Patch ``mlflow.deployments.get_deploy_client`` for every test in the module.

    The replacement client is a ``MagicMock`` whose ``predict`` returns
    ``_MOCK_CHAT_RESPONSE`` and whose ``predict_stream`` returns
    ``_MOCK_STREAM_RESPONSE``. The patch stays active until the test finishes.
    """
    fake_client = mock.MagicMock()
    fake_client.predict.return_value = _MOCK_CHAT_RESPONSE
    fake_client.predict_stream.return_value = _MOCK_STREAM_RESPONSE

    deploy_client_patch = mock.patch(
        "mlflow.deployments.get_deploy_client",
        return_value=fake_client,
    )
    with deploy_client_patch:
        # Yield nothing: the fixture exists only for its patching side effect.
        yield


@pytest.fixture
def llm() -> ChatDatabricks:
    """Build the ``ChatDatabricks`` instance shared by the chat-model tests."""
    endpoint_name = "databricks-meta-llama-3-70b-instruct"
    return ChatDatabricks(endpoint=endpoint_name, target_uri="databricks")
from databricks_langchain import ChatDatabricks

from tests.utils.chat_models import _MOCK_CHAT_RESPONSE, _MOCK_STREAM_RESPONSE, mock_client, llm

def test_dict(llm: ChatDatabricks) -> None:
d = llm.dict()
Expand Down
Loading
Loading