Skip to content

Commit

Permalink
Fix #459 - Check inner type of the scalar list
Browse files Browse the repository at this point in the history
  • Loading branch information
ziima committed Jul 9, 2020
1 parent 9b30035 commit 99d0a11
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 5 deletions.
34 changes: 29 additions & 5 deletions sqlalchemy_utils/types/scalar_list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import six
import sqlalchemy as sa
from sqlalchemy import types
Expand All @@ -11,8 +13,7 @@ class ScalarListType(types.TypeDecorator):
"""
ScalarListType type provides convenient way for saving multiple scalar
values in one column. ScalarListType works like list on python side and
saves the result as comma-separated list in the database (custom separators
can also be used).
saves the result as comma-separated list in the database.
Example ::
Expand Down Expand Up @@ -50,13 +51,31 @@ class Player(Base):
session.commit()
:param inner_type:
The type of the values. Default is ``str``.
:param separator:
Separator of the values. Default is ``,``.
:param coerce_func:
Custom function to coerce values when read from database.
By default ``inner_type`` is used instead.
"""

impl = sa.UnicodeText()

def __init__(self, coerce_func=six.text_type, separator=u','):
def __init__(self, inner_type=six.text_type, separator=u',',
coerce_func=None):
self.separator = six.text_type(separator)
self.coerce_func = coerce_func
if not isinstance(inner_type, type) and coerce_func is None:
warn_msg = (
"ScalarListType has new required argument 'inner_type'. "
"Provide the type of the values and if required, "
"pass coerce func as a keyword argument.")
warnings.warn(warn_msg, DeprecationWarning)
self.inner_type = None
self.coerce_func = inner_type
else:
self.inner_type = inner_type
self.coerce_func = coerce_func

def process_bind_param(self, value, dialect):
# Convert list of values to unicode separator-separated list
Expand All @@ -69,6 +88,10 @@ def process_bind_param(self, value, dialect):
"these strings, use a different separator string.)"
% self.separator
)
if self.inner_type is not None:
if any(not isinstance(item, self.inner_type) for item in value):
msg = "Not all items in value {} match the type {}"
raise ValueError(msg.format(value, self.inner_type))
return self.separator.join(
map(six.text_type, value)
)
Expand All @@ -78,6 +101,7 @@ def process_result_value(self, value, dialect):
if value == u'':
return []
# coerce each value
coerce_func = self.coerce_func or self.inner_type
return list(map(
self.coerce_func, value.split(self.separator)
coerce_func, value.split(self.separator)
))
54 changes: 54 additions & 0 deletions tests/types/test_scalar_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def test_save_integer_list(self, session, User):
user = session.query(User).first()
assert user.some_list == [1, 2, 3, 4]

def test_save_integer_list_invalid(self, session, User):
user = User(
some_list=[1, 2, 'invalid', 4]
)

session.add(user)
with pytest.raises(sa.exc.StatementError):
session.commit()


class TestScalarUnicodeList(object):

Expand Down Expand Up @@ -92,3 +101,48 @@ def test_save_and_retrieve_empty_list(self, session, User):

user = session.query(User).first()
assert user.some_list == []


def custom_int(value):
return int(value)


@pytest.mark.filterwarnings(
"ignore:ScalarListType has new required argument 'inner_type'")
class TestScalarListCoerceFunc(object):
"""Test deprecated behaviour with single argument which is not a type."""

@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
some_list = sa.Column(ScalarListType(custom_int))

def __repr__(self):
return 'User(%r)' % self.id

return User

@pytest.fixture
def init_models(self, User):
pass

def test_save_integer_list(self, session, User):
user = User(some_list=[1, 2, 3, 4])

session.add(user)
session.commit()

user = session.query(User).first()
assert user.some_list == [1, 2, 3, 4]

def test_save_integer_list_invalid(self, session, User):
user = User(some_list=[1, 2, 'invalid', 4])

session.add(user)
session.commit()

# It stores invalid value to database and fails on coerce after read.
with pytest.raises(ValueError, match='invalid literal for int'):
session.query(User).first()

0 comments on commit 99d0a11

Please sign in to comment.