Add new Databricks Vector Search langchain native tool VectorSearchRetrieverTool #24
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator, PrivateAttr

from databricks_langchain import DatabricksVectorSearch
from databricks_langchain.utils import IndexDetails
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool


class VectorSearchRetrieverTool(BaseTool):
    """
    A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
    This class integrates with a Databricks Vector Search and provides a convenient interface
    for building a retriever tool for agents.
    """

    index_name: str = Field(..., description="The name of the index to use, format: 'catalog.schema.index'.")
    num_results: int = Field(10, description="The number of results to return.")
    columns: Optional[List[str]] = Field(None, description="Columns to return when doing the search.")
    filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
    query_type: str = Field("ANN", description="The type of query to run.")
    tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
    tool_description: Optional[str] = Field(None, description="A description of the tool.")
    # TODO: Confirm if we can add these two to the API to support direct-access indexes or delta-sync indexes with self-managed embeddings.
    text_column: Optional[str] = Field(None, description="If using a direct-access index or delta-sync index, specify the text column.")
    embedding: Optional[Embeddings] = Field(None, description="Embedding model for self-managed embeddings.")

These two fields are required for direct-access indexes or delta-sync indexes with self-managed embeddings. Should we support these additional fields?

I feel like if we support it for DatabricksVectorSearch it makes sense to support it here.

Yeah, seems reasonable to support these, though I'd say it's worth asking vector search folks how commonly direct-access indexes are used. If it's infrequent, we could drop this to start with to simplify the API/testing surface.

No need to block this PR on that though. I figure we'll need this eventually anyway; it would just be good for us to know.

    # TODO: Confirm if we can add this endpoint field
    endpoint: Optional[str] = Field(None, description="Endpoint for DatabricksVectorSearch.")

This field was added because of this restriction in databricks-langchain. I felt that if we threw this error without giving the user the ability to rectify it, it would be a poor user experience. Alternatively, maybe we pin databricks-vectorsearch to be >=0.35.

I think it's valid to require databricks-vectorsearch >= 0.35 especially because this is new - that might be the better considering we don't need

Yeah, reasonable to require new versions of other clients!

Turns out we already mark "databricks-vectorsearch>=0.40" as a dependency here, so I'll just remove this argument.

    # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
    name: str = Field(default="", description="The name of the tool")
    description: str = Field(default="", description="The description of the tool")

    _vector_store: DatabricksVectorSearch = PrivateAttr()

    @model_validator(mode='after')
    def validate_tool_inputs(self):
        # Construct the vector store using provided params
        kwargs = {
            "index_name": self.index_name,
            "endpoint": self.endpoint,
            "embedding": self.embedding,
            "text_column": self.text_column,
            "columns": self.columns,
        }
        dbvs = DatabricksVectorSearch(**kwargs)
        self._vector_store = dbvs

        def get_tool_description():
            default_tool_description = "A vector search-based retrieval tool for querying indexed embeddings."
            index_details = IndexDetails(dbvs.index)
            if index_details.is_delta_sync_index():

Direct-access indexes don't have an associated source table, so we'll just use the default tool description.

Curious what the existing

One way to tell would be to use it as a tool with payload logging enabled & see what the

This generally looks reasonable, just curious if we can keep it in sync with the existing behavior/default

Lol yep, makes sense, the updated version in this PR is definitely better

                from databricks.sdk import WorkspaceClient

                source_table = index_details.index_spec.get('source_table', "")
                w = WorkspaceClient()
                source_table_comment = w.tables.get(full_name=source_table).comment
                if source_table_comment:
                    return (
                        default_tool_description +
                        f" The queried index uses the source table {source_table} with the description: " +
                        source_table_comment
                    )
                else:
                    return default_tool_description + f" The queried index uses the source table {source_table}"
            return default_tool_description

        self.name = self.tool_name or self.index_name
        self.description = self.tool_description or get_tool_description()

        return self

    def _run(self, query: str) -> str:
        return self._vector_store.similarity_search(
            query,
            k=self.num_results,
            filter=self.filters,
            query_type=self.query_type
        )
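
For reference, the description fallback in get_tool_description() can be exercised without a workspace by pulling the string logic into a pure function. This is only a sketch: build_tool_description and DEFAULT_TOOL_DESCRIPTION are hypothetical names that are not part of this PR, and the source table name and comment are passed in directly instead of being looked up through WorkspaceClient.

```python
from typing import Optional

# Same default text as in the diff above.
DEFAULT_TOOL_DESCRIPTION = (
    "A vector search-based retrieval tool for querying indexed embeddings."
)


def build_tool_description(
    source_table: Optional[str] = None,
    table_comment: Optional[str] = None,
) -> str:
    """Hypothetical helper mirroring the fallback order in the diff.

    - No source table (e.g. a direct-access index): default description only.
    - Source table without a comment: default description plus the table name.
    - Source table with a comment: default, table name, and the comment.
    """
    if not source_table:
        return DEFAULT_TOOL_DESCRIPTION
    description = (
        f"{DEFAULT_TOOL_DESCRIPTION} "
        f"The queried index uses the source table {source_table}"
    )
    if table_comment:
        description += f" with the description: {table_comment}"
    return description


if __name__ == "__main__":
    # "main.docs.chunks" is a made-up table name for illustration.
    print(build_tool_description())
    print(build_tool_description("main.docs.chunks"))
    print(build_tool_description("main.docs.chunks", "Product documentation chunks"))
```

Keeping the string assembly separate from the WorkspaceClient lookup would make the three cases easy to unit test, though whether that split is worth it for this PR is a judgment call.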

QQ, does this get sent to the LLM as the parameter description? If so I wonder if it's worth including examples like the ones in https://docs.databricks.com/api/workspace/vectorsearchindexes/queryindex

Oh nvm, this is in the init, not in the tool call

But seems like there is a way we can specify the description of the params for the LLM too: https://chatgpt.com/share/6764d76f-69a0-8009-8a8f-f58977753057

See also https://python.langchain.com/docs/how_to/custom_tools/#subclass-basetool (we can use args_schema)

Updated to include VectorSearchRetrieverToolInput as an args_schema
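
For anyone following along, here is a rough sketch of what an input schema for args_schema could look like. The thread only names VectorSearchRetrieverToolInput; the single `query` field and its description text are assumptions for illustration, not the definition that landed in the PR. The sketch uses only pydantic (v2) so it runs without a Databricks workspace.

```python
from pydantic import BaseModel, Field


class VectorSearchRetrieverToolInput(BaseModel):
    """Assumed shape of the tool-call input; the real fields may differ."""

    # The description on this field is what shows up in the generated
    # parameter schema, as opposed to the init-time fields such as
    # `filters`, which configure the tool and are not tool-call parameters.
    query: str = Field(
        description=(
            "The string used to query the index, for example a keyword "
            "or a natural-language question."
        )
    )


if __name__ == "__main__":
    # Inspect the JSON schema that a tool-calling integration would
    # derive from this model.
    schema = VectorSearchRetrieverToolInput.model_json_schema()
    print(schema["properties"]["query"]["description"])
    print(schema["required"])
```

Following the LangChain custom tools guide linked above, a model like this would be attached on the BaseTool subclass as a class attribute, e.g. `args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput`.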