Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements bulk_create for create_batch if available #925

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions factory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,18 @@ def chain(self):
parent_chain = ()
return (self.stub,) + parent_chain

def recurse(self, factory, declarations, force_sequence=None):
def recurse(self, factory, declarations, force_sequence=None, collect_instances=None):
from . import base
if not issubclass(factory, base.BaseFactory):
raise errors.AssociatedClassError(
"%r: Attempting to recursing into a non-factory object %r"
% (self, factory))
builder = self.builder.recurse(factory._meta, declarations)
return builder.build(parent_step=self, force_sequence=force_sequence)
return builder.build(
parent_step=self,
force_sequence=force_sequence,
collect_instances=collect_instances,
)

def __repr__(self):
return f"<BuildStep for {self.builder!r}>"
Expand All @@ -246,7 +250,7 @@ def __init__(self, factory_meta, extras, strategy):
self.extras = extras
self.force_init_sequence = extras.pop('__sequence', None)

def build(self, parent_step=None, force_sequence=None):
def build(self, parent_step=None, force_sequence=None, collect_instances=None):
"""Build a factory instance."""
# TODO: Handle "batch build" natively
pre, post = parse_declarations(
Expand Down Expand Up @@ -277,19 +281,23 @@ def build(self, parent_step=None, force_sequence=None):
kwargs=kwargs,
)

postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
if collect_instances is None:
postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
instance=instance,
step=step,
overrides=declaration.context,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
overrides=declaration.context,
results=postgen_results,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
results=postgen_results,
)
else:
collect_instances.append(instance)

return instance

def recurse(self, factory_meta, extras):
Expand Down
186 changes: 182 additions & 4 deletions factory/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
import logging
import os
import warnings
from collections import defaultdict
from typing import Dict, TypeVar

from django.contrib.auth.hashers import make_password
from django.core import files as django_files
from django.db import IntegrityError
from django.db import IntegrityError, connections, models
from django.db.models.sql import InsertQuery

from . import base, declarations, errors
from . import base, builder, declarations, enums, errors

logger = logging.getLogger('factory.generate')


DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
T = TypeVar("T")

_LAZY_LOADS: Dict[str, object] = {}


Expand All @@ -45,11 +45,29 @@ def _lazy_load_get_model():
_LAZY_LOADS['get_model'] = django_apps.apps.get_model


def connection_supports_bulk_insert(using):
"""
Does the database support bulk_insert
There are 2 pieces to this puzzle:
* The database needs to support `bulk_insert`
* AND it also needs to be capable of returning all the newly minted objects' id
If any of these is `False`, the database does NOT support bulk_insert
"""
db_features = connections[using].features
return (
db_features.has_bulk_insert
and db_features.can_return_rows_from_bulk_insert
)


class DjangoOptions(base.FactoryOptions):
def _build_default_options(self):
return super()._build_default_options() + [
base.OptionDefault('django_get_or_create', (), inherit=True),
base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
base.OptionDefault('use_bulk_create', False, inherit=True),
base.OptionDefault('skip_postgeneration_save', False, inherit=True),
]

Expand Down Expand Up @@ -165,6 +183,89 @@ def _get_or_create(cls, model_class, *args, **kwargs):

return instance

@classmethod
def supports_bulk_insert(cls):
return (cls._meta.use_bulk_create
and connection_supports_bulk_insert(cls._meta.database))

@classmethod
def create(cls, **kwargs):
"""Create an instance of the associated class, with overridden attrs."""
if not cls.supports_bulk_insert():
return super().create(**kwargs)

return cls._bulk_create(1, **kwargs)[0]

@classmethod
def create_batch(cls, size, **kwargs):
if not cls.supports_bulk_insert():
return super().create_batch(size, **kwargs)

return cls._bulk_create(size, **kwargs)

@classmethod
def _refresh_database_pks(cls, model_cls, objs):
# Avoid causing a django.core.exceptions.AppRegistryNotReady throughout all the tests.
# TODO: remove the `from . import django` from the `__init__.py`
francoisfreitag marked this conversation as resolved.
Show resolved Hide resolved
from django.contrib.contenttypes.fields import GenericForeignKey

def get_field_value(instance, field):
if isinstance(field, GenericForeignKey) and field.is_cached(instance):
return field.get_cached_value(instance)
return getattr(instance, field.name)

# Current Django version's GenericForeignKey is not made to work with bulk_insert.
#
# The issue is that it caches the object referenced, once the object is
# saved and receives a pk, the cache no longer matches. It doesn't
# matter that it's the same obj reference. This is to bypass that pk
# check and reset it.
fields_to_reset = (GenericForeignKey, models.OneToOneField)

fields = [f for f in model_cls._meta.get_fields() if isinstance(f, fields_to_reset)]
if not fields:
return

for obj in objs:
for field in fields:
setattr(obj, field.name, get_field_value(obj, field))

@classmethod
def _bulk_create(cls, size, **kwargs):
if cls._meta.abstract:
raise errors.FactoryError(
"Cannot generate instances of abstract factory %(f)s; "
"Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
"is either not set or False." % dict(f=cls.__name__))

models_to_return = []
instances = []
for _ in range(size):
step = builder.StepBuilder(cls._meta, kwargs, enums.BUILD_STRATEGY)
models_to_return.append(step.build(collect_instances=instances))

for model_cls, objs in dependency_insert_order(instances):
manager = cls._get_manager(model_cls)
cls._refresh_database_pks(model_cls, objs)

concrete_model = True
for parent in model_cls._meta.get_parent_list():
if parent._meta.concrete_model is not model_cls._meta.concrete_model:
concrete_model = False

if concrete_model:
manager.bulk_create(objs)
else:
concrete_fields = model_cls._meta.local_fields
connection = connections[cls._meta.database]

# Avoids writing the INSERT INTO sql script manually
query = InsertQuery(model_cls)
query.insert_values(concrete_fields, objs)
query.get_compiler(connection=connection).execute_sql()

return models_to_return

@classmethod
def _create(cls, model_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
Expand Down Expand Up @@ -272,6 +373,82 @@ def _make_data(self, params):
return thumb_io.getvalue()


def dependency_insert_order(data):
"""This is almost the same function from django/core/serializers/__init__.py:sort_dependencies with a slight
modification on `if hasattr(rel_model, 'natural_key') and rel_model != model:` that was removed, so we have the
REAL dependency order. The original implementation was setup to only write to fields in order if they had a known
dependency, we always want it in order regardless of the natural_key.
"""

lookup = []
model_cls_by_data = defaultdict(list)
for instance in data:
# Instance has been persisted in the database
if not instance._state.adding:
continue
# Instance already in the list
if id(instance) in lookup:
continue
model_cls_by_data[type(instance)].append(instance)

# Avoid data leaks
del lookup
del data

# Process the list of models, and get the list of dependencies
model_dependencies = []
models = list(model_cls_by_data.keys())

for model in models:
deps = set()

# Now add a dependency for any FK relation with a model that
# defines a natural key
for field in model._meta.fields:
rel_model = field.related_model
if rel_model and rel_model != model:
deps.add(rel_model)

model_dependencies.append((model, deps))

model_dependencies.reverse()
# Now sort the models to ensure that dependencies are met. This
# is done by repeatedly iterating over the input list of models.
# If all the dependencies of a given model are in the final list,
# that model is promoted to the end of the final list. This process
# continues until the input list is empty, or we do a full iteration
# over the input models without promoting a model to the final list.
# If we do a full iteration without a promotion, that means there are
# circular dependencies in the list.
model_list = []
while model_dependencies:
skipped = []
changed = False
while model_dependencies:
model, deps = model_dependencies.pop()

# If all of the models in the dependency list are either already
# on the final model list, or not on the original serialization list,
# then we've found another model with all it's dependencies satisfied.
found = True
for candidate in ((d not in models or d in model_list) for d in deps):
if not candidate:
found = False
if found:
model_list.append(model)
changed = True
else:
skipped.append((model, deps))
if not changed:
unresolved_models = (f'{model._meta.app_label}.{model._meta.object_name}'
for model, _ in sorted(skipped, key=lambda obj: obj[0].__name__))
message = f"Can't resolve dependencies for {', '.join(unresolved_models)}."
raise RuntimeError(message)
model_dependencies = skipped

return [(model_cls, model_cls_by_data[model_cls]) for model_cls in model_list]


class mute_signals:
"""Temporarily disables and then restores any django signals.
Expand Down Expand Up @@ -327,6 +504,7 @@ def __call__(self, callable_obj):
if isinstance(callable_obj, base.FactoryMetaClass):
# Retrieve __func__, the *actual* callable object.
callable_obj._create = self.wrap_method(callable_obj._create.__func__)
callable_obj._bulk_create = self.wrap_method(callable_obj._bulk_create.__func__)
callable_obj._generate = self.wrap_method(callable_obj._generate.__func__)
callable_obj._after_postgeneration = self.wrap_method(
callable_obj._after_postgeneration.__func__
Expand Down
45 changes: 45 additions & 0 deletions tests/djapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os.path

from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import signals

Expand Down Expand Up @@ -137,3 +139,46 @@ class FromAbstractWithCustomManager(AbstractWithCustomManager):

class HasMultifieldModel(models.Model):
multifield = models.ForeignKey(to=MultifieldModel, on_delete=models.CASCADE)


class P(models.Model):
pass


class R(models.Model):
is_default = models.BooleanField(default=False)
p = models.ForeignKey(P, models.CASCADE, null=True)


class S(models.Model):
r = models.ForeignKey(R, models.CASCADE)


class T(models.Model):
s = models.ForeignKey(S, models.CASCADE)


class U(models.Model):
t = models.ForeignKey(T, models.CASCADE)


class RChild(R):
text = models.CharField(max_length=10)


class A(models.Model):
p_o = models.OneToOneField('P', models.CASCADE, related_name="+")
p_f = models.ForeignKey('P', models.CASCADE, related_name="+")
p_m = models.ManyToManyField('P')


class AA(models.Model):
a = models.OneToOneField(A, models.CASCADE)
u = models.OneToOneField(U, models.CASCADE)
p = models.OneToOneField(P, models.CASCADE)


class GenericModel(models.Model):
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id = models.PositiveIntegerField()
generic_obj = GenericForeignKey("content_type", "object_id")
3 changes: 2 additions & 1 deletion tests/djapp/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@


INSTALLED_APPS = [
'tests.djapp'
'django.contrib.contenttypes',
'tests.djapp',
]

MIDDLEWARE_CLASSES = ()
Expand Down
Loading
Loading