Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Speed up function is_valid_field_name by 26% in pydantic/_internal/_fields.py #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 57 additions & 26 deletions pydantic/_internal/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
from ._typing_extra import (
get_cls_type_hints_lenient,
get_type_hints,
is_classvar,
is_finalvar,
)

if TYPE_CHECKING:
from annotated_types import BaseMetadata
Expand Down Expand Up @@ -46,15 +51,17 @@ def get_type_hints_infer_globalns(
Returns:
The object type hints.
"""
module_name = getattr(obj, '__module__', None)
module_name = getattr(obj, "__module__", None)
globalns: dict[str, Any] | None = None
if module_name:
try:
globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
return get_type_hints(
obj, globalns=globalns, localns=localns, include_extras=include_extras
)


class PydanticMetadata(Representation):
Expand Down Expand Up @@ -89,7 +96,9 @@ def __init__(self, metadata: Any):
return _PydanticGeneralMetadata # type: ignore


def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
def _update_fields_from_docstrings(
cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> None:
if config_wrapper.use_attribute_docstrings:
fields_docs = extract_docstrings_from_cls(cls)
for ann_name, field_info in fields.items():
Expand Down Expand Up @@ -133,12 +142,12 @@ def collect_model_fields( # noqa: C901

# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
annotations = cls.__dict__.get('__annotations__', {})
annotations = cls.__dict__.get("__annotations__", {})
fields: dict[str, FieldInfo] = {}

class_vars: set[str] = set()
for ann_name, ann_type in type_hints.items():
if ann_name == 'model_config':
if ann_name == "model_config":
# We never want to treat `model_config` as a field
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
# protected namespaces (where `model_config` might be allowed as a field name)
Expand All @@ -149,40 +158,49 @@ def collect_model_fields( # noqa: C901
if hasattr(b, ann_name):
from ..main import BaseModel

if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
if not (
issubclass(b, BaseModel) and ann_name in b.model_fields
):
raise NameError(
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
f' of protected namespace "{protected_namespace}".'
)
else:
valid_namespaces = tuple(
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
x
for x in config_wrapper.protected_namespaces
if not ann_name.startswith(x)
)
warnings.warn(
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
'\n\nYou may be able to resolve this warning by setting'
"\n\nYou may be able to resolve this warning by setting"
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
UserWarning,
)
if is_classvar(ann_type):
class_vars.add(ann_name)
continue
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
if _is_finalvar_with_default_val(
ann_type, getattr(cls, ann_name, PydanticUndefined)
):
class_vars.add(ann_name)
continue
if not is_valid_field_name(ann_name):
continue
if cls.__pydantic_root_model__ and ann_name != 'root':
if cls.__pydantic_root_model__ and ann_name != "root":
raise NameError(
f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
)

# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
# "... shadows an attribute" warnings
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
generic_origin = getattr(cls, "__pydantic_generic_metadata__", {}).get("origin")
for base in bases:
dataclass_fields = {
field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
field.name
for field in (
dataclasses.fields(base) if dataclasses.is_dataclass(base) else ()
)
}
if hasattr(base, ann_name):
if base is generic_origin:
Expand Down Expand Up @@ -216,7 +234,7 @@ def collect_model_fields( # noqa: C901
# defined in a base class and we can take it from there
model_fields_lookup: dict[str, FieldInfo] = {}
for x in cls.__bases__[::-1]:
model_fields_lookup.update(getattr(x, 'model_fields', {}))
model_fields_lookup.update(getattr(x, "model_fields", {}))
if ann_name in model_fields_lookup:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
Expand All @@ -239,7 +257,7 @@ def collect_model_fields( # noqa: C901

# Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
# to make sure the decorators have already been built for this exact class
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
decorators: DecoratorInfos = cls.__dict__["__pydantic_decorators__"]
if ann_name in decorators.computed_fields:
raise ValueError("you can't override a field with a computed field")
fields[ann_name] = field_info
Expand All @@ -256,11 +274,14 @@ def collect_model_fields( # noqa: C901
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str):
from ..fields import FieldInfo

if hasattr(ann_type, '__args__'):
if hasattr(ann_type, "__args__"):
for anno_arg in ann_type.__args__:
if _typing_extra.is_annotated(anno_arg):
for anno_type_arg in _typing_extra.get_args(anno_arg):
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
if (
isinstance(anno_type_arg, FieldInfo)
and anno_type_arg.alias is not None
):
warnings.warn(
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
UserWarning,
Expand All @@ -275,7 +296,9 @@ def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
return False
elif val is PydanticUndefined:
return False
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
elif isinstance(val, FieldInfo) and (
val.default is PydanticUndefined and val.default_factory is None
):
return False
else:
return True
Expand Down Expand Up @@ -303,14 +326,18 @@ def collect_dataclass_fields(

fields: dict[str, FieldInfo] = {}
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
cls_localns = dict(
vars(cls)
) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead

source_module = sys.modules.get(cls.__module__)
if source_module is not None:
types_namespace = {**source_module.__dict__, **(types_namespace or {})}

for ann_name, dataclass_field in dataclass_fields.items():
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
ann_type = _typing_extra.eval_type_lenient(
dataclass_field.type, types_namespace, cls_localns
)
if is_classvar(ann_type):
continue

Expand All @@ -327,19 +354,23 @@ def collect_dataclass_fields(
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
f"Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.",
code="clashing-init-and-init-var",
)

# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
field_info = FieldInfo.from_annotated_attribute(
ann_type, dataclass_field.default
)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)

fields[ann_name] = field_info

if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
if field_info.default is not PydanticUndefined and isinstance(
getattr(cls, ann_name, field_info), FieldInfo
):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)

Expand All @@ -354,8 +385,8 @@ def collect_dataclass_fields(


def is_valid_field_name(name: str) -> bool:
return not name.startswith('_')
return not (name and name[0] == "_")


def is_valid_privateattr_name(name: str) -> bool:
return name.startswith('_') and not name.startswith('__')
return name.startswith("_") and not name.startswith("__")