Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FR] Add chatprompt_chain decorator #391

Closed
alexchandel opened this issue Dec 20, 2024 · 1 comment · Fixed by #403
Closed

[FR] Add chatprompt_chain decorator #391

alexchandel opened this issue Dec 20, 2024 · 1 comment · Fixed by #403

Comments

@alexchandel
Copy link
Contributor

Following #389, it would be nice to have a `chatprompt_chain` decorator that simplifies the process of repeatedly evaluating function calls from a chat template.

Part of the difficulty in writing this is that the chatprompt family of types (derived from BaseChatPromptFunction) is more general and sophisticated than, yet incompatible with, the BasePromptFunction / Chat / FunctionCall family of types. For example, Chat has no factory function that accepts a BaseChatPromptFunction. It seems that some redundancy should be eliminated here and these hierarchies should be merged: BasePromptFunction should probably be a specialization that simply has one UserMessage.

#390 fixes one issue with trying to stuff a ChatPromptFunction into a Chat, but there are others.

@alexchandel
Copy link
Contributor Author

alexchandel commented Dec 20, 2024

Nonetheless I've taken a stab at it:

"""Chat prompt chain."""

import inspect
from collections.abc import Awaitable
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar, cast

from magentic.chat import Chat
from magentic.chat_model.base import ChatModel
from magentic.chat_model.message import Message
from magentic.chatprompt import AsyncChatPromptFunction, ChatPromptDecorator, ChatPromptFunction
from magentic.function_call import (
    # AsyncParallelFunctionCall,
    FunctionCall,
    # ParallelFunctionCall,
)
from magentic.logger import logfire
from magentic.prompt_chain import MaxFunctionCallsError

# Parameter specification of the decorated function, so the wrapper produced by
# `chatprompt_chain` is typed with the same call signature.
P = ParamSpec("P")
# Return type of the decorated function (the awaited value for async functions).
R = TypeVar("R")


def chatprompt_chain(
    *messages: Message[Any],
    functions: list[Callable[..., Any]] | None = None,
    stop: list[str] | None = None,
    max_retries: int = 0,
    model: ChatModel | None = None,
    max_calls: int | None = None,
) -> ChatPromptDecorator:
    """Convert a Python function to an LLM chat prompt, auto-resolving function calls.

    The `@chatprompt_chain` decorator allows you to define a prompt template for
    chat-based Large Language Models (LLM). The decorated function's arguments
    are formatted into ``messages`` and the chat is submitted. While the model's
    reply is callable (a function call), it is executed, its result is added to
    the chat, and the chat is resubmitted. The first non-callable reply is
    returned.

    Works for both regular and ``async`` functions.

    Args:
        *messages: Message templates, formatted with the decorated function's
            arguments.
        functions: Functions the model is allowed to call.
        stop: Stop sequences, passed to the underlying chat prompt function.
        max_retries: Passed to the underlying chat prompt function.
        model: Chat model, passed to the underlying chat prompt function; the
            chain submits using that prompt function's ``model``.
        max_calls: Maximum number of function calls to resolve, or ``None``
            for no limit.

    Returns:
        A decorator that turns the function into a prompt chain.

    Raises:
        MaxFunctionCallsError: Raised by the decorated function when
            ``max_calls`` function calls have already been resolved and the
            model asks for another one.
    """

    def decorator(
        func: Callable[P, Awaitable[R]] | Callable[P, R],
    ) -> AsyncChatPromptFunction[P, R] | ChatPromptFunction[P, R]:
        func_signature = inspect.signature(func)

        # Constructor arguments shared by the sync and async prompt functions.
        prompt_function_kwargs: dict[str, Any] = {
            "name": func.__name__,
            "parameters": list(func_signature.parameters.values()),
            # Let the model answer either with the declared return type or with
            # a function call that the chain resolves.
            # TODO: Also allow ParallelFunctionCall. Support this more neatly
            "return_type": func_signature.return_annotation | FunctionCall,
            "messages": messages,
            "functions": functions,
            # NOTE(review): `stop` and `max_retries` are stored on the prompt
            # function, but the Chat built in the wrappers below only receives
            # its messages, functions, output types and model, so these two may
            # have no effect on the chain — confirm.
            "stop": stop,
            "max_retries": max_retries,
            "model": model,
        }

        def check_call_limit(num_calls: int) -> None:
            """Raise MaxFunctionCallsError once `max_calls` calls were resolved."""
            if max_calls is not None and num_calls >= max_calls:
                msg = (
                    f"Function {func.__name__} reached limit of"
                    f" {max_calls} function calls"
                )
                raise MaxFunctionCallsError(msg)

        if inspect.iscoroutinefunction(func):
            async_prompt_function = AsyncChatPromptFunction[P, R](
                **prompt_function_kwargs
            )

            @wraps(func)
            async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
                with logfire.span(
                    f"Calling async prompt-chain {func.__name__}",
                    **func_signature.bind(*args, **kwargs).arguments,
                ):
                    # The call arguments are consumed by `format`; they are not
                    # Chat constructor arguments. Submit with the async API,
                    # consistent with the resubmission inside the loop.
                    chat = await Chat(
                        messages=async_prompt_function.format(*args, **kwargs),
                        functions=async_prompt_function.functions,
                        output_types=async_prompt_function.return_types,
                        model=async_prompt_function.model,
                    ).asubmit()
                    num_calls = 0
                    # Any callable reply is treated as a function call to run.
                    # NOTE(review): presumably chosen over
                    # `isinstance(..., FunctionCall)` so parallel calls are also
                    # resolved — confirm.
                    while callable(chat.last_message.content):
                        check_call_limit(num_calls)
                        chat = await chat.aexec_function_call()
                        chat = await chat.asubmit()
                        num_calls += 1
                    return chat.last_message.content

            return cast(AsyncChatPromptFunction[P, R], awrapper)

        prompt_function = ChatPromptFunction[P, R](**prompt_function_kwargs)

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            with logfire.span(
                f"Calling prompt-chain {func.__name__}",
                **func_signature.bind(*args, **kwargs).arguments,
            ):
                chat = Chat(
                    messages=prompt_function.format(*args, **kwargs),
                    functions=prompt_function.functions,
                    output_types=prompt_function.return_types,
                    model=prompt_function.model,
                ).submit()
                num_calls = 0
                # Any callable reply is treated as a function call to run.
                while callable(chat.last_message.content):
                    check_call_limit(num_calls)
                    chat = chat.exec_function_call().submit()
                    num_calls += 1
                return cast(R, chat.last_message.content)

        return cast(ChatPromptFunction[P, R], wrapper)

    return cast(ChatPromptDecorator, decorator)

You probably want to test this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
1 participant