Skip to content

Commit

Permalink
Add support for post generation in async creation
Browse files Browse the repository at this point in the history
  • Loading branch information
nadege committed Jul 23, 2023
1 parent e282b2a commit 054a945
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 12 deletions.
2 changes: 1 addition & 1 deletion factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def instantiate(self, step, args, kwargs):
def use_postgeneration_results(self, step, instance, results):
self.factory._after_postgeneration(
instance,
create=step.builder.strategy == enums.CREATE_STRATEGY,
create=step.builder.strategy != enums.BUILD_STRATEGY,
results=results,
)

Expand Down
35 changes: 25 additions & 10 deletions factory/builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Build factory instances."""

import asyncio
import collections

from . import enums, errors, utils
Expand Down Expand Up @@ -277,19 +278,33 @@ def build(self, parent_step=None, force_sequence=None):
kwargs=kwargs,
)

postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
def _handle_post_generation(instance):
postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
instance=instance,
step=step,
overrides=declaration.context,
)

self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
overrides=declaration.context,
results=postgen_results,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
step=step,
results=postgen_results,
)

if step.builder.strategy == enums.ASYNC_CREATE_STRATEGY and isinstance(instance, asyncio.Task):

def post_generation_callback(task):
instance = task.result()
_handle_post_generation(instance)

instance.add_done_callback(post_generation_callback)

else:
_handle_post_generation(instance)

return instance

def recurse(self, factory_meta, extras):
Expand Down
3 changes: 2 additions & 1 deletion factory/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ def call(self, instance, step, context):
context._asdict(),
),
)
create = step.builder.strategy == enums.CREATE_STRATEGY
create = step.builder.strategy != enums.BUILD_STRATEGY

return self.function(
instance, create, context.value, **context.extra)

Expand Down
47 changes: 47 additions & 0 deletions tests/test_using.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,28 @@ def incr_one(self, _create, _increment):
self.assertEqual(3, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))

def test_post_generation_async(self):
class TestAsyncFactory(FakeAsyncModelFactory):
class Meta:
model = AsyncTestModel

one = 1

@factory.post_generation
def incr_one(self, _create, _increment):
self.one += 1

async def test():
obj = await TestAsyncFactory.create_async()
self.assertEqual(2, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))

obj = await TestAsyncFactory.create_async(one=2)
self.assertEqual(3, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))

asyncio.run(test())

def test_post_generation_hook(self):
class TestObjectFactory(factory.Factory):
class Meta:
Expand All @@ -2598,6 +2620,31 @@ def _after_postgeneration(cls, obj, create, results):
self.assertFalse(obj.create)
self.assertEqual({'incr_one': 42}, obj.results)

def test_post_generation_hook_async(self):
class TestAsyncFactory(FakeAsyncModelFactory):
class Meta:
model = AsyncTestModel

one = 1

@factory.post_generation
def incr_one(self, _create, _increment):
self.one += 1
return 42

@classmethod
def _after_postgeneration(cls, obj, create, results):
obj.create = create
obj.results = results

async def test():
obj = await TestAsyncFactory.create_async()
self.assertEqual(2, obj.one)
self.assertTrue(obj.create)
self.assertEqual({'incr_one': 42}, obj.results)

asyncio.run(test())

def test_post_generation_extraction(self):
class TestObjectFactory(factory.Factory):
class Meta:
Expand Down

0 comments on commit 054a945

Please sign in to comment.