
Commit

make the output processors a parser class instead of a component, improved the joint training experiment tracking
liyin2015 committed Dec 29, 2024
1 parent 3889f8a commit a185e7f
Showing 12 changed files with 231 additions and 108 deletions.
8 changes: 7 additions & 1 deletion adalflow/adalflow/__init__.py
@@ -1,6 +1,6 @@
__version__ = "0.2.6"

from adalflow.core.component import Component, fun_to_component
from adalflow.core.component import Component
from adalflow.core.container import Sequential, ComponentList
from adalflow.core.base_data_class import DataClass, DataClassFormatType, required_field

@@ -24,6 +24,9 @@
FloatParser,
ListParser,
BooleanParser,
Parser,
func_to_parser,
FuncParser,
)
from adalflow.core.retriever import Retriever
from adalflow.components.output_parsers import (
@@ -101,6 +104,9 @@
"FloatParser",
"ListParser",
"BooleanParser",
"Parser",
"func_to_parser",
"FuncParser",
# Output Parsers with dataclass formatting
"YamlOutputParser",
"JsonOutputParser",
@@ -5,7 +5,7 @@
import logging

from adalflow.core.prompt_builder import Prompt
from adalflow.core.string_parser import YamlParser, JsonParser
from adalflow.core.string_parser import YamlParser, JsonParser, Parser
from adalflow.core.base_data_class import DataClass, DataClassFormatType
from adalflow.core.base_data_class import ExcludeType, IncludeType

@@ -42,7 +42,7 @@
"""


class DataClassParser:
class DataClassParser(Parser):
__doc__ = r"""Made the structured output even simpler compared with JsonOutputParser and YamlOutputParser.
1. Understands __input_fields__ and __output_fields__ from the DataClass (no need to use include/exclude to decide fields).
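With DataClassParser now subclassing Parser, it can be handed straight to a Generator as `output_processors`. A minimal sketch of that usage; the import path, the constructor keywords, and the `QAOutput` dataclass are assumptions, not taken from this diff:

```python
# Hedged sketch: a DataClass with __output_fields__ and a DataClassParser over it.
from dataclasses import dataclass, field

import adalflow as adal
from adalflow.components.output_parsers import DataClassParser  # assumed import path


@dataclass
class QAOutput(adal.DataClass):
    answer: str = field(metadata={"desc": "The final answer."})
    __output_fields__ = ["answer"]  # DataClassParser reads these directly


# Because DataClassParser is now a Parser, Generator can accept it as output_processors.
parser = DataClassParser(data_class=QAOutput, return_data_class=True)
```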
12 changes: 8 additions & 4 deletions adalflow/adalflow/components/output_parsers/outputs.py
@@ -12,7 +12,7 @@
import logging

from adalflow.core.prompt_builder import Prompt
from adalflow.core.string_parser import YamlParser, ListParser, JsonParser
from adalflow.core.string_parser import YamlParser, ListParser, JsonParser, Parser
from adalflow.core.base_data_class import DataClass, DataClassFormatType
from adalflow.core.base_data_class import ExcludeType, IncludeType

@@ -68,15 +68,19 @@
YAML_OUTPUT_PARSER_OUTPUT_TYPE = Dict[str, Any]


class OutputParser:
class OutputParser(Parser):
__doc__ = r"""The abstract class for all output parsers.
On top of the basic string Parser, it handles structured data interaction:
1. format_instructions: Return the formatted instructions to use in prompt for the output format.
2. call: Parse the output string to the desired format and return the parsed output via yaml or json.
    This interface helps users customize output parsers with a consistent interface for the Generator, even though you don't always need to subclass it.
AdalFlow uses two core components:
AdalFlow uses two core classes:
1. the Prompt to format output instruction
2. A string parser component from core.string_parser for response parsing.
2. A string parser from core.string_parser for response parsing.
"""

def __init__(self, *args, **kwargs) -> None:
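Since OutputParser now inherits from Parser, a custom parser only needs the two methods the docstring names. A minimal sketch under that assumption; the class name and instruction string are illustrative, and delegating to `JsonParser()` by calling it directly assumes Parser instances stay callable the way the old Components were:

```python
# Hedged sketch of a custom output parser built on the OutputParser interface.
from adalflow.components.output_parsers.outputs import OutputParser
from adalflow.core.string_parser import JsonParser


class SimpleJsonOutputParser(OutputParser):
    def __init__(self):
        super().__init__()
        self._json_parser = JsonParser()

    def format_instructions(self) -> str:
        # Goes into the prompt so the model knows the expected output format.
        return "Respond with a single JSON object."

    def call(self, input: str) -> dict:
        # Delegate the actual string-to-dict parsing to the core JsonParser.
        return self._json_parser(input)
```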
25 changes: 22 additions & 3 deletions adalflow/adalflow/core/__init__.py
@@ -1,12 +1,23 @@
from .base_data_class import DataClass, required_field, DataClassFormatType

from .component import Component, FunComponent, fun_to_component
from .component import Component
from .container import Sequential, ComponentList
from .db import LocalDB
from .default_prompt_template import DEFAULT_ADALFLOW_SYSTEM_PROMPT
from .embedder import Embedder, BatchEmbedder
from .generator import Generator, BackwardEngine
from .model_client import ModelClient
from .string_parser import (
Parser,
FuncParser,
func_to_parser,
YamlParser,
JsonParser,
IntParser,
FloatParser,
ListParser,
BooleanParser,
)

# from .parameter import Parameter
from .prompt_builder import Prompt
@@ -51,8 +62,6 @@
"Component",
"Sequential",
"ComponentList",
"FunComponent",
"fun_to_component",
"DataClass",
"DataClassFormatType",
"required_field",
@@ -94,6 +103,16 @@
"DialogTurn",
"Conversation",
"Tokenizer",
# Parsers
"Parser",
"FuncParser",
"func_to_parser",
"YamlParser",
"JsonParser",
"IntParser",
"FloatParser",
"ListParser",
"BooleanParser",
]

for name in __all__:
149 changes: 64 additions & 85 deletions adalflow/adalflow/core/component.py
@@ -3,7 +3,6 @@

from collections import OrderedDict, namedtuple
from typing import (
Callable,
Dict,
Any,
Optional,
@@ -519,36 +518,86 @@ def named_parameters(
# )
# plt.show()

def forward(self, *args, **kwargs):
"""
User must override this for the training scenario
if bicall is not defined.
"""
raise NotImplementedError("Subclasses must implement `forward` or `bicall`.")

def call(self, *args, **kwargs):
"""
User must override this for the inference scenario
if bicall is not defined.
"""
if self._has_bicall():
output = self.bicall(*args, **kwargs)
return output
raise NotImplementedError("Subclasses must implement `call` or `bicall`.")

def bicall(self, *args, **kwargs):
"""
If the user provides a `bicall` method, then `__call__` will automatically
dispatch here for both training and inference scenarios. This can internally
decide how to handle training vs. inference, or just produce a single unified
output type.
"""
# Default fallback if not overridden
raise NotImplementedError(
"Optional method. Implement to handle both scenarios in one place."
)

def __call__(self, *args, **kwargs):
# 1. If `bicall` is defined by the user, use it
# and let the `bicall` implementation handle
# the difference between training vs. inference.
from adalflow.optim.parameter import Parameter

print("has_bicall", self._has_bicall())

if self._has_bicall():
output = self.bicall(*args, **kwargs)

# Validation checks based on training or inference
if self.training:
# Ensure output is a Parameter in training
if not isinstance(output, Parameter):
raise ValueError(
f"Output should be of type Parameter in training mode, but got {type(output)}"
)
else:
# Ensure output is not a Parameter in inference
if isinstance(output, Parameter):
raise ValueError(
f"Output should not be of type Parameter in inference mode, but got {type(output)}"
)
return output

# 2. Otherwise, if `bicall` is not defined, fall back to forward / call
if self.training:
output = self.forward(*args, **kwargs)
print(f"{isinstance(output, Parameter)}")

# Validation for training
if not isinstance(output, Parameter):
raise ValueError(
f"Output should be of type Parameter, but got {type(output)}"
f"Output should be of type Parameter in training mode, but got {type(output)}"
)
return output
else:
output = self.call(*args, **kwargs)
# Validation for inference
if isinstance(output, Parameter):
raise ValueError(
f"Output should not be of type OutputParameter, but got {type(output)}"
f"Output should not be of type Parameter in inference mode, but got {type(output)}"
)
return output

def forward(self, *args, **kwargs):
r"""Forward pass for training mode."""
raise NotImplementedError(
f"Component {type(self).__name__} is missing the 'forward' method for training mode."
)

def call(self, *args, **kwargs):
raise NotImplementedError(
f"Component {type(self).__name__} is missing the required 'call' method."
)
def _has_bicall(self):
"""
Helper method to check if this subclass has overridden bicall.
"""
# The default `bicall` in this class raises NotImplementedError,
        # so we can check whether the method is still the same one defined on `Component`.
return self.bicall.__func__ is not Component.bicall

async def acall(self, *args, **kwargs):
r"""API call, file io."""
@@ -960,76 +1009,6 @@ def _get_init_args(self, *args, **kwargs) -> Dict[str, Any]:
return init_args


# TODO: support async call
class FunComponent(Component):
r"""Component that wraps a function.
Args:
fun (Callable): The function to be wrapped.
Examples:
function = lambda x: x + 1
fun_component = FunComponent(function)
print(fun_component(1)) # 2
"""

def __init__(self, fun: Optional[Callable] = None, afun: Optional[Callable] = None):
super().__init__()
self.fun_name = fun.__name__
EntityMapping.register(self.fun_name, fun)

def call(self, *args, **kwargs):
fun = EntityMapping.get(self.fun_name)
return fun(*args, **kwargs)

def _extra_repr(self) -> str:
return super()._extra_repr() + f"fun_name={self.fun_name}"


def fun_to_component(fun) -> FunComponent:
r"""Helper function to convert a function into a Component with
its own class name.
Can be used as both a decorator and a function.
Args:
fun (Callable): The function to be wrapped.
Returns:
FunComponent: The component that wraps the function.
Examples:
1. As a decorator:
>>> @fun_to_component
>>> def my_function(x):
>>> return x + 1
>>> # is equivalent to
>>> class MyFunctionComponent(FunComponent):
>>> def __init__(self):
>>> super().__init__(my_function)
2. As a function:
>>> my_function_component = fun_to_component(my_function)
"""

# Split the function name by underscores, capitalize each part, and join them back together
class_name = (
"".join(part.capitalize() for part in fun.__name__.split("_")) + "Component"
)
# register the function
EntityMapping.register(fun.__name__, fun)
# Define a new component class dynamically
component_class = type(
class_name,
(FunComponent,),
{"__init__": lambda self: FunComponent.__init__(self, fun)},
)
# register the component
EntityMapping.register(class_name, component_class)

return component_class()


# TODO: not used yet, will further investigate dict mode
# class ComponentDict(Component):
# r"""
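With `FunComponent` and `fun_to_component` removed, the function-wrapping role for output processing moves to `func_to_parser`/`FuncParser` in core.string_parser, which this commit adds to the package exports. A hedged sketch of the assumed replacement usage, mirroring the decorator example from the removed docstring:

```python
# Assumed usage of the new func_to_parser; that it is a decorator returning a
# callable FuncParser (as fun_to_component did for FunComponent) is an
# assumption based on the parallel naming, not something shown in this diff.
from adalflow.core.string_parser import func_to_parser


@func_to_parser
def extract_final_answer(text: str) -> str:
    # Keep only the text after the last "Answer:" marker.
    return text.rsplit("Answer:", 1)[-1].strip()


print(extract_final_answer("Some reasoning... Answer: 42"))  # -> "42"
```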
9 changes: 7 additions & 2 deletions adalflow/adalflow/core/generator.py
@@ -37,7 +37,7 @@
from adalflow.utils.cache import CachedEngine
from adalflow.tracing.callback_manager import CallbackManager
from adalflow.utils.global_config import get_adalflow_default_root_path
from adalflow.core.string_parser import JsonParser
from adalflow.core.string_parser import JsonParser, Parser


from adalflow.optim.text_grad.backend_engine_prompt import (
@@ -105,7 +105,7 @@ def __init__(
template: Optional[str] = None,
prompt_kwargs: Optional[Dict] = {},
# args for the output processing
output_processors: Optional[Component] = None,
output_processors: Optional[Parser] = None,
name: Optional[str] = None,
# args for the cache
cache_path: Optional[str] = None,
@@ -152,6 +152,11 @@ def __init__(

self.output_processors = output_processors

if output_processors and (not isinstance(output_processors, Parser)):
raise ValueError(
f"output_processors should be a Parser instance, got {type(output_processors)}"
)

self.set_parameters(prompt_kwargs)

# end of trainable parameters
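The effect of this change for callers: `output_processors` must now be a `Parser` instance, otherwise `Generator.__init__` raises a ValueError. A minimal sketch; the model client and model kwargs are placeholders, not taken from the diff:

```python
# Passing a Parser (here JsonParser) as output_processors; a Component or a
# bare function would now be rejected at construction time.
import adalflow as adal
from adalflow.core.string_parser import JsonParser

generator = adal.Generator(
    model_client=adal.OpenAIClient(),       # placeholder client
    model_kwargs={"model": "gpt-4o-mini"},  # placeholder model name
    output_processors=JsonParser(),         # must be a Parser subclass
)
```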
(The diffs for the remaining changed files are not shown here.)