Commit

VertexAI
leila-messallem committed Dec 12, 2024
1 parent 72f4de5 commit 5720a4b
Showing 2 changed files with 88 additions and 6 deletions.
35 changes: 31 additions & 4 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -15,12 +15,19 @@

from typing import Any, Optional

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse
from neo4j_graphrag.llm.types import LLMResponse, MessageList

try:
    from vertexai.generative_models import GenerativeModel, ResponseValidationError
    from vertexai.generative_models import (
        GenerativeModel,
        ResponseValidationError,
        Part,
        Content,
    )
except ImportError:
    GenerativeModel = None
    ResponseValidationError = None
@@ -69,6 +76,24 @@ def __init__(
            model_name=model_name, system_instruction=[system_instruction], **kwargs
        )


    def get_messages(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> list[Content]:
        messages = []
        if chat_history:
            try:
                MessageList(messages=chat_history)
            except ValidationError as e:
                raise LLMGenerationError(e.errors()) from e

            for message in chat_history:
                if message.get("role") == "user":
                    messages.append(
                        Content(
                            role="user",
                            parts=[Part.from_text(message.get("content"))],
                        )
                    )
                elif message.get("role") == "assistant":
                    messages.append(
                        Content(
                            role="model",
                            parts=[Part.from_text(message.get("content"))],
                        )
                    )

        messages.append(Content(role="user", parts=[Part.from_text(input)]))
        return messages

    def invoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        """Sends text to the LLM and returns a response.
@@ -80,7 +105,8 @@ def invoke(self, input: str, chat_history: Optional[list[dict[str, str]]] = None) -> LLMResponse:
            LLMResponse: The response from the LLM.
        """
        try:
            response = self.model.generate_content(input, **self.model_params)
            messages = self.get_messages(input, chat_history)
            response = self.model.generate_content(messages, **self.model_params)
            return LLMResponse(content=response.text)
        except ResponseValidationError as e:
            raise LLMGenerationError(e)
@@ -98,8 +124,9 @@ async def ainvoke(
            LLMResponse: The response from the LLM.
        """
        try:
            messages = self.get_messages(input, chat_history)
            response = await self.model.generate_content_async(
                input, **self.model_params
                messages, **self.model_params
            )
            return LLMResponse(content=response.text)
        except ResponseValidationError as e:
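
With this change, invoke and ainvoke thread an optional chat history through get_messages, so earlier turns reach Gemini as alternating Content objects instead of being dropped. A minimal usage sketch, assuming the vertexai SDK is installed and Google Cloud credentials and a project are already configured; the model name is only an example:

from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

# Example configuration; any Gemini model available to the project works.
llm = VertexAILLM(
    model_name="gemini-1.5-flash-001",
    system_instruction="You are a helpful assistant.",
)

chat_history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]

# The history is validated, converted to Vertex AI Content objects, and the
# new question is appended as the final "user" turn before the request is sent.
response = llm.invoke("What about in winter?", chat_history)
print(response.content)

Note that only the "user" and "assistant" roles are accepted in the history; "assistant" turns are mapped to Vertex AI's "model" role.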
59 changes: 57 additions & 2 deletions tests/unit/llm/test_vertexai_llm.py
@@ -13,10 +13,13 @@
# limitations under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from vertexai.generative_models import Content, Part


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -36,7 +39,59 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
    input_text = "may thy knife chip and shatter"
    response = llm.invoke(input_text)
    assert response.content == "Return text"
    llm.model.generate_content.assert_called_once_with(input_text, **model_params)
    llm.model.generate_content.assert_called_once_with([mock.ANY], **model_params)


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None:
    system_instruction = "You are a helpful assistant."
    model_name = "gemini-1.5-flash-001"
    question = "When does it set?"
    chat_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
        {"role": "assistant", "content": "Usually around 6am."},
        {"role": "user", "content": "What about next season?"},
        {"role": "assistant", "content": "Around 8am."},
    ]
    expected_response = [
        Content(
            role="user",
            parts=[Part.from_text("When does the sun come up in the summer?")],
        ),
        Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
        Content(role="user", parts=[Part.from_text("What about next season?")]),
        Content(role="model", parts=[Part.from_text("Around 8am.")]),
        Content(role="user", parts=[Part.from_text("When does it set?")]),
    ]

    llm = VertexAILLM(
        model_name=model_name, system_instruction=system_instruction
    )
    response = llm.get_messages(question, chat_history)

    GenerativeModelMock.assert_called_once_with(
        model_name=model_name, system_instruction=[system_instruction]
    )
    assert len(response) == len(expected_response)
    for actual, expected in zip(response, expected_response):
        assert actual.role == expected.role
        assert actual.parts[0].text == expected.parts[0].text


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_get_messages_validation_error(
    GenerativeModelMock: MagicMock,
) -> None:
    system_instruction = "You are a helpful assistant."
    model_name = "gemini-1.5-flash-001"
    question = "hi!"
    chat_history = [
        {"role": "model", "content": "hello!"},
    ]

    llm = VertexAILLM(
        model_name=model_name, system_instruction=system_instruction
    )
    with pytest.raises(LLMGenerationError) as exc_info:
        llm.invoke(question, chat_history)
    assert "Input should be 'user' or 'assistant'" in str(exc_info.value)


@pytest.mark.asyncio
@@ -51,4 +106,4 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
    input_text = "may thy knife chip and shatter"
    response = await llm.ainvoke(input_text)
    assert response.content == "Return text"
    llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)
    llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params)
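
The validation-error test leans on MessageList from neo4j_graphrag.llm.types: pydantic rejects any history entry whose role is not "user" or "assistant", which is where the "Input should be 'user' or 'assistant'" message comes from. A rough sketch of what such a model can look like; the class names are illustrative and the actual definitions in neo4j_graphrag.llm.types may differ:

from typing import Literal

from pydantic import BaseModel


class BaseMessage(BaseModel):
    # Only these two roles validate; anything else (e.g. "model") fails
    # before a request is ever sent to Vertex AI.
    role: Literal["user", "assistant"]
    content: str


class MessageList(BaseModel):
    messages: list[BaseMessage]


# MessageList(messages=[{"role": "model", "content": "hello!"}]) raises a
# pydantic ValidationError, which VertexAILLM.get_messages re-raises as an
# LLMGenerationError, matching the assertion in the test above.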
