diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py index e67f315..f513dfa 100644 --- a/integrations/langchain/src/databricks_langchain/vectorstores.py +++ b/integrations/langchain/src/databricks_langchain/vectorstores.py @@ -86,6 +86,10 @@ class DatabricksVectorSearch(VectorStore): Make sure the text column specified is in the index. columns: The list of column names to get when doing the search. Defaults to ``[primary_key, text_column]``. + client_args: Additional arguments to pass to the VectorSearchClient. + Allows you to pass in values like `service_principal_client_id` + and `service_principal_client_secret` for to allow for + service principal authentication instead of personal access token authentication. Instantiate: @@ -212,6 +216,7 @@ def __init__( embedding: Optional[Embeddings] = None, text_column: Optional[str] = None, columns: Optional[List[str]] = None, + client_args: Optional[Dict[str, Any]] = None, ): if not (isinstance(index_name, str) and _INDEX_NAME_PATTERN.match(index_name)): raise ValueError( @@ -230,7 +235,7 @@ def __init__( ) from e try: - self.index = VectorSearchClient().get_index( + self.index = VectorSearchClient(**(client_args or {})).get_index( endpoint_name=endpoint, index_name=index_name ) except Exception as e: