diff --git a/src/gino/crud.py b/src/gino/crud.py index 65498423..8e0d2708 100644 --- a/src/gino/crud.py +++ b/src/gino/crud.py @@ -102,7 +102,7 @@ def _set_prop(self, prop, value): self._literal = False self._props[prop] = value - async def apply(self, bind=None, timeout=DEFAULT): + async def apply(self, bind=None, timeout=DEFAULT, extra_returning_fields=tuple()): """ Apply pending updates into database by executing an ``UPDATE`` SQL. @@ -113,6 +113,9 @@ async def apply(self, bind=None, timeout=DEFAULT): ``None`` for wait forever. By default it will use the ``timeout`` execution option value if unspecified. + :param extra_returning_fields: A `tuple` of returning fields besides + fields to create/update, e.g. (`updated_at`, `created_at`) + :return: ``self`` for chaining calls. """ @@ -174,9 +177,9 @@ async def apply(self, bind=None, timeout=DEFAULT): ) .execution_options(**opts) ) - await _query_and_update( - bind, self._instance, clause, [getattr(cls, key) for key in values], opts - ) + cols = tuple(getattr(cls, key) for key in values) + extra_cols = tuple(getattr(cls, key) for key in extra_returning_fields) + await _query_and_update(bind, self._instance, clause, cols + extra_cols, opts) for prop in self._props: prop.reload(self._instance) return self diff --git a/tests/models.py b/tests/models.py index 0b6dc36e..c64c1047 100644 --- a/tests/models.py +++ b/tests/models.py @@ -8,6 +8,7 @@ from gino import Gino from gino.dialects.asyncpg import JSONB +from sqlalchemy import func DB_ARGS = dict( host=os.getenv("DB_HOST", "localhost"), @@ -149,6 +150,18 @@ class UserSetting(db.Model): col1_check = db.CheckConstraint("col1 >= 1 AND col1 <= 5") col2_idx = db.Index("col2_idx", "col2") +class Record(db.Model): + __tablename__ = "gino_records" + + id = db.Column(db.BigInteger(), primary_key=True) + value = db.Column(db.Text()) + updated_at = db.Column( + db.DateTime(timezone=True), + nullable=False, + onupdate=func.now(), + server_default=func.now(), + ) + def qsize(engine): # noinspection PyProtectedMember diff --git a/tests/test_crud.py b/tests/test_crud.py index 2622923b..34c73590 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -2,7 +2,7 @@ import pytest -from .models import db, User, UserType, Friendship, Relation, PG_URL +from .models import db, User, UserType, Friendship, Relation, Record, PG_URL pytestmark = pytest.mark.asyncio @@ -209,6 +209,20 @@ async def test_update_multiple_primary_key(engine): assert f2 +async def test_update_extra_returing_fields(engine): + r1 = await Record.create(bind=engine, id=1, value='v0') + updated1 = r1.updated_at + await r1.update(value="v1").apply(bind=engine) + assert updated1 == r1.updated_at + + r2 = await Record.create(bind=engine, id=2, value='v0') + updated2 = r2.updated_at + await r2.update(value="v2").apply( + bind=engine, extra_returning_fields=("updated_at",) + ) + assert updated2 < r2.updated_at + + async def test_delete(engine): u1 = await test_create(engine) await u1.delete(bind=engine, timeout=10)