Hybrid search optimizations #3

Open · wants to merge 1 commit into base: main
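In brief, the diff below swaps the hard-coded, tokenizer-based sparse vector generation for pluggable encoders from the pinecone-text library (BM25 or SPLADE), keeps the original tokenizer-based path available as a SparseDict option, and adds query-mode validation plus alpha weighting for hybrid queries. A minimal usage sketch of the new surface, assuming the leading `from_params` parameters elided in the diff (`index_name`, `environment`, `api_key`) are unchanged; the index name and credentials here are placeholders:

```python
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.vector_stores.types import VectorStoreSparseEncoder

# Placeholder index name and credentials. Because no explicit sparse_encoder
# is passed, the BM25 encoder is initialized internally from pinecone-text.
store = PineconeVectorStore.from_params(
    index_name="my-index",
    environment="us-west1-gcp",
    api_key="<api-key>",
    add_sparse_vector=True,
    sparse_encoder_type=VectorStoreSparseEncoder.BM25,
)
```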
176 changes: 122 additions & 54 deletions llama_index/vector_stores/pinecone.py
@@ -17,6 +17,7 @@
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
VectorStoreSparseEncoder,
)
from llama_index.vector_stores.utils import (
DEFAULT_TEXT_KEY,
@@ -31,19 +32,44 @@
METADATA_KEY = "metadata"

DEFAULT_BATCH_SIZE = 100
DEFAULT_SPARSE_VECTOR_ENCODER = VectorStoreSparseEncoder.SPLADE

_logger = logging.getLogger(__name__)

import_err_msg_pinecone_client = (
"`pinecone` package not found, please run `pip install pinecone-client`"
)

import_err_msg_pinecone_text = (
"`pinecone_text` package not found, please run `pip install pinecone-text`"
)


def get_default_tokenizer() -> Callable:
"""Get default tokenizer.

NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

"""
from transformers import BertTokenizerFast

orig_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# set some default arguments, so input is just a list of strings
return partial(
orig_tokenizer,
padding=True,
truncation=True,
max_length=512,
)
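For orientation, the returned callable accepts a plain list of strings and yields the `input_ids` consumed by `build_sparse_dict` below; a quick sketch (requires the `transformers` package):

```python
# Sketch: tokenize a small batch and pull out the token-ID lists.
tokenize = get_default_tokenizer()
encoded = tokenize(["hybrid search with pinecone"])
token_ids = encoded["input_ids"]  # List[List[int]], one row per input string
```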

def build_sparse_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
"""Build a list of sparse dictionaries from a batch of input_ids.

NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.

"""
# store a batch of sparse embeddings
sparse_emb = []
# iterate through input batch
for token_ids in input_batch:
indices = []
values = []
@@ -53,40 +79,50 @@ def build_sparse_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
indices.append(idx)
values.append(float(d[idx]))
sparse_emb.append({"indices": indices, "values": values})
# return sparse_emb list
return sparse_emb
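As an illustration of the output shape (the token IDs are arbitrary, and this assumes the folded lines above build a token-frequency map `d`, as in the Pinecone tutorial the function is taken from), repeated IDs collapse into counts:

```python
# Hypothetical token-ID batch; the output pairs each unique ID with its count.
sparse = build_sparse_dict([[101, 2023, 2023, 102]])
# -> [{"indices": [101, 2023, 102], "values": [1.0, 2.0, 1.0]}]
```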


def initialize_sparse_encoder(sparse_encoder_type: VectorStoreSparseEncoder) -> Any:
    try:
        import pinecone_text.sparse

        encoder_class_name = f"{sparse_encoder_type}Encoder"
        encoder_class = getattr(pinecone_text.sparse, encoder_class_name)
        return (
            encoder_class()
            if sparse_encoder_type != VectorStoreSparseEncoder.BM25
            else encoder_class().default()
        )
    except ImportError:
        raise ImportError(import_err_msg_pinecone_text)


def encode_batch(
    sparse_encoder: Any, context_batch: List[str], query_mode: bool
) -> List[Dict[str, Any]]:
    if query_mode:
        sparse_vectors = sparse_encoder.encode_queries(context_batch)
    else:
        sparse_vectors = sparse_encoder.encode_documents(context_batch)
    return sparse_vectors


def generate_sparse_vectors(
    sparse_encoder_type: VectorStoreSparseEncoder,
    sparse_encoder: Any,
    context_batch: List[str],
    tokenizer: Callable,
    query_mode: bool = False,
) -> List[Dict[str, Any]]:
    """Generate sparse vectors from a batch of contexts."""
    if sparse_encoder_type == VectorStoreSparseEncoder.SPARSE_DICT:
        sparse_vectors = build_sparse_dict(tokenizer(context_batch)["input_ids"])
    else:
        sparse_vectors = encode_batch(sparse_encoder, context_batch, query_mode)

    return sparse_vectors
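Putting the three helpers together, document and query encoding go through the same entry point; a sketch of the BM25 path, assuming pinecone-text is installed (`BM25Encoder().default()` loads pre-fitted parameters):

```python
# The tokenizer argument is only used on the SPARSE_DICT path, so None is fine here.
encoder = initialize_sparse_encoder(VectorStoreSparseEncoder.BM25)
doc_vecs = generate_sparse_vectors(
    VectorStoreSparseEncoder.BM25, encoder, ["a document to index"], tokenizer=None
)
query_vecs = generate_sparse_vectors(
    VectorStoreSparseEncoder.BM25,
    encoder,
    ["a user query"],
    tokenizer=None,
    query_mode=True,
)
```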


def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:
@@ -97,11 +133,6 @@ def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:
return filters



class PineconeVectorStore(BasePydanticVectorStore):
"""Pinecone Vector Store.

@@ -128,12 +159,14 @@ class PineconeVectorStore(BasePydanticVectorStore):
namespace: Optional[str]
insert_kwargs: Optional[Dict]
add_sparse_vector: bool
sparse_encoder_type: Optional[VectorStoreSparseEncoder]
text_key: str
batch_size: int
remove_text_from_metadata: bool

_pinecone_index: Any = PrivateAttr()
_tokenizer: Optional[Callable] = PrivateAttr()
_sparse_encoder: Optional[Any] = PrivateAttr()

def __init__(
self,
@@ -144,6 +177,10 @@ def __init__(
namespace: Optional[str] = None,
insert_kwargs: Optional[Dict] = None,
add_sparse_vector: bool = False,
sparse_encoder_type: Optional[
VectorStoreSparseEncoder
] = DEFAULT_SPARSE_VECTOR_ENCODER,
sparse_encoder: Optional[Any] = None,
tokenizer: Optional[Callable] = None,
text_key: str = DEFAULT_TEXT_KEY,
batch_size: int = DEFAULT_BATCH_SIZE,
@@ -154,7 +191,7 @@ def __init__(
try:
import pinecone
except ImportError:
            raise ImportError(import_err_msg_pinecone_client)

if pinecone_index is not None:
self._pinecone_index = cast(pinecone.Index, pinecone_index)
@@ -174,13 +211,23 @@ def __init__(
tokenizer = get_default_tokenizer()
self._tokenizer = tokenizer

if (
add_sparse_vector
and sparse_encoder is None
and sparse_encoder_type != VectorStoreSparseEncoder.SPARSE_DICT
):
sparse_encoder = initialize_sparse_encoder(sparse_encoder_type) # type: ignore
self._sparse_encoder = sparse_encoder

super().__init__(
index_name=index_name,
environment=environment,
api_key=api_key,
namespace=namespace,
insert_kwargs=insert_kwargs,
add_sparse_vector=add_sparse_vector,
sparse_encoder_type=sparse_encoder_type,
sparse_encoder=sparse_encoder,
text_key=text_key,
batch_size=batch_size,
remove_text_from_metadata=remove_text_from_metadata,
@@ -195,6 +242,10 @@ def from_params(
namespace: Optional[str] = None,
insert_kwargs: Optional[Dict] = None,
add_sparse_vector: bool = False,
sparse_encoder_type: Optional[
VectorStoreSparseEncoder
] = DEFAULT_SPARSE_VECTOR_ENCODER,
sparse_encoder: Optional[Any] = None,
tokenizer: Optional[Callable] = None,
text_key: str = DEFAULT_TEXT_KEY,
batch_size: int = DEFAULT_BATCH_SIZE,
@@ -204,7 +255,7 @@ def from_params(
try:
import pinecone
except ImportError:
            raise ImportError(import_err_msg_pinecone_client)

pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)
@@ -217,6 +268,8 @@ def from_params(
namespace=namespace,
insert_kwargs=insert_kwargs,
add_sparse_vector=add_sparse_vector,
sparse_encoder_type=sparse_encoder_type,
sparse_encoder=sparse_encoder,
tokenizer=tokenizer,
text_key=text_key,
batch_size=batch_size,
@@ -255,15 +308,19 @@ def add(
VECTOR_KEY: node.get_embedding(),
METADATA_KEY: metadata,
}

if self.add_sparse_vector and self._tokenizer is not None:
sparse_vector = generate_sparse_vectors(
self.sparse_encoder_type, # type: ignore
self._sparse_encoder, # type: ignore
[node.get_content(metadata_mode=MetadataMode.EMBED)],
self._tokenizer,
)[0]
entry[SPARSE_VECTOR_KEY] = sparse_vector

ids.append(node_id)
entries.append(entry)

self._pinecone_index.upsert(
entries,
namespace=self.namespace,
@@ -293,13 +350,23 @@ def client(self) -> Any:
return self._pinecone_index

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
# Check vector store and query mode compatibility
if not self.add_sparse_vector and query.mode in (
VectorStoreQueryMode.HYBRID,
VectorStoreQueryMode.SPARSE,
):
raise ValueError(
"""Cannot query PineconeVectorStore in HYBRID or SPARSE mode because
the vector store doesn't include sparse values. To have them please
set add_sparse_vectors to True during the PineconeVectorStore initialization."""
)

Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
# Handle query embedding
query_embedding = cast(List[float], query.query_embedding)
if query.alpha is not None:
query_embedding = [v * query.alpha for v in query_embedding]

"""
# Handle sparse vector generation
sparse_vector = None
if (
query.mode in (VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID)
@@ -309,35 +376,37 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
raise ValueError(
"query_str must be specified if mode is SPARSE or HYBRID."
)
            sparse_vector = generate_sparse_vectors(
                self.sparse_encoder_type,  # type: ignore
                self._sparse_encoder,  # type: ignore
                [query.query_str],
                self._tokenizer,
                query_mode=True,
            )[0]

if query.alpha is not None:
sparse_vector = {
"indices": sparse_vector["indices"],
"values": [v * (1 - query.alpha) for v in sparse_vector["values"]],
}
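The two alpha blocks implement the standard convex combination for hybrid search: dense values are scaled by alpha and sparse values by (1 - alpha), so alpha=1.0 is pure dense and alpha=0.0 pure sparse. An equivalent standalone sketch, mirroring the `hybrid_scale` helper from Pinecone's hybrid-search guide:

```python
def hybrid_scale(dense: list, sparse: dict, alpha: float):
    # alpha=1.0 -> pure dense, alpha=0.0 -> pure sparse.
    scaled_dense = [v * alpha for v in dense]
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [v * (1 - alpha) for v in sparse["values"]],
    }
    return scaled_dense, scaled_sparse
```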


# Handle filter
if query.filters is not None:
if "filter" in kwargs:
raise ValueError(
"Cannot specify filter via both query and kwargs. "
"Use kwargs only for pinecone specific items that are "
"not supported via the generic query interface."
"Cannot specify filter via both query and kwargs. Use kwargs only for Pinecone-specific items."
)
filter = _to_pinecone_filter(query.filters)
else:
filter = kwargs.pop("filter", {})

# Perform the query based on the mode
response = self._pinecone_index.query(
vector=query_embedding,
            sparse_vector=sparse_vector
            if query.mode in {VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID}
            else None,
top_k=query.similarity_top_k,
include_values=True,
include_metadata=True,
@@ -346,6 +415,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
**kwargs,
)

# Process the response
top_k_nodes = []
top_k_ids = []
top_k_scores = []
@@ -354,14 +424,12 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
node = metadata_dict_to_node(match.metadata)
node.embedding = match.values
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
_logger.debug(
"Failed to parse Node metadata, fallback to legacy logic."
)
metadata, node_info, relationships = legacy_metadata_dict_to_node(
match.metadata, text_key=self.text_key
)

text = match.metadata[self.text_key]
id = match.id
node = TextNode(
10 changes: 10 additions & 0 deletions llama_index/vector_stores/types.py
@@ -52,6 +52,16 @@ class VectorStoreQueryMode(str, Enum):
MMR = "mmr"


class VectorStoreSparseEncoder(str, Enum):
"""Vector store query mode."""

SPARSE_DICT = "SparseDict"

# encoder names implemented in pinecone-text library
BM25 = "BM25"
SPLADE = "Splade"
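The string values deliberately mirror the pinecone-text class names, which is what `initialize_sparse_encoder` relies on when it assembles the class name with an f-string; `SPARSE_DICT` is the exception, handled by the local tokenizer path instead. A quick check of the mapping:

```python
# "BM25" -> pinecone_text.sparse.BM25Encoder, "Splade" -> SpladeEncoder.
assert f"{VectorStoreSparseEncoder.BM25.value}Encoder" == "BM25Encoder"
assert f"{VectorStoreSparseEncoder.SPLADE.value}Encoder" == "SpladeEncoder"
```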


class ExactMatchFilter(BaseModel):
"""Exact match metadata filter for vector stores.
