Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Databricks Vector Search langchain native tool VectorSearchRetrieverTool #24

Merged
merged 13 commits into from
Dec 20, 2024
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from databricks_langchain.chat_models import ChatDatabricks
from databricks_langchain.embeddings import DatabricksEmbeddings
from databricks_langchain.genie import GenieAgent
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch

# Expose all integrations to users under databricks-langchain
Expand All @@ -9,4 +10,5 @@
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"GenieAgent",
"VectorSearchRetrieverTool",
]
61 changes: 60 additions & 1 deletion integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, List, Union
import json
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import numpy as np
Expand Down Expand Up @@ -95,3 +97,60 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity


class IndexType(str, Enum):
DIRECT_ACCESS = "DIRECT_ACCESS"
DELTA_SYNC = "DELTA_SYNC"


class IndexDetails:
"""An utility class to store the configuration details of an index."""

def __init__(self, index: Any):
self._index_details = index.describe()

@property
def name(self) -> str:
return self._index_details["name"]

@property
def schema(self) -> Optional[Dict]:
if self.is_direct_access_index():
schema_json = self.index_spec.get("schema_json")
if schema_json is not None:
return json.loads(schema_json)
return None

@property
def primary_key(self) -> str:
return self._index_details["primary_key"]

@property
def index_spec(self) -> Dict:
return (
self._index_details.get("delta_sync_index_spec", {})
if self.is_delta_sync_index()
else self._index_details.get("direct_access_index_spec", {})
)

@property
def embedding_vector_column(self) -> Dict:
if vector_columns := self.index_spec.get("embedding_vector_columns"):
return vector_columns[0]
return {}

@property
def embedding_source_column(self) -> Dict:
if source_columns := self.index_spec.get("embedding_source_columns"):
return source_columns[0]
return {}

def is_delta_sync_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DELTA_SYNC.value

def is_direct_access_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value

def is_databricks_managed_embeddings(self) -> bool:
return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Any, Dict, List, Optional, Type

from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, PrivateAttr, model_validator

from databricks_langchain.utils import IndexDetails
from databricks_langchain.vectorstores import DatabricksVectorSearch


class VectorSearchRetrieverToolInput(BaseModel):
query: str = Field(
description="The string used to query the index with and identify the most similar "
"vectors and return the associated documents."
)


class VectorSearchRetrieverTool(BaseTool):
"""
A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
This class integrates with a Databricks Vector Search and provides a convenient interface
for building a retriever tool for agents.
"""

index_name: str = Field(
..., description="The name of the index to use, format: 'catalog.schema.index'."
)
num_results: int = Field(10, description="The number of results to return.")
columns: Optional[List[str]] = Field(
None, description="Columns to return when doing the search."
)
filters: Optional[Dict[str, Any]] = Field(None, description="Filters to apply to the search.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ, does this get sent to the LLM as the parameter description? If so I wonder if it's worth including examples like the ones in https://docs.databricks.com/api/workspace/vectorsearchindexes/queryindex

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nvm, this is in the init, not in the tool call

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But seems like there is a way we can specify the description of the params for the LLM too: https://chatgpt.com/share/6764d76f-69a0-8009-8a8f-f58977753057

Copy link
Collaborator

@smurching smurching Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to include VectorSearchRetrieverToolInput as an args_schema

query_type: str = Field(
"ANN", description="The type of this query. Supported values are 'ANN' and 'HYBRID'."
)
tool_name: Optional[str] = Field(None, description="The name of the retrieval tool.")
tool_description: Optional[str] = Field(None, description="A description of the tool.")
text_column: Optional[str] = Field(
None,
description="The name of the text column to use for the embeddings. "
"Required for direct-access index or delta-sync index with "
"self-managed embeddings.",
)
embedding: Optional[Embeddings] = Field(
None, description="Embedding model for self-managed embeddings."
)

# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
name: str = Field(default="", description="The name of the tool")
description: str = Field(default="", description="The description of the tool")
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput

_vector_store: DatabricksVectorSearch = PrivateAttr()

@model_validator(mode="after")
def validate_tool_inputs(self):
kwargs = {
"index_name": self.index_name,
"embedding": self.embedding,
"text_column": self.text_column,
"columns": self.columns,
}
dbvs = DatabricksVectorSearch(**kwargs)
self._vector_store = dbvs

def get_tool_description():
default_tool_description = (
"A vector search-based retrieval tool for querying indexed embeddings."
)
index_details = IndexDetails(dbvs.index)
if index_details.is_delta_sync_index():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

direct access indexes don't have an associated source table so we'll just use the default tool description.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what the existing langchain-databricks DatabricksVectorSearch.as_retriever(...).as_tool(...) ends up generating as the tool description

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One way to tell would be to use it as a tool with payload logging enabled & see what the tools argument to the LLM API in model serving looks like

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generally looks reasonable, just curious if we can keep it in sync with the existing behavior/default

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image Tested it out in a notebook. The default seems to be extremely basic and lacking in content. Does this answer your question?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol yep makes sense, the updated version in this PR is definitely better

from databricks.sdk import WorkspaceClient

source_table = index_details.index_spec.get("source_table", "")
w = WorkspaceClient()
source_table_comment = w.tables.get(full_name=source_table).comment
if source_table_comment:
return (
default_tool_description
+ f" The queried index uses the source table {source_table} with the description: "
+ source_table_comment
)
else:
return (
default_tool_description
+ f" The queried index uses the source table {source_table}"
)
return default_tool_description

self.name = self.tool_name or self.index_name
self.description = self.tool_description or get_tool_description()

return self

def _run(self, query: str) -> str:
return self._vector_store.similarity_search(
query, k=self.num_results, filter=self.filters, query_type=self.query_type
)
61 changes: 1 addition & 60 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import json
import logging
import re
import uuid
from enum import Enum
from functools import partial
from typing import (
Any,
Expand All @@ -23,16 +21,11 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from databricks_langchain.utils import maximal_marginal_relevance
from databricks_langchain.utils import IndexDetails, maximal_marginal_relevance

logger = logging.getLogger(__name__)


class IndexType(str, Enum):
DIRECT_ACCESS = "DIRECT_ACCESS"
DELTA_SYNC = "DELTA_SYNC"


_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = "`%s` is not supported for index with Databricks-managed embeddings."
_INDEX_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$")
Expand Down Expand Up @@ -782,55 +775,3 @@ def _validate_embedding_dimension(embeddings: Embeddings, index_details: IndexDe
f"The specified embedding model's dimension '{actual_dimension}' does "
f"not match with the index configuration '{index_embedding_dimension}'."
)


class IndexDetails:
"""An utility class to store the configuration details of an index."""

def __init__(self, index: Any):
self._index_details = index.describe()

@property
def name(self) -> str:
return self._index_details["name"]

@property
def schema(self) -> Optional[Dict]:
if self.is_direct_access_index():
schema_json = self.index_spec.get("schema_json")
if schema_json is not None:
return json.loads(schema_json)
return None

@property
def primary_key(self) -> str:
return self._index_details["primary_key"]

@property
def index_spec(self) -> Dict:
return (
self._index_details.get("delta_sync_index_spec", {})
if self.is_delta_sync_index()
else self._index_details.get("direct_access_index_spec", {})
)

@property
def embedding_vector_column(self) -> Dict:
if vector_columns := self.index_spec.get("embedding_vector_columns"):
return vector_columns[0]
return {}

@property
def embedding_source_column(self) -> Dict:
if source_columns := self.index_spec.get("embedding_source_columns"):
return source_columns[0]
return {}

def is_delta_sync_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DELTA_SYNC.value

def is_direct_access_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value

def is_databricks_managed_embeddings(self) -> bool:
return self.is_delta_sync_index() and self.embedding_source_column.get("name") is not None
Loading
Loading