From 5720a4b6752e33b4032e990aae7d5e4078b48d83 Mon Sep 17 00:00:00 2001
From: Leila Messallem
Date: Thu, 12 Dec 2024 18:09:14 +0100
Subject: [PATCH] Add chat history support to VertexAILLM

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 36 ++++++++++++++--
 tests/unit/llm/test_vertexai_llm.py    | 60 +++++++++++++++++++++++++-
 2 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 9047bfb8..2900deec 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -15,12 +15,21 @@
 
 from typing import Any, Optional
 
+from pydantic import ValidationError
+
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse
+from neo4j_graphrag.llm.types import LLMResponse, MessageList
 
 try:
-    from vertexai.generative_models import GenerativeModel, ResponseValidationError
+    from vertexai.generative_models import (
+        Content,
+        GenerativeModel,
+        Part,
+        ResponseValidationError,
+    )
 except ImportError:
     GenerativeModel = None
     ResponseValidationError = None
+    Part = None
+    Content = None
@@ -69,6 +78,23 @@ def __init__(
             model_name=model_name, system_instruction=[system_instruction], **kwargs
         )
 
+    def get_messages(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> list[Content]:
+        messages = []
+        if chat_history:
+            try:
+                MessageList(messages=chat_history)
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+
+            for message in chat_history:
+                if message.get("role") == "user":
+                    messages.append(Content(role="user", parts=[Part.from_text(message.get("content"))]))
+                elif message.get("role") == "assistant":
+                    messages.append(Content(role="model", parts=[Part.from_text(message.get("content"))]))
+
+        messages.append(Content(role="user", parts=[Part.from_text(input)]))
+        return messages
+
     def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
@@ -80,7 +106,8 @@ def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None
             LLMResponse: The response from the LLM.
         """
         try:
-            response = self.model.generate_content(input, **self.model_params)
+            messages = self.get_messages(input, chat_history)
+            response = self.model.generate_content(messages, **self.model_params)
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
@@ -98,8 +125,9 @@
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = await self.model.generate_content_async(
-                input, **self.model_params
+                messages, **self.model_params
             )
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index adffeb1d..7b5a4cf9 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
+from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
+from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
+from vertexai.generative_models import Content, Part
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -36,7 +39,60 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     input_text = "may thy knife chip and shatter"
     response = llm.invoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content.assert_called_once_with(input_text, **model_params)
+    llm.model.generate_content.assert_called_once_with([mock.ANY], **model_params)
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "When does it set?"
+    chat_history = [
+        {"role": "user", "content": "When does the sun come up in the summer?"},
+        {"role": "assistant", "content": "Usually around 6am."},
+        {"role": "user", "content": "What about next season?"},
+        {"role": "assistant", "content": "Around 8am."},
+    ]
+    expected_response = [
+        Content(
+            role="user",
+            parts=[Part.from_text("When does the sun come up in the summer?")],
+        ),
+        Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
+        Content(role="user", parts=[Part.from_text("What about next season?")]),
+        Content(role="model", parts=[Part.from_text("Around 8am.")]),
+        Content(role="user", parts=[Part.from_text("When does it set?")]),
+    ]
+
+    llm = VertexAILLM(
+        model_name=model_name, system_instruction=system_instruction
+    )
+    response = llm.get_messages(question, chat_history)
+
+    GenerativeModelMock.assert_called_once_with(
+        model_name=model_name, system_instruction=[system_instruction]
+    )
+    assert len(response) == len(expected_response)
+    for actual, expected in zip(response, expected_response):
+        assert actual.role == expected.role
+        assert actual.parts[0].text == expected.parts[0].text
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) -> None:
+    system_instruction = "You are a helpful assistant."
+    model_name = "gemini-1.5-flash-001"
+    question = "hi!"
+    chat_history = [
+        {"role": "model", "content": "hello!"},
+    ]
+
+    llm = VertexAILLM(
+        model_name=model_name, system_instruction=system_instruction
+    )
+    with pytest.raises(LLMGenerationError) as exc_info:
+        llm.invoke(question, chat_history)
+    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)
 
 
 @pytest.mark.asyncio
@@ -51,4 +107,4 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)
+    llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params)
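
Usage note (not part of the patch): a minimal sketch of the new chat-history flow, assuming Vertex AI credentials are already configured; the model name and system instruction below are taken from the unit tests, not mandated by the patch. get_messages() validates the history against MessageList, maps the provider-agnostic "assistant" role onto Vertex AI's "model" role, and appends the current input as the final "user" turn before calling GenerativeModel.generate_content.

    from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

    # Hypothetical setup; assumes Vertex AI auth is in place.
    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        system_instruction="You are a helpful assistant.",
    )

    # Prior turns use the generic "user"/"assistant" roles;
    # get_messages() rewrites "assistant" to "model" for Vertex AI.
    chat_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
    ]

    # invoke() builds [user, model, user] Content objects and passes the
    # whole list to generate_content; per the patch, ainvoke() routes its
    # chat_history through the same get_messages() helper.
    response = llm.invoke("When does it set?", chat_history)
    print(response.content)

An entry with any other role (for example "model", as in the new validation test) fails MessageList validation, and invoke() surfaces it as LLMGenerationError before any request is sent.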