Skip to content

Commit

Permalink
add support for BM25 and SPLADE sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
DosticJelena authored and Jelena Dostić committed Dec 29, 2023
1 parent 43bb20c commit 93d9eee
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 54 deletions.
176 changes: 122 additions & 54 deletions llama_index/vector_stores/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
VectorStoreSparseEncoder,
)
from llama_index.vector_stores.utils import (
DEFAULT_TEXT_KEY,
Expand All @@ -31,19 +32,44 @@
METADATA_KEY = "metadata"

DEFAULT_BATCH_SIZE = 100
DEFAULT_SPARSE_VECTOR_ENCODER = VectorStoreSparseEncoder.SPLADE

_logger = logging.getLogger(__name__)

import_err_msg_pinecone_client = (
"`pinecone` package not found, please run `pip install pinecone-client`"
)

import_err_msg_pinecone_text = (
"`pinecone_text` package not found, please run `pip install pinecone-text`"
)


def get_default_tokenizer() -> Callable:
    """Build the default BERT-based tokenizer callable.

    NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

    Returns:
        A callable that accepts a list of strings and tokenizes them with
        sensible defaults pre-bound (padding, truncation, 512-token cap).
    """
    from transformers import BertTokenizerFast

    base_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    # Pre-bind keyword defaults so callers only pass the list of strings.
    default_kwargs = {
        "padding": True,
        "truncation": True,
        "max_length": 512,
    }
    return partial(base_tokenizer, **default_kwargs)

def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:

def build_sparse_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
"""Build a list of sparse dictionaries from a batch of input_ids.
NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.
"""
# store a batch of sparse embeddings
sparse_emb = []
# iterate through input batch
for token_ids in input_batch:
indices = []
values = []
Expand All @@ -53,40 +79,50 @@ def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
indices.append(idx)
values.append(float(d[idx]))
sparse_emb.append({"indices": indices, "values": values})
# return sparse_emb list
return sparse_emb


def generate_sparse_vectors(
context_batch: List[str], tokenizer: Callable
) -> List[Dict[str, Any]]:
"""Generate sparse vectors from a batch of contexts.
def initialize_sparse_encoder(sparse_encoder_type: VectorStoreSparseEncoder) -> Any:
    """Instantiate the pinecone-text sparse encoder for the given type.

    Args:
        sparse_encoder_type: which encoder to build (e.g. BM25, SPLADE).
            Must not be SPARSE_DICT (that mode needs no encoder object).

    Returns:
        A ``pinecone_text.sparse`` encoder instance. BM25 is created via its
        ``default()`` factory so it comes with pretrained parameters.

    Raises:
        ImportError: if the optional ``pinecone-text`` package is not installed.
    """
    # Keep the try narrow: only the optional-dependency import should be
    # translated into the friendly install hint, not unrelated errors below.
    try:
        import pinecone_text.sparse
    except ImportError:
        raise ImportError(import_err_msg_pinecone_text)

    # Encoder classes follow the "<Value>Encoder" naming convention, e.g.
    # "BM25Encoder" / "SpladeEncoder".
    # NOTE(review): relies on f-string formatting of the str-mixin enum
    # yielding the member value — confirm on Python >= 3.11 where enum
    # str()/format() semantics changed.
    encoder_class_name = f"{sparse_encoder_type}Encoder"
    encoder_class = getattr(pinecone_text.sparse, encoder_class_name)
    return (
        encoder_class()
        if sparse_encoder_type != VectorStoreSparseEncoder.BM25
        else encoder_class().default()
    )

"""
# create batch of input_ids
inputs = tokenizer(context_batch)["input_ids"]
# create sparse dictionaries
return build_dict(inputs)

def encode_batch(
    sparse_encoder: Any, context_batch: List[str], query_mode: bool
) -> List[Dict[str, Any]]:
    """Encode a batch of texts with the given sparse encoder.

    Dispatches to the encoder's query-side or document-side encoding
    depending on ``query_mode`` (queries and documents may be weighted
    differently by the encoder).
    """
    encode = (
        sparse_encoder.encode_queries if query_mode else sparse_encoder.encode_documents
    )
    return encode(context_batch)

def get_default_tokenizer() -> Callable:
"""Get default tokenizer.
NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

def generate_sparse_vectors(
    sparse_encoder_type: VectorStoreSparseEncoder,
    sparse_encoder: Any,
    context_batch: List[str],
    tokenizer: Callable,
    query_mode: bool = False,
) -> List[Dict[str, Any]]:
    """Generate sparse vectors from a batch of contexts.

    Args:
        sparse_encoder_type: encoding strategy. SPARSE_DICT tokenizes locally
            and builds token-count dicts; any other value delegates to a
            pinecone-text encoder.
        sparse_encoder: pinecone-text encoder instance (ignored for
            SPARSE_DICT).
        context_batch: texts to encode.
        tokenizer: tokenizer callable (used only for SPARSE_DICT).
        query_mode: True when encoding queries rather than documents.

    Returns:
        Sparse representations, one per input text — for SPARSE_DICT these
        are {"indices": [...], "values": [...]} dicts; otherwise whatever
        the encoder produces (presumably the same shape — verify against
        pinecone-text).
    """
    if sparse_encoder_type == VectorStoreSparseEncoder.SPARSE_DICT:
        return build_sparse_dict(tokenizer(context_batch)["input_ids"])
    return encode_batch(sparse_encoder, context_batch, query_mode)


def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:
Expand All @@ -97,11 +133,6 @@ def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:
return filters


import_err_msg = (
"`pinecone` package not found, please run `pip install pinecone-client`"
)


class PineconeVectorStore(BasePydanticVectorStore):
"""Pinecone Vector Store.
Expand All @@ -128,12 +159,14 @@ class PineconeVectorStore(BasePydanticVectorStore):
namespace: Optional[str]
insert_kwargs: Optional[Dict]
add_sparse_vector: bool
sparse_encoder_type: Optional[VectorStoreSparseEncoder]
text_key: str
batch_size: int
remove_text_from_metadata: bool

_pinecone_index: Any = PrivateAttr()
_tokenizer: Optional[Callable] = PrivateAttr()
_sparse_encoder: Optional[Any] = PrivateAttr()

def __init__(
self,
Expand All @@ -144,6 +177,10 @@ def __init__(
namespace: Optional[str] = None,
insert_kwargs: Optional[Dict] = None,
add_sparse_vector: bool = False,
sparse_encoder_type: Optional[
VectorStoreSparseEncoder
] = DEFAULT_SPARSE_VECTOR_ENCODER,
sparse_encoder: Optional[Any] = None,
tokenizer: Optional[Callable] = None,
text_key: str = DEFAULT_TEXT_KEY,
batch_size: int = DEFAULT_BATCH_SIZE,
Expand All @@ -154,7 +191,7 @@ def __init__(
try:
import pinecone
except ImportError:
raise ImportError(import_err_msg)
raise ImportError(import_err_msg_pinecone_client)

if pinecone_index is not None:
self._pinecone_index = cast(pinecone.Index, pinecone_index)
Expand All @@ -174,13 +211,23 @@ def __init__(
tokenizer = get_default_tokenizer()
self._tokenizer = tokenizer

if (
add_sparse_vector
and sparse_encoder is None
and sparse_encoder_type != VectorStoreSparseEncoder.SPARSE_DICT
):
sparse_encoder = initialize_sparse_encoder(sparse_encoder_type) # type: ignore
self._sparse_encoder = sparse_encoder

super().__init__(
index_name=index_name,
environment=environment,
api_key=api_key,
namespace=namespace,
insert_kwargs=insert_kwargs,
add_sparse_vector=add_sparse_vector,
sparse_encoder_type=sparse_encoder_type,
sparse_encoder=sparse_encoder,
text_key=text_key,
batch_size=batch_size,
remove_text_from_metadata=remove_text_from_metadata,
Expand All @@ -195,6 +242,10 @@ def from_params(
namespace: Optional[str] = None,
insert_kwargs: Optional[Dict] = None,
add_sparse_vector: bool = False,
sparse_encoder_type: Optional[
VectorStoreSparseEncoder
] = DEFAULT_SPARSE_VECTOR_ENCODER,
sparse_encoder: Optional[Any] = None,
tokenizer: Optional[Callable] = None,
text_key: str = DEFAULT_TEXT_KEY,
batch_size: int = DEFAULT_BATCH_SIZE,
Expand All @@ -204,7 +255,7 @@ def from_params(
try:
import pinecone
except ImportError:
raise ImportError(import_err_msg)
raise ImportError(import_err_msg_pinecone_client)

pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)
Expand All @@ -217,6 +268,8 @@ def from_params(
namespace=namespace,
insert_kwargs=insert_kwargs,
add_sparse_vector=add_sparse_vector,
sparse_encoder_type=sparse_encoder_type,
sparse_encoder=sparse_encoder,
tokenizer=tokenizer,
text_key=text_key,
batch_size=batch_size,
Expand Down Expand Up @@ -255,15 +308,19 @@ def add(
VECTOR_KEY: node.get_embedding(),
METADATA_KEY: metadata,
}

if self.add_sparse_vector and self._tokenizer is not None:
sparse_vector = generate_sparse_vectors(
self.sparse_encoder_type, # type: ignore
self._sparse_encoder, # type: ignore
[node.get_content(metadata_mode=MetadataMode.EMBED)],
self._tokenizer,
)[0]
entry[SPARSE_VECTOR_KEY] = sparse_vector

ids.append(node_id)
entries.append(entry)

self._pinecone_index.upsert(
entries,
namespace=self.namespace,
Expand Down Expand Up @@ -293,13 +350,23 @@ def client(self) -> Any:
return self._pinecone_index

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
# Check vector store and query mode compatibility
if not self.add_sparse_vector and query.mode in (
VectorStoreQueryMode.HYBRID,
VectorStoreQueryMode.SPARSE,
):
raise ValueError(
"""Cannot query PineconeVectorStore in HYBRID or SPARSE mode because
the vector store doesn't include sparse values. To have them please
set add_sparse_vectors to True during the PineconeVectorStore initialization."""
)

Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
# Handle query embedding
query_embedding = cast(List[float], query.query_embedding)
if query.alpha is not None:
query_embedding = [v * query.alpha for v in query_embedding]

"""
# Handle sparse vector generation
sparse_vector = None
if (
query.mode in (VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID)
Expand All @@ -309,35 +376,37 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
raise ValueError(
"query_str must be specified if mode is SPARSE or HYBRID."
)
sparse_vector = generate_sparse_vectors([query.query_str], self._tokenizer)[
0
]

sparse_vector = generate_sparse_vectors(
self.sparse_encoder_type, # type: ignore
self._sparse_encoder, # type: ignore
[query.query_str],
self._tokenizer,
query_mode=True,
)[0]

if query.alpha is not None:
sparse_vector = {
"indices": sparse_vector["indices"],
"values": [v * (1 - query.alpha) for v in sparse_vector["values"]],
}

query_embedding = None
if query.mode in (VectorStoreQueryMode.DEFAULT, VectorStoreQueryMode.HYBRID):
query_embedding = cast(List[float], query.query_embedding)
if query.alpha is not None:
query_embedding = [v * query.alpha for v in query_embedding]

# Handle filter
if query.filters is not None:
if "filter" in kwargs:
raise ValueError(
"Cannot specify filter via both query and kwargs. "
"Use kwargs only for pinecone specific items that are "
"not supported via the generic query interface."
"Cannot specify filter via both query and kwargs. Use kwargs only for Pinecone-specific items."
)
filter = _to_pinecone_filter(query.filters)
else:
filter = kwargs.pop("filter", {})

# Perform the query based on the mode
response = self._pinecone_index.query(
vector=query_embedding,
sparse_vector=sparse_vector,
sparse_vector=sparse_vector
if query.mode in {VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID}
else None,
top_k=query.similarity_top_k,
include_values=True,
include_metadata=True,
Expand All @@ -346,6 +415,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
**kwargs,
)

# Process the response
top_k_nodes = []
top_k_ids = []
top_k_scores = []
Expand All @@ -354,14 +424,12 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
node = metadata_dict_to_node(match.metadata)
node.embedding = match.values
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
_logger.debug(
"Failed to parse Node metadata, fallback to legacy logic."
)
metadata, node_info, relationships = legacy_metadata_dict_to_node(
match.metadata, text_key=self.text_key
)

text = match.metadata[self.text_key]
id = match.id
node = TextNode(
Expand Down
10 changes: 10 additions & 0 deletions llama_index/vector_stores/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ class VectorStoreQueryMode(str, Enum):
MMR = "mmr"


class VectorStoreSparseEncoder(str, Enum):
    """Sparse-vector encoder types for vector stores.

    The member values (other than SPARSE_DICT) match encoder class-name
    prefixes in the ``pinecone-text`` library, e.g. "BM25" -> BM25Encoder.
    """

    # Local tokenizer-based sparse dict (no pinecone-text dependency).
    SPARSE_DICT = "SparseDict"

    # Encoder names implemented in the pinecone-text library.
    BM25 = "BM25"
    SPLADE = "Splade"


class ExactMatchFilter(BaseModel):
"""Exact match metadata filter for vector stores.
Expand Down

0 comments on commit 93d9eee

Please sign in to comment.