move sparse encoder creation to initialization part
DosticJelena committed Nov 22, 2023
1 parent b821625 commit e2d6128
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions llama_index/vector_stores/pinecone.py
@@ -82,38 +82,42 @@ def build_sparse_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
     return sparse_emb
 
 
-def encode_batch(
-    sparse_encoder: VectorStoreSparseEncoder, context_batch: List[str], query_mode: bool
-) -> List[Dict[str, Any]]:
+def initialize_sparse_encoder(sparse_encoder_type: VectorStoreSparseEncoder) -> Any:
     try:
         import pinecone_text.sparse
 
-        encoder_class_name = f"{sparse_encoder}Encoder"
+        encoder_class_name = f"{sparse_encoder_type}Encoder"
         encoder_class = getattr(pinecone_text.sparse, encoder_class_name)
-        encoder = (
+        return (
             encoder_class()
-            if sparse_encoder != VectorStoreSparseEncoder.BM25
+            if sparse_encoder_type != VectorStoreSparseEncoder.BM25
             else encoder_class().default()
         )
-        if query_mode:
-            sparse_vectors = encoder.encode_queries(context_batch)
-        else:
-            sparse_vectors = encoder.encode_documents(context_batch)
-        return sparse_vectors
     except ImportError:
         raise ImportError(import_err_msg_pinecone_text)
 
 
+def encode_batch(
+    sparse_encoder: Any, context_batch: List[str], query_mode: bool
+) -> List[Dict[str, Any]]:
+    if query_mode:
+        sparse_vectors = sparse_encoder.encode_queries(context_batch)
+    else:
+        sparse_vectors = sparse_encoder.encode_documents(context_batch)
+    return sparse_vectors
+
+
 def generate_sparse_vectors(
-    sparse_encoder: VectorStoreSparseEncoder,
+    sparse_encoder_type: VectorStoreSparseEncoder,
+    sparse_encoder: Any,
     context_batch: List[str],
     tokenizer: Callable,
     query_mode: bool = False,
 ) -> List[Dict[str, Any]]:
     """
     Generate sparse vectors from a batch of contexts.
     """
-    if sparse_encoder == VectorStoreSparseEncoder.SPARSE_DICT:
+    if sparse_encoder_type == VectorStoreSparseEncoder.SPARSE_DICT:
         sparse_vectors = build_sparse_dict(tokenizer(context_batch)["input_ids"])
     else:
         sparse_vectors = encode_batch(sparse_encoder, context_batch, query_mode)
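
Taken together, the hunk above turns encoder construction into a one-time step: initialize_sparse_encoder builds the pinecone_text encoder, while encode_batch only runs it. Below is a minimal sketch of that flow, not part of the commit, assuming pinecone_text is installed and that these module-level helpers are importable from llama_index.vector_stores.pinecone:

from llama_index.vector_stores.pinecone import (
    VectorStoreSparseEncoder,
    encode_batch,
    initialize_sparse_encoder,
)

# Build the sparse encoder once (for BM25 the diff falls back to
# encoder_class().default()), instead of re-creating it on every call
# as the old encode_batch did.
encoder = initialize_sparse_encoder(VectorStoreSparseEncoder.BM25)

docs = ["Pinecone supports sparse-dense vectors.", "Another chunk of text."]

doc_vectors = encode_batch(encoder, docs, query_mode=False)                   # document path
query_vectors = encode_batch(encoder, ["sparse vectors"], query_mode=True)    # query path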
@@ -155,13 +159,14 @@ class PineconeVectorStore(BasePydanticVectorStore):
     namespace: Optional[str]
     insert_kwargs: Optional[Dict]
     add_sparse_vector: bool
-    sparse_vector_encoder: Optional[VectorStoreSparseEncoder]
+    sparse_encoder_type: Optional[VectorStoreSparseEncoder]
     text_key: str
     batch_size: int
     remove_text_from_metadata: bool
 
     _pinecone_index: Any = PrivateAttr()
     _tokenizer: Optional[Callable] = PrivateAttr()
+    _sparse_encoder: Optional[Any] = PrivateAttr()
 
     def __init__(
         self,
@@ -172,9 +177,10 @@ def __init__(
         namespace: Optional[str] = None,
         insert_kwargs: Optional[Dict] = None,
         add_sparse_vector: bool = False,
-        sparse_vector_encoder: Optional[
+        sparse_encoder_type: Optional[
             VectorStoreSparseEncoder
         ] = DEFAULT_SPARSE_VECTOR_ENCODER,
+        sparse_encoder: Optional[Any] = None,
         tokenizer: Optional[Callable] = None,
         text_key: str = DEFAULT_TEXT_KEY,
         batch_size: int = DEFAULT_BATCH_SIZE,
@@ -205,14 +211,23 @@ def __init__(
             tokenizer = get_default_tokenizer()
         self._tokenizer = tokenizer
 
+        if (
+            add_sparse_vector
+            and sparse_encoder is None
+            and sparse_encoder_type != VectorStoreSparseEncoder.SPARSE_DICT
+        ):
+            sparse_encoder = initialize_sparse_encoder(sparse_encoder_type)
+        self._sparse_encoder = sparse_encoder
+
         super().__init__(
             index_name=index_name,
             environment=environment,
             api_key=api_key,
             namespace=namespace,
             insert_kwargs=insert_kwargs,
             add_sparse_vector=add_sparse_vector,
-            sparse_vector_encoder=sparse_vector_encoder,
+            sparse_encoder_type=sparse_encoder_type,
+            sparse_encoder=sparse_encoder,
             text_key=text_key,
             batch_size=batch_size,
             remove_text_from_metadata=remove_text_from_metadata,
@@ -227,9 +242,10 @@ def from_params(
         namespace: Optional[str] = None,
         insert_kwargs: Optional[Dict] = None,
         add_sparse_vector: bool = False,
-        sparse_vector_encoder: Optional[
+        sparse_encoder_type: Optional[
            VectorStoreSparseEncoder
         ] = DEFAULT_SPARSE_VECTOR_ENCODER,
+        sparse_encoder: Optional[Any] = None,
         tokenizer: Optional[Callable] = None,
         text_key: str = DEFAULT_TEXT_KEY,
         batch_size: int = DEFAULT_BATCH_SIZE,
@@ -252,7 +268,8 @@ def from_params(
             namespace=namespace,
             insert_kwargs=insert_kwargs,
             add_sparse_vector=add_sparse_vector,
-            sparse_vector_encoder=sparse_vector_encoder,
+            sparse_encoder_type=sparse_encoder_type,
+            sparse_encoder=sparse_encoder,
             tokenizer=tokenizer,
             text_key=text_key,
             batch_size=batch_size,
@@ -294,7 +311,8 @@ def add(
 
         if self.add_sparse_vector and self._tokenizer is not None:
             sparse_vector = generate_sparse_vectors(
-                self.sparse_vector_encoder,  # type: ignore
+                self.sparse_encoder_type,
+                self._sparse_encoder,  # type: ignore
                 [node.get_content(metadata_mode=MetadataMode.EMBED)],
                 self._tokenizer,
             )[0]
@@ -360,7 +378,8 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
             )
 
             sparse_vector = generate_sparse_vectors(
-                self.sparse_vector_encoder,  # type: ignore
+                self.sparse_encoder_type,
+                self._sparse_encoder,  # type: ignore
                 [query.query_str],
                 self._tokenizer,
                 query_mode=True,
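
For reference, a hedged usage sketch of the constructor after this change (not part of the commit; the index and credential arguments are placeholders, since those parameters sit outside the hunks shown): with add_sparse_vector=True, __init__ now either builds the sparse encoder once from sparse_encoder_type or accepts a pre-built one through the new sparse_encoder argument.

from llama_index.vector_stores.pinecone import (
    PineconeVectorStore,
    VectorStoreSparseEncoder,
    initialize_sparse_encoder,
)

# Option 1: let __init__ create the encoder from the type
# (done once at construction, not on every add()/query() as before).
store = PineconeVectorStore(
    index_name="quickstart",        # placeholder
    environment="us-west1-gcp",     # placeholder
    api_key="YOUR_API_KEY",         # placeholder
    add_sparse_vector=True,
    sparse_encoder_type=VectorStoreSparseEncoder.BM25,
)

# Option 2: build (and possibly fit) the encoder yourself and inject it.
bm25 = initialize_sparse_encoder(VectorStoreSparseEncoder.BM25)
store = PineconeVectorStore(
    index_name="quickstart",
    environment="us-west1-gcp",
    api_key="YOUR_API_KEY",
    add_sparse_vector=True,
    sparse_encoder_type=VectorStoreSparseEncoder.BM25,
    sparse_encoder=bm25,
)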
