diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b8a89fb2..70573aa3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,35 +6,32 @@ on: jobs: build: - name: Python ${{ matrix.python-version }} / ${{ matrix.tox-environment }} + name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: + - "3.7" - "3.8" - "3.9" - "3.10" + - "pypy-3.7" - "pypy-3.8" - tox-environment: - - django32-alchemy-mongoengine - - django40-alchemy-mongoengine - include: - - python-version: "3.7" - tox-environment: django22-alchemy-mongoengine - - python-version: "pypy-3.7" - tox-environment: django22-alchemy-mongoengine - - python-version: "3.7" - tox-environment: django32-alchemy-mongoengine - - python-version: "pypy-3.7" - tox-environment: django32-alchemy-mongoengine services: mongodb: image: mongo ports: - 27017:27017 + postgresdb: + image: postgres:alpine + ports: + - 5432:5432 + env: + POSTGRES_PASSWORD: password + env: TOXENV: ${{ matrix.tox-environment }} @@ -47,7 +44,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: python -m pip install tox + run: python -m pip install tox tox-gh-actions - name: Run tests run: tox diff --git a/Makefile b/Makefile index 9474bae5..c1911d29 100644 --- a/Makefile +++ b/Makefile @@ -62,6 +62,7 @@ test: -Wdefault:"Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working":DeprecationWarning:: \ -Wdefault:"set_output_charset() is deprecated":DeprecationWarning:: \ -Wdefault:"parameter codeset is deprecated":DeprecationWarning:: \ + -Wdefault:"distutils Version classes are deprecated. Use packaging.version instead":DeprecationWarning:: \ -m unittest # DOC: Test the examples diff --git a/factory/django.py b/factory/django.py index 7e5427fe..cffe1b42 100644 --- a/factory/django.py +++ b/factory/django.py @@ -9,10 +9,13 @@ import logging import os import warnings +from collections import defaultdict +from django import __version__ as django_version from django.contrib.auth.hashers import make_password from django.core import files as django_files -from django.db import IntegrityError +from django.db import IntegrityError, connections, models +from packaging.version import Version from . 
import base, declarations, errors
@@ -21,6 +24,7 @@
 DEFAULT_DB_ALIAS = 'default'  # Same as django.db.DEFAULT_DB_ALIAS
 
+DJANGO_22 = Version(django_version) < Version('3.0')
 
 _LAZY_LOADS = {}
@@ -44,11 +48,31 @@ def _lazy_load_get_model():
     _LAZY_LOADS['get_model'] = django_apps.apps.get_model
 
 
+def connection_supports_bulk_insert(using):
+    """
+    Check whether the database behind ``using`` supports the bulk-insert path.
+
+    Two features are required:
+    * the backend must support bulk inserts, and
+    * it must be able to return the primary keys of the newly inserted rows.
+
+    If either feature is missing, bulk insertion cannot be used.
+    """
+    connection = connections[using]
+    if DJANGO_22:
+        can_return_rows_from_bulk_insert = connection.features.can_return_ids_from_bulk_insert
+    else:
+        can_return_rows_from_bulk_insert = connection.features.can_return_rows_from_bulk_insert
+    return (connection.features.has_bulk_insert
+            and can_return_rows_from_bulk_insert)
+
+
 class DjangoOptions(base.FactoryOptions):
     def _build_default_options(self):
         return super()._build_default_options() + [
             base.OptionDefault('django_get_or_create', (), inherit=True),
             base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
+            base.OptionDefault('use_bulk_create', False, inherit=True),
             base.OptionDefault('skip_postgeneration_save', False, inherit=True),
         ]
@@ -159,6 +183,58 @@ def _get_or_create(cls, model_class, *args, **kwargs):
         return instance
 
+    @classmethod
+    def supports_bulk_insert(cls):
+        return (cls._meta.use_bulk_create
+                and connection_supports_bulk_insert(cls._meta.database))
+
+    @classmethod
+    def create(cls, **kwargs):
+        """Create an instance of the associated class, with overridden attrs."""
+        if not cls.supports_bulk_insert():
+            return super().create(**kwargs)
+
+        return cls._bulk_create(1, **kwargs)[0]
+
+    @classmethod
+    def create_batch(cls, size, **kwargs):
+        if not cls.supports_bulk_insert():
+            return super().create_batch(size, **kwargs)
+
+        return cls._bulk_create(size, **kwargs)
+
+    @classmethod
+    def _refresh_database_pks(cls, model_cls, objs):
+        """
+        Work around a bulk-insert limitation of Django < 3.0.
+
+        When an instance is created and then referenced by another, still
+        unsaved, instance, the referencing instance never picks up the pk
+        assigned when the first one is saved. Re-assigning the foreign key
+        attributes below copies the freshly assigned pk onto the referencing
+        instance.
+        """
+        if not DJANGO_22:
+            return
+        fields = [f for f in model_cls._meta.get_fields()
+                  if isinstance(f, models.fields.related.ForeignObject)]
+        if not fields:
+            return
+        for obj in objs:
+            for field in fields:
+                setattr(obj, field.name, getattr(obj, field.name))
+
+    @classmethod
+    def _bulk_create(cls, size, **kwargs):
+        models_to_create = cls.build_batch(size, **kwargs)
+        collector = DependencyInsertOrderCollector()
+        collector.collect(cls, models_to_create)
+        collector.sort()
+        for model_cls, objs in collector.data.items():
+            manager = cls._get_manager(model_cls)
+            cls._refresh_database_pks(model_cls, objs)
+            manager.bulk_create(objs)
+        return models_to_create
+
     @classmethod
     def _create(cls, model_class, *args, **kwargs):
         """Create an instance of the model, and save it to the database."""
@@ -263,6 +339,129 @@ def _make_data(self, params):
         return thumb_io.getvalue()
 
 
+class DependencyInsertOrderCollector:
+    def __init__(self):
+        # {model: [instances]} mapping of the objects to insert, per model.
+        self.data = defaultdict(list)
+        # Tracks insertion-order dependencies: a model must be inserted after
+        # the models it references. Only concrete model classes
+        # should be included, as the dependencies exist only between actual
+        # database tables; proxy models are represented here by their concrete
+        # parent.
+        self.dependencies = defaultdict(set)  # {model: {models}}
+
+    def add(self, objs, source=None, nullable=False):
+        """
+        Add 'objs' to the collection of objects to be inserted, in order.
+        If the call is the result of following a relation, 'source' should be
+        the model holding that relation, and 'nullable' should be set to True
+        if the relation can be null.
+        Return a list of all objects that were not already collected.
+        """
+        if not objs:
+            return []
+        new_objs = []
+        model = objs[0].__class__
+        instances = self.data[model]
+        lookup = [id(instance) for instance in instances]
+        for obj in objs:
+            if not obj._state.adding:
+                continue
+            if id(obj) not in lookup:
+                new_objs.append(obj)
+        instances.extend(new_objs)
+        # Nullable relations do not have to constrain the insertion order,
+        # so they do not register a dependency.
+        if source is not None and not nullable:
+            self.add_dependency(source, model)
+        return new_objs
+
+    def add_dependency(self, model, dependency):
+        self.dependencies[model._meta.concrete_model].add(
+            dependency._meta.concrete_model
+        )
+        self.data.setdefault(dependency, self.data.default_factory())
+
+    def collect(
+        self,
+        factory_cls,
+        objs,
+        source=None,
+        nullable=False,
+    ):
+        """
+        Add 'objs', and every unsaved model instance they reference, to the
+        collection of objects to be inserted.
+
+        'objs' must be a homogeneous iterable collection of model instances
+        (e.g. the result of ``build_batch``). Instances reachable through
+        foreign keys and through the factory's post-declarations are
+        collected recursively. If the call is the result of following a
+        relation, 'source' should be the model holding that relation and
+        'nullable' should be set to True if the relation can be null.
+        """
+        new_objs = self.add(
+            objs, source, nullable
+        )
+        if not new_objs:
+            return
+
+        model = new_objs[0].__class__
+
+        # The candidate relations are the ones that come from N-1 and 1-1 relations.
+        candidate_relations = (
+            f for f in model._meta.get_fields(include_hidden=True)
+            if isinstance(f, models.ForeignKey)
+        )
+
+        collected_objs = []
+        for field in candidate_relations:
+            for obj in new_objs:
+                val = getattr(obj, field.name)
+                if isinstance(val, models.Model):
+                    collected_objs.append(val)
+
+        # Also collect instances attached by the factory's post-declarations.
+        for name in factory_cls._meta.post_declarations.as_dict().keys():
+            for obj in new_objs:
+                val = getattr(obj, name, None)
+                if isinstance(val, models.Model):
+                    collected_objs.append(val)
+
+        if collected_objs:
+            new_objs = self.collect(
+                factory_cls=factory_cls, objs=collected_objs, source=model
+            )
+
+    def sort(self):
+        """
+        Sort the models from the fewest dependencies to the most dependencies.
+
+        Models with no dependencies are inserted first; every other model is
+        inserted only after the models it depends on.
+ """ + sorted_models = [] + concrete_models = set() + models = list(self.data) + while len(sorted_models) < len(models): + found = False + for model in models: + if model in sorted_models: + continue + dependencies = self.dependencies.get(model._meta.concrete_model) + if not (dependencies and dependencies.difference(concrete_models)): + sorted_models.append(model) + concrete_models.add(model._meta.concrete_model) + found = True + if not found: + logger.debug('dependency order could not be determined') + return + self.data = {model: self.data[model] for model in sorted_models} + + class mute_signals: """Temporarily disables and then restores any django signals. @@ -318,6 +517,7 @@ def __call__(self, callable_obj): if isinstance(callable_obj, base.FactoryMetaClass): # Retrieve __func__, the *actual* callable object. callable_obj._create = self.wrap_method(callable_obj._create.__func__) + callable_obj._bulk_create = self.wrap_method(callable_obj._bulk_create.__func__) callable_obj._generate = self.wrap_method(callable_obj._generate.__func__) return callable_obj diff --git a/setup.cfg b/setup.cfg index 3ae6d65f..b1d0cf68 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,9 @@ classifiers = zip_safe = false packages = factory python_requires = >=3.7 -install_requires = Faker>=0.7.0 +install_requires = + packaging + Faker>=0.7.0 [options.extras_require] dev = diff --git a/tests/djapp/models.py b/tests/djapp/models.py index fb34e907..e19ba3bb 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -133,3 +133,24 @@ class Meta: class FromAbstractWithCustomManager(AbstractWithCustomManager): pass + + +class Level2(models.Model): + + foo = models.CharField(max_length=20) + + +class LevelA1(models.Model): + + level_2 = models.ForeignKey(Level2, on_delete=models.CASCADE) + + +class LevelA2(models.Model): + + level_2 = models.ForeignKey(Level2, on_delete=models.CASCADE) + + +class Level0(models.Model): + + level_a1 = models.ForeignKey(LevelA1, on_delete=models.CASCADE) + level_a2 = models.ForeignKey(LevelA2, on_delete=models.CASCADE) diff --git a/tests/djapp/settings_pg.py b/tests/djapp/settings_pg.py new file mode 100644 index 00000000..a1c3812f --- /dev/null +++ b/tests/djapp/settings_pg.py @@ -0,0 +1,35 @@ +# Copyright: See the LICENSE file. + +"""Settings for factory_boy/Django tests.""" + +import os + +from .settings import * # noqa: F401, F403 + +try: + # pypy does not support `psycopg2` or `psycopg2-binary` + # This is a package that only gets installed with pypy, and it needs to be + # initialized for it to work properly. 
It mimic `psycopg2` 1-to-1 + from psycopg2cffi import compat + compat.register() +except ImportError: + pass + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'NAME': os.environ.get('POSTGRES_DATABASE', 'factory_boy_test'), + 'USER': os.environ.get('POSTGRES_USER', 'postgres'), + 'PASSWORD': os.environ.get('POSTGRES_PASSWORD', 'password'), + 'HOST': os.environ.get('POSTGRES_HOST', 'localhost'), + 'PORT': os.environ.get('POSTGRES_PORT', '5432'), + }, + 'replica': { + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'NAME': os.environ.get('POSTGRES_DATABASE', 'factory_boy_test') + '_rp', + 'USER': os.environ.get('POSTGRES_USER', 'postgres'), + 'PASSWORD': os.environ.get('POSTGRES_PASSWORD', 'password'), + 'HOST': os.environ.get('POSTGRES_HOST', 'localhost'), + 'PORT': os.environ.get('POSTGRES_PORT', '5432'), + } +} diff --git a/tests/test_django.py b/tests/test_django.py index 48b94329..fa576b28 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -11,8 +11,11 @@ from django import test as django_test from django.conf import settings from django.contrib.auth.hashers import check_password +from django.core.management.color import no_style +from django.db import connections from django.db.models import signals from django.test import utils as django_test_utils +from faker import Factory as FakerFactory import factory.django @@ -23,10 +26,15 @@ except ImportError: Image = None +faker = FakerFactory.create() + + # Setup Django before importing Django models. os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.djapp.settings') django.setup() +SKIP_BULK_INSERT = not factory.django.connection_supports_bulk_insert(factory.django.DEFAULT_DB_ALIAS) + from .djapp import models # noqa:E402 isort:skip test_state = {} @@ -72,7 +80,7 @@ class Meta: model = models.MultifieldModel django_get_or_create = ['slug'] - text = factory.Faker('text') + text = factory.LazyAttribute(lambda n: faker.text()[:20]) class AbstractBaseFactory(factory.django.DjangoModelFactory): @@ -143,6 +151,68 @@ class Meta: text = factory.Sequence(lambda n: "text%s" % n) +class Level2Factory(factory.django.DjangoModelFactory): + class Meta: + model = models.Level2 + use_bulk_create = True + + foo = factory.Sequence(lambda n: "foo%s" % n) + + +class LevelA1Factory(factory.django.DjangoModelFactory): + class Meta: + model = models.LevelA1 + use_bulk_create = True + + level_2 = factory.SubFactory(Level2Factory) + + +class LevelA2Factory(factory.django.DjangoModelFactory): + class Meta: + model = models.LevelA2 + use_bulk_create = True + + level_2 = factory.SubFactory(Level2Factory) + + +class DependencyInsertOrderCollector(django_test.TestCase): + + def test_empty(self): + collector = factory.django.DependencyInsertOrderCollector() + collector.collect(Level2Factory, []) + collector.sort() + + self.assertEqual(collector.data, {}) + + +@unittest.skipIf(SKIP_BULK_INSERT, "bulk insert not supported by current db.") +class DjangoBulkInsert(django_test.TestCase): + + def test_single_object_create(self): + with self.assertNumQueries(1): + Level2Factory() + + def test_single_object_create_batch(self): + with self.assertNumQueries(1): + Level2Factory.create_batch(10) + + def test_one_level_nested_single_object_create(self): + with self.assertNumQueries(2): + LevelA1Factory() + + existing_level2 = Level2Factory() + with self.assertNumQueries(1): + LevelA1Factory(level_2=existing_level2) + + def test_one_level_nested_single_object_create_batch(self): + with self.assertNumQueries(2): + 
LevelA1Factory.create_batch(10) + + existing_level2 = Level2Factory() + with self.assertNumQueries(1): + LevelA1Factory.create_batch(10, level_2=existing_level2) + + class ModelTests(django_test.TestCase): databases = {'default', 'replica'} @@ -164,7 +234,16 @@ class Meta: self.assertEqual(obj, models.StandardModel.objects.using('replica').get()) -class DjangoPkSequenceTestCase(django_test.TestCase): +class DjangoResetTestCase(django_test.TestCase): + def reset_database_sequences(self, *models): + using = factory.django.DEFAULT_DB_ALIAS + with connections[using].cursor() as cursor: + sequence_sql = connections[using].ops.sequence_reset_sql(no_style(), models) + for command in sequence_sql: + cursor.execute(command) + + +class DjangoPkSequenceTestCase(DjangoResetTestCase): def setUp(self): super().setUp() StandardFactory.reset_sequence() @@ -180,6 +259,8 @@ def test_pk_many(self): self.assertEqual('foo1', std2.foo) def test_pk_creation(self): + self.reset_database_sequences(StandardFactory._meta.model) + std1 = StandardFactory.create() self.assertEqual('foo0', std1.foo) self.assertEqual(1, std1.pk) @@ -194,6 +275,8 @@ def test_pk_force_value(self): self.assertEqual('foo0', std1.foo) # sequence is unrelated to pk self.assertEqual(10, std1.pk) + self.reset_database_sequences(StandardFactory._meta.model) + StandardFactory.reset_sequence() std2 = StandardFactory.create() self.assertEqual('foo0', std2.foo) @@ -374,7 +457,8 @@ def test_force_pk(self): self.assertEqual('foo0', nonint2.pk) -class DjangoAbstractBaseSequenceTestCase(django_test.TestCase): +class DjangoAbstractBaseSequenceTestCase(DjangoResetTestCase): + def test_auto_sequence_son(self): """The sequence of the concrete son of an abstract model should be autonomous.""" obj = ConcreteSonFactory() @@ -397,6 +481,8 @@ class ConcreteSonFactory(AbstractBaseFactory): class Meta: model = models.ConcreteSon + self.reset_database_sequences(models.ConcreteSon) + obj = ConcreteSonFactory() self.assertEqual(1, obj.pk) self.assertEqual("foo0", obj.foo) @@ -1065,11 +1151,12 @@ class DjangoModelFactoryDuplicateSaveDeprecationTest(django_test.TestCase): class StandardFactoryWithPost(StandardFactory): @factory.post_generation def post_action(obj, create, extracted, **kwargs): - return 3 + obj.non_existant_field = 3 def test_create_warning(self): with self.assertWarns(DeprecationWarning) as cm: - self.StandardFactoryWithPost.create() + instance = self.StandardFactoryWithPost.create() + assert instance.non_existant_field == 3 [msg] = cm.warning.args self.assertEqual( diff --git a/tests/test_using.py b/tests/test_using.py index 07dfbb47..9d400ca7 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -72,6 +72,9 @@ def create(self, **kwargs): instance._defaults = None return instance + def bulk_create(self, objs, **kwargs): + return objs + def values_list(self, *args, **kwargs): return self @@ -81,6 +84,16 @@ def order_by(self, *args, **kwargs): def using(self, db): return self + class _meta: + concrete_model = None + + @staticmethod + def get_fields(*args, **kwargs): + return [] + + class _state: + adding = True + objects = FakeModelManager() def __init__(self, **kwargs): diff --git a/tox.ini b/tox.ini index 0ce0d706..0f80f232 100644 --- a/tox.ini +++ b/tox.ini @@ -2,25 +2,40 @@ minversion = 1.9 envlist = lint - py{37,38,39,py3}-django22-alchemy-mongoengine - py{37,38,39,310,py3}-django32-alchemy-mongoengine - py{38,39,310,py3}-django40-alchemy-mongoengine - py310-djangomain-alchemy-mongoengine + py{37,38,39,py37,py38}-django22-{sqlite,postgres} + 
py{37,38,39,310,py37,py38}-django32-{sqlite,postgres} + py{38,39,310,py38}-django40-{sqlite,postgres} + py310-djangomain-{sqlite,postgres} docs examples linkcheck toxworkdir = {env:TOX_WORKDIR:.tox} +[gh-actions] +python = + 3.7: py37 + 3.8: py38 + 3.9: py39 + 3.10: py310 + pypy-3.7: pypy37 + pypy-3.8: pypy38 + [testenv] deps = + Pillow + SQLAlchemy + mongoengine django22: Django>=2.2,<2.3 django32: Django>=3.2,<3.3 django40: Django>=4.0,<4.1 djangomain: https://github.com/django/django/archive/main.tar.gz - django{22,32,40,main}: Pillow - alchemy: SQLAlchemy - mongoengine: mongoengine + py{37,38,39,310}-django{22,32,40,main}-postgres: psycopg2-binary + py{py37,py38}-django{22,32,40,main}-postgres: psycopg2cffi + +setenv = + py: DJANGO_SETTINGS_MODULE=tests.djapp.settings + postgres: DJANGO_SETTINGS_MODULE=tests.djapp.settings_pg whitelist_externals = make commands = make test
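
For reference, a minimal sketch of how a project would opt into the bulk-insert path introduced above. The `myapp.Author` / `myapp.Book` models are hypothetical and not part of this change; only the `use_bulk_create` Meta option and the fallback behaviour come from the diff (they mirror the `Level2Factory` / `LevelA1Factory` fixtures added to `tests/test_django.py`).

```python
# Usage sketch for the new `use_bulk_create` Meta option.
# `myapp.Author` and `myapp.Book` are hypothetical models, not part of this change.
import factory
import factory.django


class AuthorFactory(factory.django.DjangoModelFactory):
    class Meta:
        model = 'myapp.Author'     # hypothetical model
        use_bulk_create = True     # opt in to the bulk-insert path

    name = factory.Sequence(lambda n: "author%d" % n)


class BookFactory(factory.django.DjangoModelFactory):
    class Meta:
        model = 'myapp.Book'       # hypothetical model with a ForeignKey to Author
        use_bulk_create = True

    author = factory.SubFactory(AuthorFactory)
    title = factory.Sequence(lambda n: "book%d" % n)


# On a backend whose bulk_create() can return primary keys (e.g. PostgreSQL),
# this builds 100 Book and 100 Author instances in memory, then issues one
# bulk_create() per model, inserting the Author rows before the Book rows that
# reference them. On backends without that capability, create_batch() silently
# falls back to the usual one INSERT per object.
books = BookFactory.create_batch(100)
```

The fallback is decided per factory by `supports_bulk_insert()`, which combines the `use_bulk_create` option with `connection_supports_bulk_insert()` for the factory's configured database alias.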