diff --git a/lightrag/lightrag/components/model_client/ollama_client.py b/lightrag/lightrag/components/model_client/ollama_client.py
index fc569dd1..4ee1d005 100644
--- a/lightrag/lightrag/components/model_client/ollama_client.py
+++ b/lightrag/lightrag/components/model_client/ollama_client.py
@@ -9,6 +9,7 @@
     List,
     Type,
     Generator as GeneratorType,
+    AsyncGenerator as AsyncGeneratorType,
     Union,
 )
 import backoff
@@ -36,6 +37,12 @@ def parse_stream_response(completion: GeneratorType) -> Any:
         log.debug(f"Raw chunk: {chunk}")
         yield chunk["response"] if "response" in chunk else None
 
+
+async def aparse_stream_response(completion: AsyncGeneratorType) -> Any:
+    """Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
+    async for chunk in completion:
+        log.debug(f"Raw chunk: {chunk}")
+        yield chunk["response"] if "response" in chunk else None
 
 def parse_generate_response(completion: GenerateResponse) -> Any:
     """Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
@@ -187,6 +194,16 @@ def parse_chat_completion(
         else:
             return parse_generate_response(completion)
 
+    async def aparse_chat_completion(
+        self, completion: Union[GenerateResponse, AsyncGeneratorType]
+    ) -> Any:
+        """Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
+        log.debug(f"completion: {completion}, {isinstance(completion, AsyncGeneratorType)}")
+        if isinstance(completion, AsyncGeneratorType):  # streaming
+            return aparse_stream_response(completion)
+        else:
+            return parse_generate_response(completion)
+
     def parse_embedding_response(
         self, response: Dict[str, List[float]]
     ) -> EmbedderOutput:
diff --git a/lightrag/lightrag/components/model_client/openai_client.py b/lightrag/lightrag/components/model_client/openai_client.py
index 46a5c525..8f56be60 100644
--- a/lightrag/lightrag/components/model_client/openai_client.py
+++ b/lightrag/lightrag/components/model_client/openai_client.py
@@ -27,7 +27,7 @@
 
 openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])
 
-from openai import OpenAI, AsyncOpenAI, Stream
+from openai import OpenAI, AsyncOpenAI, Stream, AsyncStream
 from openai import (
     APITimeoutError,
     InternalServerError,
@@ -62,6 +62,13 @@ def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
         parsed_content = parse_stream_response(completion)
         yield parsed_content
 
+
+async def ahandle_stream_response(generator: AsyncStream[ChatCompletionChunk]):
+    """Handle the streaming response (async)"""
+    async for completion in generator:
+        log.debug(f"Raw chunk completion: {completion}")
+        parsed_content = parse_stream_response(completion)
+        yield parsed_content
 
 def get_all_messages_content(completion: ChatCompletion) -> List[str]:
     r"""When the n > 1, get all the messages content."""
@@ -138,6 +145,14 @@ def init_async_client(self):
             raise ValueError("Environment variable OPENAI_API_KEY must be set")
         return AsyncOpenAI(api_key=api_key)
 
+    async def aparse_chat_completion(
+        self,
+        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
+    ) -> Any:
+        """Parse the completion to a str."""
+        log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}")
+        return self.chat_completion_parser(completion)
+
     def parse_chat_completion(
         self,
         completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
@@ -236,6 +251,10 @@ async def acall(
         if model_type == ModelType.EMBEDDER:
             return await self.async_client.embeddings.create(**api_kwargs)
         elif model_type == ModelType.LLM:
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("streaming call")
+                self.chat_completion_parser = ahandle_stream_response
+                return await self.async_client.chat.completions.create(**api_kwargs)
             return await self.async_client.chat.completions.create(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
diff --git a/lightrag/lightrag/core/generator.py b/lightrag/lightrag/core/generator.py
index 1368f046..70260057 100644
--- a/lightrag/lightrag/core/generator.py
+++ b/lightrag/lightrag/core/generator.py
@@ -162,6 +162,29 @@ def _extra_repr(self) -> str:
         s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}"
         return s
 
+    async def _apost_call(self, completion: Any) -> GeneratorOutputType:
+        r"""Get string completion and process it with the output_processors."""
+        try:
+            response = await self.model_client.aparse_chat_completion(completion)
+        except Exception as e:
+            log.error(f"Error parsing the completion {completion}: {e}")
+            return GeneratorOutput(raw_response=str(completion), error=str(e))
+
+        # the output processors operate on the str, the raw_response field.
+        output: GeneratorOutputType = GeneratorOutput(raw_response=response)
+
+        if self.output_processors:
+            try:
+                response = self.output_processors(response)
+                output.data = response
+            except Exception as e:
+                log.error(f"Error processing the output processors: {e}")
+                output.error = str(e)
+        else:  # default to string output
+            output.data = response
+
+        return output
+
     def _post_call(self, completion: Any) -> GeneratorOutputType:
         r"""Get string completion and process it with the output_processors."""
         try:
@@ -251,7 +274,7 @@ async def acall(
         self,
         prompt_kwargs: Optional[Dict] = {},
         model_kwargs: Optional[Dict] = {},
-    ) -> GeneratorOutputType:
+    ) -> GeneratorOutputType:
         r"""Async call the model with the input and model_kwargs.
 
         :warning::
@@ -263,8 +286,8 @@ async def acall(
         api_kwargs = self._pre_call(prompt_kwargs, model_kwargs)
         completion = await self.model_client.acall(
             api_kwargs=api_kwargs, model_type=self.model_type
-        )
-        output = self._post_call(completion)
+        )
+        output = await self._apost_call(completion)
         log.info(f"output: {output}")
         return output
 
diff --git a/lightrag/lightrag/core/model_client.py b/lightrag/lightrag/core/model_client.py
index 2e3b699c..0916bd69 100644
--- a/lightrag/lightrag/core/model_client.py
+++ b/lightrag/lightrag/core/model_client.py
@@ -90,6 +90,12 @@ def parse_chat_completion(self, completion: Any) -> Any:
             f"{type(self).__name__} must implement parse_chat_completion method"
         )
 
+    async def aparse_chat_completion(self, completion: Any) -> Any:
+        r"""Parse the chat completion to str."""
+        raise NotImplementedError(
+            f"{type(self).__name__} must implement aparse_chat_completion method"
+        )
+
     def parse_embedding_response(self, response: Any) -> EmbedderOutput:
         r"""Parse the embedding response to a structure LightRAG components can understand."""
         raise NotImplementedError(
diff --git a/lightrag/pyproject.toml b/lightrag/pyproject.toml
index 609b4892..c53474ce 100644
--- a/lightrag/pyproject.toml
+++ b/lightrag/pyproject.toml
@@ -61,6 +61,7 @@ ollama = { version = "^0.2.1", optional = true }
 [tool.poetry.group.test.dependencies]
 pytest = "^8.1.1"
 pytest-mock = "^3.14.0"
+pytest-asyncio = "^0.23.8"
 torch = "^2.3.1"
 ollama = "^0.2.1"
 faiss-cpu = "^1.8.0"
diff --git a/lightrag/tests/test_generator_async.py b/lightrag/tests/test_generator_async.py
new file mode 100644
index 00000000..6d2bec3e
--- /dev/null
+++ b/lightrag/tests/test_generator_async.py
@@ -0,0 +1,25 @@
+import pytest
+from typing import AsyncGenerator
+from unittest.mock import AsyncMock
+from lightrag.components.model_client.ollama_client import OllamaClient
+from lightrag.core import Generator
+
+async def async_gen() -> AsyncGenerator[dict, None]:
+    yield {"response": "I"}
+    yield {"response": " am"}
+    yield {"response": " hungry"}
+
+
+@pytest.mark.asyncio
+async def test_acall():
+    ollama_client = OllamaClient()
+    ollama_client.acall = AsyncMock(return_value=async_gen())
+
+    generator = Generator(model_client=ollama_client)
+    output = await generator.acall({}, {})
+
+    result = []
+    async for value in output.data:
+        result.append(value)
+    assert result == ["I", " am", " hungry"]
+
diff --git a/lightrag/tests/test_ollama_async_client.py b/lightrag/tests/test_ollama_async_client.py
new file mode 100644
index 00000000..811f3aa4
--- /dev/null
+++ b/lightrag/tests/test_ollama_async_client.py
@@ -0,0 +1,45 @@
+import pytest
+from typing import AsyncGenerator
+from unittest.mock import AsyncMock
+from lightrag.core.types import ModelType
+from lightrag.components.model_client.ollama_client import OllamaClient
+
+@pytest.mark.asyncio
+async def test_ollama_llm_client_async():
+    ollama_client = AsyncMock(spec=OllamaClient())
+    ollama_client.acall.return_value = {"message": "Hello"}
+    print("Testing ollama LLM async client")
+
+    # run the model
+    kwargs = {
+        "model": "qwen2:0.5b",
+    }
+    api_kwargs = ollama_client.convert_inputs_to_api_kwargs(
+        input="Hello world",
+        model_kwargs=kwargs,
+        model_type=ModelType.LLM,
+    ).return_value = {"prompt": "Hello World", "model": "qwen2:0.5b"}
+
+    assert api_kwargs == {"prompt": "Hello World", "model": "qwen2:0.5b"}
+
+    output = await ollama_client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+    assert output == {"message": "Hello"}
+
+
+async def async_gen() -> AsyncGenerator[dict, None]:
+    yield {"response": "I"}
+    yield {"response": " am"}
+    yield {"response": " cool"}
+
+@pytest.mark.asyncio
+async def test_async_generator_completion():
+    ollama_client = OllamaClient()
+    print("Testing ollama LLM async client")
+
+    output = await ollama_client.aparse_chat_completion(async_gen())
+
+    result = []
+    async for value in output:
+        result.append(value)
+
+    assert result == ["I", " am", " cool"]
diff --git a/lightrag/tests/test_ollama_client.py b/lightrag/tests/test_ollama_client.py
index b4a07a39..85358f95 100644
--- a/lightrag/tests/test_ollama_client.py
+++ b/lightrag/tests/test_ollama_client.py
@@ -1,5 +1,6 @@
 import unittest
-from unittest.mock import Mock
+import asyncio
+from unittest.mock import Mock, AsyncMock
 
 from lightrag.core.types import ModelType
 from lightrag.components.model_client.ollama_client import OllamaClient
@@ -34,6 +35,28 @@ def test_ollama_llm_client(self):
         ).return_value = {"message": "Hello"}
         assert output == {"message": "Hello"}
 
+    def test_ollama_llm_client_acall(self):
+        # unittest.TestCase does not await async test methods, so drive the coroutine explicitly.
+        async def _run():
+            ollama_client = AsyncMock(spec=OllamaClient())
+            ollama_client.acall.return_value = {"message": "Hello"}
+            print("Testing ollama LLM client")
+            # run the model
+            kwargs = {
+                "model": "qwen2:0.5b",
+            }
+            api_kwargs = ollama_client.convert_inputs_to_api_kwargs(
+                input="Hello world",
+                model_kwargs=kwargs,
+                model_type=ModelType.LLM,
+            ).return_value = {"prompt": "Hello World", "model": "qwen2:0.5b"}
+            assert api_kwargs == {"prompt": "Hello World", "model": "qwen2:0.5b"}
+
+            output = await ollama_client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+            assert output == {"message": "Hello"}
+
+        asyncio.run(_run())
+
     def test_ollama_embedding_client(self):
         ollama_client = Mock(spec=OllamaClient())
         print("Testing ollama embedding client")