From 628890750e282ea5cbff1a8d915d6061dd13cd67 Mon Sep 17 00:00:00 2001
From: Leila Messallem
Date: Fri, 13 Dec 2024 13:45:34 +0100
Subject: [PATCH] Fix mypy errors

---
 src/neo4j_graphrag/generation/graphrag.py |  4 +++-
 src/neo4j_graphrag/llm/anthropic_llm.py   |  8 +++++---
 src/neo4j_graphrag/llm/cohere_llm.py      |  8 +++++---
 src/neo4j_graphrag/llm/mistralai_llm.py   | 12 +++++++-----
 src/neo4j_graphrag/llm/openai_llm.py      |  6 +++---
 src/neo4j_graphrag/llm/types.py           |  2 +-
 src/neo4j_graphrag/llm/vertexai_llm.py    |  8 +++++---
 7 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py
index 5694a910..e360d641 100644
--- a/src/neo4j_graphrag/generation/graphrag.py
+++ b/src/neo4j_graphrag/generation/graphrag.py
@@ -146,7 +146,9 @@ def search(
         result["retriever_result"] = retriever_result
         return RagResultModel(**result)
 
-    def build_query(self, query_text: str, chat_history: list[dict[str, str]]) -> str:
+    def build_query(
+        self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> str:
         if chat_history:
             summarization_prompt = ChatSummaryTemplate().format(
                 chat_history=chat_history
diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py
index 6f1632e3..2c220c2a 100644
--- a/src/neo4j_graphrag/llm/anthropic_llm.py
+++ b/src/neo4j_graphrag/llm/anthropic_llm.py
@@ -71,7 +71,9 @@ def __init__(
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> Iterable[MessageParam]:
         messages = []
         if chat_history:
             try:
@@ -83,7 +85,7 @@ def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -107,7 +109,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
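
For context on the signature changes above: `chat_history` used to be a required positional `list`, which forced callers (and mypy) to supply an argument even when there was no history. A minimal sketch of the call sites after this change, assuming the package's public `AnthropicLLM` export and an `ANTHROPIC_API_KEY` in the environment; the model name is illustrative:

```python
from neo4j_graphrag.llm import AnthropicLLM

# Model name is illustrative; any Anthropic chat model id works here.
llm = AnthropicLLM(model_name="claude-3-5-sonnet-20240620")

# chat_history is now Optional, so both calls type-check under mypy.
first = llm.invoke("What is a knowledge graph?")
followup = llm.invoke(
    "How is one queried?",
    chat_history=[
        {"role": "user", "content": "What is a knowledge graph?"},
        {"role": "assistant", "content": first.content},
    ],
)
print(followup.content)
```

The same relaxation is applied uniformly to the Cohere, MistralAI, OpenAI, and VertexAI wrappers below.
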
diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py
index 3a502b3f..63fc396b 100644
--- a/src/neo4j_graphrag/llm/cohere_llm.py
+++ b/src/neo4j_graphrag/llm/cohere_llm.py
@@ -74,7 +74,9 @@ def __init__(
         self.client = cohere.ClientV2(**kwargs)
         self.async_client = cohere.AsyncClientV2(**kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> ChatMessages:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
@@ -88,7 +90,7 @@ def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type:
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -112,7 +114,7 @@ def invoke(
         )
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py
index 63dd6a66..b95249c8 100644
--- a/src/neo4j_graphrag/llm/mistralai_llm.py
+++ b/src/neo4j_graphrag/llm/mistralai_llm.py
@@ -31,8 +31,8 @@
     from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None  # type: ignore
-    SDKError = None  # type: ignore
+    Mistral = None
+    SDKError = None
 
 
 class MistralAILLM(LLMInterface):
@@ -64,7 +64,9 @@ def __init__(
             api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)
 
-    def get_messages(self, input: str, chat_history: list) -> list[Messages]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> list[Messages]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
@@ -78,7 +80,7 @@ def get_messages(self, input: str, chat_history: list) -> list[Messages]:
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
 
@@ -110,7 +112,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat completion
         model and returns the response's content.
 
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 146b6169..a349a38c 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -61,7 +61,7 @@ def __init__(
         super().__init__(model_name, model_params, system_instruction)
 
     def get_messages(
-        self, input: str, chat_history: list
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
         if self.system_instruction:
@@ -76,7 +76,7 @@ def get_messages(
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
 
@@ -103,7 +103,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.
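
Each `get_messages` above assembles the payload in the same order: optional system instruction, then validated history, then the current input. A standalone sketch of that shared pattern, with plain dicts standing in for the provider-specific message types (names here are illustrative, not the library's API):

```python
from typing import Any, Optional

def build_messages(
    input: str,
    chat_history: Optional[list[Any]] = None,
    system_instruction: Optional[str] = None,
) -> list[dict[str, str]]:
    """Assemble a chat payload: system instruction, history, then the new turn."""
    messages: list[dict[str, str]] = []
    if system_instruction:
        messages.append({"role": "system", "content": system_instruction})
    if chat_history:  # the default None (and an empty list) skip this branch
        messages.extend(chat_history)
    messages.append({"role": "user", "content": input})
    return messages

assert build_messages("hi") == [{"role": "user", "content": "hi"}]
```

The truthiness check is why `Optional[list[Any]] = None` is a safe default: `None` behaves exactly like an empty history.
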
diff --git a/src/neo4j_graphrag/llm/types.py b/src/neo4j_graphrag/llm/types.py
index d243871f..a6888475 100644
--- a/src/neo4j_graphrag/llm/types.py
+++ b/src/neo4j_graphrag/llm/types.py
@@ -7,7 +7,7 @@ class LLMResponse(BaseModel):
 
 
 class BaseMessage(BaseModel):
-    role: Literal["user", "assistant"]
+    role: Literal["user", "assistant", "system"]
     content: str
 
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index ae27e4ff..66ad6683 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -76,7 +76,9 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )
 
-    def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> list[Content]:
         messages = []
         if chat_history:
             try:
@@ -102,7 +104,7 @@ def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
         return messages
 
     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -121,7 +123,7 @@ def invoke(
             raise LLMGenerationError(e)
 
     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
        """Asynchronously sends text to the LLM and returns a response.
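
One note on the `types.py` hunk: the Cohere and MistralAI wrappers prepend a `SystemMessage` in `get_messages`, so the base `role` literal has to admit `"system"` or pydantic validation would reject that entry. A self-contained sketch of the widened model, mirroring `BaseMessage`:

```python
from typing import Literal
from pydantic import BaseModel, ValidationError

class BaseMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

# "system" now validates; anything outside the literal still fails.
BaseMessage(role="system", content="You are a helpful assistant.")
try:
    BaseMessage(role="tool", content="...")
except ValidationError as e:
    print("rejected as expected:", e.error_count(), "error(s)")
```
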