Skip to content

Commit

Permalink
Unpack PEP 695 type aliases if using the Annotated form (pydantic#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos authored Dec 27, 2024
1 parent fac2bfd commit b236291
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 73 deletions.
25 changes: 19 additions & 6 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1967,12 +1967,25 @@ def _computed_field_schema(
def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema:
"""Generate schema for an Annotated type, e.g. `Annotated[int, Field(...)]` or `Annotated[int, Gt(0)]`."""
FieldInfo = import_cached_field_info()

source_type, *annotations = self._get_args_resolving_forward_refs(
annotated_type,
required=True,
)
schema = self._apply_annotations(source_type, annotations)
# Ideally, we should delegate all this to `_typing_extra.unpack_annotated`, e.g.:
# `typ, annotations = _typing_extra.unpack_annotated(annotated_type); schema = self.apply_annotations(...)`
# if it was able to use a `NsResolver`. But because `unpack_annotated` is also used
# when constructing `FieldInfo` instances (where we don't have access to a `NsResolver`),
# the implementation of the function does *not* resolve forward annotations. This could
# be solved by calling `unpack_annotated` directly inside `collect_model_fields`.
# For now, we at least resolve the annotated type if it is a forward ref, but note that
# unexpected results will happen if you have something like `Annotated[Alias, ...]` and
# `Alias` is a PEP 695 type alias containing forward references.
typ, *annotations = get_args(annotated_type)
if isinstance(typ, str):
typ = _typing_extra._make_forward_ref(typ)
if isinstance(typ, ForwardRef):
typ = self._resolve_forward_ref(typ)

typ, sub_annotations = _typing_extra.unpack_annotated(typ)
annotations = sub_annotations + annotations

schema = self._apply_annotations(typ, annotations)
# put the default validator last so that TypeAdapter.get_default_value() works
# even if there are function validators involved
for annotation in annotations:
Expand Down
94 changes: 93 additions & 1 deletion pydantic/_internal/_typing_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,89 @@ def annotated_type(tp: Any, /) -> Any | None:
return get_args(tp)[0] if is_annotated(tp) else None


def unpack_annotated(annotation: Any, /) -> tuple[Any, list[Any]]:
"""Unpack the annotation if it is wrapped with the `Annotated` type qualifier.
This function also unpacks PEP 695 type aliases if necessary (and also generic
aliases with a PEP 695 type alias origin). However, it does *not* try to evaluate
forward references, so users should make sure the type alias' `__value__` does not
contain unresolvable forward references.
Example:
```python {test="skip" lint="skip"}
from typing import Annotated
type InnerList[T] = Annotated[list[T], 'meta_1']
type MyList[T] = Annotated[InnerList[T], 'meta_2']
type MyIntList = MyList[int]
_unpack_annotated(MyList)
#> (list[T], ['meta_1', 'meta_2'])
_unpack_annotated(MyList[int])
#> (list[int], ['meta_1', 'meta_2'])
_unpack_annotated(MyIntList)
#> (list[int], ['meta_1', 'meta_2'])
```
Returns:
A two-tuple, the first element is the annotated type and the second element
is a list containing the annotated metadata. If the annotation wasn't
wrapped with `Annotated` in the first place, it is returned as is and the
metadata list is empty.
"""
if is_annotated(annotation):
typ, *metadata = typing_extensions.get_args(annotation)
# The annotated type might be a PEP 695 type alias, so we need to recursively
# unpack it. Note that we could make an optimization here: the following next
# call to `_unpack_annotated` could omit the `is_annotated` check, because Python
# already flattens `Annotated[Annotated[<type>, ...], ...]` forms. However, we would
# need to "re-enable" the check for further recursive calls.
typ, sub_meta = unpack_annotated(typ)
metadata = sub_meta + metadata
return typ, metadata
elif is_type_alias_type(annotation):
try:
value = annotation.__value__
except NameError:
# The type alias value contains an unresolvable reference. Note that even if it
# resolves successfully, it might contain string annotations, and because of design
# limitations we don't evaluate the type (we don't have access to a `NsResolver` instance).
pass
else:
typ, metadata = unpack_annotated(value)
if metadata:
# Having metadata means the type alias' `__value__` was an `Annotated` form
# (or, recursively, a type alias to an `Annotated` form). It is important to
# check for this as we don't want to unpack "normal" type aliases (e.g. `type MyInt = int`).
return typ, metadata
return annotation, []
elif is_generic_alias(annotation):
# When parametrized, a PEP 695 type alias becomes a generic alias
# (e.g. with `type MyList[T] = Annotated[list[T], ...]`, `MyList[int]`
# is a generic alias).
origin = typing_extensions.get_origin(annotation)
if is_type_alias_type(origin):
try:
value = origin.__value__
except NameError:
pass
else:
# Circular import (note that these two functions should probably be defined in `_typing_extra`):
from ._generics import get_standard_typevars_map, replace_types

# While Python already handles type variable replacement for simple `Annotated` forms,
# we need to manually apply the same logic for PEP 695 type aliases:
# - With `MyList = Annotated[list[T], ...]`, `MyList[int] == Annotated[list[int], ...]`
# - With `type MyList = Annotated[list[T], ...]`, `MyList[int].__value__ == Annotated[list[T], ...]`.
value = replace_types(value, get_standard_typevars_map(annotation))
typ, metadata = unpack_annotated(value)
if metadata:
return typ, metadata
return annotation, []

return annotation, []


def is_unpack(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Unpack` special form.
Expand Down Expand Up @@ -221,6 +304,8 @@ def is_paramspec(tp: Any, /) -> bool:
if sys.version_info >= (3, 12):
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)

_IS_PY310 = sys.version_info[:2] == (3, 10)


def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
"""Return whether the provided argument is an instance of `TypeAliasType`.
Expand All @@ -234,7 +319,14 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
#> True
```
"""
return isinstance(tp, _TYPE_ALIAS_TYPES)
if _IS_PY310:
# Parametrized PEP 695 type aliases are instances of `types.GenericAlias` in typing_extensions>=4.13.0.
# On Python 3.10, with `Alias[int]` being such an instance of `GenericAlias`,
# `isinstance(Alias[int], TypeAliasType)` returns `True`.
# See https://github.com/python/cpython/issues/89828.
return type(tp) is not types.GenericAlias and isinstance(tp, _TYPE_ALIAS_TYPES)
else:
return isinstance(tp, _TYPE_ALIAS_TYPES)


def is_classvar(tp: Any, /) -> bool:
Expand Down
102 changes: 52 additions & 50 deletions pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
"""This class exists solely to add type checking for the `**kwargs` in `FieldInfo.from_field`."""

# TODO PEP 747: use TypeForm:
annotation: type[Any] | None
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None
alias: str | None
Expand Down Expand Up @@ -207,7 +208,7 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
"""
self._attributes_set = {k: v for k, v in kwargs.items() if v is not _Unset}
kwargs = {k: _DefaultValues.get(k) if v is _Unset else v for k, v in kwargs.items()} # type: ignore
self.annotation, annotation_metadata = self._extract_metadata(kwargs.get('annotation'))
self.annotation = kwargs.get('annotation')
self.evaluated = False

default = kwargs.pop('default', PydanticUndefined)
Expand Down Expand Up @@ -247,7 +248,7 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
self.init_var = kwargs.pop('init_var', None)
self.kw_only = kwargs.pop('kw_only', None)

self.metadata = self._collect_metadata(kwargs) + annotation_metadata # type: ignore
self.metadata = self._collect_metadata(kwargs) # type: ignore

@staticmethod
def from_field(default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs]) -> FieldInfo:
Expand Down Expand Up @@ -310,34 +311,51 @@ class MyModel(pydantic.BaseModel):
Returns:
An instance of the field metadata.
"""
final = False
if _typing_extra.is_finalvar(annotation):
final = True
if annotation is not typing_extensions.Final:
# 1. Check if the annotation is the `Final` type qualifier:
final = _typing_extra.is_finalvar(annotation)
if final:
if _typing_extra.is_generic_alias(annotation):
# 1.1. The annotation is a parametrized `Final`, e.g. `Final[int]`.
# In this case, `annotation` will be `int`:
annotation = typing_extensions.get_args(annotation)[0]

if _typing_extra.is_annotated(annotation):
first_arg, *extra_args = typing_extensions.get_args(annotation)
if _typing_extra.is_finalvar(first_arg):
final = True
field_info_annotations = [a for a in extra_args if isinstance(a, FieldInfo)]
field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=first_arg)
else:
# 1.2. The annotation is a bare `Final`. Use `Any` as a type annotation:
return FieldInfo(annotation=Any, frozen=True) # pyright: ignore[reportArgumentType] (PEP 747)

# 2. Check if the annotation is an `Annotated` form.
# In this case, `annotation` will be the annotated type:
annotation, metadata = _typing_extra.unpack_annotated(annotation)

# 3. If we have metadata, `annotation` was the annotated type:
if metadata:
# 3.1. Check if the annotated type is the `Final` type qualifier.
# (i.e. `Annotated[Final[...], ...]`). Note that we only do
# so if `final` isn't `True` already, because we don't want to
# support the invalid `Final[Annotated[Final, ...]]` form.
if not final:
final = _typing_extra.is_finalvar(annotation)
if final and _typing_extra.is_generic_alias(annotation):
annotation = typing_extensions.get_args(annotation)[0]

field_info_annotations = [a for a in metadata if isinstance(a, FieldInfo)]
field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=annotation)
if field_info:
new_field_info = copy(field_info)
new_field_info.annotation = first_arg
new_field_info.annotation = annotation
new_field_info.frozen = final or field_info.frozen
metadata: list[Any] = []
for a in extra_args:
field_metadata: list[Any] = []
for a in metadata:
if _typing_extra.is_deprecated_instance(a):
new_field_info.deprecated = a.message
elif not isinstance(a, FieldInfo):
metadata.append(a)
field_metadata.append(a)
else:
metadata.extend(a.metadata)
new_field_info.metadata = metadata
field_metadata.extend(a.metadata)
new_field_info.metadata = field_metadata
return new_field_info

return FieldInfo(annotation=annotation, frozen=final or None) # pyright: ignore[reportArgumentType]
# 4. We don't have metadata:
return FieldInfo(annotation=annotation, frozen=final or None) # pyright: ignore[reportArgumentType] (PEP 747)

@staticmethod
def from_annotated_attribute(annotation: type[Any], default: Any) -> FieldInfo:
Expand Down Expand Up @@ -367,16 +385,16 @@ class MyModel(pydantic.BaseModel):
if annotation is default:
raise PydanticUserError(
'Error when building FieldInfo from annotated attribute. '
"Make sure you don't have any field name clashing with a type annotation ",
"Make sure you don't have any field name clashing with a type annotation.",
code='unevaluable-type-annotation',
)

final = _typing_extra.is_finalvar(annotation)
if final and annotation is not typing_extensions.Final:
if final and _typing_extra.is_generic_alias(annotation):
annotation = typing_extensions.get_args(annotation)[0]

if isinstance(default, FieldInfo):
default.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) # pyright: ignore[reportArgumentType]
default.annotation, annotation_metadata = _typing_extra.unpack_annotated(annotation)
default.metadata += annotation_metadata
default = default.merge_field_infos(
*[x for x in annotation_metadata if isinstance(x, FieldInfo)], default, annotation=default.annotation
Expand All @@ -394,7 +412,7 @@ class MyModel(pydantic.BaseModel):
annotation = annotation.type

pydantic_field = FieldInfo._from_dataclass_field(default)
pydantic_field.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) # pyright: ignore[reportArgumentType]
pydantic_field.annotation, annotation_metadata = _typing_extra.unpack_annotated(annotation)
pydantic_field.metadata += annotation_metadata
pydantic_field = pydantic_field.merge_field_infos(
*[x for x in annotation_metadata if isinstance(x, FieldInfo)],
Expand All @@ -407,19 +425,20 @@ class MyModel(pydantic.BaseModel):
pydantic_field.kw_only = getattr(default, 'kw_only', None)
return pydantic_field

if _typing_extra.is_annotated(annotation):
first_arg, *extra_args = typing_extensions.get_args(annotation)
field_infos = [a for a in extra_args if isinstance(a, FieldInfo)]
field_info = FieldInfo.merge_field_infos(*field_infos, annotation=first_arg, default=default)
metadata: list[Any] = []
for a in extra_args:
annotation, metadata = _typing_extra.unpack_annotated(annotation)

if metadata:
field_infos = [a for a in metadata if isinstance(a, FieldInfo)]
field_info = FieldInfo.merge_field_infos(*field_infos, annotation=annotation, default=default)
field_metadata: list[Any] = []
for a in metadata:
if _typing_extra.is_deprecated_instance(a):
field_info.deprecated = a.message
elif not isinstance(a, FieldInfo):
metadata.append(a)
field_metadata.append(a)
else:
metadata.extend(a.metadata)
field_info.metadata = metadata
field_metadata.extend(a.metadata)
field_info.metadata = field_metadata
return field_info

return FieldInfo(annotation=annotation, default=default, frozen=final or None) # pyright: ignore[reportArgumentType]
Expand Down Expand Up @@ -516,23 +535,6 @@ def _from_dataclass_field(dc_field: DataclassField[Any]) -> FieldInfo:
dc_field_metadata = {k: v for k, v in dc_field.metadata.items() if k in _FIELD_ARG_NAMES}
return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) # pyright: ignore[reportCallIssue]

@staticmethod
def _extract_metadata(annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]:
"""Tries to extract metadata/constraints from an annotation if it uses `Annotated`.
Args:
annotation: The type hint annotation for which metadata has to be extracted.
Returns:
A tuple containing the extracted metadata type and the list of extra arguments.
"""
if annotation is not None:
if _typing_extra.is_annotated(annotation):
first_arg, *extra_args = typing_extensions.get_args(annotation)
return first_arg, list(extra_args)

return annotation, []

@staticmethod
def _collect_metadata(kwargs: dict[str, Any]) -> list[Any]:
"""Collect annotations from kwargs.
Expand Down
Loading

0 comments on commit b236291

Please sign in to comment.