diff --git a/lightrag/CHANGELOG.md b/lightrag/CHANGELOG.md index b0c989c0..d77172bc 100644 --- a/lightrag/CHANGELOG.md +++ b/lightrag/CHANGELOG.md @@ -1,7 +1,12 @@ -## [0.1.0-beta.5] - 2024-07-20 +## [0.1.0-beta.5] - 2024-07-23 ### Fixed -- Suppport Enum in `DataClass` schema. https://github.com/SylphAI-Inc/LightRAG/pull/135 +- [issue 134](https://github.com/SylphAI-Inc/AdalFlow/issues/134) Suppport Enum in `DataClass` schema. https://github.com/SylphAI-Inc/LightRAG/pull/135 +- [issue 154](https://github.com/SylphAI-Inc/AdalFlow/issues/154) Fixed the `DataClass.from_dict` failure on `list[int]` type due to conditional check failure in the functional. + +### Added +- Support streaming in Generator (sync call) [issue 149](https://github.com/SylphAI-Inc/AdalFlow/issues/149) +- Support streaming in OpenAIClient (sync call) ## [0.1.0-beta.3, 4] - 2024-07-18 diff --git a/lightrag/lightrag/__init__.py b/lightrag/lightrag/__init__.py index 5e437dcd..c972eb01 100644 --- a/lightrag/lightrag/__init__.py +++ b/lightrag/lightrag/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0-beta.4" +__version__ = "0.1.0-beta.5" diff --git a/lightrag/lightrag/components/model_client/ollama_client.py b/lightrag/lightrag/components/model_client/ollama_client.py index 1290cbac..f4fc8cd0 100644 --- a/lightrag/lightrag/components/model_client/ollama_client.py +++ b/lightrag/lightrag/components/model_client/ollama_client.py @@ -1,7 +1,7 @@ """Ollama ModelClient integration.""" import os -from typing import Dict, Optional, Any, TypeVar, List, Type +from typing import Dict, Optional, Any, TypeVar, List, Type, Generator, Union import backoff import logging import warnings @@ -48,11 +48,67 @@ class OllamaClient(ModelClient): If not provided, it will look for OLLAMA_HOST env variable. Defaults to None. The default host is "http://localhost:11434". + Setting model_kwargs: + + For LLM, expect model_kwargs to have the following keys: + + model (str, required): + Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library + + stream (bool, default: False ) – Whether to stream the results. + + + options (Optional[dict], optional) + Options that affect model output. + + # If not specified the following defaults will be assigned. + + "seed": 0, - Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. + + "num_predict": 128, - Maximum number of tokens to predict when generating text. (-1 = infinite generation, -2 = fill context) + + "top_k": 40, - Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. + + "top_p": 0.9, - Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. + + "tfs_z": 1, - Tail free sampling. This is used to reduce the impact of less probable tokens from the output. Disabled by default (e.g. 1) (More documentation here for specifics) + + "repeat_last_n": 64, - Sets how far back the model should look back to prevent repetition. (0 = disabled, -1 = num_ctx) + + "temperature": 0.8, - The temperature of the model. Increasing the temperature will make the model answer more creatively. + + "repeat_penalty": 1.1, - Sets how strongly to penalize repetitions. A higher value(e.g., 1.5 will penlaize repetitions more strongly, while lowe values *e.g., 0.9 will be more lenient.) + + "mirostat": 0.0, - Enable microstat smapling for controlling perplexity. (0 = disabled, 1 = microstat, 2 = microstat 2.0) + + "mirostat_tau": 0.5, - Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. + + "mirostat_eta": 0.1, - Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. + + "stop": ["\n", "user:"], - Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. + + "num_ctx": 2048, - Sets the size of the context window used to generate the next token. + + For EMBEDDER, expect model_kwargs to have the following keys: + + model (str, required): + Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library + + prompt (str, required): + String that is sent to the Embedding model. + + options (Optional[dict], optional): + See LLM args for defaults. + References: - https://github.com/ollama/ollama-python - https://github.com/ollama/ollama - Models: https://ollama.com/library + - Ollama API: https://github.com/ollama/ollama/blob/main/docs/api.md + - Options Parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md. + - LlamaCPP API documentation(Ollama is based on this): https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#low-level-api + - LLM API: https://llama-cpp-python.readthedocs.io/en/stable/api-reference/#llama_cpp.Llama.create_completion Tested Ollama models: 7/9/24 @@ -90,14 +146,25 @@ def init_async_client(self): self.async_client = ollama.AsyncClient(host=self._host) - def parse_chat_completion(self, completion: GenerateResponse) -> Any: + def parse_chat_completion( + self, completion: Union[GenerateResponse, Generator] + ) -> Any: """Parse the completion to a str. We use the generate with prompt instead of chat with messages.""" - log.debug(f"completion: {completion}") - if "response" in completion: - return completion["response"] + log.debug(f"completion: {completion}, {isinstance(completion, Generator)}") + if isinstance(completion, Generator): # streaming + for chunk in completion: + log.debug(f"Raw chunk: {chunk}") + yield chunk["response"] if "response" in chunk else None else: - log.error(f"Error parsing the completion: {completion}") - raise ValueError(f"Error parsing the completion: {completion}") + if "response" in completion: + return completion["response"] + else: + log.error( + f"Error parsing the completion: {completion}, type: {type(completion)}" + ) + raise ValueError( + f"Error parsing the completion: {completion}, type: {type(completion)}" + ) def parse_embedding_response( self, response: Dict[str, List[float]] @@ -118,63 +185,7 @@ def convert_inputs_to_api_kwargs( model_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED, ) -> Dict: - r""" - API Reference: https://github.com/ollama/ollama/blob/main/docs/api.md - - Options Parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md. - - LlamaCPP API documentation(Ollama is based on this): https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#low-level-api - - For LLM, expect model_kwargs to have the following keys: - - model (str, required): - Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library - - prompt (str, required): - String that is sent to the LLM. - - options (Optional[dict], optional) - Options that affect model output. - - # If not specified the following defaults will be assigned. - - "seed": 0, - Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. - - "num_predict": 128, - Maximum number of tokens to predict when generating text. (-1 = infinite generation, -2 = fill context) - - "top_k": 40, - Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. - - "top_p": 0.9, - Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. - - "tfs_z": 1, - Tail free sampling. This is used to reduce the impact of less probable tokens from the output. Disabled by default (e.g. 1) (More documentation here for specifics) - - "repeat_last_n": 64, - Sets how far back the model should look back to prevent repetition. (0 = disabled, -1 = num_ctx) - - "temperature": 0.8, - The temperature of the model. Increasing the temperature will make the model answer more creatively. - - "repeat_penalty": 1.1, - Sets how strongly to penalize repetitions. A higher value(e.g., 1.5 will penlaize repetitions more strongly, while lowe values *e.g., 0.9 will be more lenient.) - - "mirostat": 0.0, - Enable microstat smapling for controlling perplexity. (0 = disabled, 1 = microstat, 2 = microstat 2.0) - - "mirostat_tau": 0.5, - Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. - - "mirostat_eta": 0.1, - Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. - - "stop": ["\n", "user:"], - Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. - - "num_ctx": 2048, - Sets the size of the context window used to generate the next token. - - For EMBEDDER, expect model_kwargs to have the following keys: - - model (str, required): - Use `ollama list` via your CLI or visit ollama model page on https://ollama.com/library - - prompt (str, required): - String that is sent to the Embedding model. - - options (Optional[dict], optional): - See LLM args for defaults. - """ + r"""Convert the input and model_kwargs to api_kwargs for the Ollama SDK client.""" # TODO: ollama will support batch embedding in the future: https://ollama.com/blog/embedding-models final_model_kwargs = model_kwargs.copy() if model_type == ModelType.EMBEDDER: @@ -251,3 +262,20 @@ def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]: output = super().to_dict(exclude=exclude) return output + + +# if __name__ == "__main__": +# from lightrag.core.generator import Generator +# from lightrag.components.model_client import OllamaClient +# from lightrag.utils import setup_env, get_logger + +# # log = get_logger(level="DEBUG") + +# setup_env() + +# model_client = OllamaClient() +# model_kwargs = {"model": "phi3", "stream": True} +# generator = Generator(model_client=model_client, model_kwargs=model_kwargs) +# output = generator({"input_str": "What is the capital of France?"}) +# for chunk in output.data: +# print(chunk) diff --git a/lightrag/lightrag/components/model_client/openai_client.py b/lightrag/lightrag/components/model_client/openai_client.py index d374d123..46a5c525 100644 --- a/lightrag/lightrag/components/model_client/openai_client.py +++ b/lightrag/lightrag/components/model_client/openai_client.py @@ -1,7 +1,17 @@ """OpenAI ModelClient integration.""" import os -from typing import Dict, Sequence, Optional, List, Any, TypeVar, Callable +from typing import ( + Dict, + Sequence, + Optional, + List, + Any, + TypeVar, + Callable, + Generator, + Union, +) import logging import backoff @@ -17,7 +27,7 @@ openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) -from openai import OpenAI, AsyncOpenAI +from openai import OpenAI, AsyncOpenAI, Stream from openai import ( APITimeoutError, InternalServerError, @@ -26,6 +36,7 @@ BadRequestError, ) from openai.types import Completion, CreateEmbeddingResponse +from openai.types.chat import ChatCompletionChunk, ChatCompletion log = logging.getLogger(__name__) @@ -33,18 +44,31 @@ # completion parsing functions and you can combine them into one singple chat completion parser -def get_first_message_content(completion: Completion) -> str: +def get_first_message_content(completion: ChatCompletion) -> str: r"""When we only need the content of the first message. It is the default parser for chat completion.""" return completion.choices[0].message.content -def get_all_messages_content(completion: Completion) -> List[str]: +def parse_stream_response(completion: ChatCompletionChunk) -> str: + r"""Parse the response of the stream API.""" + return completion.choices[0].delta.content + + +def handle_streaming_response(generator: Stream[ChatCompletionChunk]): + r"""Handle the streaming response.""" + for completion in generator: + log.debug(f"Raw chunk completion: {completion}") + parsed_content = parse_stream_response(completion) + yield parsed_content + + +def get_all_messages_content(completion: ChatCompletion) -> List[str]: r"""When the n > 1, get all the messages content.""" return [c.message.content for c in completion.choices] -def get_probabilities(completion: Completion) -> List[List[TokenLogProb]]: +def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]: r"""Get the probabilities of each token in the completion.""" log_probs = [] for c in completion.choices: @@ -114,9 +138,12 @@ def init_async_client(self): raise ValueError("Environment variable OPENAI_API_KEY must be set") return AsyncOpenAI(api_key=api_key) - def parse_chat_completion(self, completion: Completion) -> Any: + def parse_chat_completion( + self, + completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]], + ) -> Any: """Parse the completion to a str.""" - log.debug(f"completion: {completion}") + log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}") return self.chat_completion_parser(completion) def parse_embedding_response( @@ -173,12 +200,16 @@ def convert_inputs_to_api_kwargs( ) def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): """ - kwargs is the combined input and model_kwargs + kwargs is the combined input and model_kwargs. Support streaming call. """ log.info(f"api_kwargs: {api_kwargs}") if model_type == ModelType.EMBEDDER: return self.sync_client.embeddings.create(**api_kwargs) elif model_type == ModelType.LLM: + if "stream" in api_kwargs and api_kwargs.get("stream", False): + log.debug("streaming call") + self.chat_completion_parser = handle_streaming_response + return self.sync_client.chat.completions.create(**api_kwargs) return self.sync_client.chat.completions.create(**api_kwargs) else: raise ValueError(f"model_type {model_type} is not supported") @@ -226,3 +257,23 @@ def to_dict(self) -> Dict[str, Any]: ] # unserializable object output = super().to_dict(exclude=exclude) return output + + +# if __name__ == "__main__": +# from lightrag.core import Generator +# from lightrag.utils import setup_env, get_logger + +# log = get_logger(level="DEBUG") + +# setup_env() +# prompt_kwargs = {"input_str": "What is the meaning of life?"} + +# gen = Generator( +# model_client=OpenAIClient(), +# model_kwargs={"model": "gpt-3.5-turbo", "stream": True}, +# ) +# gen_response = gen(prompt_kwargs) +# print(f"gen_response: {gen_response}") + +# for genout in gen_response.data: +# print(f"genout: {genout}") diff --git a/lightrag/lightrag/components/output_parsers/outputs.py b/lightrag/lightrag/components/output_parsers/outputs.py index 04cc3bb4..983601b4 100644 --- a/lightrag/lightrag/components/output_parsers/outputs.py +++ b/lightrag/lightrag/components/output_parsers/outputs.py @@ -275,11 +275,17 @@ def format_instructions( def call(self, input: str) -> Any: try: output_dict = self.output_processors(input) + log.debug(f"{__class__.__name__} output_dict: {output_dict}") + + except Exception as e: + log.error(f"Error in parsing JSON to JSON: {e}") + raise e + try: if self._return_data_class: return self.data_class.from_dict(output_dict) return output_dict except Exception as e: - log.error(f"Error in parsing JSON to JSON: {e}") + log.error(f"Error in converting dict to data class: {e}") raise e def _extra_repr(self) -> str: diff --git a/lightrag/lightrag/core/base_data_class.py b/lightrag/lightrag/core/base_data_class.py index 19585ada..f7ced0ef 100644 --- a/lightrag/lightrag/core/base_data_class.py +++ b/lightrag/lightrag/core/base_data_class.py @@ -262,7 +262,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "DataClass": - Convert the json/yaml output from LLM prediction to a dataclass instance. - Restore the dataclass instance from the serialized output used for states saving. """ - return dataclass_obj_from_dict(cls, data) + try: + dclass = dataclass_obj_from_dict(cls, data) + logger.debug(f"Dataclass instance created from dict: {dclass}") + return dclass + except TypeError as e: + raise ValueError(f"Failed to load data: {e}") @classmethod def from_json(cls, json_str: str) -> "DataClass": diff --git a/lightrag/lightrag/core/functional.py b/lightrag/lightrag/core/functional.py index 6b329fbf..f3c1e155 100644 --- a/lightrag/lightrag/core/functional.py +++ b/lightrag/lightrag/core/functional.py @@ -178,6 +178,26 @@ def extract_dataclass_type(type_hint): return type_hint if is_dataclass(type_hint) else None +def check_data_class_field_args_zero(cls): + """Check if the field is a dataclass.""" + return ( + hasattr(cls, "__args__") + and len(cls.__args__) > 0 + and cls.__args__[0] + and hasattr(cls.__args__[0], "__dataclass_fields__") + ) + + +def check_data_class_field_args_one(cls): + """Check if the field is a dataclass.""" + return ( + hasattr(cls, "__args__") + and len(cls.__args__) > 1 + and cls.__args__[1] + and hasattr(cls.__args__[1], "__dataclass_fields__") + ) + + def dataclass_obj_from_dict(cls: Type[object], data: Dict[str, object]) -> Any: r"""Convert a dictionary to a dataclass object. @@ -213,7 +233,7 @@ class TrecDataList: # TrecDataList(data=[TrecData(question='What is the capital of France?', label=0)], name='trec_data_list') """ - + log.debug(f"Dataclass: {cls}, Data: {data}") if is_dataclass(cls) or is_potential_dataclass( cls ): # Optional[Address] will be false, and true for each check @@ -230,44 +250,40 @@ class TrecDataList: } ) elif isinstance(data, (list, tuple)): + log.debug(f"List or Tuple: {cls}, {data}") restored_data = [] for item in data: - if cls.__args__[0] and hasattr(cls.__args__[0], "__dataclass_fields__"): + if check_data_class_field_args_zero(cls): # restore the value to its dataclass type restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item)) else: # Use the original data [Any] restored_data.append(item) - return restored_data elif isinstance(data, set): + log.debug(f"Set: {cls}, {data}") restored_data = set() for item in data: - if cls.__args__[0] and hasattr(cls.__args__[0], "__dataclass_fields__"): + if check_data_class_field_args_zero(cls): # restore the value to its dataclass type restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item)) else: # Use the original data [Any] restored_data.add(item) - return restored_data elif isinstance(data, dict): + log.debug(f"Dict: {cls}, {data}") for key, value in data.items(): - if ( - hasattr(cls, "__args__") - and len(cls.__args__) > 1 - and cls.__args__[1] - and hasattr(cls.__args__[1], "__dataclass_fields__") - ): + if check_data_class_field_args_one(cls): # restore the value to its dataclass type data[key] = dataclass_obj_from_dict(cls.__args__[1], value) else: # Use the original data [Any] data[key] = value return data - + # else normal data like int, str, float, etc. else: log.debug(f"Not datclass, or list, or dict: {cls}, use the original data.") return data diff --git a/lightrag/lightrag/core/generator.py b/lightrag/lightrag/core/generator.py index d44ac6b8..03b90047 100644 --- a/lightrag/lightrag/core/generator.py +++ b/lightrag/lightrag/core/generator.py @@ -167,20 +167,19 @@ def _post_call(self, completion: Any) -> GeneratorOutputType: try: response = self.model_client.parse_chat_completion(completion) except Exception as e: - log.error(f"Error parsing the completion: {e}") + log.error(f"Error parsing the completion {completion}: {e}") # response = str(completion) return GeneratorOutput(raw_response=str(completion), error=str(e)) # the output processors operate on the str, the raw_response field. output: GeneratorOutputType = GeneratorOutput(raw_response=response) - response = deepcopy(response) if self.output_processors: try: response = self.output_processors(response) output.data = response except Exception as e: - log.error(f"Error processing the output: {e}") + log.error(f"Error processing the output processors: {e}") output.error = str(e) else: # default to string output output.data = response @@ -232,11 +231,17 @@ def call( completion = self.model_client.call( api_kwargs=api_kwargs, model_type=self.model_type ) - output = self._post_call(completion) + except Exception as e: log.error(f"Error calling the model: {e}") output = GeneratorOutput(error=str(e)) + try: + output = self._post_call(completion) + except Exception as e: + log.error(f"Error processing the output: {e}") + output = GeneratorOutput(raw_response=str(completion), error=str(e)) + log.info(f"output: {output}") return output diff --git a/lightrag/lightrag/core/string_parser.py b/lightrag/lightrag/core/string_parser.py index fe0c480c..7e52a2d4 100644 --- a/lightrag/lightrag/core/string_parser.py +++ b/lightrag/lightrag/core/string_parser.py @@ -192,6 +192,7 @@ def call(self, input: str) -> JSON_PARSER_OUTPUT_TYPE: # Parse JSON string with json.loads and yaml.safe_load try: json_obj = F.parse_json_str_to_obj(json_str) + log.debug(f"json_obj: {json_obj}") return json_obj except Exception as e: log.error(f"Error at parsing JSON string: {e}") diff --git a/lightrag/pyproject.toml b/lightrag/pyproject.toml index c9149c15..609b4892 100644 --- a/lightrag/pyproject.toml +++ b/lightrag/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "lightrag" -version = "0.1.0-beta.4" +version = "0.1.0-beta.5" description = "The Lightning Library for LLM Applications." authors = ["Li Yin "] readme = "README.md" diff --git a/lightrag/tests/test_base_data_class.py b/lightrag/tests/test_base_data_class.py index 8084f4b7..24cd27ee 100644 --- a/lightrag/tests/test_base_data_class.py +++ b/lightrag/tests/test_base_data_class.py @@ -346,6 +346,57 @@ class LabelDataClass(DataClass, str, enum.Enum): ) +@dataclass +class ListDataclass(DataClass): + answer: str = field(metadata={"desc": "The answer to the user question."}) + pmids: list[int] = field( + metadata={"desc": "The PMIDs of the relevant articles used to answer."} + ) + + +class TestUnnestedDataclass(unittest.TestCase): + def test_list_dataclass(self): + instance = ListDataclass(answer="answer", pmids=[1, 2, 3]) + result = instance.to_dict() + print(f"result: {result}") + expected = "{'answer': 'answer', 'pmids': [1, 2, 3]}" + self.assertEqual(str(result), expected) + restored_instance = ListDataclass.from_dict(result) + self.assertEqual(restored_instance, instance) + + def test_dict_dataclass(self): + @dataclass + class DictDataclass(DataClass): + answer: str = field(metadata={"desc": "The answer to the user question."}) + pmids: Dict[str, int] = field( + metadata={"desc": "The PMIDs of the relevant articles used to answer."} + ) + + instance = DictDataclass(answer="answer", pmids={"a": 1, "b": 2, "c": 3}) + result = instance.to_dict() + print(f"result: {result}") + expected = "{'answer': 'answer', 'pmids': {'a': 1, 'b': 2, 'c': 3}}" + self.assertEqual(str(result), expected) + restored_instance = DictDataclass.from_dict(result) + self.assertEqual(restored_instance, instance) + + def test_set_dataclass(self): + @dataclass + class SetDataclass(DataClass): + answer: str = field(metadata={"desc": "The answer to the user question."}) + pmids: Set[int] = field( + metadata={"desc": "The PMIDs of the relevant articles used to answer."} + ) + + instance = SetDataclass(answer="answer", pmids={1, 2, 3}) + result = instance.to_dict() + print(f"result: {result}") + expected = "{'answer': 'answer', 'pmids': {1, 2, 3}}" + self.assertEqual(str(result), expected) + restored_instance = SetDataclass.from_dict(result) + self.assertEqual(restored_instance, instance) + + if __name__ == "__main__": unittest.main()