From 5dc5c6b01ea4f6a170edae9102cd168ff789ed46 Mon Sep 17 00:00:00 2001 From: Arnav Singhvi Date: Fri, 1 Nov 2024 14:57:17 -0700 Subject: [PATCH 1/3] wip - dspy.RM/retrieve refactor --- dspy/clients/__init__.py | 3 +- dspy/clients/rm.py | 33 +++++++ dspy/retrieve/__init__.py | 3 +- dspy/retrieve/embedder.py | 16 ++++ dspy/retrieve/retrieve.py | 85 ++---------------- examples/rm_migration.ipynb | 167 ++++++++++++++++++++++++++++++++++++ 6 files changed, 227 insertions(+), 80 deletions(-) create mode 100644 dspy/clients/rm.py create mode 100644 dspy/retrieve/embedder.py create mode 100644 examples/rm_migration.ipynb diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 079dc5420..db2f86da1 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1 +1,2 @@ -from .lm import LM \ No newline at end of file +from .lm import LM +from .rm import RM \ No newline at end of file diff --git a/dspy/clients/rm.py b/dspy/clients/rm.py new file mode 100644 index 000000000..2357adb54 --- /dev/null +++ b/dspy/clients/rm.py @@ -0,0 +1,33 @@ +from typing import Any, Callable, List, Optional + +from dspy.primitives.prediction import Prediction +from dspy.retrieve.embedder import Embedder + +class RM: + def __init__( + self, + search_function: Callable[..., Any], + embedder: Optional[Embedder] = None, + result_formatter: Optional[Callable[[Any], Prediction]] = None, + **provider_kwargs + ): + self.embedder = embedder + self.search_function = search_function + self.result_formatter = result_formatter or self.default_formatter + self.provider_kwargs = provider_kwargs + + def __call__(self, query: str, k: Optional[int] = None) -> Prediction: + if self.embedder: + query_vector = self.embedder([query])[0] + query_input = query_vector + else: + query_input = query + search_args = self.provider_kwargs.copy() + search_args['query'] = query_input + if k is not None: + search_args['k'] = k + results = self.search_function(**search_args) + return self.result_formatter(results) + + def default_formatter(self, results) -> Prediction: + return Prediction(passages=results) \ No newline at end of file diff --git a/dspy/retrieve/__init__.py b/dspy/retrieve/__init__.py index 2f699c23a..67b83464e 100644 --- a/dspy/retrieve/__init__.py +++ b/dspy/retrieve/__init__.py @@ -1 +1,2 @@ -from .retrieve import Retrieve, RetrieveThenRerank \ No newline at end of file +from .retrieve import Retrieve, RetrieveThenRerank +from .embedder import Embedder \ No newline at end of file diff --git a/dspy/retrieve/embedder.py b/dspy/retrieve/embedder.py new file mode 100644 index 000000000..d9b8c182d --- /dev/null +++ b/dspy/retrieve/embedder.py @@ -0,0 +1,16 @@ +from typing import Callable, List, Optional + + +class Embedder: + def __init__(self, embedding_model: str = 'text-embedding-ada-002', embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None): + self.embedding_model = embedding_model + self.embedding_function = embedding_function or self.default_embedding_function + + def default_embedding_function(self, texts: List[str]) -> List[List[float]]: + from litellm import embedding + embeddings_response = embedding(model=self.embedding_model, input=texts) + embeddings = [data['embedding'] for data in embeddings_response.data] + return embeddings + + def __call__(self, texts: List[str]) -> List[List[float]]: + return self.embedding_function(texts) \ No newline at end of file diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 37ac0390d..6fe282d0c 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -4,7 +4,7 @@ import dsp from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction -from dspy.utils.callback import with_callbacks +from dspy.clients import RM def single_query_passage(passages): @@ -22,86 +22,15 @@ class Retrieve(Parameter): input_variable = "query" desc = "takes a search query and returns one or more potentially relevant passages from a corpus" - def __init__(self, k=3, callbacks=None): - self.stage = random.randbytes(8).hex() + def __init__(self, rm: RM, k=3): + self.rm = rm self.k = k - self.callbacks = callbacks or [] - - def reset(self): - pass - - def dump_state(self, save_verbose=False): - """save_verbose is set as a default argument to support the inherited Parameter interface for dump_state""" - state_keys = ["k"] - return {k: getattr(self, k) for k in state_keys} - - def load_state(self, state): - for name, value in state.items(): - setattr(self, name, value) - - @with_callbacks - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def forward( - self, - query_or_queries: Union[str, List[str]] = None, - query: Optional[str] = None, - k: Optional[int] = None, - by_prob: bool = True, - with_metadata: bool = False, - **kwargs, - ) -> Union[List[str], Prediction, List[Prediction]]: - query_or_queries = query_or_queries or query - # queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries - # queries = [query.strip().split('\n')[0].strip() for query in queries] + #TODO - add back saving/loading for retrievers - # # print(queries) - # # TODO: Consider removing any quote-like markers that surround the query too. - # k = k if k is not None else self.k - # passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) - # return Prediction(passages=passages) - queries = ( - [query_or_queries] - if isinstance(query_or_queries, str) - else query_or_queries - ) - queries = [query.strip().split("\n")[0].strip() for query in queries] - - # print(queries) - # TODO: Consider removing any quote-like markers that surround the query too. + def __call__(self, query_or_queries, k=None): k = k if k is not None else self.k - if not with_metadata: - passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs) - return Prediction(passages=passages) - else: - passages = dsp.retrieveEnsemblewithMetadata( - queries, k=k, by_prob=by_prob, **kwargs, - ) - if isinstance(passages[0], List): - pred_returns = [] - for query_passages in passages: - passages_dict = { - key: [] - for key in list(query_passages[0].keys()) - if key != "tracking_idx" - } - for psg in query_passages: - for key, value in psg.items(): - if key == "tracking_idx": - continue - passages_dict[key].append(value) - if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") - pred_returns.append(Prediction(**passages_dict)) - return pred_returns - elif isinstance(passages[0], Dict): - # passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} - return single_query_passage(passages=passages) - - -# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. + return self.rm(query_or_queries, k=k) class RetrieveThenRerank(Parameter): @@ -163,4 +92,4 @@ def forward( pred_returns.append(Prediction(**passages_dict)) return pred_returns elif isinstance(passages[0], Dict): - return single_query_passage(passages=passages) + return single_query_passage(passages=passages) \ No newline at end of file diff --git a/examples/rm_migration.ipynb b/examples/rm_migration.ipynb new file mode 100644 index 000000000..312fb1f99 --- /dev/null +++ b/examples/rm_migration.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "{DSPy.RM Migration - TBD}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Querying ColBERTv2 \n", + "\n", + "import requests\n", + "import os\n", + "from typing import Any, Dict, List, Optional, Union\n", + "from dspy import RM, Retrieve, Embedder\n", + "from dspy.primitives.prediction import Prediction\n", + "\n", + "def colbert_search_function(query: str, k: int, url: str, post_requests: bool = False) -> List[Dict[str, Any]]:\n", + " if post_requests:\n", + " headers = {\"Content-Type\": \"application/json; charset=utf-8\"}\n", + " payload = {\"query\": query, \"k\": k}\n", + " res = requests.post(url, json=payload, headers=headers, timeout=10)\n", + " else:\n", + " payload = {\"query\": query, \"k\": k}\n", + " res = requests.get(url, params=payload, timeout=10)\n", + " \n", + " res.raise_for_status()\n", + " topk = res.json()[\"topk\"][:k]\n", + " topk = [{**doc, \"long_text\": doc.get(\"text\", \"\")} for doc in topk]\n", + " return topk\n", + "\n", + "def colbert_result_formatter(results: List[Dict[str, Any]]) -> Prediction:\n", + " passages = [doc[\"long_text\"] for doc in results]\n", + " return Prediction(passages=passages)\n", + "\n", + "colbert_url = \"http://20.102.90.50:2017/wiki17_abstracts\"\n", + "\n", + "colbert_rm = RM(\n", + " search_function=colbert_search_function,\n", + " result_formatter=colbert_result_formatter,\n", + " url=colbert_url,\n", + " post_requests=False\n", + ")\n", + "\n", + "retrieve = Retrieve(rm=colbert_rm, k=10)\n", + "query_text = \"Example query text\"\n", + "results = retrieve(query_text)\n", + "print(results.passages)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Querying Databricks Mosaic AI Vector Search \n", + "\n", + "#client setup\n", + "databricks_token = os.environ.get(\"DATABRICKS_TOKEN\")\n", + "databricks_endpoint = os.environ.get(\"DATABRICKS_HOST\")\n", + "databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token)\n", + "\n", + "#custom logic for querying and sorting the docs\n", + "def databricks_search_function(\n", + " query,\n", + " k,\n", + " index_name,\n", + " columns,\n", + " query_type='ANN',\n", + " filters_json=None,\n", + " client=None\n", + "):\n", + " results = client.vector_search_indexes.query(\n", + " index_name=index_name,\n", + " query_type=query_type,\n", + " query_text=query,\n", + " num_results=k,\n", + " columns=columns,\n", + " filters_json=filters_json,\n", + " ).as_dict()\n", + "\n", + " items = []\n", + " col_names = [column[\"name\"] for column in results[\"manifest\"][\"columns\"]]\n", + " for data_row in results[\"result\"][\"data_array\"]:\n", + " item = {col_name: val for col_name, val in zip(col_names, data_row)}\n", + " items.append(item)\n", + " sorted_docs = sorted(items, key=lambda x: x[\"score\"], reverse=True)\n", + " return sorted_docs\n", + "\n", + "def databricks_result_formatter(results) -> Prediction:\n", + " passages = [doc['some_text_column'] for doc in results] \n", + " return Prediction(passages=passages)\n", + "\n", + "databricks_rm = RM(\n", + " search_function=databricks_search_function,\n", + " result_formatter=databricks_result_formatter,\n", + " client=databricks_client,\n", + " index_name='your_index_name',\n", + " columns=['id', 'some_text_column'],\n", + " filters_json=None\n", + ")\n", + "\n", + "retrieve = Retrieve(rm=databricks_rm, k=3)\n", + "results = retrieve(\"Example query text\")\n", + "print(results.passages)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Querying Deeplake Vector Store\n", + "\n", + "embedder = Embedder()\n", + "\n", + "deeplake_vectorstore_name = 'vectorstore_name'\n", + "deeplake_client = deeplake.VectorStore(\n", + " path=deeplake_vectorstore_name,\n", + " embedding_function=embedder\n", + ")\n", + "\n", + "def deeplake_search_function(query, k, client=None):\n", + " results = client.search(query, k=k)\n", + " return results\n", + "\n", + "def deeplake_result_formatter(results) -> Prediction:\n", + " passages = [doc['text'] for doc in results['documents']]\n", + " return Prediction(passages=passages)\n", + "\n", + "\n", + "deeplake_rm = RM(\n", + " embedder=embedder,\n", + " search_function=deeplake_search_function,\n", + " result_formatter=deeplake_result_formatter,\n", + " client=deeplake_client\n", + ")\n", + "\n", + "retrieve = Retrieve(rm=deeplake_rm, k=3)\n", + "results = retrieve(\"some text\")\n", + "print(results.passages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TBD..." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From d69acddf05ae0e8eec7f7d6fda7366890cb9d9cb Mon Sep 17 00:00:00 2001 From: Arnav Singhvi Date: Thu, 14 Nov 2024 21:35:49 -0800 Subject: [PATCH 2/3] update dspy.Retrieve interface --- dsp/modules/colbertv2.py | 2 +- dsp/primitives/search.py | 1 + dspy/__init__.py | 2 +- dspy/clients/__init__.py | 3 +- dspy/clients/embedding.py | 43 +++++---- dspy/clients/rm.py | 33 ------- dspy/retrieve/__init__.py | 3 +- dspy/retrieve/colbertv2_rm.py | 62 +++++++++++++ dspy/retrieve/embedder.py | 16 ---- dspy/retrieve/retrieve.py | 41 ++++++--- examples/rm_migration.ipynb | 167 ---------------------------------- 11 files changed, 121 insertions(+), 252 deletions(-) delete mode 100644 dspy/clients/rm.py create mode 100644 dspy/retrieve/colbertv2_rm.py delete mode 100644 dspy/retrieve/embedder.py delete mode 100644 examples/rm_migration.ipynb diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 67b246c5e..47b6e0d7b 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -8,7 +8,7 @@ # TODO: Ideally, this takes the name of the index and looks up its port. - +#TODO remove references of ColBERTv2 from here now that it is supported in retrieve/ class ColBERTv2: """Wrapper for the ColBERTv2 Retrieval.""" diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 1ad9a07cd..81689f122 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -7,6 +7,7 @@ logger = logging.getLogger(__name__) +#TODO remove references now that Retrieve interface is supported def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" if not dsp.settings.rm: diff --git a/dspy/__init__.py b/dspy/__init__.py index 3ba977eb8..3e95d1667 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -24,7 +24,7 @@ Mistral = dsp.Mistral Databricks = dsp.Databricks Cohere = dsp.Cohere -ColBERTv2 = dsp.ColBERTv2 +# ColBERTv2 = dsp.ColBERTv2 ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal Pyserini = dsp.PyseriniRetriever diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index de50766c9..2fc0e2543 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,8 +1,7 @@ from .lm import LM -from .rm import RM from .provider import Provider, TrainingJob from .base_lm import BaseLM, inspect_history -from .embedding import Embedding +from .embedding import Embedder import litellm import os from pathlib import Path diff --git a/dspy/clients/embedding.py b/dspy/clients/embedding.py index eec41c32b..ccaaedab0 100644 --- a/dspy/clients/embedding.py +++ b/dspy/clients/embedding.py @@ -1,8 +1,8 @@ import litellm import numpy as np +from typing import Callable, List, Union, Optional - -class Embedding: +class Embedder: """DSPy embedding class. The class for computing embeddings for text inputs. This class provides a unified interface for both: @@ -13,7 +13,7 @@ class Embedding: For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use litellm to handle the API calls and caching. - For custom embedding models, pass a callable function that: + For custom embedding models, pass a callable function to `embedding_function` that: - Takes a list of strings as input. - Returns embeddings as either: - A 2D numpy array of float32 values @@ -21,9 +21,10 @@ class Embedding: - Each row should represent one embedding vector Args: - model: The embedding model to use. This can be either a string (representing the name of the hosted embedding - model, must be an embedding model supported by litellm) or a callable that represents a custom embedding - model. + embedding_model: The embedding model to use, either a string (for hosted models supported by litellm) or + a callable function that returns custom embeddings. + embedding_function: An optional custom embedding function. If not provided, defaults to litellm + for hosted models when `embedding_model` is a string. Examples: Example 1: Using a hosted model. @@ -31,7 +32,7 @@ class Embedding: ```python import dspy - embedder = dspy.Embedding("openai/text-embedding-3-small") + embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small") embeddings = embedder(["hello", "world"]) assert embeddings.shape == (2, 1536) @@ -41,21 +42,27 @@ class Embedding: ```python import dspy + import numpy as np def my_embedder(texts): return np.random.rand(len(texts), 10) - embedder = dspy.Embedding(my_embedder) + embedder = dspy.Embedding(embedding_function=my_embedder) embeddings = embedder(["hello", "world"]) assert embeddings.shape == (2, 10) ``` """ - def __init__(self, model): - self.model = model + def __init__(self, embedding_model: Union[str, Callable[[List[str]], List[List[float]]]] = 'text-embedding-ada-002', embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None): + self.embedding_model = embedding_model + self.embedding_function = embedding_function or self.default_embedding_function + + def default_embedding_function(self, texts: List[str], caching: bool = True, **kwargs) -> List[List[float]]: + embeddings_response = litellm.embedding(model=self.embedding_model, input=texts, caching=caching, **kwargs) + return [data['embedding'] for data in embeddings_response.data] - def __call__(self, inputs, caching=True, **kwargs): + def __call__(self, inputs: Union[str, List[str]], caching: bool = True, **kwargs) -> np.ndarray: """Compute embeddings for the given inputs. Args: @@ -68,10 +75,12 @@ def __call__(self, inputs, caching=True, **kwargs): """ if isinstance(inputs, str): inputs = [inputs] - if isinstance(self.model, str): - embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs) - return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32) - elif callable(self.model): - return np.array(self.model(inputs, **kwargs), dtype=np.float32) + if callable(self.embedding_function): + embeddings = self.embedding_function(inputs, **kwargs) + elif isinstance(self.embedding_model, str): + embeddings = self.default_embedding_function(inputs, caching=caching, **kwargs) else: - raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.") + raise ValueError( + f"`embedding_model` must be a string or `embedding_function` must be a callable, but got types: `embedding_model`={type(self.embedding_model)}, `embedding_function`={type(self.embedding_function)}." + ) + return np.array(embeddings, dtype=np.float32) diff --git a/dspy/clients/rm.py b/dspy/clients/rm.py deleted file mode 100644 index 2357adb54..000000000 --- a/dspy/clients/rm.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Callable, List, Optional - -from dspy.primitives.prediction import Prediction -from dspy.retrieve.embedder import Embedder - -class RM: - def __init__( - self, - search_function: Callable[..., Any], - embedder: Optional[Embedder] = None, - result_formatter: Optional[Callable[[Any], Prediction]] = None, - **provider_kwargs - ): - self.embedder = embedder - self.search_function = search_function - self.result_formatter = result_formatter or self.default_formatter - self.provider_kwargs = provider_kwargs - - def __call__(self, query: str, k: Optional[int] = None) -> Prediction: - if self.embedder: - query_vector = self.embedder([query])[0] - query_input = query_vector - else: - query_input = query - search_args = self.provider_kwargs.copy() - search_args['query'] = query_input - if k is not None: - search_args['k'] = k - results = self.search_function(**search_args) - return self.result_formatter(results) - - def default_formatter(self, results) -> Prediction: - return Prediction(passages=results) \ No newline at end of file diff --git a/dspy/retrieve/__init__.py b/dspy/retrieve/__init__.py index 67b83464e..2f699c23a 100644 --- a/dspy/retrieve/__init__.py +++ b/dspy/retrieve/__init__.py @@ -1,2 +1 @@ -from .retrieve import Retrieve, RetrieveThenRerank -from .embedder import Embedder \ No newline at end of file +from .retrieve import Retrieve, RetrieveThenRerank \ No newline at end of file diff --git a/dspy/retrieve/colbertv2_rm.py b/dspy/retrieve/colbertv2_rm.py new file mode 100644 index 000000000..387d97739 --- /dev/null +++ b/dspy/retrieve/colbertv2_rm.py @@ -0,0 +1,62 @@ +from typing import Any, Union, Optional, List +from dspy.primitives.prediction import Prediction +import dspy +from dsp.utils import dotdict +import requests +import functools + +class ColBERTv2(dspy.Retrieve): + def __init__(self, url: str = "http://0.0.0.0", port: Optional[Union[str, int]] = None, post_requests: bool = False): + super().__init__(embedder=None) + self.post_requests = post_requests + self.url = f"{url}:{port}" if port else url + + def forward(self, query: str, k: int = 10) -> Any: + if self.post_requests: + topk = colbertv2_post_request(self.url, query, k) + else: + topk = colbertv2_get_request(self.url, query, k) + return dotdict({'passages': [dotdict(psg) for psg in topk]}) + + +from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory +from dsp.utils import dotdict +@CacheMemory.cache +def colbertv2_get_request_v2(url: str, query: str, k: int): + assert ( + k <= 100 + ), "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment." + + payload = {"query": query, "k": k} + res = requests.get(url, params=payload, timeout=10) + + topk = res.json()["topk"][:k] + topk = [{**d, "long_text": d["text"]} for d in topk] + return topk[:k] + + +@functools.cache +@NotebookCacheMemory.cache +def colbertv2_get_request_v2_wrapped(*args, **kwargs): + return colbertv2_get_request_v2(*args, **kwargs) + + +colbertv2_get_request = colbertv2_get_request_v2_wrapped + + +@CacheMemory.cache +def colbertv2_post_request_v2(url: str, query: str, k: int): + headers = {"Content-Type": "application/json; charset=utf-8"} + payload = {"query": query, "k": k} + res = requests.post(url, json=payload, headers=headers, timeout=10) + + return res.json()["topk"][:k] + + +@functools.cache +@NotebookCacheMemory.cache +def colbertv2_post_request_v2_wrapped(*args, **kwargs): + return colbertv2_post_request_v2(*args, **kwargs) + + +colbertv2_post_request = colbertv2_post_request_v2_wrapped \ No newline at end of file diff --git a/dspy/retrieve/embedder.py b/dspy/retrieve/embedder.py deleted file mode 100644 index d9b8c182d..000000000 --- a/dspy/retrieve/embedder.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Callable, List, Optional - - -class Embedder: - def __init__(self, embedding_model: str = 'text-embedding-ada-002', embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None): - self.embedding_model = embedding_model - self.embedding_function = embedding_function or self.default_embedding_function - - def default_embedding_function(self, texts: List[str]) -> List[List[float]]: - from litellm import embedding - embeddings_response = embedding(model=self.embedding_model, input=texts) - embeddings = [data['embedding'] for data in embeddings_response.data] - return embeddings - - def __call__(self, texts: List[str]) -> List[List[float]]: - return self.embedding_function(texts) \ No newline at end of file diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 6fe282d0c..820178dec 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,10 +1,12 @@ import random -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import dsp from dspy.predict.parameter import Parameter +from abc import ABC, abstractmethod from dspy.primitives.prediction import Prediction -from dspy.clients import RM +from dspy.clients.embedding import Embedder +from dspy.utils.callback import with_callbacks def single_query_passage(passages): @@ -17,20 +19,33 @@ def single_query_passage(passages): return Prediction(**passages_dict) -class Retrieve(Parameter): - name = "Search" - input_variable = "query" - desc = "takes a search query and returns one or more potentially relevant passages from a corpus" - - def __init__(self, rm: RM, k=3): - self.rm = rm +class Retrieve(ABC): + def __init__(self, embedder: Optional[Embedder] = None, k: int = 5, callbacks: Optional[List[Any]] = None): + self.embedder = embedder self.k = k + self.callbacks = callbacks or [] + + @abstractmethod + def forward(self, query: str, k: Optional[int] = None) -> Any: + """ + Retrievers implement this method with their custom retrieval logic. + Must return an object that has a 'passages' attribute (ideally `dspy.Prediction`). + """ + pass - #TODO - add back saving/loading for retrievers - - def __call__(self, query_or_queries, k=None): + def __call__(self, query: str, k: Optional[int] = None) -> Any: + """ + Calls the forward method and checks if the result has a 'passages' attribute. + """ k = k if k is not None else self.k - return self.rm(query_or_queries, k=k) + result = self.forward(query, k) + if not hasattr(result, 'passages'): + raise ValueError("The 'forward' method must return an object with a 'passages' attribute (ideally `dspy.Prediction`).") + for callback in self.callbacks: + callback(result) + return result + +# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. class RetrieveThenRerank(Parameter): diff --git a/examples/rm_migration.ipynb b/examples/rm_migration.ipynb deleted file mode 100644 index 312fb1f99..000000000 --- a/examples/rm_migration.ipynb +++ /dev/null @@ -1,167 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "{DSPy.RM Migration - TBD}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#Querying ColBERTv2 \n", - "\n", - "import requests\n", - "import os\n", - "from typing import Any, Dict, List, Optional, Union\n", - "from dspy import RM, Retrieve, Embedder\n", - "from dspy.primitives.prediction import Prediction\n", - "\n", - "def colbert_search_function(query: str, k: int, url: str, post_requests: bool = False) -> List[Dict[str, Any]]:\n", - " if post_requests:\n", - " headers = {\"Content-Type\": \"application/json; charset=utf-8\"}\n", - " payload = {\"query\": query, \"k\": k}\n", - " res = requests.post(url, json=payload, headers=headers, timeout=10)\n", - " else:\n", - " payload = {\"query\": query, \"k\": k}\n", - " res = requests.get(url, params=payload, timeout=10)\n", - " \n", - " res.raise_for_status()\n", - " topk = res.json()[\"topk\"][:k]\n", - " topk = [{**doc, \"long_text\": doc.get(\"text\", \"\")} for doc in topk]\n", - " return topk\n", - "\n", - "def colbert_result_formatter(results: List[Dict[str, Any]]) -> Prediction:\n", - " passages = [doc[\"long_text\"] for doc in results]\n", - " return Prediction(passages=passages)\n", - "\n", - "colbert_url = \"http://20.102.90.50:2017/wiki17_abstracts\"\n", - "\n", - "colbert_rm = RM(\n", - " search_function=colbert_search_function,\n", - " result_formatter=colbert_result_formatter,\n", - " url=colbert_url,\n", - " post_requests=False\n", - ")\n", - "\n", - "retrieve = Retrieve(rm=colbert_rm, k=10)\n", - "query_text = \"Example query text\"\n", - "results = retrieve(query_text)\n", - "print(results.passages)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#Querying Databricks Mosaic AI Vector Search \n", - "\n", - "#client setup\n", - "databricks_token = os.environ.get(\"DATABRICKS_TOKEN\")\n", - "databricks_endpoint = os.environ.get(\"DATABRICKS_HOST\")\n", - "databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token)\n", - "\n", - "#custom logic for querying and sorting the docs\n", - "def databricks_search_function(\n", - " query,\n", - " k,\n", - " index_name,\n", - " columns,\n", - " query_type='ANN',\n", - " filters_json=None,\n", - " client=None\n", - "):\n", - " results = client.vector_search_indexes.query(\n", - " index_name=index_name,\n", - " query_type=query_type,\n", - " query_text=query,\n", - " num_results=k,\n", - " columns=columns,\n", - " filters_json=filters_json,\n", - " ).as_dict()\n", - "\n", - " items = []\n", - " col_names = [column[\"name\"] for column in results[\"manifest\"][\"columns\"]]\n", - " for data_row in results[\"result\"][\"data_array\"]:\n", - " item = {col_name: val for col_name, val in zip(col_names, data_row)}\n", - " items.append(item)\n", - " sorted_docs = sorted(items, key=lambda x: x[\"score\"], reverse=True)\n", - " return sorted_docs\n", - "\n", - "def databricks_result_formatter(results) -> Prediction:\n", - " passages = [doc['some_text_column'] for doc in results] \n", - " return Prediction(passages=passages)\n", - "\n", - "databricks_rm = RM(\n", - " search_function=databricks_search_function,\n", - " result_formatter=databricks_result_formatter,\n", - " client=databricks_client,\n", - " index_name='your_index_name',\n", - " columns=['id', 'some_text_column'],\n", - " filters_json=None\n", - ")\n", - "\n", - "retrieve = Retrieve(rm=databricks_rm, k=3)\n", - "results = retrieve(\"Example query text\")\n", - "print(results.passages)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#Querying Deeplake Vector Store\n", - "\n", - "embedder = Embedder()\n", - "\n", - "deeplake_vectorstore_name = 'vectorstore_name'\n", - "deeplake_client = deeplake.VectorStore(\n", - " path=deeplake_vectorstore_name,\n", - " embedding_function=embedder\n", - ")\n", - "\n", - "def deeplake_search_function(query, k, client=None):\n", - " results = client.search(query, k=k)\n", - " return results\n", - "\n", - "def deeplake_result_formatter(results) -> Prediction:\n", - " passages = [doc['text'] for doc in results['documents']]\n", - " return Prediction(passages=passages)\n", - "\n", - "\n", - "deeplake_rm = RM(\n", - " embedder=embedder,\n", - " search_function=deeplake_search_function,\n", - " result_formatter=deeplake_result_formatter,\n", - " client=deeplake_client\n", - ")\n", - "\n", - "retrieve = Retrieve(rm=deeplake_rm, k=3)\n", - "results = retrieve(\"some text\")\n", - "print(results.passages)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TBD..." - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From fdfa1f176ec65f9fb7fc06af857a54b978b7fbb0 Mon Sep 17 00:00:00 2001 From: Arnav Singhvi Date: Fri, 15 Nov 2024 15:02:22 -0800 Subject: [PATCH 3/3] updated dspy.Retriever interface --- dsp/primitives/search.py | 1 - dspy/__init__.py | 1 + dspy/adapters/chat_adapter.py | 1 - dspy/retrieve/retrieve.py | 110 ++++-- dspy/retriever/__init__.py | 1 + .../colbertv2_retriever.py} | 22 +- dspy/retriever/databricks_retriever.py | 347 ++++++++++++++++++ dspy/retriever/faiss_retriever.py | 138 +++++++ dspy/retriever/milvus_retriever.py | 115 ++++++ dspy/retriever/pinecone_retriever.py | 178 +++++++++ dspy/retriever/retriever.py | 60 +++ tests/clients/test_embedding.py | 8 +- 12 files changed, 950 insertions(+), 32 deletions(-) create mode 100644 dspy/retriever/__init__.py rename dspy/{retrieve/colbertv2_rm.py => retriever/colbertv2_retriever.py} (73%) create mode 100644 dspy/retriever/databricks_retriever.py create mode 100644 dspy/retriever/faiss_retriever.py create mode 100644 dspy/retriever/milvus_retriever.py create mode 100644 dspy/retriever/pinecone_retriever.py create mode 100644 dspy/retriever/retriever.py diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 81689f122..1ad9a07cd 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -7,7 +7,6 @@ logger = logging.getLogger(__name__) -#TODO remove references now that Retrieve interface is supported def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" if not dsp.settings.rm: diff --git a/dspy/__init__.py b/dspy/__init__.py index 3e95d1667..f30522f06 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -4,6 +4,7 @@ from .predict import * from .primitives import * from .retrieve import * +from .retriever import * from .signatures import * # Functional must be imported after primitives, predict and signatures diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index edb5e1870..3a7b0de4c 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -17,7 +17,6 @@ from pydantic.fields import FieldInfo from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin -from dspy.adapters.base import Adapter from ..signatures.field import OutputField from ..signatures.signature import SignatureMeta from ..signatures.utils import get_dspy_field_type diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 820178dec..a83763a86 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,11 +1,11 @@ import random -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union +import logging +from functools import lru_cache import dsp from dspy.predict.parameter import Parameter -from abc import ABC, abstractmethod from dspy.primitives.prediction import Prediction -from dspy.clients.embedding import Embedder from dspy.utils.callback import with_callbacks @@ -18,36 +18,102 @@ def single_query_passage(passages): passages_dict["passages"] = passages_dict.pop("long_text") return Prediction(**passages_dict) +@lru_cache(maxsize=None) +def warn_once(msg: str): + logging.warning(msg) -class Retrieve(ABC): - def __init__(self, embedder: Optional[Embedder] = None, k: int = 5, callbacks: Optional[List[Any]] = None): - self.embedder = embedder +class Retrieve(Parameter): + name = "Search" + input_variable = "query" + desc = "takes a search query and returns one or more potentially relevant passages from a corpus" + + def __init__(self, k=3, callbacks=None): + warn_once( + "Existing retriever integrations under dspy/retrieve inheriting `dspy.Retrieve` are deprecated and will be removed in DSPy 2.6+. \n" + "For future retriever integrations, please use the `dspy.Retriever` interface under dspy/retriever/retriever.py and reference any of the custom integrations supported in dspy/retriever/" + ) + self.stage = random.randbytes(8).hex() self.k = k self.callbacks = callbacks or [] - @abstractmethod - def forward(self, query: str, k: Optional[int] = None) -> Any: - """ - Retrievers implement this method with their custom retrieval logic. - Must return an object that has a 'passages' attribute (ideally `dspy.Prediction`). - """ + def reset(self): pass - def __call__(self, query: str, k: Optional[int] = None) -> Any: - """ - Calls the forward method and checks if the result has a 'passages' attribute. - """ + def dump_state(self, save_verbose=False): + """save_verbose is set as a default argument to support the inherited Parameter interface for dump_state""" + state_keys = ["k"] + return {k: getattr(self, k) for k in state_keys} + + def load_state(self, state): + for name, value in state.items(): + setattr(self, name, value) + + @with_callbacks + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, + query_or_queries: Union[str, List[str]] = None, + query: Optional[str] = None, + k: Optional[int] = None, + by_prob: bool = True, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: + query_or_queries = query_or_queries or query + + # queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries + # queries = [query.strip().split('\n')[0].strip() for query in queries] + + # # print(queries) + # # TODO: Consider removing any quote-like markers that surround the query too. + # k = k if k is not None else self.k + # passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) + # return Prediction(passages=passages) + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + queries = [query.strip().split("\n")[0].strip() for query in queries] + + # print(queries) + # TODO: Consider removing any quote-like markers that surround the query too. k = k if k is not None else self.k - result = self.forward(query, k) - if not hasattr(result, 'passages'): - raise ValueError("The 'forward' method must return an object with a 'passages' attribute (ideally `dspy.Prediction`).") - for callback in self.callbacks: - callback(result) - return result + if not with_metadata: + passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs) + return Prediction(passages=passages) + else: + passages = dsp.retrieveEnsemblewithMetadata( + queries, k=k, by_prob=by_prob, **kwargs, + ) + if isinstance(passages[0], List): + pred_returns = [] + for query_passages in passages: + passages_dict = { + key: [] + for key in list(query_passages[0].keys()) + if key != "tracking_idx" + } + for psg in query_passages: + for key, value in psg.items(): + if key == "tracking_idx": + continue + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + # passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} + return single_query_passage(passages=passages) + # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. +#TODO potentially add for deprecation/removal in 2.6+ class RetrieveThenRerank(Parameter): name = "Search" input_variable = "query" diff --git a/dspy/retriever/__init__.py b/dspy/retriever/__init__.py new file mode 100644 index 000000000..3d04bf69e --- /dev/null +++ b/dspy/retriever/__init__.py @@ -0,0 +1 @@ +from .retriever import Retriever \ No newline at end of file diff --git a/dspy/retrieve/colbertv2_rm.py b/dspy/retriever/colbertv2_retriever.py similarity index 73% rename from dspy/retrieve/colbertv2_rm.py rename to dspy/retriever/colbertv2_retriever.py index 387d97739..38e891c4e 100644 --- a/dspy/retrieve/colbertv2_rm.py +++ b/dspy/retriever/colbertv2_retriever.py @@ -1,11 +1,26 @@ -from typing import Any, Union, Optional, List -from dspy.primitives.prediction import Prediction +from typing import Any, Union, Optional import dspy from dsp.utils import dotdict import requests import functools -class ColBERTv2(dspy.Retrieve): +class ColBERTv2(dspy.Retriever): + """ + ColBERTv2 Retriever for retrieval of top-k most relevant text passages for given query. + + Args: + post_requests (bool): Determines if POST requests should be used + instead of GET requests for querying the server. + url (str): URL endpoint for ColBERTv2 server + + Returns: + An object containing the retrieved passages. + + Example: + from dspy.retriever.colbertv2_retriever import ColBERTv2 + results = ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=5).passages + print(results) + """ def __init__(self, url: str = "http://0.0.0.0", port: Optional[Union[str, int]] = None, post_requests: bool = False): super().__init__(embedder=None) self.post_requests = post_requests @@ -20,7 +35,6 @@ def forward(self, query: str, k: int = 10) -> Any: from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory -from dsp.utils import dotdict @CacheMemory.cache def colbertv2_get_request_v2(url: str, query: str, k: int): assert ( diff --git a/dspy/retriever/databricks_retriever.py b/dspy/retriever/databricks_retriever.py new file mode 100644 index 000000000..fe25e6995 --- /dev/null +++ b/dspy/retriever/databricks_retriever.py @@ -0,0 +1,347 @@ +import json +import os +from importlib.util import find_spec +from typing import Any, Dict, List, Optional + +import requests + +import dspy +from dspy.primitives.prediction import Prediction +from dspy.clients.embedding import Embedder + +_databricks_sdk_installed = find_spec("databricks.sdk") is not None + + +class DatabricksRetriever(dspy.Retriever): + """ + A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k + embeddings for a given query. + + Examples: + Below is a code snippet that shows how to set up a Databricks Vector Search Index + and configure a DatabricksRetriever module to query the index. + + (example adapted from "Databricks: How to create and query a Vector Search Index: + https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index) + + ```python + from databricks.vector_search.client import VectorSearchClient + + # Create a Databricks Vector Search Endpoint + client = VectorSearchClient() + client.create_endpoint( + name="your_vector_search_endpoint_name", + endpoint_type="STANDARD" + ) + + # Create a Databricks Direct Access Vector Search Index + index = client.create_direct_access_index( + endpoint_name="your_vector_search_endpoint_name", + index_name="your_index_name", + primary_key="id", + embedding_dimension=1024, + embedding_vector_column="text_vector", + schema={ + "id": "int", + "field2": "str", + "field3": "float", + "text_vector": "array" + } + ) + + # Create a DatabricksRetriever module to query the Databricks Direct Access Vector + # Search Index + from dspy.retriever.databricks_retriever import DatabricksRetriever + + retriever = DatabricksRetriever( + databricks_index_name = "your_index_name", + docs_id_column_name="id", + text_column_name="field2", + k=3 + ) + ``` + + Below is a code snippet that shows how to query the Databricks Direct Access Vector + Search Index using the DatabricksRetriever module: + + ```python + retrieved_results = retriever(query="Example query text") + ``` + """ + + def __init__( + self, + databricks_index_name: str, + databricks_endpoint: Optional[str] = None, + databricks_token: Optional[str] = None, + columns: Optional[List[str]] = None, + filters_json: Optional[str] = None, + query_type: str = "ANN", + k: int = 3, + docs_id_column_name: str = "id", + text_column_name: str = "text", + embedder: Optional[Embedder] = None, + callbacks: Optional[List[Any]] = None, + ): + """ + Args: + databricks_index_name (str): The name of the Databricks Vector Search Index to query. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_token (Optional[str]): The Databricks Workspace authentication token to use + when querying the Vector Search Index. Defaults to the value of the + ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is + used to identify the token based on the current environment. + columns (Optional[List[str]]): Extra column names to include in response, + in addition to the document id and text columns specified by + ``docs_id_column_name`` and ``text_column_name``. + filters_json (Optional[str]): A JSON string specifying additional query filters. + Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value + less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id`` + column value greater than or equal to 5 and less than 10. + query_type (str): The type of search query to perform. Must be 'ANN', 'HYBRID', or 'VECTOR'. + k (int): The number of documents to retrieve. + docs_id_column_name (str): The name of the column in the Databricks Vector Search Index + containing document IDs. + text_column_name (str): The name of the column in the Databricks Vector Search Index + containing document text to retrieve. + embedder (Optional[Embedder]): An embedder to convert query text to vectors when + using 'VECTOR' query_type. + callbacks (Optional[List[Any]]): A list of callback functions. + """ + super().__init__(embedder=embedder, k=k, callbacks=callbacks) + self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN") + self.databricks_endpoint = ( + databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST") + ) + if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: + raise ValueError( + "To retrieve documents with Databricks Vector Search, you must install the" + " databricks-sdk Python library, supply the databricks_token and" + " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST" + " environment variables." + ) + self.databricks_index_name = databricks_index_name + self.columns = list({docs_id_column_name, text_column_name, *(columns or [])}) + self.filters_json = filters_json + self.query_type = query_type + self.docs_id_column_name = docs_id_column_name + self.text_column_name = text_column_name + + def _extract_doc_ids(self, item: Dict[str, Any]) -> str: + """Extracts the document id from a search result. + + Args: + item (Dict[str, Any]): A record from the search results. + + Returns: + str: Document id. + """ + if self.docs_id_column_name == "metadata": + docs_dict = json.loads(item["metadata"]) + return docs_dict["document_id"] + return item[self.docs_id_column_name] + + def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: + """Extracts search result column values, excluding the "text" and "id" columns. + + Args: + item (Dict[str, Any]): A record from the search results. + + Returns: + Dict[str, Any]: Search result column values, excluding the "text" and "id" columns. + """ + extra_columns = {k: v for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name]} + if self.docs_id_column_name == "metadata": + extra_columns = { + **extra_columns, + **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, + } + return extra_columns + + def forward(self, query: str, k: Optional[int] = None) -> dspy.Prediction: + """ + Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the + specified query. + + Args: + query (str): The query text for which to retrieve relevant documents. + k (Optional[int]): The number of documents to retrieve. If None, defaults to self.k. + + Returns: + dspy.Prediction: An object containing the retrieved results. + """ + k = k or self.k + query_text = query + query_vector = None + + if self.query_type.upper() in ["ANN", "HYBRID"]: + query_text = query + elif self.query_type.upper() == "VECTOR": + if self.embedder: + query_vector = self.embedder.embed(query) + query_text = None + else: + raise ValueError("An embedder must be provided when using 'VECTOR' query_type without providing a query vector.") + else: + raise ValueError(f"Unsupported query_type: {self.query_type}") + + if _databricks_sdk_installed: + results = self._query_via_databricks_sdk( + index_name=self.databricks_index_name, + k=k, + columns=self.columns, + query_type=self.query_type.upper(), + query_text=query_text, + query_vector=query_vector, + databricks_token=self.databricks_token, + databricks_endpoint=self.databricks_endpoint, + filters_json=self.filters_json, + ) + else: + results = self._query_via_requests( + index_name=self.databricks_index_name, + k=k, + columns=self.columns, + databricks_token=self.databricks_token, + databricks_endpoint=self.databricks_endpoint, + query_type=self.query_type.upper(), + query_text=query_text, + query_vector=query_vector, + filters_json=self.filters_json, + ) + + # Checking if defined columns are present in the index columns + col_names = [column["name"] for column in results["manifest"]["columns"]] + + if self.docs_id_column_name not in col_names: + raise Exception( + f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}" + ) + + if self.text_column_name not in col_names: + raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}") + + # Extracting the results + items = [] + for data_row in results["result"]["data_array"]: + item = {col_name: val for col_name, val in zip(col_names, data_row)} + items.append(item) + + # Sorting results by score in descending order + sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[:k] + + # Returning the prediction + return Prediction( + passages=[doc[self.text_column_name] for doc in sorted_docs], + doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs], + extra_columns=[self._get_extra_columns(doc) for doc in sorted_docs], + ) + + @staticmethod + def _query_via_databricks_sdk( + index_name: str, + k: int, + columns: List[str], + query_type: str, + query_text: Optional[str], + query_vector: Optional[List[float]], + databricks_token: Optional[str], + databricks_endpoint: Optional[str], + filters_json: Optional[str], + ) -> Dict[str, Any]: + """ + Query a Databricks Vector Search Index via the Databricks SDK. + Assumes that the databricks-sdk Python library is installed. + + Args: + index_name (str): Name of the Databricks vector search index to query + k (int): Number of relevant documents to retrieve. + columns (List[str]): Column names to include in response. + query_text (Optional[str]): Text query for which to find relevant documents. Exactly + one of query_text or query_vector must be specified. + query_vector (Optional[List[float]]): Numeric query vector for which to find relevant + documents. Exactly one of query_text or query_vector must be specified. + filters_json (Optional[str]): JSON string representing additional query filters. + databricks_token (str): Databricks authentication token. If not specified, + the token is resolved from the current environment. + databricks_endpoint (str): Databricks index endpoint url. If not specified, + the endpoint is resolved from the current environment. + Returns: + Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. + """ + from databricks.sdk import WorkspaceClient + + if (query_text, query_vector).count(None) != 1: + raise ValueError("Exactly one of query_text or query_vector must be specified.") + + databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token) + return databricks_client.vector_search_indexes.query_index( + index_name=index_name, + query_type=query_type, + query_text=query_text, + query_vector=query_vector, + columns=columns, + filters_json=filters_json, + num_results=k, + ).as_dict() + + @staticmethod + def _query_via_requests( + index_name: str, + k: int, + columns: List[str], + databricks_token: str, + databricks_endpoint: str, + query_type: str, + query_text: Optional[str], + query_vector: Optional[List[float]], + filters_json: Optional[str], + ) -> Dict[str, Any]: + """ + Query a Databricks Vector Search Index via the Python requests library. + + Args: + index_name (str): Name of the Databricks vector search index to query + k (int): Number of relevant documents to retrieve. + columns (List[str]): Column names to include in response. + databricks_token (str): Databricks authentication token. + databricks_endpoint (str): Databricks index endpoint url. + query_text (Optional[str]): Text query for which to find relevant documents. Exactly + one of query_text or query_vector must be specified. + query_vector (Optional[List[float]]): Numeric query vector for which to find relevant + documents. Exactly one of query_text or query_vector must be specified. + filters_json (Optional[str]): JSON string representing additional query filters. + + Returns: + Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query. + """ + if (query_text, query_vector).count(None) != 1: + raise ValueError("Exactly one of query_text or query_vector must be specified.") + + headers = { + "Authorization": f"Bearer {databricks_token}", + "Content-Type": "application/json", + } + payload = { + "columns": columns, + "num_results": k, + "query_type": query_type, + } + if filters_json is not None: + payload["filters_json"] = filters_json + if query_text is not None: + payload["query_text"] = query_text + elif query_vector is not None: + payload["query_vector"] = query_vector + response = requests.post( + f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", + json=payload, + headers=headers, + ) + results = response.json() + if "error_code" in results: + raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") + return results diff --git a/dspy/retriever/faiss_retriever.py b/dspy/retriever/faiss_retriever.py new file mode 100644 index 000000000..31cc3270c --- /dev/null +++ b/dspy/retriever/faiss_retriever.py @@ -0,0 +1,138 @@ +"""Retriever model for faiss: https://github.com/facebookresearch/faiss. +Author: Jagane Sundar: https://github.com/jagane. +(modified to support `dspy.Retriever` interface) +""" + +import logging +from typing import List, Any, Optional + +import numpy as np + +import dspy +from dspy import Embedder + +try: + import faiss +except ImportError: + faiss = None + +if faiss is None: + raise ImportError( + """ + The faiss package is required. Install it using `pip install dspy-ai[faiss-cpu]` + """, + ) + +logger = logging.getLogger(__name__) + +class FaissRetriever(dspy.Retriever): + """A retrieval module that uses an in-memory Faiss index to return the top passages for a given query. + + Args: + document_chunks: The input text chunks. + embedder: An instance of `dspy.Embedder` to compute embeddings. + k (int, optional): The number of top passages to retrieve. Defaults to 3. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + + Examples: + Below is a code snippet that shows how to use this as the default retriever: + + ```python + import dspy + from dspy.retriever.faiss_retriever import FaissRetriever + + # Custom embedding function using SentenceTransformers and dspy.Embedder + def sentence_transformers_embedder(texts): + #(pip install sentence-transformers) + from sentence_transformers import SentenceTransformer + model = SentenceTransformer('all-MiniLM-L6-v2') + embeddings = model.encode(texts, batch_size=256, normalize_embeddings=True) + return embeddings.tolist() + embedder = dspy.Embedder(embedding_function=sentence_transformers_embedder) + + document_chunks = [ + "The superbowl this year was played between the San Francisco 49ers and the Kansas City Chiefs", + "Pop corn is often served in a bowl", + "The Rice Bowl is a Chinese Restaurant located in the city of Tucson, Arizona", + "Mars is the fourth planet in the Solar System", + "An aquarium is a place where children can learn about marine life", + "The capital of the United States is Washington, D.C", + "Rock and Roll musicians are honored by being inducted in the Rock and Roll Hall of Fame", + "Music albums were published on Long Play Records in the 70s and 80s", + "Sichuan cuisine is a spicy cuisine from central China", + "The interest rates for mortgages are considered to be very high in 2024", + ] + + retriever = FaissRetriever(document_chunks, embedder=embedder) + results = retriever("I am in the mood for Chinese food").passages + print(results) + ``` + """ + + def __init__( + self, + document_chunks: List[str], + embedder: Optional[Embedder] = None, + k: int = 3, + callbacks: Optional[List[Any]] = None, + ): + """Inits the faiss retriever. + + Args: + document_chunks: A list of input strings. + embedder: An instance of `dspy.Embedder` to compute embeddings. + k: Number of matches to return. + """ + if embedder is not None and not isinstance(embedder, dspy.Embedder): + raise ValueError("If provided, the embedder must be of type `dspy.Embedder`.") + self.embedder = embedder + embeddings = self.embedder(document_chunks) + xb = np.array(embeddings) + d = xb.shape[1] + logger.info(f"FaissRetriever: embedding size={d}") + if len(xb) < 100: + self._faiss_index = faiss.IndexFlatL2(d) + self._faiss_index.add(xb) + else: + # If we have at least 100 vectors, we use Voronoi cells + nlist = 100 + quantizer = faiss.IndexFlatL2(d) + self._faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist) + self._faiss_index.train(xb) + self._faiss_index.add(xb) + + logger.info(f"{self._faiss_index.ntotal} vectors in faiss index") + self._document_chunks = document_chunks # Save the input document chunks + + super().__init__(embedder=self.embedder, k=k, callbacks=callbacks) + + def _dump_raw_results(self, queries, index_list, distance_list) -> None: + for i in range(len(queries)): + indices = index_list[i] + distances = distance_list[i] + logger.debug(f"Query: {queries[i]}") + for j in range(len(indices)): + logger.debug( + f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}" + ) + return + + def forward(self, query: str, k: Optional[int] = None, **kwargs) -> dspy.Prediction: + """Search the faiss index for k or self.k top passages for query. + + Args: + query (str): The query to search for. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + """ + k = k or self.k + embeddings = self.embedder([query]) + emb_npa = np.array(embeddings) + distance_list, index_list = self._faiss_index.search(emb_npa, k) + # self._dump_raw_results([query], index_list, distance_list) + passages = [self._document_chunks[ind] for ind in index_list[0]] + doc_ids = [ind for ind in index_list[0]] + return dspy.Prediction(passages=passages, doc_ids=doc_ids) diff --git a/dspy/retriever/milvus_retriever.py b/dspy/retriever/milvus_retriever.py new file mode 100644 index 000000000..515f33994 --- /dev/null +++ b/dspy/retriever/milvus_retriever.py @@ -0,0 +1,115 @@ +""" +Retriever model for Milvus or Zilliz Cloud +""" + +from typing import List, Optional, Any + +import dspy +from dspy import Embedder + +try: + from pymilvus import MilvusClient +except ImportError: + raise ImportError( + "The pymilvus library is required to use MilvusRetriever. Install it with `pip install dspy-ai[milvus]`", + ) + +class MilvusRetriever(dspy.Retriever): + """ + A retrieval module that uses Milvus to return passages for a given query. + + Assumes that a Milvus collection has been created and populated with the following field: + - text: The text of the passage + + Args: + collection_name (str): The name of the Milvus collection to query against. + uri (str, optional): The Milvus connection URI. Defaults to "http://localhost:19530". + token (str, optional): The Milvus connection token. Defaults to None. + db_name (str, optional): The Milvus database name. Defaults to "default". + embedder (dspy.Embedder): An instance of `dspy.Embedder` to compute embeddings. + k (int, optional): The number of top passages to retrieve. Defaults to 3. + callbacks (Optional[List[Any]]): A list of callback functions. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + + Examples: + Below is a code snippet that shows how to use this as the default retriever: + ```python + import dspy + from dspy.retriever.milvus_retriever import MilvusRetriever + + # Create an Embedder instance + embedder = dspy.Embedder(embedding_model="text-embedding-ada-002") + + retriever = MilvusRetriever( + collection_name="", + uri="", + token="", + embedder=embedder, + k=3 + ) + results = retriever(query).passages + print(results) + ``` + """ + + def __init__( + self, + collection_name: str, + uri: Optional[str] = "http://localhost:19530", + token: Optional[str] = None, + db_name: Optional[str] = "default", + embedder: Embedder = None, + k: int = 3, + callbacks: Optional[List[Any]] = None, + ): + if embedder is not None and not isinstance(embedder, dspy.Embedder): + raise ValueError("If provided, the embedder must be of type `dspy.Embedder`.") + super().__init__(embedder=embedder, k=k, callbacks=callbacks) + + self.milvus_client = MilvusClient(uri=uri, token=token, db_name=db_name) + + # Check if collection exists + if collection_name not in self.milvus_client.list_collections(): + raise AttributeError(f"Milvus collection not found: {collection_name}") + self.collection_name = collection_name + + def forward(self, query: str, k: Optional[int] = None) -> dspy.Prediction: + """ + Retrieve passages from Milvus that are relevant to the specified query. + + Args: + query (str): The query text for which to retrieve relevant passages. + k (Optional[int]): The number of passages to retrieve. If None, defaults to self.k. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + """ + k = k or self.k + query_embedding = self.embedder([query])[0] + + # Milvus expects embeddings as lists + query_embedding = query_embedding.tolist() + + milvus_res = self.milvus_client.search( + collection_name=self.collection_name, + data=[query_embedding], + output_fields=["text"], + limit=k, + ) + + results = [] + for res in milvus_res: + for r in res: + text = r["entity"]["text"] + doc_id = r["id"] + distance = r["distance"] + results.append((text, doc_id, distance)) + + sorted_results = sorted(results, key=lambda x: x[2], reverse=True)[:k] + passages = [x[0] for x in sorted_results] + doc_ids = [x[1] for x in sorted_results] + distances = [x[2] for x in sorted_results] + + return dspy.Prediction(passages=passages, doc_ids=doc_ids, scores=distances) \ No newline at end of file diff --git a/dspy/retriever/pinecone_retriever.py b/dspy/retriever/pinecone_retriever.py new file mode 100644 index 000000000..945008459 --- /dev/null +++ b/dspy/retriever/pinecone_retriever.py @@ -0,0 +1,178 @@ +""" +Retriever model for Pinecone +Author: Dhar Rawal (@drawal1) +(modified to support `dspy.Retriever` interface) +""" + +from typing import List, Optional, Any, Union + +import dspy +from dspy import Embedder +from dspy.primitives.prediction import Prediction +from dsp.utils import dotdict + +try: + import pinecone +except ImportError: + pinecone = None + +if pinecone is None: + raise ImportError( + "The pinecone library is required to use PineconeRetriever. Install it with `pip install dspy-ai[pinecone]`", + ) + + +class PineconeRetriever(dspy.Retriever): + """ + A retrieval module that uses Pinecone to return the top passages for a given query or list of queries. + + Assumes that the Pinecone index has been created and populated with the following metadata: + - text: The text of the passage + + Args: + pinecone_index_name (str): The name of the Pinecone index to query against. + pinecone_api_key (str, optional): The Pinecone API key. Defaults to None. + pinecone_env (str, optional): The Pinecone environment. Defaults to None. + embedder (dspy.Embedder): An instance of `dspy.Embedder` to compute embeddings. + k (int, optional): The number of top passages to retrieve. Defaults to 3. + callbacks (Optional[List[Any]]): A list of callback functions. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + + Examples: + Below is a code snippet that shows how to use this as the default retriever: + ```python + import dspy + from dspy.retriever.pinecone_retriever import PineconeRetriever + + # Create an Embedder instance + embedder = dspy.Embedder(embedding_model="text-embedding-ada-002") + + retriever = PineconeRetriever( + pinecone_index_name="", + pinecone_api_key="", + pinecone_env="", + embedder=embedder, + k=3 + ) + + results = retriever(query).passages + print(results) + ``` + """ + + def __init__( + self, + pinecone_index_name: str, + pinecone_api_key: Optional[str] = None, + pinecone_env: Optional[str] = None, + dimension: Optional[int] = None, + distance_metric: Optional[str] = None, + embedder: Embedder = None, + k: int = 3, + callbacks: Optional[List[Any]] = None, + ): + if embedder is None or not isinstance(embedder, dspy.Embedder): + raise ValueError("An embedder of type `dspy.Embedder` must be provided.") + self.embedder = embedder + super().__init__(embedder=self.embedder, k=k, callbacks=callbacks) + + self._pinecone_index = self._init_pinecone( + index_name=pinecone_index_name, + api_key=pinecone_api_key, + environment=pinecone_env, + dimension=dimension, + distance_metric=distance_metric, + ) + + def _init_pinecone( + self, + index_name: str, + api_key: Optional[str] = None, + environment: Optional[str] = None, + dimension: Optional[int] = None, + distance_metric: Optional[str] = None, + ) -> pinecone.Index: + """Initialize pinecone and return the loaded index. + + Args: + index_name (str): The name of the index to load. If the index is not does not exist, it will be created. + api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided. + environment (str, optional): The environment (ie. `us-west1-gcp` or `gcp-starter`. Defaults to env PINECONE_ENVIRONMENT. + + Raises: + ValueError: If api_key or environment is not provided and not set as an environment variable. + + Returns: + pinecone.Index: The loaded index. + """ + + # Pinecone init overrides default if kwargs are present, so we need to exclude if None + kwargs = {} + if api_key: + kwargs["api_key"] = api_key + if environment: + kwargs["environment"] = environment + pinecone.init(**kwargs) + + active_indexes = pinecone.list_indexes() + if index_name not in active_indexes: + if dimension is None or distance_metric is None: + raise ValueError( + "dimension and distance_metric must be provided since the index does not exist and needs to be created." + ) + + pinecone.create_index( + name=index_name, + dimension=dimension, + metric=distance_metric, + ) + + return pinecone.Index(index_name) + + def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction: + """Search with Pinecone for top k passages for the query or queries. + + Args: + query (Union[str, List[str]]): The query or list of queries to search for. + k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + """ + k = k or self.k + queries = [query] if isinstance(query, str) else query + queries = [q for q in queries if q] + embeddings = self.embedder(queries) + # For single query, just look up the top k passages + if len(queries) == 1: + results_dict = self._pinecone_index.query( + embeddings[0], top_k=self.k, include_metadata=True, + ) + + # Sort results by score + sorted_results = sorted( + results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True, + ) + passages = [result["metadata"]["text"] for result in sorted_results] + passages = [dotdict({"long_text": passage for passage in passages})] + return Prediction(passages=passages) + + # For multiple queries, query each and return the highest scoring passages + # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x + passage_scores = {} + for embedding in embeddings: + results_dict = self._pinecone_index.query( + embedding, top_k=self.k * 3, include_metadata=True, + ) + for result in results_dict["matches"]: + passage_scores[result["metadata"]["text"]] = ( + passage_scores.get(result["metadata"]["text"], 0.0) + + result["score"] + ) + + sorted_passages = sorted( + passage_scores.items(), key=lambda x: x[1], reverse=True, + )[: self.k] + return Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]) \ No newline at end of file diff --git a/dspy/retriever/retriever.py b/dspy/retriever/retriever.py new file mode 100644 index 000000000..ce0a96ede --- /dev/null +++ b/dspy/retriever/retriever.py @@ -0,0 +1,60 @@ +from typing import Any, List, Optional + +from abc import ABC, abstractmethod +from dspy.clients.embedding import Embedder +from dspy.utils.callback import with_callbacks + +import os +from pathlib import Path +from diskcache import Cache + +DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") + + +class Retriever(ABC): + def __init__(self, embedder: Optional[Embedder] = None, k: int = 5, callbacks: Optional[List[Any]] = None, cache: bool = False): + """ + Interface for composing retrievers in DSPy to return relevant passages or documents based on a query. + + Args: + embedder (Optional[Embedder]): An instance of `dspy.Embedder` used to compute embeddings + for queries and documents. If `None`, embedding functionality should be implemented + within the subclass. Defaults to `None`. + k (int): The default number of top passages to retrieve when not specified in the `forward` method. Defaults to `5`. + callbacks (Optional[List[Any]]): A list of callback functions to be called during retrieval. + cache (bool): Enable retrieval caching. Disabled by default. + """ + self.embedder = embedder + self.k = k + self.callbacks = callbacks or [] + self.cache_enabled = cache + self.cache = Cache(directory=DISK_CACHE_DIR) if self.cache_enabled else None + + @abstractmethod + def forward(self, query: str, k: Optional[int] = None) -> Any: + """ + Retrievers implement this method with their custom retrieval logic. + Must return an object that has a 'passages' attribute (ideally `dspy.Prediction`). + """ + pass + + @with_callbacks + def __call__(self, query: str, k: Optional[int] = None) -> Any: + """ + Calls the forward method and checks if the result has a 'passages' attribute. + """ + k = k if k is not None else self.k + if self.cache_enabled and self.cache is not None: + cache_key = (query, k) + try: + result = self.cache[cache_key] + except KeyError: + result = self.forward(query, k) + self.cache[cache_key] = result + else: + result = self.forward(query, k) + if not hasattr(result, 'passages'): + raise ValueError( + "The 'forward' method must return an object with a 'passages' attribute (ideally `dspy.Prediction`)." + ) + return result diff --git a/tests/clients/test_embedding.py b/tests/clients/test_embedding.py index d12850e52..0ac9e24ba 100644 --- a/tests/clients/test_embedding.py +++ b/tests/clients/test_embedding.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch import numpy as np -from dspy.clients.embedding import Embedding +from dspy.clients.embedding import Embedder # Mock response format similar to litellm's embedding response. @@ -27,7 +27,7 @@ def test_litellm_embedding(): mock_litellm.return_value = MockEmbeddingResponse(mock_embeddings) # Create embedding instance and call it. - embedding = Embedding(model) + embedding = Embedder(model) result = embedding(inputs) # Verify litellm was called with correct parameters. @@ -51,7 +51,7 @@ def mock_embedding_fn(texts): return expected_embeddings # Create embedding instance with callable - embedding = Embedding(mock_embedding_fn) + embedding = Embedder(mock_embedding_fn) result = embedding(inputs) np.testing.assert_allclose(result, expected_embeddings) @@ -60,5 +60,5 @@ def mock_embedding_fn(texts): def test_invalid_model_type(): # Test that invalid model type raises ValueError with pytest.raises(ValueError): - embedding = Embedding(123) # Invalid model type + embedding = Embedder(123) # Invalid model type embedding(["test"])