Merge pull request #155 from SylphAI-Inc/main
[release 0.1.0.b5]
liyin2015 authored Jul 23, 2024
2 parents 4f67642 + 8aefe98 commit b4493d7
Showing 11 changed files with 262 additions and 94 deletions.
9 changes: 7 additions & 2 deletions lightrag/CHANGELOG.md
@@ -1,7 +1,12 @@
## [0.1.0-beta.5] - 2024-07-20
## [0.1.0-beta.5] - 2024-07-23

### Fixed
- Support Enum in `DataClass` schema. https://github.com/SylphAI-Inc/LightRAG/pull/135
- [issue 134](https://github.com/SylphAI-Inc/AdalFlow/issues/134) Support Enum in `DataClass` schema. https://github.com/SylphAI-Inc/LightRAG/pull/135
- [issue 154](https://github.com/SylphAI-Inc/AdalFlow/issues/154) Fixed `DataClass.from_dict` failing on `list[int]` fields due to a conditional check error in `functional`.

### Added
- Support streaming in Generator (sync call) [issue 149](https://github.com/SylphAI-Inc/AdalFlow/issues/149)
- Support streaming in OpenAIClient (sync call)

## [0.1.0-beta.3, 4] - 2024-07-18

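To illustrate the new streaming entries above, here is a minimal, hedged usage sketch; it mirrors the commented-out example added to `openai_client.py` later in this PR, and it assumes the `OpenAIClient` import path shown there plus an `OPENAI_API_KEY` in the environment.

```python
# Hedged sketch of the sync streaming path added in 0.1.0-beta.5.
# Assumes OPENAI_API_KEY is set; the model name is illustrative.
from lightrag.core import Generator
from lightrag.components.model_client import OpenAIClient  # assumed import path
from lightrag.utils import setup_env

setup_env()  # load environment variables such as OPENAI_API_KEY

gen = Generator(
    model_client=OpenAIClient(),
    model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
)
output = gen({"input_str": "What is the meaning of life?"})

# With stream=True, output.data is a generator of text chunks.
for chunk in output.data:
    if chunk is not None:  # guard against empty deltas in the final chunk
        print(chunk, end="", flush=True)
```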
2 changes: 1 addition & 1 deletion lightrag/lightrag/__init__.py
@@ -1 +1 @@
__version__ = "0.1.0-beta.4"
__version__ = "0.1.0-beta.5"
156 changes: 92 additions & 64 deletions lightrag/lightrag/components/model_client/ollama_client.py
@@ -1,7 +1,7 @@
"""Ollama ModelClient integration."""

import os
from typing import Dict, Optional, Any, TypeVar, List, Type
from typing import Dict, Optional, Any, TypeVar, List, Type, Generator, Union
import backoff
import logging
import warnings
@@ -48,11 +48,67 @@ class OllamaClient(ModelClient):
If not provided, it will look for OLLAMA_HOST env variable. Defaults to None.
The default host is "http://localhost:11434".
Setting model_kwargs:
For LLM, expect model_kwargs to have the following keys:
model (str, required):
Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library
stream (bool, default: False) - Whether to stream the results.
options (Optional[dict], optional)
Options that affect model output.
# If not specified the following defaults will be assigned.
"seed": 0, - Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt.
"num_predict": 128, - Maximum number of tokens to predict when generating text. (-1 = infinite generation, -2 = fill context)
"top_k": 40, - Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative.
"top_p": 0.9, - Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.
"tfs_z": 1, - Tail free sampling. This is used to reduce the impact of less probable tokens from the output. Disabled by default (e.g. 1) (More documentation here for specifics)
"repeat_last_n": 64, - Sets how far back the model should look back to prevent repetition. (0 = disabled, -1 = num_ctx)
"temperature": 0.8, - The temperature of the model. Increasing the temperature will make the model answer more creatively.
"repeat_penalty": 1.1, - Sets how strongly to penalize repetitions. A higher value(e.g., 1.5 will penlaize repetitions more strongly, while lowe values *e.g., 0.9 will be more lenient.)
"mirostat": 0.0, - Enable microstat smapling for controlling perplexity. (0 = disabled, 1 = microstat, 2 = microstat 2.0)
"mirostat_tau": 0.5, - Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text.
"mirostat_eta": 0.1, - Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive.
"stop": ["\n", "user:"], - Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
"num_ctx": 2048, - Sets the size of the context window used to generate the next token.
For EMBEDDER, expect model_kwargs to have the following keys:
model (str, required):
Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library
prompt (str, required):
String that is sent to the Embedding model.
options (Optional[dict], optional):
See LLM args for defaults.
References:
- https://github.com/ollama/ollama-python
- https://github.com/ollama/ollama
- Models: https://ollama.com/library
- Ollama API: https://github.com/ollama/ollama/blob/main/docs/api.md
- Options Parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md.
- LlamaCPP API documentation (Ollama is based on this): https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#low-level-api
- LLM API: https://llama-cpp-python.readthedocs.io/en/stable/api-reference/#llama_cpp.Llama.create_completion
Tested Ollama models: 7/9/24
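As a concrete, hedged illustration of the keys above, a typical LLM `model_kwargs` might look like the sketch below; the model name and option values are placeholders, not defaults of this client.

```python
# Illustrative only: model name and option values are placeholders.
model_kwargs = {
    "model": "phi3",     # any model from `ollama list` or https://ollama.com/library
    "stream": False,     # set True to receive a generator of response chunks
    "options": {         # optional; unspecified keys fall back to the defaults listed above
        "temperature": 0.8,
        "top_k": 40,
        "top_p": 0.9,
        "num_ctx": 2048,
        "stop": ["\n", "user:"],
    },
}
```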
@@ -90,14 +146,25 @@ def init_async_client(self):

self.async_client = ollama.AsyncClient(host=self._host)

def parse_chat_completion(self, completion: GenerateResponse) -> Any:
def parse_chat_completion(
self, completion: Union[GenerateResponse, Generator]
) -> Any:
"""Parse the completion to a str. We use the generate with prompt instead of chat with messages."""
log.debug(f"completion: {completion}")
if "response" in completion:
return completion["response"]
log.debug(f"completion: {completion}, {isinstance(completion, Generator)}")
if isinstance(completion, Generator): # streaming
for chunk in completion:
log.debug(f"Raw chunk: {chunk}")
yield chunk["response"] if "response" in chunk else None
else:
log.error(f"Error parsing the completion: {completion}")
raise ValueError(f"Error parsing the completion: {completion}")
if "response" in completion:
return completion["response"]
else:
log.error(
f"Error parsing the completion: {completion}, type: {type(completion)}"
)
raise ValueError(
f"Error parsing the completion: {completion}, type: {type(completion)}"
)
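For the streaming branch, here is a small self-contained sketch of the parsing behavior using fake Ollama-style chunks (plain dicts with a `response` key); it only imitates the logic above and does not call the client.

```python
def fake_stream():
    # Imitates ollama generate(..., stream=True): an iterator of dicts with a "response" key.
    for piece in ("Par", "is", "."):
        yield {"response": piece}

def parse_stream(completion):
    # Mirrors the streaming branch of parse_chat_completion above.
    for chunk in completion:
        yield chunk.get("response")

print("".join(parse_stream(fake_stream())))  # -> Paris.
```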

def parse_embedding_response(
self, response: Dict[str, List[float]]
@@ -118,63 +185,7 @@ def convert_inputs_to_api_kwargs(
model_kwargs: Dict = {},
model_type: ModelType = ModelType.UNDEFINED,
) -> Dict:
r"""
API Reference: https://github.com/ollama/ollama/blob/main/docs/api.md
Options Parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md.
LlamaCPP API documentation (Ollama is based on this): https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#low-level-api
For LLM, expect model_kwargs to have the following keys:
model (str, required):
Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library
prompt (str, required):
String that is sent to the LLM.
options (Optional[dict], optional)
Options that affect model output.
# If not specified the following defaults will be assigned.
"seed": 0, - Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt.
"num_predict": 128, - Maximum number of tokens to predict when generating text. (-1 = infinite generation, -2 = fill context)
"top_k": 40, - Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative.
"top_p": 0.9, - Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.
"tfs_z": 1, - Tail free sampling. This is used to reduce the impact of less probable tokens from the output. Disabled by default (e.g. 1) (More documentation here for specifics)
"repeat_last_n": 64, - Sets how far back the model should look back to prevent repetition. (0 = disabled, -1 = num_ctx)
"temperature": 0.8, - The temperature of the model. Increasing the temperature will make the model answer more creatively.
"repeat_penalty": 1.1, - Sets how strongly to penalize repetitions. A higher value(e.g., 1.5 will penlaize repetitions more strongly, while lowe values *e.g., 0.9 will be more lenient.)
"mirostat": 0.0, - Enable microstat smapling for controlling perplexity. (0 = disabled, 1 = microstat, 2 = microstat 2.0)
"mirostat_tau": 0.5, - Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text.
"mirostat_eta": 0.1, - Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive.
"stop": ["\n", "user:"], - Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
"num_ctx": 2048, - Sets the size of the context window used to generate the next token.
For EMBEDDER, expect model_kwargs to have the following keys:
model (str, required):
Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library
prompt (str, required):
String that is sent to the Embedding model.
options (Optional[dict], optional):
See LLM args for defaults.
"""
r"""Convert the input and model_kwargs to api_kwargs for the Ollama SDK client."""
# TODO: ollama will support batch embedding in the future: https://ollama.com/blog/embedding-models
final_model_kwargs = model_kwargs.copy()
if model_type == ModelType.EMBEDDER:
@@ -251,3 +262,20 @@ def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:

output = super().to_dict(exclude=exclude)
return output


# if __name__ == "__main__":
# from lightrag.core.generator import Generator
# from lightrag.components.model_client import OllamaClient
# from lightrag.utils import setup_env, get_logger

# # log = get_logger(level="DEBUG")

# setup_env()

# model_client = OllamaClient()
# model_kwargs = {"model": "phi3", "stream": True}
# generator = Generator(model_client=model_client, model_kwargs=model_kwargs)
# output = generator({"input_str": "What is the capital of France?"})
# for chunk in output.data:
# print(chunk)
67 changes: 59 additions & 8 deletions lightrag/lightrag/components/model_client/openai_client.py
@@ -1,7 +1,17 @@
"""OpenAI ModelClient integration."""

import os
from typing import Dict, Sequence, Optional, List, Any, TypeVar, Callable
from typing import (
Dict,
Sequence,
Optional,
List,
Any,
TypeVar,
Callable,
Generator,
Union,
)

import logging
import backoff
@@ -17,7 +27,7 @@

openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

from openai import OpenAI, AsyncOpenAI
from openai import OpenAI, AsyncOpenAI, Stream
from openai import (
APITimeoutError,
InternalServerError,
@@ -26,25 +36,39 @@
BadRequestError,
)
from openai.types import Completion, CreateEmbeddingResponse
from openai.types.chat import ChatCompletionChunk, ChatCompletion


log = logging.getLogger(__name__)
T = TypeVar("T")


# completion parsing functions; you can combine them into a single chat completion parser
def get_first_message_content(completion: Completion) -> str:
def get_first_message_content(completion: ChatCompletion) -> str:
r"""When we only need the content of the first message.
It is the default parser for chat completion."""
return completion.choices[0].message.content


def get_all_messages_content(completion: Completion) -> List[str]:
def parse_stream_response(completion: ChatCompletionChunk) -> str:
r"""Parse the response of the stream API."""
return completion.choices[0].delta.content


def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
r"""Handle the streaming response."""
for completion in generator:
log.debug(f"Raw chunk completion: {completion}")
parsed_content = parse_stream_response(completion)
yield parsed_content
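A hedged, self-contained illustration of these helpers using stand-in chunk objects: real `ChatCompletionChunk` instances come from the OpenAI SDK, the mock below only provides the `choices[0].delta.content` path that `parse_stream_response` reads, and the import assumes the module path shown in this diff with the `openai` package installed.

```python
from types import SimpleNamespace

from lightrag.components.model_client.openai_client import handle_streaming_response

def mock_chunk(text):
    # Mimics chunk.choices[0].delta.content, the only attribute parse_stream_response touches.
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=text))])

pieces = handle_streaming_response(iter([mock_chunk("Hello"), mock_chunk(", world")]))
print("".join(pieces))  # -> Hello, world
```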


def get_all_messages_content(completion: ChatCompletion) -> List[str]:
r"""When the n > 1, get all the messages content."""
return [c.message.content for c in completion.choices]


def get_probabilities(completion: Completion) -> List[List[TokenLogProb]]:
def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
r"""Get the probabilities of each token in the completion."""
log_probs = []
for c in completion.choices:
@@ -114,9 +138,12 @@ def init_async_client(self):
raise ValueError("Environment variable OPENAI_API_KEY must be set")
return AsyncOpenAI(api_key=api_key)

def parse_chat_completion(self, completion: Completion) -> Any:
def parse_chat_completion(
self,
completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
) -> Any:
"""Parse the completion to a str."""
log.debug(f"completion: {completion}")
log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}")
return self.chat_completion_parser(completion)

def parse_embedding_response(
@@ -173,12 +200,16 @@ def convert_inputs_to_api_kwargs(
)
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
"""
kwargs is the combined input and model_kwargs
kwargs is the combined input and model_kwargs. Supports streaming calls.
"""
log.info(f"api_kwargs: {api_kwargs}")
if model_type == ModelType.EMBEDDER:
return self.sync_client.embeddings.create(**api_kwargs)
elif model_type == ModelType.LLM:
if "stream" in api_kwargs and api_kwargs.get("stream", False):
log.debug("streaming call")
self.chat_completion_parser = handle_streaming_response
return self.sync_client.chat.completions.create(**api_kwargs)
return self.sync_client.chat.completions.create(**api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")
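Driving the client directly (without `Generator`) would look roughly like the sketch below; `OPENAI_API_KEY` in the environment, the `ModelType` import path, and the `input=` parameter name are assumptions based on the `ModelClient` interface visible in this diff.

```python
# Hedged sketch: import paths, parameter names, and the model name are assumptions.
from lightrag.core.types import ModelType
from lightrag.components.model_client import OpenAIClient

client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="What is the capital of France?",
    model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
    model_type=ModelType.LLM,
)
completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)

# With stream=True, call() swaps in handle_streaming_response,
# so parse_chat_completion yields text chunks.
for text in client.parse_chat_completion(completion):
    if text is not None:  # final chunk may carry an empty delta
        print(text, end="", flush=True)
```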
@@ -226,3 +257,23 @@ def to_dict(self) -> Dict[str, Any]:
] # unserializable object
output = super().to_dict(exclude=exclude)
return output


# if __name__ == "__main__":
# from lightrag.core import Generator
# from lightrag.utils import setup_env, get_logger

# log = get_logger(level="DEBUG")

# setup_env()
# prompt_kwargs = {"input_str": "What is the meaning of life?"}

# gen = Generator(
# model_client=OpenAIClient(),
# model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
# )
# gen_response = gen(prompt_kwargs)
# print(f"gen_response: {gen_response}")

# for genout in gen_response.data:
# print(f"genout: {genout}")
8 changes: 7 additions & 1 deletion lightrag/lightrag/components/output_parsers/outputs.py
@@ -275,11 +275,17 @@ def format_instructions(
def call(self, input: str) -> Any:
try:
output_dict = self.output_processors(input)
log.debug(f"{__class__.__name__} output_dict: {output_dict}")

except Exception as e:
log.error(f"Error in parsing JSON to JSON: {e}")
raise e
try:
if self._return_data_class:
return self.data_class.from_dict(output_dict)
return output_dict
except Exception as e:
log.error(f"Error in parsing JSON to JSON: {e}")
log.error(f"Error in converting dict to data class: {e}")
raise e

def _extra_repr(self) -> str:
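The change above separates JSON-parsing failures from data-class-conversion failures. Below is a hedged, standalone sketch of the same two-stage pattern; it does not use the parser class from this file.

```python
import json
from dataclasses import dataclass

@dataclass
class Answer:
    city: str
    confidence: float

def parse(raw: str) -> Answer:
    try:
        output_dict = json.loads(raw)   # stage 1: raw text -> dict
    except Exception as e:
        raise ValueError(f"Error in parsing JSON: {e}")
    try:
        return Answer(**output_dict)    # stage 2: dict -> data class
    except Exception as e:
        raise ValueError(f"Error in converting dict to data class: {e}")

print(parse('{"city": "Paris", "confidence": 0.9}'))
```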
7 changes: 6 additions & 1 deletion lightrag/lightrag/core/base_data_class.py
@@ -262,7 +262,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "DataClass":
- Convert the json/yaml output from LLM prediction to a dataclass instance.
- Restore the dataclass instance from the serialized output used for states saving.
"""
return dataclass_obj_from_dict(cls, data)
try:
dclass = dataclass_obj_from_dict(cls, data)
logger.debug(f"Dataclass instance created from dict: {dclass}")
return dclass
except TypeError as e:
raise ValueError(f"Failed to load data: {e}")

@classmethod
def from_json(cls, json_str: str) -> "DataClass":
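A hedged sketch of the path fixed by issue 154 (`list[int]` fields), assuming the usual `@dataclass`-decorated subclass pattern for `DataClass`; the class and field names are illustrative.

```python
from dataclasses import dataclass, field
from typing import List

from lightrag.core.base_data_class import DataClass

@dataclass
class Scores(DataClass):  # assumed subclassing pattern
    name: str = ""
    values: List[int] = field(default_factory=list)

# Previously failed on the List[int] field; now reconstructs the instance.
restored = Scores.from_dict({"name": "test", "values": [1, 2, 3]})
print(restored)  # Scores(name='test', values=[1, 2, 3])
# Per this PR, a TypeError raised during reconstruction now surfaces as a ValueError.
```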