From 95dfa9060a15f22d466fbe012fb0fc31d0317e0d Mon Sep 17 00:00:00 2001 From: Serg Tereshchenko Date: Wed, 12 Jan 2022 18:41:00 +0200 Subject: [PATCH] Add basic typing support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Only `Factory.build()` and `Factory.create()` are properly typed, provided the class is declared as `class UserFactory(Factory[User]):`. Relies on mypy for tests. Reviewed-By: Raphaël Barrois --- Makefile | 1 + factory/__init__.py | 6 ++++-- factory/base.py | 21 ++++++++++++++------- factory/django.py | 7 ++++--- factory/faker.py | 2 +- setup.cfg | 1 + tests/test_typing.py | 28 ++++++++++++++++++++++++++++ 7 files changed, 53 insertions(+), 13 deletions(-) create mode 100644 tests/test_typing.py diff --git a/Makefile b/Makefile index a31a9fb8..0dbae867 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,7 @@ testall: # DOC: Run tests for the currently installed version # Remove cgi warning when dropping support for Django<=4.1. test: + mypy --ignore-missing-imports tests/test_typing.py python \ -b \ -X dev \ diff --git a/factory/__init__.py b/factory/__init__.py index bdc3ac0d..8b26dddc 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -1,5 +1,7 @@ # Copyright: See the LICENSE file. +import sys + from .base import ( BaseDictFactory, BaseListFactory, @@ -70,10 +72,10 @@ pass __author__ = 'Raphaël Barrois ' -try: +if sys.version_info >= (3, 8): # Python 3.8+ import importlib.metadata as importlib_metadata -except ImportError: +else: import importlib_metadata __version__ = importlib_metadata.version("factory_boy") diff --git a/factory/base.py b/factory/base.py index 36b2359a..3804a8b6 100644 --- a/factory/base.py +++ b/factory/base.py @@ -4,11 +4,14 @@ import collections import logging import warnings +from typing import Generic, List, TypeVar from . import builder, declarations, enums, errors, utils logger = logging.getLogger('factory.generate') +T = TypeVar('T') + # Factory metaclasses @@ -405,7 +408,7 @@ def reset(self, next_value=0): self.seq = next_value -class BaseFactory: +class BaseFactory(Generic[T]): """Factory base support for sequences, attributes and stubs.""" # Backwards compatibility @@ -506,12 +509,12 @@ def _create(cls, model_class, *args, **kwargs): return model_class(*args, **kwargs) @classmethod - def build(cls, **kwargs): + def build(cls, **kwargs) -> T: """Build an instance of the associated class, with overridden attrs.""" return cls._generate(enums.BUILD_STRATEGY, kwargs) @classmethod - def build_batch(cls, size, **kwargs): + def build_batch(cls, size: int, **kwargs) -> List[T]: """Build a batch of instances of the given class, with overridden attrs. Args: @@ -523,12 +526,12 @@ def build_batch(cls, size, **kwargs): return [cls.build(**kwargs) for _ in range(size)] @classmethod - def create(cls, **kwargs): + def create(cls, **kwargs) -> T: """Create an instance of the associated class, with overridden attrs.""" return cls._generate(enums.CREATE_STRATEGY, kwargs) @classmethod - def create_batch(cls, size, **kwargs): + def create_batch(cls, size: int, **kwargs) -> List[T]: """Create a batch of instances of the given class, with overridden attrs. Args: @@ -627,18 +630,22 @@ def simple_generate_batch(cls, create, size, **kwargs): return cls.generate_batch(strategy, size, **kwargs) -class Factory(BaseFactory, metaclass=FactoryMetaClass): +class Factory(BaseFactory[T], metaclass=FactoryMetaClass): """Factory base with build and create support. This class has the ability to support multiple ORMs by using custom creation functions. """ + # Backwards compatibility + AssociatedClassError: type[Exception] + class Meta(BaseMeta): pass -# Backwards compatibility +# Add the association after metaclass execution. +# Otherwise, AssociatedClassError would be detected as a declaration. Factory.AssociatedClassError = errors.AssociatedClassError diff --git a/factory/django.py b/factory/django.py index 9526b775..665e2dfe 100644 --- a/factory/django.py +++ b/factory/django.py @@ -9,6 +9,7 @@ import logging import os import warnings +from typing import TypeVar from django.contrib.auth.hashers import make_password from django.core import files as django_files @@ -20,9 +21,9 @@ DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS +T = TypeVar("T") - -_LAZY_LOADS = {} +_LAZY_LOADS: dict[str, object] = {} def get_model(app, model): @@ -72,7 +73,7 @@ def get_model_class(self): return self.model -class DjangoModelFactory(base.Factory): +class DjangoModelFactory(base.Factory[T]): """Factory for Django models. This makes sure that the 'sequence' field of created objects is a new id. diff --git a/factory/faker.py b/factory/faker.py index 6ed2e28c..f58e2f09 100644 --- a/factory/faker.py +++ b/factory/faker.py @@ -47,7 +47,7 @@ def evaluate(self, instance, step, extra): subfaker = self._get_faker(locale) return subfaker.format(self.provider, **extra) - _FAKER_REGISTRY = {} + _FAKER_REGISTRY: dict[str, faker.Faker] = {} _DEFAULT_LOCALE = faker.config.DEFAULT_LOCALE @classmethod diff --git a/setup.cfg b/setup.cfg index 3ba2b7aa..13b09b91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ dev = Django flake8 isort + mypy Pillow SQLAlchemy sqlalchemy_utils diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 00000000..03182c00 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,28 @@ +# Copyright: See the LICENSE file. + +import dataclasses +import unittest +import typing + +import factory + + +@dataclasses.dataclass +class User: + name: str + email: str + id: int + + +class TypingTests(unittest.TestCase): + def test_simple_factory(self) -> None: + class UserFactory(factory.Factory[User]): + name = "John Doe" + email = "john.doe@example.org" + id = 42 + class Meta: + model = User + + result: User + result = UserFactory.build() + result = UserFactory.create()