Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created a TypedPredictorSignature class that builds a signature from Pydantic models - this signature is optimized for use with TypedPredictor and TypedChainOfThought #1655

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import json
import enum
import typing
from typing import Annotated, Callable, List, Tuple, Union # noqa: UP035

Expand Down Expand Up @@ -236,6 +237,14 @@ def parse(x):
format=lambda x: x if isinstance(x, str) else str(x),
parser=type_,
)
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
+ f" (Respond with one of: {', '.join(type_.__members__)})",
format=lambda x: x if isinstance(x, str) else str(x),
parser=lambda x: type_(x.strip()),
)
else:
# Anything else we wrap in a pydantic object
if (
Expand Down
130 changes: 130 additions & 0 deletions dspy/functional/typed_predictor_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Annotated, Type, Union
from typing import get_origin, get_args

from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic.fields import FieldInfo

import dspy
from dspy import InputField, OutputField, Signature


class TypedPredictorSignature:
drawal1 marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def create(
cls,
pydantic_class_for_dspy_input_fields: Type[BaseModel],
pydantic_class_for_dspy_output_fields: Type[BaseModel],
prefix_instructions: str = "") -> Type[Signature]:
"""
Return a DSPy Signature class that can be used to extract the output parameters.

:param pydantic_class_for_dspy_input_fields: Pydantic class that defines the DSPy InputField's.
:param pydantic_class_for_dspy_output_fields: Pydantic class that defines the DSPy OutputField's.
:param prefix_instructions: Optional text that is prefixed to the instructions.
:return: A DSPy Signature class optimized for use with a TypedPredictor to extract structured information.
"""
if prefix_instructions:
prefix_instructions += "\n\n"
instructions = prefix_instructions + "Use only the available information to extract the output fields.\n\n"
dspy_fields = {}
for field_name, field in pydantic_class_for_dspy_input_fields.model_fields.items():
if field.default and 'typing.Annotated' in str(field.default):
raise ValueError(f"Field '{field_name}' is annotated incorrectly. See 'Constraints on compound types' in https://docs.pydantic.dev/latest/concepts/fields/")

is_default_value_specified, is_marked_as_optional, inner_field = cls._process_field(field)
Copy link
Collaborator

@okhat okhat Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting to see special support for optional fields hmm, do we need optional fields? Or are Optional[.] types enough? I see that you have this for input fields here and for output fields below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pydantic allows multiple ways to specify a field schema, including the Optional property. It can be on the field spec inside the annotation or outside it. IMO, we should not make users guess which Pydantic field metadata we do and don't support. If we can support it, we should.

if is_marked_as_optional:
if field.default is None or field.default is PydanticUndefined:
field.default = 'null'
field.description = inner_field.description
field.examples = inner_field.examples
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting to see examples and metadata per field

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are definitely valuable signals. We should leverage Pydantic field metadata instead of duplicating in DSPy signature fields

field.metadata = inner_field.metadata
field.json_schema_extra = inner_field.json_schema_extra
else:
field.validate_default = False

input_field = InputField(desc=field.description)
dspy_fields[field_name] = (field.annotation, input_field)

for field_name, field in pydantic_class_for_dspy_output_fields.model_fields.items():
if field.default and 'typing.Annotated' in str(field.default):
raise ValueError(f"Field '{field_name}' is annotated incorrectly. See 'Constraints on compound types' in https://docs.pydantic.dev/latest/concepts/fields/")

is_default_value_specified, is_marked_as_optional, inner_field = cls._process_field(field)
if is_marked_as_optional:
if field.default is None or field.default is PydanticUndefined:
field.default = 'null'
field.description = inner_field.description
field.examples = inner_field.examples
field.metadata = inner_field.metadata
field.json_schema_extra = inner_field.json_schema_extra
else:
field.validate_default = False

if field.default is PydanticUndefined:
raise ValueError(
f"Field '{field_name}' has no default value. Required fields must have a default value. "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah interesting concept here, requiring a default value for every required output field, why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users may provide input without passing the value of a required parameter. So how to indicate that the required value was missing in the input, without throwing an exception? And without forcing the LM to hallucinate a required missing value?

Luckily, you can specify that Pydantic default value should not be subject to validation. This means that regardless of the type (str, int, float, ...), you can specify a default value of "NOT_FOUND" and this signature will correctly detect and return it without hallucinating.

No wrap validators necessary, which in any case don't distinguish between invalid and missing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unique nature of agentic applications is that we cannot assume user input is being validated in the UI for required input. Combine this with LM ability to hallucinate when forced to do so...

"Change the field to be Optional or specify a default value."
)

output_field = OutputField(desc=field.description if field.description else "")
dspy_fields[field_name] = (field.annotation, output_field)

instructions += f"When extracting '{field_name}':\n"
instructions += f"If it is not mentioned in the input fields, return: '{field.default}'. "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting choice of words. "When extracting" and "return". Something to keep in mind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to work. Any pitfalls here? Are you thinking if the field name was "extracted_name" or field.default was "return"? Not quite sure how to get around that


examples = field.examples
if examples:
quoted_examples = [f"'{example}'" for example in examples]
Copy link
Collaborator

@okhat okhat Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naive quotes will fail on complex values? same with the naive join below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only issue I have seen is the LM sometimes returning string values enclosed in quotes. For example, returning "'John Doe'", instead of "John Doe". But this is easy to strip. It correctly extracts numbers without putting them in quotes but I guess the prompt could be enhanced to detect if the example values should be quoted. Any other suggestions on how this could be improved?

Also, re. complex types - isn't that what I am testing with complex Input and Output BaseModels? So maybe the code for pulling apart complex types and building a custom signature is unavoidable. Thoughts?

instructions += f"Example values are: {', '.join(quoted_examples)} etc. "

if field.metadata:
constraints = [meta for meta in field.metadata if 'Validator' not in str(meta)]
if field.json_schema_extra and 'invalid_value' in field.json_schema_extra:
instructions += f"If the extracted value does not conform to: {constraints}, return: '{field.json_schema_extra['invalid_value']}'."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very interesting, asking the model to signal bad/hard fields

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Detecting missing and invalid values is a "must-have" requirement for production-quality apps. Here, I am basically trying to avoid the multiple retries that the framework makes when the extracted value does not match the constraints specified in the Pydantic BaseModel.

The LM may be incorrectly extracting the parameter or the user has specified an invalid value. In both cases, the system should not hallucinate the closest value that matches the actual input. Rather, it should flag the invalid field and let the application handle the error with a proper error-correction workflow ("Provided value was invalid. It must be... Here are some examples... Did you mean...?")

else:
print(f"WARNING - Field: '{field_name}' is missing an 'invalid_value' attribute. Fields with value constraints should specify an 'invalid_value'.")
instructions += f"If the extracted value does not conform to: {constraints}, return: '{field.default}'."

instructions += '\n\n'

return dspy.Signature(dspy_fields, instructions.strip())

@classmethod
def _process_field(cls, field: FieldInfo) -> tuple[bool, bool, FieldInfo]:
    """
    Summarise how a Pydantic field was declared.

    :param field: The FieldInfo taken from a model's ``model_fields``.
    :return: A tuple ``(is_default_value_specified, is_marked_as_optional, field_info)``.
        The first item is True when ``field.is_required()`` is False. The second
        comes from ``_analyze_field_annotation``. The third is the FieldInfo found
        inside the annotation when there is one (its ``annotation`` attribute is
        set to the unwrapped inner type), otherwise the ``field`` that was passed in.
    """
    has_default = not field.is_required()
    optional, unwrapped_type, embedded = cls._analyze_field_annotation(field.annotation)

    # No FieldInfo nested in the annotation: the outer field is the one to use.
    if not embedded:
        return has_default, optional, field

    # The nested FieldInfo does not know its own type yet, so record it here.
    embedded.annotation = unwrapped_type
    return has_default, optional, embedded

@classmethod
def _analyze_field_annotation(cls, annotation):
    """
    Unwrap a field annotation into optionality, inner type and embedded FieldInfo.

    Recognises ``Optional[T]``, ``Union[..., None]`` and ``T | None`` as optional,
    then looks through ``Annotated[T, ...]`` (so ``Optional[Annotated[T, Field(...)]]``
    is handled as well).

    :param annotation: The type annotation of a Pydantic field.
    :return: A tuple ``(is_optional, inner_type, field_info)``. ``is_optional`` is True
        when ``None`` is a member of a union annotation. ``inner_type`` is the annotation
        with the ``None`` member and any ``Annotated`` wrapper removed. ``field_info`` is
        the first FieldInfo found in the ``Annotated`` metadata, or None.
    """
    import types  # local import: only needed to recognise PEP 604 unions

    is_optional = False
    inner_type = annotation
    field_info = None

    # ``X | None`` (Python 3.10+) reports ``types.UnionType`` as its origin rather
    # than ``typing.Union``; fall back to ``Union`` where that type does not exist.
    union_origins = (Union, getattr(types, "UnionType", Union))
    origin = get_origin(annotation)

    # Field is specified as Optional[...] / Union[..., None]
    if any(origin is candidate for candidate in union_origins):
        members = get_args(annotation)
        if type(None) in members:
            is_optional = True
            non_none = tuple(member for member in members if member is not type(None))
            # Keep every non-None member instead of only one of the first two,
            # so Union[int, str, None] unwraps to Union[int, str].
            inner_type = non_none[0] if len(non_none) == 1 else Union[non_none]

    # Look through Annotated[T, ...] and pick up the first FieldInfo in its metadata.
    if get_origin(inner_type) is Annotated:
        inner_type, *metadata = get_args(inner_type)
        field_info = next((item for item in metadata if isinstance(item, FieldInfo)), None)

    return is_optional, inner_type, field_info
82 changes: 82 additions & 0 deletions tests/functional/test_typed_predictor_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from pydantic import BaseModel, Field, ValidationError, field_validator
from typing import Annotated, Optional
from dspy.functional.typed_predictor_signature import TypedPredictorSignature


class PydanticInput(BaseModel):
    # Minimal input model: a single required string field. It is passed as the
    # input-side model to TypedPredictorSignature.create in the test in this file.
    command: str


class PydanticOutput1(BaseModel):
@field_validator("name", mode="wrap")
@staticmethod
def validate_name(name, handler):
try:
return handler(name)
except ValidationError:
return 'INVALID_NAME'

name: Annotated[str,
Field(default='NOT_FOUND', max_length=15,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use Annotated instead of assignment = Field(...)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Pydantic docs - "In case you use field constraints with compound types, an error can happen in some cases. To avoid potential issues, you can use Annotated:"

See https://docs.pydantic.dev/latest/concepts/fields/#validation-alias

title='Name', description='The name of the person',
examples=['John Doe', 'Jane Doe'],
json_schema_extra={'invalid_value': 'INVALID_NAME'}
)
]

class PydanticOutput2(BaseModel):
    # Integer field constrained to 0 < age < 150 with a default of -999; the
    # 'invalid_value' marker is attached through json_schema_extra.
    age: Annotated[
        int,
        Field(
            gt=0,
            lt=150,
            default=-999,
            json_schema_extra={'invalid_value': '-8888'},
        ),
    ]

    @field_validator("age", mode="wrap")
    @staticmethod
    def validate_age(value, next_handler):
        # Delegate to the regular validation; a ValidationError is replaced
        # by the sentinel -8888 instead of propagating.
        try:
            checked = next_handler(value)
        except ValidationError:
            return -8888
        return checked

class PydanticOutput3(BaseModel):
    # Same range constraints and 'invalid_value' marker as a Field inside
    # Annotated, but the default (-999) is given by plain assignment rather
    # than through Field(default=...). This model defines no validator.
    age: Annotated[int,
                   Field(gt=0, lt=150,
                         json_schema_extra={'invalid_value': '-8888'}
                         )
                   ] = -999

class PydanticOutput4(BaseModel):
    # Field declared as Optional[Annotated[...]] with no default assigned.
    # NOTE(review): presumably exists to exercise the Optional[Annotated[...]]
    # unwrapping in TypedPredictorSignature - confirm against that class.
    age: Optional[Annotated[int,
                            Field(gt=0, lt=150,
                                  json_schema_extra={'invalid_value': '-8888'}
                                  )]]

class PydanticOutput5(BaseModel):
    # String field that must match the e-mail pattern below; defaults to
    # 'NOT_FOUND' and carries an 'invalid_value' marker in json_schema_extra.
    email: Annotated[
        str,
        Field(
            default='NOT_FOUND',
            pattern=r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
            json_schema_extra={'invalid_value': 'INVALID_EMAIL'},
        ),
    ]

    @field_validator("email", mode="wrap")
    @staticmethod
    def validate_email(value, next_handler):
        # Delegate to the regular validation; a ValidationError is replaced
        # by the sentinel string instead of propagating.
        try:
            checked = next_handler(value)
        except ValidationError:
            return 'INVALID_EMAIL'
        return checked

@pytest.mark.parametrize("pydantic_output_class", [
    # NOTE(review): only PydanticOutput4 is currently exercised; the other
    # models are commented out - confirm whether they are expected to pass.
    # PydanticOutput1,
    # PydanticOutput2,
    # PydanticOutput3,
    PydanticOutput4,
    # PydanticOutput5
])
def test_valid_pydantic_types(pydantic_output_class: type[BaseModel]):
    """Smoke test: creating a signature from the input and output models returns a non-None value."""
    # The parameter is a model class (not a string), hence the type[BaseModel] annotation.
    dspy_signature_class = TypedPredictorSignature.create(PydanticInput, pydantic_output_class)
    assert dspy_signature_class is not None

Loading