Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow turning off context unrolling for post_generation decorator. #1110

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,11 @@ When calling the factory, some arguments will be extracted for this method:
- If a ``post`` argument is passed, it will be passed as the ``extracted`` field
- Any argument starting with ``post__XYZ`` will be extracted, its ``post__`` prefix
removed, and added to the kwargs passed to the post-generation hook.
- By default kwargs are "unrolled" before running the post-generation hook.
This means that any lazily-evaluated constructs (e.g. a :class:`LazyFunction`)
will be evaluated before post-generation.
Unrolling can be disabled with the ``unroll_context`` decorator argument:
``@post_generation(unroll_context=False)``

Extracted arguments won't be passed to the :attr:`~FactoryOptions.model` class.

Expand Down
16 changes: 14 additions & 2 deletions factory/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,17 @@ def container_attribute(func):
return declarations.ContainerAttribute(func, strict=False)


def post_generation(fun):
return declarations.PostGeneration(fun)
def post_generation(fun=None, unroll_context=True):
"""Post-generation decorator that allows turning context unrolling on/off.

Turning off context unrolling is useful e.g. for passing a LazyFunction as
a post-generation keyword argument.
"""
class PostGeneration(declarations.PostGeneration):
UNROLL_CONTEXT_BEFORE_EVALUATION = unroll_context

def post_generation_(fun):
return PostGeneration(fun)

# Note: fun will be None when the decorator is used with parentheses.
return post_generation_(fun) if fun is not None else post_generation_
46 changes: 46 additions & 0 deletions tests/test_using.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,52 @@ class Meta:
self.assertEqual(5 ** 2 - 1, obj_squared.value)
self.assertEqual(6 ** 2 * 5 * 2, obj_combined.value)

def test_post_generation_unroll_context(self):
class DummyFactory(factory.Factory):
class Meta:
model = Dummy

value = 0
generated = []

@factory.post_generation()
@staticmethod
def pg1(obj, create, extracted, **kwargs):
"""Post-generation with context unrolling enabled."""
if extracted is None:
return
obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v'])

@factory.post_generation(unroll_context=False)
@staticmethod
def pg2(obj, create, extracted, **kwargs):
"""Post-generation with context unrolling disabled."""
if extracted is None:
return
obj.generated = DummyFactory.build_batch(extracted, value=kwargs['v'])

obj = DummyFactory.build(value=4)
self.assertEqual(4, obj.value)
self.assertEqual([], obj.generated)

obj = DummyFactory.build(value=100, pg1=3, pg1__v=10)
self.assertEqual(100, obj.value)
self.assertEqual(3, len(obj.generated))
self.assertEqual([10, 10, 10], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))

obj = DummyFactory.build(value=100, pg1=2, pg1__v=factory.Iterator([78, 79, 80]))
self.assertEqual(100, obj.value)
self.assertEqual(2, len(obj.generated))
self.assertEqual([78, 78], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))

obj = DummyFactory.build(value=100, pg2=3, pg2__v=factory.Iterator([78, 79, 80]))
self.assertEqual(100, obj.value)
self.assertEqual(3, len(obj.generated))
self.assertEqual([78, 79, 80], [g.value for g in obj.generated])
self.assertTrue(all(g.generated == [] for g in obj.generated))


class TraitTestCase(unittest.TestCase):
def test_traits(self):
Expand Down