Skip to content

Commit

Permalink
Allow turning off context unrolling for post_generation decorator.
Browse files Browse the repository at this point in the history
  • Loading branch information
m000 committed Jan 3, 2025
1 parent 4209372 commit 78c2ec3
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,11 @@ When calling the factory, some arguments will be extracted for this method:
- If a ``post`` argument is passed, it will be passed as the ``extracted`` field
- Any argument starting with ``post__XYZ`` will be extracted, its ``post__`` prefix
removed, and added to the kwargs passed to the post-generation hook.
- By default kwargs are "unrolled" before running the post-generation hook.
This means that any lazily-evaluated constructs (e.g. a :class:`LazyFunction`)
will be evaluated before post-generation.
Unrolling can be disabled with the ``unroll_context`` decorator argument:
``@post_generation(unroll_context=False)``

Extracted arguments won't be passed to the :attr:`~FactoryOptions.model` class.

Expand Down
16 changes: 14 additions & 2 deletions factory/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,17 @@ def container_attribute(func):
return declarations.ContainerAttribute(func, strict=False)


def post_generation(fun):
return declarations.PostGeneration(fun)
def post_generation(fun=None, unroll_context=True):
"""Post-generation decorator that allows turning context unrolling on/off.
Turning off context unrolling is useful e.g. for passing a LazyFunction as
a post-generation keyword argument.
"""
class PostGeneration(declarations.PostGeneration):
UNROLL_CONTEXT_BEFORE_EVALUATION = unroll_context

def post_generation_(fun):
return PostGeneration(fun)

# Note: fun will be None when the decorator is used with parentheses.
return post_generation_(fun) if fun is not None else post_generation_
46 changes: 46 additions & 0 deletions tests/test_using.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,52 @@ class Meta:
self.assertEqual(5 ** 2 - 1, obj_squared.value)
self.assertEqual(6 ** 2 * 5 * 2, obj_combined.value)

def test_post_generation_unrolling(self):
class DummyFactory(factory.Factory):
class Meta:
model = Dummy

value = 0
generated = []

@factory.post_generation()
@staticmethod
def pg1(obj, create, extracted, **kwargs):
"""Post-generation with context unrolling enabled."""
if extracted is None:
return
obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v'])

@factory.post_generation(unroll_context=False)
@staticmethod
def pg2(obj, create, extracted, **kwargs):
"""Post-generation with context unrolling disabled."""
if extracted is None:
return
obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v'])

obj = DummyFactory.build(value=4)
self.assertEqual(4, obj.value)
self.assertEqual([], obj.generated)

obj = DummyFactory.build(value=100, pg1=3, pg1__v=10)
self.assertEqual(100, obj.value)
self.assertEqual(3, len(obj.generated))
self.assertEqual([10, 10, 10], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))

obj = DummyFactory.build(value=100, pg1=2, pg1__v=factory.Iterator([78, 79, 80]))
self.assertEqual(100, obj.value)
self.assertEqual(2, len(obj.generated))
self.assertEqual([78, 78], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))

obj = DummyFactory.build(value=100, pg2=3, pg2__v=factory.Iterator([78, 79, 80]))
self.assertEqual(100, obj.value)
self.assertEqual(3, len(obj.generated))
self.assertEqual([78, 79, 80], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))


class TraitTestCase(unittest.TestCase):
def test_traits(self):
Expand Down

0 comments on commit 78c2ec3

Please sign in to comment.