From 2ecf6b1394351d8caae708b6bf59d4932c38c2be Mon Sep 17 00:00:00 2001 From: Manolis Stamatogiannakis Date: Fri, 3 Jan 2025 23:58:41 +0100 Subject: [PATCH] Allow turning off context unrolling for post_generation decorator. --- docs/reference.rst | 5 +++++ factory/helpers.py | 16 ++++++++++++++-- tests/test_using.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/docs/reference.rst b/docs/reference.rst index 2122b9f7..8da8a6b2 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -1672,6 +1672,11 @@ When calling the factory, some arguments will be extracted for this method: - If a ``post`` argument is passed, it will be passed as the ``extracted`` field - Any argument starting with ``post__XYZ`` will be extracted, its ``post__`` prefix removed, and added to the kwargs passed to the post-generation hook. +- By default kwargs are "unrolled" before running the post-generation hook. + This means that any lazily-evaluated constructs (e.g. a :class:`LazyFunction`) + will be evaluated before post-generation. + Unrolling can be disabled with the ``unroll_context`` decorator argument: + ``@post_generation(unroll_context=False)`` Extracted arguments won't be passed to the :attr:`~FactoryOptions.model` class. diff --git a/factory/helpers.py b/factory/helpers.py index 496de6e3..da46f36d 100644 --- a/factory/helpers.py +++ b/factory/helpers.py @@ -113,5 +113,17 @@ def container_attribute(func): return declarations.ContainerAttribute(func, strict=False) -def post_generation(fun): - return declarations.PostGeneration(fun) +def post_generation(fun=None, unroll_context=True): + """Post-generation decorator that allows turning context unrolling on/off. + + Turning off context unrolling is useful e.g. for passing a LazyFunction as + a post-generation keyword argument. + """ + class PostGeneration(declarations.PostGeneration): + UNROLL_CONTEXT_BEFORE_EVALUATION = unroll_context + + def post_generation_(fun): + return PostGeneration(fun) + + # Note: fun will be None when the decorator is used with parentheses. + return post_generation_(fun) if fun is not None else post_generation_ diff --git a/tests/test_using.py b/tests/test_using.py index 5b2200a6..02061660 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -1249,6 +1249,52 @@ class Meta: self.assertEqual(5 ** 2 - 1, obj_squared.value) self.assertEqual(6 ** 2 * 5 * 2, obj_combined.value) + def test_post_generation_unroll_context(self): + class DummyFactory(factory.Factory): + class Meta: + model = Dummy + + value = 0 + generated = [] + + @factory.post_generation() + @staticmethod + def pg1(obj, create, extracted, **kwargs): + """Post-generation with context unrolling enabled.""" + if extracted is None: + return + obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v']) + + @factory.post_generation(unroll_context=False) + @staticmethod + def pg2(obj, create, extracted, **kwargs): + """Post-generation with context unrolling disabled.""" + if extracted is None: + return + obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v']) + + obj = DummyFactory.build(value=4) + self.assertEqual(4, obj.value) + self.assertEqual([], obj.generated) + + obj = DummyFactory.build(value=100, pg1=3, pg1__v=10) + self.assertEqual(100, obj.value) + self.assertEqual(3, len(obj.generated)) + self.assertEqual([10, 10, 10], [g.value for g in obj.generated]) + self.assertTrue(all(g.generated == [] for g in obj.generated)) + + obj = DummyFactory.build(value=100, pg1=2, pg1__v=factory.Iterator([78, 79, 80])) + self.assertEqual(100, obj.value) + self.assertEqual(2, len(obj.generated)) + self.assertEqual([78, 78], [g.value for g in obj.generated]) + self.assertTrue(all(g.generated == [] for g in obj.generated)) + + obj = DummyFactory.build(value=100, pg2=3, pg2__v=factory.Iterator([78, 79, 80])) + self.assertEqual(100, obj.value) + self.assertEqual(3, len(obj.generated)) + self.assertEqual([78, 79, 80], [g.value for g in obj.generated]) + self.assertTrue(all(g.generated == [] for g in obj.generated)) + class TraitTestCase(unittest.TestCase): def test_traits(self):