Fix mypy errors
leila-messallem committed Dec 13, 2024
1 parent f2792ff commit 6288907
Showing 7 changed files with 29 additions and 19 deletions.
4 changes: 3 additions & 1 deletion src/neo4j_graphrag/generation/graphrag.py

@@ -146,7 +146,9 @@ def search(
         result["retriever_result"] = retriever_result
         return RagResultModel(**result)

-    def build_query(self, query_text: str, chat_history: list[dict[str, str]]) -> str:
+    def build_query(
+        self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> str:
         if chat_history:
             summarization_prompt = ChatSummaryTemplate().format(
                 chat_history=chat_history
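
The `build_query` hunk shows the commit's central pattern: the method body already treats a missing history as valid (`if chat_history:`), but the old annotation made the argument required, so callers omitting it or passing `None` failed type checking. A minimal sketch of the before and after (hypothetical free functions, not the repo's code):

from typing import Optional

def build_query_old(query_text: str, chat_history: list[dict[str, str]]) -> str:
    # mypy rejects build_query_old("hi") and build_query_old("hi", None):
    # the parameter is neither Optional nor defaulted.
    return query_text

def build_query_new(
    query_text: str, chat_history: Optional[list[dict[str, str]]] = None
) -> str:
    if chat_history:
        return f"[{len(chat_history)} prior turns] {query_text}"
    return query_text

print(build_query_new("What is GraphRAG?"))  # no history needed
print(build_query_new("And next?", [{"role": "user", "content": "hi"}]))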

8 changes: 5 additions & 3 deletions src/neo4j_graphrag/llm/anthropic_llm.py

@@ -71,7 +71,9 @@ def __init__(
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

-    def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> Iterable[MessageParam]:
         messages = []
         if chat_history:
             try:
@@ -83,7 +85,7 @@ def get_messages(self, input: str, chat_history: list) -> Iterable[MessageParam]
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -107,7 +109,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
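
The same two-part fix repeats in each provider class: `get_messages` used to require a plain `list`, while `invoke`/`ainvoke` hold an `Optional[...]` and pass it straight through, which mypy reports as an incompatible argument. A contrived sketch of that mismatch and its resolution (standalone functions assumed to mirror the repo's call pattern; the real method bodies differ):

from typing import Any, Optional

def get_messages(
    input: str, chat_history: Optional[list[Any]] = None
) -> list[dict[str, str]]:
    # Optional + truthiness check: None and [] both mean "no history".
    messages: list[dict[str, str]] = []
    if chat_history:
        messages.extend(chat_history)
    messages.append({"role": "user", "content": input})
    return messages

def invoke(
    input: str, chat_history: Optional[list[Any]] = None
) -> list[dict[str, str]]:
    # With the old `chat_history: list` signature on get_messages, this
    # pass-through failed mypy roughly as: Argument 2 to "get_messages" has
    # incompatible type "Optional[list[Any]]"; expected "list[Any]".
    return get_messages(input, chat_history)

print(invoke("hello", [{"role": "assistant", "content": "hi"}]))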

8 changes: 5 additions & 3 deletions src/neo4j_graphrag/llm/cohere_llm.py

@@ -74,7 +74,9 @@ def __init__(
         self.client = cohere.ClientV2(**kwargs)
         self.async_client = cohere.AsyncClientV2(**kwargs)

-    def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> ChatMessages:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
@@ -88,7 +90,7 @@ def get_messages(self, input: str, chat_history: list) -> ChatMessages:  # type: ignore
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -112,7 +114,7 @@ def invoke(
         )

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
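
Typing the history as `list[Any]` deliberately moves checking from mypy to runtime: entries may be plain dicts or richer objects, and the `try:` blocks visible in the Anthropic and VertexAI hunks suggest validation happens there. A sketch of that division of labour, using a hypothetical `Message` model in place of the repo's own message types:

from typing import Any, Optional
from pydantic import BaseModel, ValidationError

class Message(BaseModel):  # hypothetical stand-in for the repo's BaseMessage
    role: str
    content: str

def get_messages(
    input: str, chat_history: Optional[list[Any]] = None
) -> list[dict[str, str]]:
    messages: list[dict[str, str]] = []
    if chat_history:
        try:
            # Statically chat_history is list[Any]; each entry is checked here.
            messages = [Message(**m).model_dump() for m in chat_history]
        except (TypeError, ValidationError) as e:
            raise ValueError(f"Invalid chat history: {e}")
    messages.append({"role": "user", "content": input})
    return messages

print(get_messages("hello", [{"role": "assistant", "content": "hi"}]))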

12 changes: 7 additions & 5 deletions src/neo4j_graphrag/llm/mistralai_llm.py

@@ -31,8 +31,8 @@
     from mistralai import Mistral, Messages
     from mistralai.models.sdkerror import SDKError
 except ImportError:
-    Mistral = None  # type: ignore
-    SDKError = None  # type: ignore
+    Mistral = None
+    SDKError = None


 class MistralAILLM(LLMInterface):
@@ -64,7 +64,9 @@ def __init__(
         api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)

-    def get_messages(self, input: str, chat_history: list) -> list[Messages]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> list[Messages]:
         messages = []
         if self.system_instruction:
             messages.append(SystemMessage(content=self.system_instruction).model_dump())
@@ -78,7 +80,7 @@ def get_messages(self, input: str, chat_history: list) -> list[Messages]:
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
         and returns the response's content.
@@ -110,7 +112,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
         completion model and returns the response's content.
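
Dropping the `# type: ignore` comments on the import fallback is the inverse kind of mypy fix: if the checker treats the unresolved `mistralai` import as `Any` (for example when the optional dependency is absent and missing imports are ignored), then `Mistral = None` is already legal and the suppression itself gets flagged as unused. That reading is an assumption about the repo's mypy configuration; the guard pattern in isolation looks like this:

from typing import Any

try:
    from mistralai import Mistral  # typed as Any when mypy cannot resolve the package
except ImportError:
    Mistral = None  # no suppression needed if the name above is Any

def make_client(api_key: str) -> Any:
    # Hypothetical helper: fail loudly only when the client is actually needed.
    if Mistral is None:
        raise ImportError("mistralai is not installed; try `pip install mistralai`")
    return Mistral(api_key=api_key)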

6 changes: 3 additions & 3 deletions src/neo4j_graphrag/llm/openai_llm.py

@@ -61,7 +61,7 @@ def __init__(
         super().__init__(model_name, model_params, system_instruction)

     def get_messages(
-        self, input: str, chat_history: list
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
         if self.system_instruction:
@@ -76,7 +76,7 @@ def get_messages(
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
@@ -103,7 +103,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
         completion model and returns the response's content.
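
A detail worth noting across these signatures: the default is `None` rather than `[]`. That is the idiomatic choice because a mutable default is created once and shared across calls, so one invocation's history could leak into the next. A quick demonstration of the pitfall the chosen signature avoids:

from typing import Any, Optional

def risky(input: str, chat_history: list[Any] = []) -> list[Any]:
    chat_history.append({"role": "user", "content": input})  # mutates the shared default!
    return chat_history

def safe(input: str, chat_history: Optional[list[Any]] = None) -> list[Any]:
    history = list(chat_history) if chat_history else []
    history.append({"role": "user", "content": input})
    return history

print(risky("a"))  # [{'role': 'user', 'content': 'a'}]
print(risky("b"))  # both 'a' and 'b': 'a' leaked in from the previous call
print(safe("a"))   # [{'role': 'user', 'content': 'a'}]
print(safe("b"))   # [{'role': 'user', 'content': 'b'}]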

2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/types.py

@@ -7,7 +7,7 @@ class LLMResponse(BaseModel):


 class BaseMessage(BaseModel):
-    role: Literal["user", "assistant"]
+    role: Literal["user", "assistant", "system"]
     content: str
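
The one-line `types.py` change is load-bearing for the rest of the commit: several providers prepend `self.system_instruction` via `SystemMessage(...).model_dump()`, and a `role` literal without "system" would presumably reject exactly that message at validation time. A standalone sketch of the widened model:

from typing import Literal
from pydantic import BaseModel, ValidationError

class BaseMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

print(BaseMessage(role="system", content="You are terse."))  # valid after this commit

try:
    BaseMessage(role="tool", content="...")
except ValidationError:
    print("unknown roles are still rejected at validation time")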

8 changes: 5 additions & 3 deletions src/neo4j_graphrag/llm/vertexai_llm.py

@@ -76,7 +76,9 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )

-    def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> list[Content]:
         messages = []
         if chat_history:
             try:
@@ -102,7 +104,7 @@ def get_messages(self, input: str, chat_history: list[str]) -> list[Content]:
         return messages

     def invoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -121,7 +123,7 @@ def invoke(
             raise LLMGenerationError(e)

     async def ainvoke(
-        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+        self, input: str, chat_history: Optional[list[Any]] = None
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
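
VertexAI is the one provider whose old annotation was outright wrong: `list[str]`, although history entries are role/content records, and Vertex AI uses the role "model" where the neutral history says "assistant". The loose `list[Any]` plus per-provider conversion is the pragmatic endpoint; a provider-agnostic sketch of that adapter shape (plain dicts stand in for the real `Content` objects):

from typing import Any, Optional

# Hypothetical neutral -> provider role mapping, modelled on Vertex AI's naming.
ROLE_MAP = {"user": "user", "assistant": "model"}

def to_provider_messages(
    input: str, chat_history: Optional[list[Any]] = None
) -> list[dict[str, Any]]:
    messages: list[dict[str, Any]] = []
    if chat_history:
        for m in chat_history:
            messages.append({"role": ROLE_MAP[m["role"]], "parts": [m["content"]]})
    messages.append({"role": "user", "parts": [input]})
    return messages

print(to_provider_messages("next?", [{"role": "assistant", "content": "hi"}]))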
