Skip to content

Commit

Permalink
Change GraphRAG search parameter query to query_text (neo4j#89)
Browse files Browse the repository at this point in the history
* Fix doc for GraphRAG (arg is called `query` and not `query_text`)

* Rename GraphRAG search parameter query to query_text for consistency with the Retriever interface - and leave open the possibility of adding a query_vector param later on if requested

* Update CHANGELOG

* Check disk space before install

* Check container size

* Check docker image size

* Test with one single neo/python version
  • Loading branch information
stellasia authored Aug 7, 2024
1 parent 337aeba commit 6998402
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
- Corrected initialization to allow specifying the embedding model name.
- Removed sentence_transformers from embeddings/__init__.py to avoid ImportError when the package is not installed.

### Changed
- `GraphRAG.search` method first parameter has been renamed `query_text` (was `query`) for consistency with the retrievers interface.

## 0.3.0

### Added
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ rag = GraphRAG(retriever=retriever, llm=llm)

# Query the graph
query_text = "How do I do similarity search in Neo4j?"
response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
response = rag.search(query=query_text, retriever_config={"top_k": 5})
print(response.answer)
```

Expand Down
2 changes: 1 addition & 1 deletion examples/graphrag_custom_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
{context}
Question:
{query}
{query_text}
Answer:
"""
Expand Down
12 changes: 6 additions & 6 deletions src/neo4j_genai/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(

def search(
self,
query: str,
query_text: str,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool = False,
Expand All @@ -64,7 +64,7 @@ def search(
3. Generation: answer generation with LLM
Args:
query (str): The user question
query_text (str): The user question
examples: Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
Expand All @@ -76,20 +76,20 @@ def search(
"""
try:
validated_data = RagSearchModel(
query=query,
query_text=query_text,
examples=examples,
retriever_config=retriever_config or {},
return_context=return_context,
)
except ValidationError as e:
raise SearchValidationError(e.errors())
query = validated_data.query
query_text = validated_data.query_text
retriever_result: RetrieverResult = self.retriever.search(
query_text=query, **validated_data.retriever_config
query_text=query_text, **validated_data.retriever_config
)
context = "\n".join(item.content for item in retriever_result.items)
prompt = self.prompt_template.format(
query=query, context=context, examples=validated_data.examples
query_text=query_text, context=context, examples=validated_data.examples
)
logger.debug(f"RAG: retriever_result={retriever_result}")
logger.debug(f"RAG: prompt={prompt}")
Expand Down
8 changes: 4 additions & 4 deletions src/neo4j_genai/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ class RagTemplate(PromptTemplate):
{examples}
Question:
{query}
{query_text}
Answer:
"""
EXPECTED_INPUTS = ["context", "query", "examples"]
EXPECTED_INPUTS = ["context", "query_text", "examples"]

def format(self, query: str, context: str, examples: str) -> str:
return super().format(query=query, context=context, examples=examples)
def format(self, query_text: str, context: str, examples: str) -> str:
return super().format(query_text=query_text, context=context, examples=examples)


class Text2CypherTemplate(PromptTemplate):
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/generation/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_llm(cls, value: Any) -> Any:


class RagSearchModel(BaseModel):
query: str
query_text: str
examples: str = ""
retriever_config: dict[str, Any] = {}
return_context: bool = False
Expand Down
10 changes: 5 additions & 5 deletions tests/e2e/test_graphrag_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_graphrag_happy_path(
llm.invoke.return_value = LLMResponse(content="some text")

result = rag.search(
query="biology",
query_text="biology",
retriever_config={
"top_k": 2,
},
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_graphrag_happy_path_return_context(
llm.invoke.return_value = LLMResponse(content="some text")

result = rag.search(
query="biology",
query_text="biology",
retriever_config={
"top_k": 2,
},
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_graphrag_happy_path_examples(
llm.invoke.return_value = LLMResponse(content="some text")

result = rag.search(
query="biology",
query_text="biology",
retriever_config={
"top_k": 2,
},
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_graphrag_llm_error(

with pytest.raises(LLMGenerationError):
rag.search(
query="biology",
query_text="biology",
)


Expand All @@ -203,5 +203,5 @@ def test_graphrag_retrieval_error(

with pytest.raises(TypeError):
rag.search(
query="biology",
query_text="biology",
)
4 changes: 3 additions & 1 deletion tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@

def test_graphrag_prompt_template() -> None:
template = RagTemplate()
prompt = template.format(context="my context", query="user's query", examples="")
prompt = template.format(
context="my context", query_text="user's query", examples=""
)
assert (
prompt
== """Answer the user question using the following context
Expand Down

0 comments on commit 6998402

Please sign in to comment.