Skip to content

Commit

Permalink
Add extra_returning_fields in apply
Browse files Browse the repository at this point in the history
  • Loading branch information
kigawas committed Mar 17, 2021
1 parent 1145f24 commit b3be96a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/gino/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _set_prop(self, prop, value):
self._literal = False
self._props[prop] = value

async def apply(self, bind=None, timeout=DEFAULT):
async def apply(self, bind=None, timeout=DEFAULT, extra_returning_fields=tuple()):
"""
Apply pending updates into database by executing an ``UPDATE`` SQL.
Expand All @@ -113,6 +113,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
``None`` for wait forever. By default it will use the ``timeout``
execution option value if unspecified.
:param extra_returning_fields: A `tuple` of returning fields besides
fields to create/update, e.g. (`updated_at`, `created_at`)
:return: ``self`` for chaining calls.
"""
Expand Down Expand Up @@ -174,9 +177,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
)
.execution_options(**opts)
)
await _query_and_update(
bind, self._instance, clause, [getattr(cls, key) for key in values], opts
)
cols = tuple(getattr(cls, key) for key in values)
extra_cols = tuple(getattr(cls, key) for key in extra_returning_fields)
await _query_and_update(bind, self._instance, clause, cols + extra_cols, opts)
for prop in self._props:
prop.reload(self._instance)
return self
Expand Down
7 changes: 7 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from gino import Gino
from gino.dialects.asyncpg import JSONB
from sqlalchemy import func

DB_ARGS = dict(
host=os.getenv("DB_HOST", "localhost"),
Expand Down Expand Up @@ -49,6 +50,12 @@ class User(db.Model):
weight = db.IntegerProperty(prop_name='parameter')
height = db.IntegerProperty(default=170, prop_name='parameter')
bio = db.StringProperty(prop_name='parameter')
updated_at = db.Column(
db.DateTime(timezone=True),
nullable=False,
onupdate=func.now(),
server_default=func.now(),
)

@balance.after_get
def balance(self, val):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,20 @@ async def test_update_multiple_primary_key(engine):
assert f2


async def test_update_extra_returing_fields(engine):
u1 = await test_create(engine)
updated1 = u1.updated_at
await u1.update(nickname="new_nickname").apply(bind=engine)
assert updated1 == u1.updated_at

u2 = await test_create(engine)
updated2 = u2.updated_at
await u2.update(nickname="new_nickname").apply(
bind=engine, extra_returning_fields=("updated_at",)
)
assert updated2 < u2.updated_at


async def test_delete(engine):
u1 = await test_create(engine)
await u1.delete(bind=engine, timeout=10)
Expand Down

0 comments on commit b3be96a

Please sign in to comment.