From 271e3aacab28bc8386f662417dd96218c469dc82 Mon Sep 17 00:00:00 2001 From: DosticJelena Date: Thu, 28 Dec 2023 17:00:08 +0100 Subject: [PATCH] support different Pinecone initialization depending on the version --- llama_index/vector_stores/pinecone.py | 57 +++++++++++++++++++-------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/llama_index/vector_stores/pinecone.py b/llama_index/vector_stores/pinecone.py index 6ffec95095f9a..0638b02bcf263 100644 --- a/llama_index/vector_stores/pinecone.py +++ b/llama_index/vector_stores/pinecone.py @@ -10,6 +10,9 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, cast +from packaging import version +from pkg_resources import get_distribution + from llama_index.bridge.pydantic import PrivateAttr from llama_index.schema import BaseNode, MetadataMode, TextNode from llama_index.vector_stores.types import ( @@ -69,7 +72,8 @@ def _transform_pinecone_filter_operator(operator: str) -> str: def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]: - """Build a list of sparse dictionaries from a batch of input_ids. + """ + Build a list of sparse dictionaries from a batch of input_ids. NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/. @@ -93,7 +97,8 @@ def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]: def generate_sparse_vectors( context_batch: List[str], tokenizer: Callable ) -> List[Dict[str, Any]]: - """Generate sparse vectors from a batch of contexts. + """ + Generate sparse vectors from a batch of contexts. NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/. @@ -105,7 +110,8 @@ def generate_sparse_vectors( def get_default_tokenizer() -> Callable: - """Get default tokenizer. + """ + Get default tokenizer. NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/. @@ -157,7 +163,8 @@ def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict: class PineconeVectorStore(BasePydanticVectorStore): - """Pinecone Vector Store. + """ + Pinecone Vector Store. In this vector store, embeddings and docs are stored within a Pinecone index. @@ -217,14 +224,24 @@ def __init__( if pinecone_index is not None: self._pinecone_index = cast(pinecone.Index, pinecone_index) else: - if index_name is None or environment is None: - raise ValueError( - "Must specify index_name and environment " - "if not directly passing in client." - ) + pinecone_client_version = get_distribution("pinecone-client").version + + if version.parse(pinecone_client_version) >= version.parse("3.0.0"): + if index_name is None: + raise ValueError( + "Must specify index_name if not directly passing in client." + ) + pinecone_instance = pinecone.Pinecone(api_key=api_key) + self._pinecone_index = pinecone_instance.Index(index_name) + else: + if index_name is None or environment is None: + raise ValueError( + "Must specify index_name and environment " + "if not directly passing in client." + ) - pinecone.init(api_key=api_key, environment=environment) - self._pinecone_index = pinecone.Index(index_name) + pinecone.init(api_key=api_key, environment=environment) + self._pinecone_index = pinecone.Index(index_name) insert_kwargs = insert_kwargs or {} @@ -265,8 +282,14 @@ def from_params( except ImportError: raise ImportError(import_err_msg) - pinecone.init(api_key=api_key, environment=environment) - pinecone_index = pinecone.Index(index_name) + pinecone_client_version = get_distribution("pinecone-client").version + + if version.parse(pinecone_client_version) >= version.parse("3.0.0"): + pinecone_instance = pinecone.Pinecone(api_key=api_key) + pinecone_index = pinecone_instance.Index(index_name) + else: + pinecone.init(api_key=api_key, environment=environment) + pinecone_index = pinecone.Index(index_name) return cls( pinecone_index=pinecone_index, @@ -286,14 +309,15 @@ def from_params( @classmethod def class_name(cls) -> str: - return "PinconeVectorStore" + return "PineconeVectorStore" def add( self, nodes: List[BaseNode], **add_kwargs: Any, ) -> List[str]: - """Add nodes to index. + """ + Add nodes to index. Args: nodes: List[BaseNode]: list of nodes with embeddings @@ -353,7 +377,8 @@ def client(self) -> Any: return self._pinecone_index def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. + """ + Query index for top k most similar nodes. Args: query_embedding (List[float]): query embedding