diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index d407c0a2a..a26a63009 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -1,5 +1,6 @@ import inspect import json +import enum import typing from typing import Annotated, Callable, List, Tuple, Union # noqa: UP035 @@ -236,6 +237,14 @@ def parse(x): format=lambda x: x if isinstance(x, str) else str(x), parser=type_, ) + elif inspect.isclass(type_) and issubclass(type_, enum.Enum): + signature = signature.with_updated_fields( + name, + desc=field.json_schema_extra.get("desc", "") + + f" (Respond with one of: {', '.join(type_.__members__)})", + format=lambda x: x if isinstance(x, str) else str(x), + parser=lambda x: type_(x.strip()), + ) else: # Anything else we wrap in a pydantic object if ( diff --git a/dspy/functional/typed_predictor_signature.py b/dspy/functional/typed_predictor_signature.py new file mode 100644 index 000000000..b94416b7f --- /dev/null +++ b/dspy/functional/typed_predictor_signature.py @@ -0,0 +1,130 @@ +from typing import Annotated, Type, Union +from typing import get_origin, get_args + +from pydantic import BaseModel +from pydantic_core import PydanticUndefined +from pydantic.fields import FieldInfo + +import dspy +from dspy import InputField, OutputField, Signature + + +class TypedPredictorSignature: + @classmethod + def create( + cls, + pydantic_class_for_dspy_input_fields: Type[BaseModel], + pydantic_class_for_dspy_output_fields: Type[BaseModel], + prefix_instructions: str = "") -> Type[Signature]: + """ + Return a DSPy Signature class that can be used to extract the output parameters. + + :param pydantic_class_for_dspy_input_fields: Pydantic class that defines the DSPy InputField's. + :param pydantic_class_for_dspy_output_fields: Pydantic class that defines the DSPy OutputField's. + :param prefix_instructions: Optional text that is prefixed to the instructions. + :return: A DSPy Signature class optimizedfor use with a TypedPredictor to extract structured information. + """ + if prefix_instructions: + prefix_instructions += "\n\n" + instructions = prefix_instructions + "Use only the available information to extract the output fields.\n\n" + dspy_fields = {} + for field_name, field in pydantic_class_for_dspy_input_fields.model_fields.items(): + if field.default and 'typing.Annotated' in str(field.default): + raise ValueError(f"Field '{field_name}' is annotated incorrectly. See 'Constraints on compound types' in https://docs.pydantic.dev/latest/concepts/fields/") + + is_default_value_specified, is_marked_as_optional, inner_field = cls._process_field(field) + if is_marked_as_optional: + if field.default is None or field.default is PydanticUndefined: + field.default = 'null' + field.description = inner_field.description + field.examples = inner_field.examples + field.metadata = inner_field.metadata + field.json_schema_extra = inner_field.json_schema_extra + else: + field.validate_default = False + + input_field = InputField(desc=field.description) + dspy_fields[field_name] = (field.annotation, input_field) + + for field_name, field in pydantic_class_for_dspy_output_fields.model_fields.items(): + if field.default and 'typing.Annotated' in str(field.default): + raise ValueError(f"Field '{field_name}' is annotated incorrectly. See 'Constraints on compound types' in https://docs.pydantic.dev/latest/concepts/fields/") + + is_default_value_specified, is_marked_as_optional, inner_field = cls._process_field(field) + if is_marked_as_optional: + if field.default is None or field.default is PydanticUndefined: + field.default = 'null' + field.description = inner_field.description + field.examples = inner_field.examples + field.metadata = inner_field.metadata + field.json_schema_extra = inner_field.json_schema_extra + else: + field.validate_default = False + + if field.default is PydanticUndefined: + raise ValueError( + f"Field '{field_name}' has no default value. Required fields must have a default value. " + "Change the field to be Optional or specify a default value." + ) + + output_field = OutputField(desc=field.description if field.description else "") + dspy_fields[field_name] = (field.annotation, output_field) + + instructions += f"When extracting '{field_name}':\n" + instructions += f"If it is not mentioned in the input fields, return: '{field.default}'. " + + examples = field.examples + if examples: + quoted_examples = [f"'{example}'" for example in examples] + instructions += f"Example values are: {', '.join(quoted_examples)} etc. " + + if field.metadata: + constraints = [meta for meta in field.metadata if 'Validator' not in str(meta)] + if field.json_schema_extra and 'invalid_value' in field.json_schema_extra: + instructions += f"If the extracted value does not conform to: {constraints}, return: '{field.json_schema_extra['invalid_value']}'." + else: + print(f"WARNING - Field: '{field_name}' is missing an 'invalid_value' attribute. Fields with value constraints should specify an 'invalid_value'.") + instructions += f"If the extracted value does not conform to: {constraints}, return: '{field.default}'." + + instructions += '\n\n' + + return dspy.Signature(dspy_fields, instructions.strip()) + + @classmethod + def _process_field(cls, field: FieldInfo) -> tuple[bool, bool, FieldInfo]: + is_default_value_specified = not field.is_required() + is_marked_as_optional, inner_type, field_info = cls._analyze_field_annotation(field.annotation) + if field_info: + field_info.annotation = inner_type + return is_default_value_specified, is_marked_as_optional, field_info + + return is_default_value_specified, is_marked_as_optional, field + + @classmethod + def _analyze_field_annotation(cls, annotation): + is_optional = False + inner_type = annotation + field_info = None + + # If field is specfied as Optional[Annotated[...]] + if get_origin(annotation) is Union: + args = get_args(annotation) + if type(None) in args: + is_optional = True + inner_type = args[0] if args[0] is not type(None) else args[1] + # Not sure why I added this, perhaps for some aother way of specifying optional fields? + # elif hasattr(annotation, '_name') and annotation._name == 'Optional': + # is_optional = True + # inner_type = get_args(annotation)[0] + + + # Check if it's Annotated + if get_origin(inner_type) is Annotated: + args = get_args(inner_type) + inner_type = args[0] + for arg in args[1:]: + if isinstance(arg, FieldInfo): + field_info = arg + break + + return is_optional, inner_type, field_info diff --git a/tests/functional/test_typed_predictor_signature.py b/tests/functional/test_typed_predictor_signature.py new file mode 100644 index 000000000..a55523a17 --- /dev/null +++ b/tests/functional/test_typed_predictor_signature.py @@ -0,0 +1,82 @@ +import pytest +from pydantic import BaseModel, Field, ValidationError, field_validator +from typing import Annotated, Optional +from dspy.functional.typed_predictor_signature import TypedPredictorSignature + + +class PydanticInput(BaseModel): + command: str + + +class PydanticOutput1(BaseModel): + @field_validator("name", mode="wrap") + @staticmethod + def validate_name(name, handler): + try: + return handler(name) + except ValidationError: + return 'INVALID_NAME' + + name: Annotated[str, + Field(default='NOT_FOUND', max_length=15, + title='Name', description='The name of the person', + examples=['John Doe', 'Jane Doe'], + json_schema_extra={'invalid_value': 'INVALID_NAME'} + ) + ] + +class PydanticOutput2(BaseModel): + @field_validator("age", mode="wrap") + @staticmethod + def validate_age(age, handler): + try: + return handler(age) + except ValidationError: + return -8888 + + age: Annotated[int, + Field(gt=0, lt=150, default=-999, + json_schema_extra={'invalid_value': '-8888'} + ) + ] + +class PydanticOutput3(BaseModel): + age: Annotated[int, + Field(gt=0, lt=150, + json_schema_extra={'invalid_value': '-8888'} + ) + ] = -999 + +class PydanticOutput4(BaseModel): + age: Optional[Annotated[int, + Field(gt=0, lt=150, + json_schema_extra={'invalid_value': '-8888'} + )]] + +class PydanticOutput5(BaseModel): + @field_validator("email", mode="wrap") + @staticmethod + def validate_email(email, handler): + try: + return handler(email) + except ValidationError: + return 'INVALID_EMAIL' + + email: Annotated[str, + Field(default='NOT_FOUND', + pattern=r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', + json_schema_extra={'invalid_value': 'INVALID_EMAIL'} + ) + ] + +@pytest.mark.parametrize("pydantic_output_class", [ + # PydanticOutput1, + # PydanticOutput2, + # PydanticOutput3, + PydanticOutput4, + # PydanticOutput5 +]) +def test_valid_pydantic_types(pydantic_output_class: str): + dspy_signature_class = TypedPredictorSignature.create(PydanticInput, pydantic_output_class) + assert dspy_signature_class is not None +