diff --git a/invenio_db/uow.py b/invenio_db/uow.py index f5830f3..79a2408 100644 --- a/invenio_db/uow.py +++ b/invenio_db/uow.py @@ -108,6 +108,10 @@ def on_post_commit(self, uow): """Called right after the commit phase.""" pass + def on_exception(self, uow, exception): + """Called in case of an exception.""" + pass + def on_rollback(self, uow): """Called in the rollback phase (after the transaction rollback).""" pass @@ -165,10 +169,10 @@ def __enter__(self): """Entering the context.""" return self - def __exit__(self, exc_type, *args): + def __exit__(self, exc_type, exc_value, traceback): """Rollback on exception.""" if exc_type is not None: - self.rollback() + self.rollback(exception=exc_value) self._mark_dirty() @property @@ -193,9 +197,17 @@ def commit(self): op.on_post_commit(self) self._mark_dirty() - def rollback(self): + def rollback(self, exception=None): """Rollback the database session.""" self.session.rollback() + + # Run exception operations + for op in self._operations: + op.on_exception(self, exception) + + # Commit exception operations + self.session.commit() + # Run rollback operations for op in self._operations: op.on_rollback(self)