Skip to content

Commit

Permalink
Add extra_returning_fields in apply
Browse files Browse the repository at this point in the history
  • Loading branch information
kigawas committed Mar 17, 2021
1 parent 1145f24 commit 24aa060
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/gino/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _set_prop(self, prop, value):
self._literal = False
self._props[prop] = value

async def apply(self, bind=None, timeout=DEFAULT):
async def apply(self, bind=None, timeout=DEFAULT, extra_returning_fields=tuple()):
"""
Apply pending updates into database by executing an ``UPDATE`` SQL.
Expand All @@ -113,6 +113,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
``None`` for wait forever. By default it will use the ``timeout``
execution option value if unspecified.
:param extra_returning_fields: A `tuple` of returning fields besides
fields to create/update, e.g. (`updated_at`, `created_at`)
:return: ``self`` for chaining calls.
"""
Expand Down Expand Up @@ -174,9 +177,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
)
.execution_options(**opts)
)
await _query_and_update(
bind, self._instance, clause, [getattr(cls, key) for key in values], opts
)
cols = tuple(getattr(cls, key) for key in values)
extra_cols = tuple(getattr(cls, key) for key in extra_returning_fields)
await _query_and_update(bind, self._instance, clause, cols + extra_cols, opts)
for prop in self._props:
prop.reload(self._instance)
return self
Expand Down
13 changes: 13 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from gino import Gino
from gino.dialects.asyncpg import JSONB
from sqlalchemy import func

DB_ARGS = dict(
host=os.getenv("DB_HOST", "localhost"),
Expand Down Expand Up @@ -149,6 +150,18 @@ class UserSetting(db.Model):
col1_check = db.CheckConstraint("col1 >= 1 AND col1 <= 5")
col2_idx = db.Index("col2_idx", "col2")

class Record(db.Model):
__tablename__ = "gino_records"

id = db.Column(db.BigInteger(), primary_key=True)
value = db.Column(db.Text())
updated_at = db.Column(
db.DateTime(timezone=True),
nullable=False,
onupdate=func.now(),
server_default=func.now(),
)


def qsize(engine):
# noinspection PyProtectedMember
Expand Down
16 changes: 15 additions & 1 deletion tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from .models import db, User, UserType, Friendship, Relation, PG_URL
from .models import db, User, UserType, Friendship, Relation, Record, PG_URL

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -209,6 +209,20 @@ async def test_update_multiple_primary_key(engine):
assert f2


async def test_update_extra_returing_fields(engine):
r1 = await Record.create(bind=engine, id=1, value='v0')
updated1 = r1.updated_at
await r1.update(value="v1").apply(bind=engine)
assert updated1 == r1.updated_at

r2 = await Record.create(bind=engine, id=2, value='v0')
updated2 = r2.updated_at
await r2.update(value="v2").apply(
bind=engine, extra_returning_fields=("updated_at",)
)
assert updated2 < r2.updated_at


async def test_delete(engine):
u1 = await test_create(engine)
await u1.delete(bind=engine, timeout=10)
Expand Down

0 comments on commit 24aa060

Please sign in to comment.