Add streaming to acall #158

Open · wants to merge 3 commits into base: main

17 changes: 17 additions & 0 deletions lightrag/lightrag/components/model_client/ollama_client.py
@@ -9,6 +9,7 @@
List,
Type,
Generator as GeneratorType,
AsyncGenerator as AsyncGeneratorType,
Union,
)
import backoff
@@ -36,6 +37,12 @@ def parse_stream_response(completion: GeneratorType) -> Any:
log.debug(f"Raw chunk: {chunk}")
yield chunk["response"] if "response" in chunk else None

async def aparse_stream_response(completion: AsyncGeneratorType) -> Any:
"""Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
async for chunk in completion:
log.debug(f"Raw chunk: {chunk}")
yield chunk["response"] if "response" in chunk else None


def parse_generate_response(completion: GenerateResponse) -> Any:
"""Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
@@ -187,6 +194,16 @@ def parse_chat_completion(
else:
return parse_generate_response(completion)

async def aparse_chat_completion(
self, completion: Union[GenerateResponse, AsyncGeneratorType]
) -> Any:
"""Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
log.debug(f"completion: {completion}, {isinstance(completion, GeneratorType)}")
if isinstance(completion, AsyncGeneratorType): # streaming
return aparse_stream_response(completion)
else:
return parse_generate_response(completion)

def parse_embedding_response(
self, response: Dict[str, List[float]]
) -> EmbedderOutput:
21 changes: 20 additions & 1 deletion lightrag/lightrag/components/model_client/openai_client.py
@@ -27,7 +27,7 @@

openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

from openai import OpenAI, AsyncOpenAI, Stream
from openai import OpenAI, AsyncOpenAI, Stream, AsyncStream
from openai import (
APITimeoutError,
InternalServerError,
@@ -62,6 +62,13 @@ def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
parsed_content = parse_stream_response(completion)
yield parsed_content

async def ahandle_stream_response(generator: AsyncStream[ChatCompletionChunk]):
"""Handle the streaming response (async)"""
async for completion in generator:
log.debug(f"Raw chunk completion: {completion}")
parsed_content = parse_stream_response(completion)
yield parsed_content


def get_all_messages_content(completion: ChatCompletion) -> List[str]:
r"""When the n > 1, get all the messages content."""
@@ -138,6 +145,14 @@ def init_async_client(self):
raise ValueError("Environment variable OPENAI_API_KEY must be set")
return AsyncOpenAI(api_key=api_key)

async def aparse_chat_completion(
self,
completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]],
) -> Any:
"""Parse the completion to a str."""
log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}")
return self.chat_completion_parser(completion)

def parse_chat_completion(
self,
completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
@@ -236,6 +251,10 @@ async def acall(
if model_type == ModelType.EMBEDDER:
return await self.async_client.embeddings.create(**api_kwargs)
elif model_type == ModelType.LLM:
if "stream" in api_kwargs and api_kwargs.get("stream", False):
log.debug("streaming call")
self.chat_completion_parser = ahandle_stream_response
return await self.async_client.chat.completions.create(**api_kwargs)
return await self.async_client.chat.completions.create(**api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")
29 changes: 26 additions & 3 deletions lightrag/lightrag/core/generator.py
@@ -162,6 +162,29 @@ def _extra_repr(self) -> str:
s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}"
return s

async def _apost_call(self, completion: Any) -> GeneratorOutputType:
r"""Get string completion and process it with the output_processors."""
try:
response = await self.model_client.aparse_chat_completion(completion)
except Exception as e:
log.error(f"Error parsing the completion {completion}: {e}")
return GeneratorOutput(raw_response=str(completion), error=str(e))

# the output processors operate on the str, the raw_response field.
output: GeneratorOutputType = GeneratorOutput(raw_response=response)

if self.output_processors:
try:
response = self.output_processors(response)
output.data = response
except Exception as e:
log.error(f"Error processing the output processors: {e}")
output.error = str(e)
else: # default to string output
output.data = response

return output

def _post_call(self, completion: Any) -> GeneratorOutputType:
r"""Get string completion and process it with the output_processors."""
try:
@@ -251,7 +274,7 @@ async def acall(
self,
prompt_kwargs: Optional[Dict] = {},
model_kwargs: Optional[Dict] = {},
) -> GeneratorOutputType:
) -> GeneratorOutputType:
r"""Async call the model with the input and model_kwargs.

:warning::
@@ -263,8 +286,8 @@
api_kwargs = self._pre_call(prompt_kwargs, model_kwargs)
completion = await self.model_client.acall(
api_kwargs=api_kwargs, model_type=self.model_type

Review comment (Member):
I'm thinking we need new functions, self.model_client_call and model_client_acall, to call the model client and parse the completion; that way, the pre_call and post_call do not need to be async for now.

Review comment (Member):
I'm working on something where I also need to separate this out:

    def _model_client_call(self, api_kwargs: Dict) -> Any:
        # call the model client
        try:
            # check the cache
            index_content = json.dumps(api_kwargs)  # all messages
            cached_completion = self._check_cache(index_content)
            if cached_completion is not None:
                return cached_completion
            completion = self.model_client.call(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
            # prepare cache
            self._save_cache(index_content, completion)
            return completion
        except Exception as e:
            log.error(f"Error calling the model: {e}")
            raise e

You can use this minus the cache; here is how to use it in the call:

        output: GeneratorOutputType = None
        # call the model client

        completion = None
        try:
            completion = self._model_client_call(api_kwargs=api_kwargs)
        except Exception as e:
            log.error(f"Error calling the model: {e}")
            output = GeneratorOutput(error=str(e))
        # process the completion
        if completion:
            try:
                output = self._post_call(completion)

            except Exception as e:
                log.error(f"Error processing the output: {e}")
                output = GeneratorOutput(raw_response=str(completion), error=str(e))
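
A possible async counterpart along the same lines, sketched only (minus the cache); the name _model_client_acall simply follows the suggestion above:

    async def _model_client_acall(self, api_kwargs: Dict) -> Any:
        # async counterpart: await the model client's acall and surface errors the same way
        try:
            completion = await self.model_client.acall(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
            return completion
        except Exception as e:
            log.error(f"Error calling the model: {e}")
            raise e

acall could then mirror the sync flow above, with _apost_call taking the place of _post_call when handling the completion.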

)
output = self._post_call(completion)
)
output = await self._apost_call(completion)
log.info(f"output: {output}")
return output
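
For illustration (not part of the diff), a minimal end-to-end sketch of consuming a streamed acall through Generator with the Ollama client; it assumes a locally running Ollama server, and the host, model name, and prompt variable input_str are illustrative:

import asyncio

from lightrag.components.model_client.ollama_client import OllamaClient
from lightrag.core import Generator


async def main():
    generator = Generator(
        model_client=OllamaClient(host="http://localhost:11434"),
        model_kwargs={"model": "qwen2:0.5b", "stream": True},
    )
    output = await generator.acall(prompt_kwargs={"input_str": "Why is the sky blue?"})
    # with streaming, _apost_call stores the async generator from aparse_stream_response in output.data
    async for token in output.data:
        print(token, end="", flush=True)


asyncio.run(main())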

6 changes: 6 additions & 0 deletions lightrag/lightrag/core/model_client.py
@@ -90,6 +90,12 @@ def parse_chat_completion(self, completion: Any) -> Any:
f"{type(self).__name__} must implement parse_chat_completion method"
)

async def aparse_chat_completion(self, completion: Any) -> Any:
r"""Parse the chat completion to str."""
raise NotImplementedError(
f"{type(self).__name__} must implement aparse_chat_completion method"
)

def parse_embedding_response(self, response: Any) -> EmbedderOutput:
r"""Parse the embedding response to a structure LightRAG components can understand."""
raise NotImplementedError(
1 change: 1 addition & 0 deletions lightrag/pyproject.toml
@@ -61,6 +61,7 @@ ollama = { version = "^0.2.1", optional = true }
[tool.poetry.group.test.dependencies]
pytest = "^8.1.1"
pytest-mock = "^3.14.0"
pytest-asyncio = "^0.23.8"
torch = "^2.3.1"
ollama = "^0.2.1"
faiss-cpu = "^1.8.0"
25 changes: 25 additions & 0 deletions lightrag/tests/test_generator_async.py
@@ -0,0 +1,25 @@
import pytest
from typing import AsyncGenerator
from unittest.mock import AsyncMock
from lightrag.components.model_client.ollama_client import OllamaClient
from lightrag.core import Generator

async def async_gen() -> AsyncGenerator[dict, None]:
yield {"response": "I"}
yield {"response": " am"}
yield {"response": " hungry"}


@pytest.mark.asyncio
async def test_acall():
ollama_client = OllamaClient()
ollama_client.acall = AsyncMock(return_value=async_gen())

generator = Generator(model_client=ollama_client)
output = await generator.acall({}, {})

result = []
async for value in output.data:
result.append(value)
assert result == ["I", " am", " hungry"]

45 changes: 45 additions & 0 deletions lightrag/tests/test_ollama_async_client.py
@@ -0,0 +1,45 @@
import pytest
from typing import AsyncGenerator
from unittest.mock import AsyncMock
from lightrag.core.types import ModelType
from lightrag.components.model_client.ollama_client import OllamaClient

@pytest.mark.asyncio
async def test_ollama_llm_client_async():
ollama_client = AsyncMock(spec=OllamaClient())
ollama_client.acall.return_value = {"message": "Hello"}
print("Testing ollama LLM async client")

# run the model
kwargs = {
"model": "qwen2:0.5b",
}
api_kwargs = ollama_client.convert_inputs_to_api_kwargs(
input="Hello world",
model_kwargs=kwargs,
model_type=ModelType.LLM,
).return_value = {"prompt": "Hello World", "model": "qwen2:0.5b"}

assert api_kwargs == {"prompt": "Hello World", "model": "qwen2:0.5b"}

output = await ollama_client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
assert output == {"message": "Hello"}


async def async_gen() -> AsyncGenerator[dict, None]:
yield {"response": "I"}
yield {"response": " am"}
yield {"response": " cool"}

@pytest.mark.asyncio
async def test_async_generator_completion():
ollama_client = OllamaClient()
print("Testing ollama LLM async client")

output = await ollama_client.aparse_chat_completion(async_gen())

result = []
async for value in output:
result.append(value)

assert result == ["I", " am", " cool"]
21 changes: 20 additions & 1 deletion lightrag/tests/test_ollama_client.py
@@ -1,5 +1,6 @@
import unittest
from unittest.mock import Mock
import asyncio
from unittest.mock import Mock, AsyncMock
from lightrag.core.types import ModelType
from lightrag.components.model_client.ollama_client import OllamaClient

@@ -34,6 +35,24 @@ def test_ollama_llm_client(self):
).return_value = {"message": "Hello"}
assert output == {"message": "Hello"}

async def test_ollama_llm_client_acall(self):
ollama_client = AsyncMock(spec=OllamaClient())
ollama_client.acall.return_value = {"message": "Hello"}
print("Testing ollama LLM client")
# run the model
kwargs = {
"model": "qwen2:0.5b",
}
api_kwargs = ollama_client.convert_inputs_to_api_kwargs(
input="Hello world",
model_kwargs=kwargs,
model_type=ModelType.LLM,
).return_value = {"prompt": "Hello World", "model": "qwen2:0.5b"}
assert api_kwargs == {"prompt": "Hello World", "model": "qwen2:0.5b"}

output = await ollama_client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
assert output == {"message": "Hello"}

def test_ollama_embedding_client(self):
ollama_client = Mock(spec=OllamaClient())
print("Testing ollama embedding client")