Skip to content

Commit

Permalink
support different Pinecone initialization depending on the version
Browse files Browse the repository at this point in the history
  • Loading branch information
DosticJelena authored and Jelena Dostić committed Dec 29, 2023
1 parent e114f1f commit 0b016da
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions llama_index/vector_stores/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast

from packaging import version
from pkg_resources import get_distribution

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
Expand Down Expand Up @@ -69,7 +72,8 @@ def _transform_pinecone_filter_operator(operator: str) -> str:


def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
"""Build a list of sparse dictionaries from a batch of input_ids.
"""
Build a list of sparse dictionaries from a batch of input_ids.
NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.
Expand All @@ -93,7 +97,8 @@ def build_dict(input_batch: List[List[int]]) -> List[Dict[str, Any]]:
def generate_sparse_vectors(
context_batch: List[str], tokenizer: Callable
) -> List[Dict[str, Any]]:
"""Generate sparse vectors from a batch of contexts.
"""
Generate sparse vectors from a batch of contexts.
NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.
Expand All @@ -105,7 +110,8 @@ def generate_sparse_vectors(


def get_default_tokenizer() -> Callable:
"""Get default tokenizer.
"""
Get default tokenizer.
NOTE: taken from https://www.pinecone.io/learn/hybrid-search-intro/.
Expand Down Expand Up @@ -157,7 +163,8 @@ def _to_pinecone_filter(standard_filters: MetadataFilters) -> dict:


class PineconeVectorStore(BasePydanticVectorStore):
"""Pinecone Vector Store.
"""
Pinecone Vector Store.
In this vector store, embeddings and docs are stored within a
Pinecone index.
Expand Down Expand Up @@ -217,14 +224,24 @@ def __init__(
if pinecone_index is not None:
self._pinecone_index = cast(pinecone.Index, pinecone_index)
else:
if index_name is None or environment is None:
raise ValueError(
"Must specify index_name and environment "
"if not directly passing in client."
)
pinecone_client_version = get_distribution("pinecone-client").version

if version.parse(pinecone_client_version) >= version.parse("3.0.0"):
if index_name is None:
raise ValueError(
"Must specify index_name if not directly passing in client."
)
pinecone_instance = pinecone.Pinecone(api_key=api_key)
self._pinecone_index = pinecone_instance.Index(index_name)
else:
if index_name is None or environment is None:
raise ValueError(
"Must specify index_name and environment "
"if not directly passing in client."
)

pinecone.init(api_key=api_key, environment=environment)
self._pinecone_index = pinecone.Index(index_name)
pinecone.init(api_key=api_key, environment=environment)
self._pinecone_index = pinecone.Index(index_name)

insert_kwargs = insert_kwargs or {}

Expand Down Expand Up @@ -265,8 +282,14 @@ def from_params(
except ImportError:
raise ImportError(import_err_msg)

pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)
pinecone_client_version = get_distribution("pinecone-client").version

if version.parse(pinecone_client_version) >= version.parse("3.0.0"):
pinecone_instance = pinecone.Pinecone(api_key=api_key)
pinecone_index = pinecone_instance.Index(index_name)
else:
pinecone.init(api_key=api_key, environment=environment)
pinecone_index = pinecone.Index(index_name)

return cls(
pinecone_index=pinecone_index,
Expand All @@ -286,14 +309,15 @@ def from_params(

@classmethod
def class_name(cls) -> str:
return "PinconeVectorStore"
return "PineconeVectorStore"

def add(
self,
nodes: List[BaseNode],
**add_kwargs: Any,
) -> List[str]:
"""Add nodes to index.
"""
Add nodes to index.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
Expand Down Expand Up @@ -353,7 +377,8 @@ def client(self) -> Any:
return self._pinecone_index

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
"""
Query index for top k most similar nodes.
Args:
query_embedding (List[float]): query embedding
Expand Down

0 comments on commit 0b016da

Please sign in to comment.