diff --git a/django_celery_monitor/managers.py b/django_celery_monitor/managers.py index 2d21d5ae..38c83eb3 100644 --- a/django_celery_monitor/managers.py +++ b/django_celery_monitor/managers.py @@ -5,6 +5,7 @@ from celery import states from celery.events.state import Task from celery.utils.time import maybe_timedelta +import django from django.db import models, router, transaction from .utils import Now @@ -26,13 +27,30 @@ def select_for_update_or_create(self, defaults=None, **kwargs): select_for_update when getting the object. """ defaults = defaults or {} - lookup, params = self._extract_model_params(defaults, **kwargs) self._for_write = True + + if django.VERSION < (2,2): + lookup, params = self._extract_model_params(defaults, **kwargs) + with transaction.atomic(using=self.db): + try: + obj = self.select_for_update().get(**lookup) + except self.model.DoesNotExist: + obj, created = self._create_object_from_params(lookup, params) + if created: + return obj, created + for k, v in defaults.items(): + setattr(obj, k, v() if callable(v) else v) + obj.save(using=self.db) + return obj, False + with transaction.atomic(using=self.db): try: - obj = self.select_for_update().get(**lookup) + obj = self.select_for_update().get(**kwargs) except self.model.DoesNotExist: - obj, created = self._create_object_from_params(lookup, params) + params = self._extract_model_params(defaults, **kwargs) + # Lock the row so that a concurrent update is blocked until + # after update_or_create() has performed its save. + obj, created = self._create_object_from_params(kwargs, params, lock=True) if created: return obj, created for k, v in defaults.items():