Skip to content

Commit

Permalink
fix array index lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Dec 13, 2024
1 parent cf39576 commit 6338305
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 18 deletions.
8 changes: 1 addition & 7 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"model_fields_.test_arrayfield.TestQuerying.test_icontains",
# Field 'field' expected a number but got Value(1).
"model_fields_.test_arrayfield.TestQuerying.test_exact_with_expression",
# int() argument must be a string, a bytes-like object or a real number, not 'list'
"model_fields_.test_arrayfield.TestQuerying.test_index_annotation",
# Wrong results
"model_fields_.test_arrayfield.TestQuerying.test_index",
"model_fields_.test_arrayfield.TestQuerying.test_index_chained",
"model_fields_.test_arrayfield.TestQuerying.test_index_nested",
"model_fields_.test_arrayfield.TestQuerying.test_order_by_slice",
# $lt treats null values as zero.
"model_fields_.test_arrayfield.TestQuerying.test_lt",
"model_fields_.test_arrayfield.TestQuerying.test_len",
"model_fields_.test_arrayfield.TestQuerying.test_index_chained",
# None is $in None
"model_fields_.test_arrayfield.TestQuerying.test_in_as_F_object",
}
Expand Down
7 changes: 2 additions & 5 deletions django_mongodb/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def get_transform(self, name):
except ValueError:
pass
else:
index += 1 # postgres uses 1-indexing
return IndexTransformFactory(index, self.base_field)
try:
start, end = name.split("_")
Expand Down Expand Up @@ -306,10 +305,8 @@ def __init__(self, index, base_field, *args, **kwargs):
self.base_field = base_field

def as_mql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
if not lhs.endswith("]"):
lhs = "(%s)" % lhs
return "%s[%%s]" % lhs, (*params, self.index)
lhs_mql = process_lhs(self, compiler, connection)
return {"$arrayElemAt": [lhs_mql, self.index]}

@property
def output_field(self):
Expand Down
10 changes: 4 additions & 6 deletions tests/model_fields_/test_arrayfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ def test_index_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance])

@unittest.expectedFailure
def test_index_used_on_nested_data(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
Expand Down Expand Up @@ -388,7 +387,7 @@ def test_slice(self):
NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
)

def test_order_by_slice(self):
def test_order_by_index(self):
more_objs = (
NullableIntegerArrayModel.objects.create(field=[1, 637]),
NullableIntegerArrayModel.objects.create(field=[2, 1]),
Expand All @@ -398,19 +397,18 @@ def test_order_by_slice(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.order_by("field__1"),
[
self.objs[0],
self.objs[1],
self.objs[4],
more_objs[2],
more_objs[1],
more_objs[3],
self.objs[2],
self.objs[3],
more_objs[0],
self.objs[4],
self.objs[1],
self.objs[0],
],
)

@unittest.expectedFailure
def test_slice_nested(self):
instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
self.assertSequenceEqual(
Expand Down

0 comments on commit 6338305

Please sign in to comment.