diff --git a/docs/changelog.rst b/docs/changelog.rst index e45e4be1..34f3cac0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,8 @@ ChangeLog - :issue:`366`: Add :class:`factory.django.Password` to generate Django :class:`~django.contrib.auth.models.User` passwords. + - :issue:`304`: Add :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session_factory` to dynamically + create sessions for use by the :class:`~factory.alchemy.SQLAlchemyModelFactory`. - Add support for Django 3.2 - Add support for Django 4.0 - Add support for Python 3.10 diff --git a/docs/orms.rst b/docs/orms.rst index e464fc4e..3d255a31 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -369,6 +369,25 @@ To work, this class needs an `SQLAlchemy`_ session object affected to the :attr: SQLAlchemy session to use to communicate with the database when creating an object through this :class:`SQLAlchemyModelFactory`. + .. attribute:: sqlalchemy_session_factory + + .. versionadded:: 3.3.0 + + :class:`~collections.abc.Callable` returning a :class:`~sqlalchemy.orm.Session` instance to use to communicate + with the database. You can either provide the session through this attribute, or through + :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session`, but not both at the same time. + + .. code-block:: python + + from . import common + + class UserFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = User + sqlalchemy_session_factory = lambda: common.Session() + + username = 'john' + .. attribute:: sqlalchemy_session_persistence Control the action taken by ``sqlalchemy_session`` at the end of a create call. diff --git a/docs/recipes.rst b/docs/recipes.rst index 6afa9d1a..a1717ad5 100644 --- a/docs/recipes.rst +++ b/docs/recipes.rst @@ -52,8 +52,10 @@ simply use a :class:`factory.Iterator` on the chosen queryset: language = factory.Iterator(models.Language.objects.all()) -Here, ``models.Language.objects.all()`` won't be evaluated until the -first call to ``UserFactory``; thus avoiding DB queries at import time. +Here, ``models.Language.objects.all()`` is a +:class:`~django.db.models.query.QuerySet` and will only hit the database when +``factory_boy`` starts iterating on it, i.e on the first call to +``UserFactory``; thus avoiding DB queries at import time. Reverse dependencies (reverse ForeignKey) diff --git a/factory/alchemy.py b/factory/alchemy.py index ef7a591a..cf20b537 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -22,10 +22,18 @@ def _check_sqlalchemy_session_persistence(self, meta, value): (meta, VALID_SESSION_PERSISTENCE_TYPES, value) ) + @staticmethod + def _check_has_sqlalchemy_session_set(meta, value): + if value and meta.sqlalchemy_session: + raise RuntimeError("Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both") + def _build_default_options(self): return super()._build_default_options() + [ base.OptionDefault('sqlalchemy_get_or_create', (), inherit=True), base.OptionDefault('sqlalchemy_session', None, inherit=True), + base.OptionDefault( + 'sqlalchemy_session_factory', None, inherit=True, checker=self._check_has_sqlalchemy_session_set + ), base.OptionDefault( 'sqlalchemy_session_persistence', None, @@ -90,6 +98,10 @@ def _get_or_create(cls, model_class, session, args, kwargs): @classmethod def _create(cls, model_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" + session_factory = cls._meta.sqlalchemy_session_factory + if session_factory: + cls._meta.sqlalchemy_session = session_factory() + session = cls._meta.sqlalchemy_session if session is None: diff --git a/factory/django.py b/factory/django.py index 3d60d92c..cffe1b42 100644 --- a/factory/django.py +++ b/factory/django.py @@ -24,7 +24,7 @@ DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS -DJANGO_22 = Version('2.2') <= Version(django_version) < Version('3.0') +DJANGO_22 = Version(django_version) < Version('3.0') _LAZY_LOADS = {} @@ -205,9 +205,18 @@ def create_batch(cls, size, **kwargs): @classmethod def _refresh_database_pks(cls, model_cls, objs): + """ + Before Django 3.0, there is an issue when bulk_insert. + + The issue is that if you create an instance of a model, + and reference it in another unsaved instance of a model. + When you create the instance of the first one, the pk/id + is never updated on the sub model that referenced the first. + """ if not DJANGO_22: return - fields = [f for f in model_cls._meta.get_fields() if isinstance(f, models.fields.related.ForeignObject)] + fields = [f for f in model_cls._meta.get_fields() + if isinstance(f, models.fields.related.ForeignObject)] if not fields: return for obj in objs: @@ -217,17 +226,13 @@ def _refresh_database_pks(cls, model_cls, objs): @classmethod def _bulk_create(cls, size, **kwargs): models_to_create = cls.build_batch(size, **kwargs) - collector = Collector(cls._meta.database) + collector = DependencyInsertOrderCollector() collector.collect(cls, models_to_create) collector.sort() for model_cls, objs in collector.data.items(): manager = cls._get_manager(model_cls) - for instance in objs: - models.signals.pre_save.send(model_cls, instance=instance, created=False) cls._refresh_database_pks(model_cls, objs) manager.bulk_create(objs) - for instance in objs: - models.signals.post_save.send(model_cls, instance=instance, created=True) return models_to_create @classmethod @@ -334,19 +339,10 @@ def _make_data(self, params): return thumb_io.getvalue() -class Collector: - def __init__(self, using): - self.using = using +class DependencyInsertOrderCollector: + def __init__(self): # Initially, {model: {instances}}, later values become lists. self.data = defaultdict(list) - # {model: {(field, value): {instances}}} - self.field_updates = defaultdict(functools.partial(defaultdict, set)) - # {model: {field: {instances}}} - self.restricted_objects = defaultdict(functools.partial(defaultdict, set)) - # fast_deletes is a list of queryset-likes that can be deleted without - # fetching the objects into memory. - self.fast_deletes = [] - # Tracks deletion-order dependency for databases without transactions # or ability to defer constraint checks. Only concrete model classes # should be included, as the dependencies exist only between actual @@ -354,9 +350,9 @@ def __init__(self, using): # parent. self.dependencies = defaultdict(set) # {model: {models}} - def add(self, objs, source=None, nullable=False, reverse_dependency=False): + def add(self, objs, source=None, nullable=False): """ - Add 'objs' to the collection of objects to be deleted. If the call is + Add 'objs' to the collection of objects to be inserted in order. If the call is the result of a cascade, 'source' should be the model that caused it, and 'nullable' should be set to True if the relation can be null. Return a list of all objects that were not already collected. @@ -372,21 +368,15 @@ def add(self, objs, source=None, nullable=False, reverse_dependency=False): continue if id(obj) not in lookup: new_objs.append(obj) - # import ipdb; ipdb.sset_trace() instances.extend(new_objs) # Nullable relationships can be ignored -- they are nulled out before # deleting, and therefore do not affect the order in which objects have # to be deleted. if source is not None and not nullable: - self.add_dependency(source, model, reverse_dependency=reverse_dependency) - # if not nullable: - # import ipdb; ipdb.sset_trace() - # self.add_dependency(source, model, reverse_dependency=reverse_dependency) + self.add_dependency(source, model) return new_objs - def add_dependency(self, model, dependency, reverse_dependency=False): - if reverse_dependency: - model, dependency = dependency, model + def add_dependency(self, model, dependency): self.dependencies[model._meta.concrete_model].add( dependency._meta.concrete_model ) @@ -398,11 +388,6 @@ def collect( objs, source=None, nullable=False, - collect_related=True, - source_attr=None, - reverse_dependency=False, - keep_parents=False, - fail_on_restricted=True, ): """ Add 'objs' to the collection of objects to be deleted as well as all @@ -412,10 +397,6 @@ def collect( If the call is the result of a cascade, 'source' should be the model that caused it and 'nullable' should be set to True, if the relation can be null. - If 'reverse_dependency' is True, 'source' will be deleted before the - current model, rather than after. (Needed for cascading to parent - models, the one case in which the cascade follows the forwards - direction of an FK rather than the reverse direction.) If 'keep_parents' is True, data of parent model's will be not deleted. If 'fail_on_restricted' is False, error won't be raised even if it's prohibited to delete such objects due to RESTRICT, that defers @@ -424,32 +405,27 @@ def collect( can be deleted. """ new_objs = self.add( - objs, source, nullable, reverse_dependency=reverse_dependency + objs, source, nullable ) if not new_objs: return - # import ipdb; ipdb.sset_trace() model = new_objs[0].__class__ - def get_candidate_relations(opts): - # The candidate relations are the ones that come from N-1 and 1-1 relations. - # N-N (i.e., many-to-many) relations aren't candidates for deletion. - return ( - f - for f in opts.get_fields(include_hidden=True) - if isinstance(f, models.ForeignKey) - ) + # The candidate relations are the ones that come from N-1 and 1-1 relations. + candidate_relations = ( + f for f in model._meta.get_fields(include_hidden=True) + if isinstance(f, models.ForeignKey) + ) collected_objs = [] - for field in get_candidate_relations(model._meta): + for field in candidate_relations: for obj in new_objs: val = getattr(obj, field.name) if isinstance(val, models.Model): collected_objs.append(val) - for name, _ in factory_cls._meta.post_declarations.as_dict().items(): - + for name, in factory_cls._meta.post_declarations.as_dict().keys(): for obj in new_objs: val = getattr(obj, name, None) if isinstance(val, models.Model): @@ -457,14 +433,19 @@ def get_candidate_relations(opts): if collected_objs: new_objs = self.collect( - factory_cls=factory_cls, objs=collected_objs, source=model, reverse_dependency=False + factory_cls=factory_cls, objs=collected_objs, source=model ) def sort(self): + """ + Sort the model instances by the least dependecies to the most dependencies. + + We want to insert the models with no dependencies first, and continue inserting + using the models that the higher models depend on. + """ sorted_models = [] concrete_models = set() models = list(self.data) - # import ipdb; ipdb.sset_trace() while len(sorted_models) < len(models): found = False for model in models: @@ -476,6 +457,7 @@ def sort(self): concrete_models.add(model._meta.concrete_model) found = True if not found: + logger.debug('dependency order could not be determined') return self.data = {model: self.data[model] for model in sorted_models} diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py index 03410838..005fb0fa 100644 --- a/tests/test_alchemy.py +++ b/tests/test_alchemy.py @@ -264,6 +264,34 @@ def test_build_does_not_raises_exception_when_no_session_was_set(self): self.assertEqual(inst1.id, 1) +class SQLAlchemySessionFactoryTestCase(unittest.TestCase): + + def test_create_get_session_from_sqlalchemy_session_factory(self): + class SessionGetterFactory(SQLAlchemyModelFactory): + class Meta: + model = models.StandardModel + sqlalchemy_session = None + sqlalchemy_session_factory = lambda: models.session + + id = factory.Sequence(lambda n: n) + + SessionGetterFactory.create() + self.assertEqual(SessionGetterFactory._meta.sqlalchemy_session, models.session) + # Reuse the session obtained from sqlalchemy_session_factory. + SessionGetterFactory.create() + + def test_create_raise_exception_sqlalchemy_session_factory_not_callable(self): + message = "^Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both$" + with self.assertRaisesRegex(RuntimeError, message): + class SessionAndGetterFactory(SQLAlchemyModelFactory): + class Meta: + model = models.StandardModel + sqlalchemy_session = models.session + sqlalchemy_session_factory = lambda: models.session + + id = factory.Sequence(lambda n: n) + + class NameConflictTests(unittest.TestCase): """Regression test for `TypeError: _save() got multiple values for argument 'session'` diff --git a/tests/test_django.py b/tests/test_django.py index 3d59522e..fa576b28 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -175,6 +175,16 @@ class Meta: level_2 = factory.SubFactory(Level2Factory) +class DependencyInsertOrderCollector(django_test.TestCase): + + def test_empty(self): + collector = factory.django.DependencyInsertOrderCollector() + collector.collect(Level2Factory, []) + collector.sort() + + self.assertEqual(collector.data, {}) + + @unittest.skipIf(SKIP_BULK_INSERT, "bulk insert not supported by current db.") class DjangoBulkInsert(django_test.TestCase):