diff --git a/docs/changelog.rst b/docs/changelog.rst index e45e4be1..34f3cac0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,8 @@ ChangeLog - :issue:`366`: Add :class:`factory.django.Password` to generate Django :class:`~django.contrib.auth.models.User` passwords. + - :issue:`304`: Add :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session_factory` to dynamically + create sessions for use by the :class:`~factory.alchemy.SQLAlchemyModelFactory`. - Add support for Django 3.2 - Add support for Django 4.0 - Add support for Python 3.10 diff --git a/docs/orms.rst b/docs/orms.rst index e464fc4e..3d255a31 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -369,6 +369,25 @@ To work, this class needs an `SQLAlchemy`_ session object affected to the :attr: SQLAlchemy session to use to communicate with the database when creating an object through this :class:`SQLAlchemyModelFactory`. + .. attribute:: sqlalchemy_session_factory + + .. versionadded:: 3.3.0 + + :class:`~collections.abc.Callable` returning a :class:`~sqlalchemy.orm.Session` instance to use to communicate + with the database. You can either provide the session through this attribute, or through + :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session`, but not both at the same time. + + .. code-block:: python + + from . import common + + class UserFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = User + sqlalchemy_session_factory = lambda: common.Session() + + username = 'john' + .. attribute:: sqlalchemy_session_persistence Control the action taken by ``sqlalchemy_session`` at the end of a create call. diff --git a/factory/alchemy.py b/factory/alchemy.py index ef7a591a..cf20b537 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -22,10 +22,18 @@ def _check_sqlalchemy_session_persistence(self, meta, value): (meta, VALID_SESSION_PERSISTENCE_TYPES, value) ) + @staticmethod + def _check_has_sqlalchemy_session_set(meta, value): + if value and meta.sqlalchemy_session: + raise RuntimeError("Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both") + def _build_default_options(self): return super()._build_default_options() + [ base.OptionDefault('sqlalchemy_get_or_create', (), inherit=True), base.OptionDefault('sqlalchemy_session', None, inherit=True), + base.OptionDefault( + 'sqlalchemy_session_factory', None, inherit=True, checker=self._check_has_sqlalchemy_session_set + ), base.OptionDefault( 'sqlalchemy_session_persistence', None, @@ -90,6 +98,10 @@ def _get_or_create(cls, model_class, session, args, kwargs): @classmethod def _create(cls, model_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" + session_factory = cls._meta.sqlalchemy_session_factory + if session_factory: + cls._meta.sqlalchemy_session = session_factory() + session = cls._meta.sqlalchemy_session if session is None: diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py index 03410838..005fb0fa 100644 --- a/tests/test_alchemy.py +++ b/tests/test_alchemy.py @@ -264,6 +264,34 @@ def test_build_does_not_raises_exception_when_no_session_was_set(self): self.assertEqual(inst1.id, 1) +class SQLAlchemySessionFactoryTestCase(unittest.TestCase): + + def test_create_get_session_from_sqlalchemy_session_factory(self): + class SessionGetterFactory(SQLAlchemyModelFactory): + class Meta: + model = models.StandardModel + sqlalchemy_session = None + sqlalchemy_session_factory = lambda: models.session + + id = factory.Sequence(lambda n: n) + + SessionGetterFactory.create() + self.assertEqual(SessionGetterFactory._meta.sqlalchemy_session, models.session) + # Reuse the session obtained from sqlalchemy_session_factory. + SessionGetterFactory.create() + + def test_create_raise_exception_sqlalchemy_session_factory_not_callable(self): + message = "^Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both$" + with self.assertRaisesRegex(RuntimeError, message): + class SessionAndGetterFactory(SQLAlchemyModelFactory): + class Meta: + model = models.StandardModel + sqlalchemy_session = models.session + sqlalchemy_session_factory = lambda: models.session + + id = factory.Sequence(lambda n: n) + + class NameConflictTests(unittest.TestCase): """Regression test for `TypeError: _save() got multiple values for argument 'session'`