diff --git a/llama_index/vector_stores/pinecone.py b/llama_index/vector_stores/pinecone.py index d88f71d1ed3cb..accffebdc1af9 100644 --- a/llama_index/vector_stores/pinecone.py +++ b/llama_index/vector_stores/pinecone.py @@ -82,30 +82,34 @@ def build_sparse_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]: return sparse_emb -def encode_batch( - sparse_encoder: VectorStoreSparseEncoder, context_batch: List[str], query_mode: bool -) -> List[Dict[str, Any]]: +def initialize_sparse_encoder(sparse_encoder_type: VectorStoreSparseEncoder) -> Any: try: import pinecone_text.sparse - encoder_class_name = f"{sparse_encoder}Encoder" + encoder_class_name = f"{sparse_encoder_type}Encoder" encoder_class = getattr(pinecone_text.sparse, encoder_class_name) - encoder = ( + return ( encoder_class() - if sparse_encoder != VectorStoreSparseEncoder.BM25 + if sparse_encoder_type != VectorStoreSparseEncoder.BM25 else encoder_class().default() ) - if query_mode: - sparse_vectors = encoder.encode_queries(context_batch) - else: - sparse_vectors = encoder.encode_documents(context_batch) - return sparse_vectors except ImportError: raise ImportError(import_err_msg_pinecone_text) +def encode_batch( + sparse_encoder: Any, context_batch: List[str], query_mode: bool +) -> List[Dict[str, Any]]: + if query_mode: + sparse_vectors = sparse_encoder.encode_queries(context_batch) + else: + sparse_vectors = sparse_encoder.encode_documents(context_batch) + return sparse_vectors + + def generate_sparse_vectors( - sparse_encoder: VectorStoreSparseEncoder, + sparse_encoder_type: VectorStoreSparseEncoder, + sparse_encoder: Any, context_batch: List[str], tokenizer: Callable, query_mode: bool = False, @@ -113,7 +117,7 @@ def generate_sparse_vectors( """ Generate sparse vectors from a batch of contexts. """ - if sparse_encoder == VectorStoreSparseEncoder.SPARSE_DICT: + if sparse_encoder_type == VectorStoreSparseEncoder.SPARSE_DICT: sparse_vectors = build_sparse_dict(tokenizer(context_batch)["input_ids"]) else: sparse_vectors = encode_batch(sparse_encoder, context_batch, query_mode) @@ -155,13 +159,14 @@ class PineconeVectorStore(BasePydanticVectorStore): namespace: Optional[str] insert_kwargs: Optional[Dict] add_sparse_vector: bool - sparse_vector_encoder: Optional[VectorStoreSparseEncoder] + sparse_encoder_type: Optional[VectorStoreSparseEncoder] text_key: str batch_size: int remove_text_from_metadata: bool _pinecone_index: Any = PrivateAttr() _tokenizer: Optional[Callable] = PrivateAttr() + _sparse_encoder: Optional[Any] = PrivateAttr() def __init__( self, @@ -172,9 +177,10 @@ def __init__( namespace: Optional[str] = None, insert_kwargs: Optional[Dict] = None, add_sparse_vector: bool = False, - sparse_vector_encoder: Optional[ + sparse_encoder_type: Optional[ VectorStoreSparseEncoder ] = DEFAULT_SPARSE_VECTOR_ENCODER, + sparse_encoder: Optional[Any] = None, tokenizer: Optional[Callable] = None, text_key: str = DEFAULT_TEXT_KEY, batch_size: int = DEFAULT_BATCH_SIZE, @@ -205,6 +211,14 @@ def __init__( tokenizer = get_default_tokenizer() self._tokenizer = tokenizer + if ( + add_sparse_vector + and sparse_encoder is None + and sparse_encoder_type != VectorStoreSparseEncoder.SPARSE_DICT + ): + sparse_encoder = initialize_sparse_encoder(sparse_encoder_type) + self._sparse_encoder = sparse_encoder + super().__init__( index_name=index_name, environment=environment, @@ -212,7 +226,8 @@ def __init__( namespace=namespace, insert_kwargs=insert_kwargs, add_sparse_vector=add_sparse_vector, - sparse_vector_encoder=sparse_vector_encoder, + sparse_encoder_type=sparse_encoder_type, + sparse_encoder=sparse_encoder, text_key=text_key, batch_size=batch_size, remove_text_from_metadata=remove_text_from_metadata, @@ -227,9 +242,10 @@ def from_params( namespace: Optional[str] = None, insert_kwargs: Optional[Dict] = None, add_sparse_vector: bool = False, - sparse_vector_encoder: Optional[ + sparse_encoder_type: Optional[ VectorStoreSparseEncoder ] = DEFAULT_SPARSE_VECTOR_ENCODER, + sparse_encoder: Optional[Any] = None, tokenizer: Optional[Callable] = None, text_key: str = DEFAULT_TEXT_KEY, batch_size: int = DEFAULT_BATCH_SIZE, @@ -252,7 +268,8 @@ def from_params( namespace=namespace, insert_kwargs=insert_kwargs, add_sparse_vector=add_sparse_vector, - sparse_vector_encoder=sparse_vector_encoder, + sparse_encoder_type=sparse_encoder_type, + sparse_encoder=sparse_encoder, tokenizer=tokenizer, text_key=text_key, batch_size=batch_size, @@ -294,7 +311,8 @@ def add( if self.add_sparse_vector and self._tokenizer is not None: sparse_vector = generate_sparse_vectors( - self.sparse_vector_encoder, # type: ignore + self.sparse_encoder_type, + self._sparse_encoder, # type: ignore [node.get_content(metadata_mode=MetadataMode.EMBED)], self._tokenizer, )[0] @@ -360,7 +378,8 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul ) sparse_vector = generate_sparse_vectors( - self.sparse_vector_encoder, # type: ignore + self.sparse_encoder_type, + self._sparse_encoder, # type: ignore [query.query_str], self._tokenizer, query_mode=True,