Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
kingbuzzman committed Dec 12, 2024
1 parent 4209372 commit 62cd6f5
Show file tree
Hide file tree
Showing 6 changed files with 515 additions and 22 deletions.
34 changes: 21 additions & 13 deletions factory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,18 @@ def chain(self):
parent_chain = ()
return (self.stub,) + parent_chain

def recurse(self, factory, declarations, force_sequence=None):
def recurse(self, factory, declarations, force_sequence=None, collect_instances=None):
from . import base
if not issubclass(factory, base.BaseFactory):
raise errors.AssociatedClassError(
"%r: Attempting to recursing into a non-factory object %r"
% (self, factory))
builder = self.builder.recurse(factory._meta, declarations)
return builder.build(parent_step=self, force_sequence=force_sequence)
return builder.build(
parent_step=self,
force_sequence=force_sequence,
collect_instances=collect_instances,
)

def __repr__(self):
return f"<BuildStep for {self.builder!r}>"
Expand All @@ -246,7 +250,7 @@ def __init__(self, factory_meta, extras, strategy):
self.extras = extras
self.force_init_sequence = extras.pop('__sequence', None)

def build(self, parent_step=None, force_sequence=None):
def build(self, parent_step=None, force_sequence=None, collect_instances=None):
"""Build a factory instance."""
# TODO: Handle "batch build" natively
pre, post = parse_declarations(
Expand Down Expand Up @@ -277,19 +281,23 @@ def build(self, parent_step=None, force_sequence=None):
kwargs=kwargs,
)

postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
if collect_instances is None:
postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
instance=instance,
step=step,
overrides=declaration.context,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
overrides=declaration.context,
results=postgen_results,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
results=postgen_results,
)
else:
collect_instances.append(instance)

return instance

def recurse(self, factory_meta, extras):
Expand Down
186 changes: 182 additions & 4 deletions factory/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
import logging
import os
import warnings
from collections import defaultdict
from typing import Dict, TypeVar

from django.contrib.auth.hashers import make_password
from django.core import files as django_files
from django.db import IntegrityError
from django.db import IntegrityError, connections, models
from django.db.models.sql import InsertQuery

from . import base, declarations, errors
from . import base, builder, declarations, enums, errors

logger = logging.getLogger('factory.generate')


DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
T = TypeVar("T")

_LAZY_LOADS: Dict[str, object] = {}


Expand All @@ -45,11 +45,29 @@ def _lazy_load_get_model():
_LAZY_LOADS['get_model'] = django_apps.apps.get_model


def connection_supports_bulk_insert(using):
    """Tell whether the database behind alias ``using`` can serve bulk inserts.

    Two backend feature flags must both hold:

    * ``has_bulk_insert``: the backend supports inserting in bulk;
    * ``can_return_rows_from_bulk_insert``: the backend can also return the
      rows it just inserted, which is how the new objects receive their ids.

    If either flag is falsy, the connection is treated as not supporting
    bulk inserts.
    """
    features = connections[using].features
    bulk_insert_available = features.has_bulk_insert
    return bulk_insert_available and features.can_return_rows_from_bulk_insert


class DjangoOptions(base.FactoryOptions):
def _build_default_options(self):
    """Return the inherited option defaults followed by the Django-specific ones."""
    inherited = super()._build_default_options()
    django_defaults = [
        base.OptionDefault('django_get_or_create', (), inherit=True),
        base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
        base.OptionDefault('use_bulk_create', False, inherit=True),
        base.OptionDefault('skip_postgeneration_save', False, inherit=True),
    ]
    return inherited + django_defaults

Expand Down Expand Up @@ -165,6 +183,89 @@ def _get_or_create(cls, model_class, *args, **kwargs):

return instance

@classmethod
def supports_bulk_insert(cls):
    """Tell whether this factory should persist its instances in bulk.

    Truthy only when the ``use_bulk_create`` option is enabled and the
    factory's configured database connection supports bulk inserts.
    """
    options = cls._meta
    return options.use_bulk_create and connection_supports_bulk_insert(options.database)

@classmethod
def create(cls, **kwargs):
    """Create one instance of the associated class, applying ``kwargs`` overrides.

    Uses the bulk-insert path (with a batch of one) when the factory
    supports it; otherwise defers to the parent implementation.
    """
    if cls.supports_bulk_insert():
        created = cls._bulk_create(1, **kwargs)
        return created[0]

    return super().create(**kwargs)

@classmethod
def create_batch(cls, size, **kwargs):
    """Create ``size`` instances, applying ``kwargs`` overrides to each.

    Uses the bulk-insert path when the factory supports it; otherwise
    defers to the parent implementation.
    """
    if cls.supports_bulk_insert():
        return cls._bulk_create(size, **kwargs)

    return super().create_batch(size, **kwargs)

@classmethod
def _refresh_database_pks(cls, model_cls, objs):
    """Re-assign cached relation values on ``objs`` so they match saved targets.

    For every ``GenericForeignKey`` / ``OneToOneField`` declared on
    ``model_cls``, the current value is read from each object and set
    again through the field's own name. Nothing happens when the model
    has no such field.
    """
    # Imported here to avoid a django.core.exceptions.AppRegistryNotReady
    # throughout all the tests.
    # TODO: remove the `from . import django` from the `__init__.py`
    from django.contrib.contenttypes.fields import GenericForeignKey

    # Current Django version's GenericForeignKey is not made to work with
    # bulk inserts: it caches the referenced object, and once that object
    # is saved and receives a pk the cache no longer matches, even though
    # it is the very same object. Re-setting the value bypasses that pk
    # check and resets the cache.
    resettable_types = (GenericForeignKey, models.OneToOneField)
    relation_fields = [
        field
        for field in model_cls._meta.get_fields()
        if isinstance(field, resettable_types)
    ]
    if not relation_fields:
        return

    for obj in objs:
        for field in relation_fields:
            if isinstance(field, GenericForeignKey) and field.is_cached(obj):
                current = field.get_cached_value(obj)
            else:
                current = getattr(obj, field.name)
            setattr(obj, field.name, current)

@classmethod
def _bulk_create(cls, size, **kwargs):
    """Build ``size`` instances and persist them, model by model, in bulk.

    Every instance is built with the BUILD strategy while the builder
    collects each generated object (including related ones) into a single
    list. The collected objects are then grouped per model, ordered so
    that dependencies are inserted first, and written to the database.

    Args:
        size: number of top-level instances to generate.
        **kwargs: declaration overrides applied to every instance.

    Returns:
        The list of top-level instances, one per requested item.

    Raises:
        errors.FactoryError: if the factory is abstract.
    """
    if cls._meta.abstract:
        raise errors.FactoryError(
            "Cannot generate instances of abstract factory %(f)s; "
            "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
            "is either not set or False." % dict(f=cls.__name__))

    models_to_return = []
    instances = []
    for _ in range(size):
        # StepBuilder pops '__sequence' out of the mapping it receives, so
        # each builder gets its own copy. Sharing one dict would apply a
        # forced sequence to the first instance of the batch only.
        step = builder.StepBuilder(cls._meta, dict(kwargs), enums.BUILD_STRATEGY)
        models_to_return.append(step.build(collect_instances=instances))

    for model_cls, objs in dependency_insert_order(instances):
        manager = cls._get_manager(model_cls)
        cls._refresh_database_pks(model_cls, objs)

        # A model whose parent list contains a different concrete model
        # cannot go through ``bulk_create`` here and is inserted manually.
        # NOTE(review): presumably because bulk_create does not handle
        # multi-table inheritance children; confirm against Django docs.
        has_other_concrete_parent = any(
            parent._meta.concrete_model is not model_cls._meta.concrete_model
            for parent in model_cls._meta.get_parent_list()
        )

        if not has_other_concrete_parent:
            manager.bulk_create(objs)
        else:
            concrete_fields = model_cls._meta.local_fields
            connection = connections[cls._meta.database]

            # Avoids writing the INSERT INTO sql script manually
            query = InsertQuery(model_cls)
            query.insert_values(concrete_fields, objs)
            query.get_compiler(connection=connection).execute_sql()

    return models_to_return

@classmethod
def _create(cls, model_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
Expand Down Expand Up @@ -272,6 +373,82 @@ def _make_data(self, params):
return thumb_io.getvalue()


def dependency_insert_order(data):
    """Group unsaved instances by model and order the models for insertion.

    Adapted from ``sort_dependencies`` in
    ``django/core/serializers/__init__.py``. The original's
    ``hasattr(rel_model, 'natural_key')`` filter is removed, so every relation
    to another model counts as a dependency regardless of natural keys. This
    yields the real insertion order.

    Args:
        data: iterable of model instances. Instances already persisted
            (``_state.adding`` is false) are ignored, and an object that
            appears more than once is only kept once.

    Returns:
        A list of ``(model_class, instances)`` pairs, ordered so that a model
        comes after the models (present in ``data``) that its fields point to.

    Raises:
        RuntimeError: if the models depend on each other circularly.
    """
    seen_ids = set()
    instances_by_model = defaultdict(list)
    for instance in data:
        # Instance has been persisted in the database
        if not instance._state.adding:
            continue
        # Same object already queued: inserting it twice would duplicate it
        if id(instance) in seen_ids:
            continue
        seen_ids.add(id(instance))
        instances_by_model[type(instance)].append(instance)

    # Collect, for each model, every other model its fields relate to.
    models = list(instances_by_model)
    model_dependencies = []
    for model in models:
        deps = {
            field.related_model
            for field in model._meta.fields
            if field.related_model and field.related_model != model
        }
        model_dependencies.append((model, deps))

    model_dependencies.reverse()
    # Now sort the models to ensure that dependencies are met. This
    # is done by repeatedly iterating over the input list of models.
    # If all the dependencies of a given model are in the final list,
    # that model is promoted to the end of the final list. This process
    # continues until the input list is empty, or we do a full iteration
    # over the input models without promoting a model to the final list.
    # If we do a full iteration without a promotion, that means there are
    # circular dependencies in the list.
    model_list = []
    while model_dependencies:
        skipped = []
        changed = False
        while model_dependencies:
            model, deps = model_dependencies.pop()

            # A dependency is satisfied when it is already in the final
            # list, or when it has no instance to insert at all.
            satisfied = all(
                dep not in instances_by_model or dep in model_list
                for dep in deps
            )
            if satisfied:
                model_list.append(model)
                changed = True
            else:
                skipped.append((model, deps))
        if not changed:
            unresolved_models = (
                f'{model._meta.app_label}.{model._meta.object_name}'
                for model, _ in sorted(skipped, key=lambda obj: obj[0].__name__)
            )
            message = f"Can't resolve dependencies for {', '.join(unresolved_models)}."
            raise RuntimeError(message)
        model_dependencies = skipped

    return [(model_cls, instances_by_model[model_cls]) for model_cls in model_list]


class mute_signals:
"""Temporarily disables and then restores any django signals.
Expand Down Expand Up @@ -327,6 +504,7 @@ def __call__(self, callable_obj):
if isinstance(callable_obj, base.FactoryMetaClass):
# Retrieve __func__, the *actual* callable object.
callable_obj._create = self.wrap_method(callable_obj._create.__func__)
callable_obj._bulk_create = self.wrap_method(callable_obj._bulk_create.__func__)
callable_obj._generate = self.wrap_method(callable_obj._generate.__func__)
callable_obj._after_postgeneration = self.wrap_method(
callable_obj._after_postgeneration.__func__
Expand Down
45 changes: 45 additions & 0 deletions tests/djapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os.path

from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import signals

Expand Down Expand Up @@ -137,3 +139,46 @@ class FromAbstractWithCustomManager(AbstractWithCustomManager):

class HasMultifieldModel(models.Model):
multifield = models.ForeignKey(to=MultifieldModel, on_delete=models.CASCADE)


class P(models.Model):
    """Test model that declares no fields of its own."""


class R(models.Model):
    """Test model with a boolean flag and a nullable foreign key to ``P``."""

    is_default = models.BooleanField(default=False)
    p = models.ForeignKey(to=P, on_delete=models.CASCADE, null=True)


class S(models.Model):
    """Test model holding a required foreign key to ``R``."""

    r = models.ForeignKey(to=R, on_delete=models.CASCADE)


class T(models.Model):
    """Test model holding a required foreign key to ``S``."""

    s = models.ForeignKey(to=S, on_delete=models.CASCADE)


class U(models.Model):
    """Test model holding a required foreign key to ``T``."""

    t = models.ForeignKey(to=T, on_delete=models.CASCADE)


class RChild(R):
    """Subclass of the concrete model ``R`` that adds a short text field."""

    text = models.CharField(max_length=10)


class A(models.Model):
    """Test model relating to ``P`` through one-to-one, foreign-key and many-to-many fields."""

    p_o = models.OneToOneField(to='P', on_delete=models.CASCADE, related_name="+")
    p_f = models.ForeignKey(to='P', on_delete=models.CASCADE, related_name="+")
    p_m = models.ManyToManyField(to='P')


class AA(models.Model):
    """Test model with one-to-one links to ``A``, ``U`` and ``P``."""

    a = models.OneToOneField(to=A, on_delete=models.CASCADE)
    u = models.OneToOneField(to=U, on_delete=models.CASCADE)
    p = models.OneToOneField(to=P, on_delete=models.CASCADE)


class GenericModel(models.Model):
    """Test model exposing a generic relation built from ``content_type`` and ``object_id``."""

    content_type = models.ForeignKey(to=ContentType, on_delete=models.CASCADE)
    object_id = models.PositiveIntegerField()
    generic_obj = GenericForeignKey(ct_field="content_type", fk_field="object_id")
3 changes: 2 additions & 1 deletion tests/djapp/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@


INSTALLED_APPS = [
'tests.djapp'
'django.contrib.contenttypes',
'tests.djapp',
]

MIDDLEWARE_CLASSES = ()
Expand Down
Loading

0 comments on commit 62cd6f5

Please sign in to comment.