From 68de8e75c6862588dd265d96567bcf34c079186b Mon Sep 17 00:00:00 2001 From: Serg Tereshchenko Date: Wed, 12 Jan 2022 18:41:00 +0200 Subject: [PATCH] Add basic typing support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Only `Factory.build()` and `Factory.create()` are properly typed, provided the class is declared as `class UserFactory(Factory[User]):`. Relies on mypy for tests. Reviewed-By: Raphaël Barrois --- Makefile | 1 + docs/changelog.rst | 1 + factory/__init__.py | 6 ++++-- factory/base.py | 21 ++++++++++++++------- factory/django.py | 7 ++++--- factory/faker.py | 3 ++- setup.cfg | 1 + tests/test_typing.py | 31 +++++++++++++++++++++++++++++++ tox.ini | 1 + 9 files changed, 59 insertions(+), 13 deletions(-) create mode 100644 tests/test_typing.py diff --git a/Makefile b/Makefile index a31a9fb8..0dbae867 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,7 @@ testall: # DOC: Run tests for the currently installed version # Remove cgi warning when dropping support for Django<=4.1. test: + mypy --ignore-missing-imports tests/test_typing.py python \ -b \ -X dev \ diff --git a/docs/changelog.rst b/docs/changelog.rst index b98bcfce..6397f107 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,7 @@ ChangeLog - Add support for Django 4.2 - Add support for Django 5.0 - Add support for Python 3.12 +- :issue:`903`: Add basic typing annotations *Bugfix:* diff --git a/factory/__init__.py b/factory/__init__.py index bdc3ac0d..8b26dddc 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -1,5 +1,7 @@ # Copyright: See the LICENSE file. +import sys + from .base import ( BaseDictFactory, BaseListFactory, @@ -70,10 +72,10 @@ pass __author__ = 'Raphaël Barrois ' -try: +if sys.version_info >= (3, 8): # Python 3.8+ import importlib.metadata as importlib_metadata -except ImportError: +else: import importlib_metadata __version__ = importlib_metadata.version("factory_boy") diff --git a/factory/base.py b/factory/base.py index 36b2359a..8d499501 100644 --- a/factory/base.py +++ b/factory/base.py @@ -4,11 +4,14 @@ import collections import logging import warnings +from typing import Generic, List, Type, TypeVar from . import builder, declarations, enums, errors, utils logger = logging.getLogger('factory.generate') +T = TypeVar('T') + # Factory metaclasses @@ -405,7 +408,7 @@ def reset(self, next_value=0): self.seq = next_value -class BaseFactory: +class BaseFactory(Generic[T]): """Factory base support for sequences, attributes and stubs.""" # Backwards compatibility @@ -506,12 +509,12 @@ def _create(cls, model_class, *args, **kwargs): return model_class(*args, **kwargs) @classmethod - def build(cls, **kwargs): + def build(cls, **kwargs) -> T: """Build an instance of the associated class, with overridden attrs.""" return cls._generate(enums.BUILD_STRATEGY, kwargs) @classmethod - def build_batch(cls, size, **kwargs): + def build_batch(cls, size: int, **kwargs) -> List[T]: """Build a batch of instances of the given class, with overridden attrs. Args: @@ -523,12 +526,12 @@ def build_batch(cls, size, **kwargs): return [cls.build(**kwargs) for _ in range(size)] @classmethod - def create(cls, **kwargs): + def create(cls, **kwargs) -> T: """Create an instance of the associated class, with overridden attrs.""" return cls._generate(enums.CREATE_STRATEGY, kwargs) @classmethod - def create_batch(cls, size, **kwargs): + def create_batch(cls, size: int, **kwargs) -> List[T]: """Create a batch of instances of the given class, with overridden attrs. Args: @@ -627,18 +630,22 @@ def simple_generate_batch(cls, create, size, **kwargs): return cls.generate_batch(strategy, size, **kwargs) -class Factory(BaseFactory, metaclass=FactoryMetaClass): +class Factory(BaseFactory[T], metaclass=FactoryMetaClass): """Factory base with build and create support. This class has the ability to support multiple ORMs by using custom creation functions. """ + # Backwards compatibility + AssociatedClassError: Type[Exception] + class Meta(BaseMeta): pass -# Backwards compatibility +# Add the association after metaclass execution. +# Otherwise, AssociatedClassError would be detected as a declaration. Factory.AssociatedClassError = errors.AssociatedClassError diff --git a/factory/django.py b/factory/django.py index 9526b775..b53fd5b5 100644 --- a/factory/django.py +++ b/factory/django.py @@ -9,6 +9,7 @@ import logging import os import warnings +from typing import Dict, TypeVar from django.contrib.auth.hashers import make_password from django.core import files as django_files @@ -20,9 +21,9 @@ DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS +T = TypeVar("T") - -_LAZY_LOADS = {} +_LAZY_LOADS: Dict[str, object] = {} def get_model(app, model): @@ -72,7 +73,7 @@ def get_model_class(self): return self.model -class DjangoModelFactory(base.Factory): +class DjangoModelFactory(base.Factory[T]): """Factory for Django models. This makes sure that the 'sequence' field of created objects is a new id. diff --git a/factory/faker.py b/factory/faker.py index 6ed2e28c..88ae644c 100644 --- a/factory/faker.py +++ b/factory/faker.py @@ -14,6 +14,7 @@ class Meta: import contextlib +from typing import Dict import faker import faker.config @@ -47,7 +48,7 @@ def evaluate(self, instance, step, extra): subfaker = self._get_faker(locale) return subfaker.format(self.provider, **extra) - _FAKER_REGISTRY = {} + _FAKER_REGISTRY: Dict[str, faker.Faker] = {} _DEFAULT_LOCALE = faker.config.DEFAULT_LOCALE @classmethod diff --git a/setup.cfg b/setup.cfg index 3ba2b7aa..13b09b91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ dev = Django flake8 isort + mypy Pillow SQLAlchemy sqlalchemy_utils diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 00000000..c2f8b564 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,31 @@ +# Copyright: See the LICENSE file. + +import dataclasses +import unittest + +import factory + + +@dataclasses.dataclass +class User: + name: str + email: str + id: int + + +class TypingTests(unittest.TestCase): + + def test_simple_factory(self) -> None: + + class UserFactory(factory.Factory[User]): + name = "John Doe" + email = "john.doe@example.org" + id = 42 + + class Meta: + model = User + + result: User + result = UserFactory.build() + result = UserFactory.create() + self.assertEqual(result.name, "John Doe") diff --git a/tox.ini b/tox.ini index 9010d318..d842c759 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ passenv = POSTGRES_HOST POSTGRES_DATABASE deps = + mypy alchemy: SQLAlchemy alchemy: sqlalchemy_utils mongo: mongoengine