Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - dspy.RM/retrieve refactor #1739

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dspy/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .lm import LM
from .lm import LM
from .rm import RM
33 changes: 33 additions & 0 deletions dspy/clients/rm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Any, Callable, List, Optional

from dspy.primitives.prediction import Prediction
from dspy.retrieve.embedder import Embedder

class RM:
    """Generic retrieval-model client.

    Wraps an arbitrary provider ``search_function`` together with an optional
    ``Embedder`` (for vector-search backends) and a ``result_formatter`` that
    converts the provider's raw results into a ``Prediction``.
    """

    def __init__(
        self,
        search_function: Callable[..., Any],
        embedder: Optional[Embedder] = None,
        result_formatter: Optional[Callable[[Any], Prediction]] = None,
        **provider_kwargs
    ):
        # Provider-specific keyword arguments (URLs, clients, index names, ...)
        # are captured here and forwarded on every search call.
        self.embedder = embedder
        self.search_function = search_function
        self.result_formatter = result_formatter or self.default_formatter
        self.provider_kwargs = provider_kwargs

    def __call__(self, query: str, k: Optional[int] = None) -> Prediction:
        """Search for ``query``, embedding it first when an embedder is configured.

        ``k`` is forwarded to the provider only when explicitly given, so the
        provider's own default (or one baked into ``provider_kwargs``) applies
        otherwise.
        """
        query_input = self.embedder([query])[0] if self.embedder else query
        search_args = dict(self.provider_kwargs, query=query_input)
        if k is not None:
            search_args["k"] = k
        raw_results = self.search_function(**search_args)
        return self.result_formatter(raw_results)

    def default_formatter(self, results) -> Prediction:
        """Fallback formatter: wrap raw results as ``Prediction(passages=...)``."""
        return Prediction(passages=results)
3 changes: 2 additions & 1 deletion dspy/retrieve/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .retrieve import Retrieve, RetrieveThenRerank
from .retrieve import Retrieve, RetrieveThenRerank
from .embedder import Embedder
16 changes: 16 additions & 0 deletions dspy/retrieve/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Callable, List, Optional


class Embedder:
    """Callable that maps a batch of texts to embedding vectors.

    By default it delegates to LiteLLM's ``embedding`` endpoint using
    ``embedding_model``; callers may instead supply any ``embedding_function``
    honoring the same ``List[str] -> List[List[float]]`` contract.
    """

    # NOTE(review): #1735 will replace this implementation once merged.

    def __init__(
        self,
        embedding_model: str = 'text-embedding-ada-002',
        embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None,
    ):
        self.embedding_model = embedding_model
        # Fall back to the LiteLLM-backed default when no custom function is given.
        self.embedding_function = embedding_function or self.default_embedding_function

    def default_embedding_function(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` via litellm using ``self.embedding_model``."""
        # Imported lazily so litellm is only required when actually used.
        from litellm import embedding
        embeddings_response = embedding(model=self.embedding_model, input=texts)
        embeddings = [data['embedding'] for data in embeddings_response.data]
        return embeddings

    def __call__(self, texts: List[str]) -> List[List[float]]:
        return self.embedding_function(texts)
85 changes: 7 additions & 78 deletions dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dsp
from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.utils.callback import with_callbacks
from dspy.clients import RM


def single_query_passage(passages):
Expand All @@ -22,86 +22,15 @@ class Retrieve(Parameter):
input_variable = "query"
desc = "takes a search query and returns one or more potentially relevant passages from a corpus"

def __init__(self, k=3, callbacks=None):
self.stage = random.randbytes(8).hex()
def __init__(self, rm: RM, k=3):
self.rm = rm
self.k = k
self.callbacks = callbacks or []

def reset(self):
pass

def dump_state(self, save_verbose=False):
"""save_verbose is set as a default argument to support the inherited Parameter interface for dump_state"""
state_keys = ["k"]
return {k: getattr(self, k) for k in state_keys}

def load_state(self, state):
for name, value in state.items():
setattr(self, name, value)

@with_callbacks
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(
self,
query_or_queries: Union[str, List[str]] = None,
query: Optional[str] = None,
k: Optional[int] = None,
by_prob: bool = True,
with_metadata: bool = False,
**kwargs,
) -> Union[List[str], Prediction, List[Prediction]]:
query_or_queries = query_or_queries or query

# queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
# queries = [query.strip().split('\n')[0].strip() for query in queries]
#TODO - add back saving/loading for retrievers

# # print(queries)
# # TODO: Consider removing any quote-like markers that surround the query too.
# k = k if k is not None else self.k
# passages = dsp.retrieveEnsemble(queries, k=k,**kwargs)
# return Prediction(passages=passages)
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
queries = [query.strip().split("\n")[0].strip() for query in queries]

# print(queries)
# TODO: Consider removing any quote-like markers that surround the query too.
def __call__(self, query_or_queries, k=None):
k = k if k is not None else self.k
if not with_metadata:
passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs)
return Prediction(passages=passages)
else:
passages = dsp.retrieveEnsemblewithMetadata(
queries, k=k, by_prob=by_prob, **kwargs,
)
if isinstance(passages[0], List):
pred_returns = []
for query_passages in passages:
passages_dict = {
key: []
for key in list(query_passages[0].keys())
if key != "tracking_idx"
}
for psg in query_passages:
for key, value in psg.items():
if key == "tracking_idx":
continue
passages_dict[key].append(value)
if "long_text" in passages_dict:
passages_dict["passages"] = passages_dict.pop("long_text")
pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
# passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...}
return single_query_passage(passages=passages)


# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.
return self.rm(query_or_queries, k=k)


class RetrieveThenRerank(Parameter):
Expand Down Expand Up @@ -163,4 +92,4 @@ def forward(
pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
return single_query_passage(passages=passages)
return single_query_passage(passages=passages)
167 changes: 167 additions & 0 deletions examples/rm_migration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DSPy.RM Migration Guide (work in progress): examples of porting existing retrievers to the new `RM` / `Retrieve` interface."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Querying ColBERTv2 \n",
"\n",
"import requests\n",
"import os\n",
"from typing import Any, Dict, List, Optional, Union\n",
"from dspy import RM, Retrieve, Embedder\n",
"from dspy.primitives.prediction import Prediction\n",
"\n",
"def colbert_search_function(query: str, k: int, url: str, post_requests: bool = False) -> List[Dict[str, Any]]:\n",
" if post_requests:\n",
" headers = {\"Content-Type\": \"application/json; charset=utf-8\"}\n",
" payload = {\"query\": query, \"k\": k}\n",
" res = requests.post(url, json=payload, headers=headers, timeout=10)\n",
" else:\n",
" payload = {\"query\": query, \"k\": k}\n",
" res = requests.get(url, params=payload, timeout=10)\n",
" \n",
" res.raise_for_status()\n",
" topk = res.json()[\"topk\"][:k]\n",
" topk = [{**doc, \"long_text\": doc.get(\"text\", \"\")} for doc in topk]\n",
" return topk\n",
"\n",
"def colbert_result_formatter(results: List[Dict[str, Any]]) -> Prediction:\n",
" passages = [doc[\"long_text\"] for doc in results]\n",
" return Prediction(passages=passages)\n",
"\n",
"colbert_url = \"http://20.102.90.50:2017/wiki17_abstracts\"\n",
"\n",
"colbert_rm = RM(\n",
" search_function=colbert_search_function,\n",
" result_formatter=colbert_result_formatter,\n",
" url=colbert_url,\n",
" post_requests=False\n",
")\n",
"\n",
"retrieve = Retrieve(rm=colbert_rm, k=10)\n",
"query_text = \"Example query text\"\n",
"results = retrieve(query_text)\n",
"print(results.passages)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Querying Databricks Mosaic AI Vector Search \n",
"\n",
"from databricks.sdk import WorkspaceClient\n",
"\n",
"#client setup\n",
"databricks_token = os.environ.get(\"DATABRICKS_TOKEN\")\n",
"databricks_endpoint = os.environ.get(\"DATABRICKS_HOST\")\n",
"databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token)\n",
"\n",
"#custom logic for querying and sorting the docs\n",
"def databricks_search_function(\n",
" query,\n",
" k,\n",
" index_name,\n",
" columns,\n",
" query_type='ANN',\n",
" filters_json=None,\n",
" client=None\n",
"):\n",
" results = client.vector_search_indexes.query(\n",
" index_name=index_name,\n",
" query_type=query_type,\n",
" query_text=query,\n",
" num_results=k,\n",
" columns=columns,\n",
" filters_json=filters_json,\n",
" ).as_dict()\n",
"\n",
" items = []\n",
" col_names = [column[\"name\"] for column in results[\"manifest\"][\"columns\"]]\n",
" for data_row in results[\"result\"][\"data_array\"]:\n",
" item = {col_name: val for col_name, val in zip(col_names, data_row)}\n",
" items.append(item)\n",
" sorted_docs = sorted(items, key=lambda x: x[\"score\"], reverse=True)\n",
" return sorted_docs\n",
"\n",
"def databricks_result_formatter(results) -> Prediction:\n",
" passages = [doc['some_text_column'] for doc in results] \n",
" return Prediction(passages=passages)\n",
"\n",
"databricks_rm = RM(\n",
" search_function=databricks_search_function,\n",
" result_formatter=databricks_result_formatter,\n",
" client=databricks_client,\n",
" index_name='your_index_name',\n",
" columns=['id', 'some_text_column'],\n",
" filters_json=None\n",
")\n",
"\n",
"retrieve = Retrieve(rm=databricks_rm, k=3)\n",
"results = retrieve(\"Example query text\")\n",
"print(results.passages)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Querying Deeplake Vector Store\n",
"\n",
"import deeplake\n",
"\n",
"embedder = Embedder()\n",
"\n",
"deeplake_vectorstore_name = 'vectorstore_name'\n",
"deeplake_client = deeplake.VectorStore(\n",
" path=deeplake_vectorstore_name,\n",
" embedding_function=embedder\n",
")\n",
"\n",
"def deeplake_search_function(query, k, client=None):\n",
" results = client.search(query, k=k)\n",
" return results\n",
"\n",
"def deeplake_result_formatter(results) -> Prediction:\n",
" passages = [doc['text'] for doc in results['documents']]\n",
" return Prediction(passages=passages)\n",
"\n",
"\n",
"deeplake_rm = RM(\n",
" embedder=embedder,\n",
" search_function=deeplake_search_function,\n",
" result_formatter=deeplake_result_formatter,\n",
" client=deeplake_client\n",
")\n",
"\n",
"retrieve = Retrieve(rm=deeplake_rm, k=3)\n",
"results = retrieve(\"some text\")\n",
"print(results.passages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TBD..."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading