diff --git a/factory/base.py b/factory/base.py index 36b2359a..754098c3 100644 --- a/factory/base.py +++ b/factory/base.py @@ -4,11 +4,14 @@ import collections import logging import warnings +from typing import Generic, List, TypeVar from . import builder, declarations, enums, errors, utils logger = logging.getLogger('factory.generate') +T = TypeVar('T') + # Factory metaclasses @@ -405,7 +408,7 @@ def reset(self, next_value=0): self.seq = next_value -class BaseFactory: +class BaseFactory(Generic[T]): """Factory base support for sequences, attributes and stubs.""" # Backwards compatibility @@ -506,12 +509,12 @@ def _create(cls, model_class, *args, **kwargs): return model_class(*args, **kwargs) @classmethod - def build(cls, **kwargs): + def build(cls, **kwargs) -> T: """Build an instance of the associated class, with overridden attrs.""" return cls._generate(enums.BUILD_STRATEGY, kwargs) @classmethod - def build_batch(cls, size, **kwargs): + def build_batch(cls, size, **kwargs) -> List[T]: """Build a batch of instances of the given class, with overridden attrs. Args: @@ -523,12 +526,12 @@ def build_batch(cls, size, **kwargs): return [cls.build(**kwargs) for _ in range(size)] @classmethod - def create(cls, **kwargs): + def create(cls, **kwargs) -> T: """Create an instance of the associated class, with overridden attrs.""" return cls._generate(enums.CREATE_STRATEGY, kwargs) @classmethod - def create_batch(cls, size, **kwargs): + def create_batch(cls, size, **kwargs) -> List[T]: """Create a batch of instances of the given class, with overridden attrs. Args: @@ -627,7 +630,7 @@ def simple_generate_batch(cls, create, size, **kwargs): return cls.generate_batch(strategy, size, **kwargs) -class Factory(BaseFactory, metaclass=FactoryMetaClass): +class Factory(BaseFactory[T], meta=FactoryMetaClass): """Factory base with build and create support. This class has the ability to support multiple ORMs by using custom creation diff --git a/factory/django.py b/factory/django.py index 7e5427fe..7e6e8311 100644 --- a/factory/django.py +++ b/factory/django.py @@ -8,6 +8,7 @@ import io import logging import os +from typing import TypeVar import warnings from django.contrib.auth.hashers import make_password @@ -20,7 +21,7 @@ DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS - +T = TypeVar("T") _LAZY_LOADS = {} @@ -72,7 +73,7 @@ def get_model_class(self): return self.model -class DjangoModelFactory(base.Factory): +class DjangoModelFactory(base.Factory[T]): """Factory for Django models. This makes sure that the 'sequence' field of created objects is a new id.