Skip to content

Commit

Permalink
Changed default dialect to 2 (#3467)
Browse files Browse the repository at this point in the history
* Changed default dialect to 2

* Codestyle fixes

* Fixed async tests

* Added handling of RESP3 responses

* Fixed flacky tests

* Codestyle fix

* Added separate file to hold default value
  • Loading branch information
vladvildanov authored Jan 9, 2025
1 parent 9dfaeae commit 7a6b412
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 49 deletions.
4 changes: 3 additions & 1 deletion redis/commands/search/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

FIELDNAME = object()


Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(self, query: str = "*") -> None:
self._with_schema = False
self._verbatim = False
self._cursor = []
self._dialect = None
self._dialect = DEFAULT_DIALECT
self._add_scores = False
self._scorer = "TFIDF"

Expand Down
3 changes: 3 additions & 0 deletions redis/commands/search/dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
DEFAULT_DIALECT = 2
4 changes: 3 additions & 1 deletion redis/commands/search/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional, Union

from redis.commands.search.dialect import DEFAULT_DIALECT


class Query:
"""
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, query_string: str) -> None:
self._highlight_fields: List = []
self._language: Optional[str] = None
self._expander: Optional[str] = None
self._dialect: Optional[int] = None
self._dialect: int = DEFAULT_DIALECT

def query_string(self) -> str:
"""Return the query string of this query only."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ async def test_search_commands_in_pipeline(decoded_r: redis.Redis):
@pytest.mark.redismod
async def test_query_timeout(decoded_r: redis.Redis):
q1 = Query("foo").timeout(5000)
assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10]
assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10]
q2 = Query("foo").timeout("not_a_number")
with pytest.raises(redis.ResponseError):
await decoded_r.ft().search(q2)
Expand Down
20 changes: 9 additions & 11 deletions tests/test_auth/test_token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,18 @@ def on_next(token):
assert len(tokens) > 0

@pytest.mark.parametrize(
"exp_refresh_ratio,tokens_refreshed",
"exp_refresh_ratio",
[
(0.9, 2),
(0.28, 4),
(0.9),
(0.28),
],
ids=[
"Refresh ratio = 0.9, 2 tokens in 0,1 second",
"Refresh ratio = 0.28, 4 tokens in 0,1 second",
"Refresh ratio = 0.9",
"Refresh ratio = 0.28",
],
)
@pytest.mark.asyncio
async def test_async_success_token_renewal(
self, exp_refresh_ratio, tokens_refreshed
):
async def test_async_success_token_renewal(self, exp_refresh_ratio):
tokens = []
mock_provider = Mock(spec=IdentityProviderInterface)
mock_provider.request_token.side_effect = [
Expand Down Expand Up @@ -129,7 +127,7 @@ async def on_next(token):
await mgr.start_async(mock_listener, block_for_initial=True)
await asyncio.sleep(0.1)

assert len(tokens) == tokens_refreshed
assert len(tokens) > 0

@pytest.mark.parametrize(
"block_for_initial,tokens_acquired",
Expand Down Expand Up @@ -203,7 +201,7 @@ def on_next(token):
# additional token renewal.
sleep(0.1)

assert len(tokens) == 1
assert len(tokens) > 0

@pytest.mark.asyncio
async def test_async_token_renewal_with_skip_initial(self):
Expand Down Expand Up @@ -245,7 +243,7 @@ async def on_next(token):
# due to additional token renewal.
await asyncio.sleep(0.2)

assert len(tokens) == 2
assert len(tokens) > 0

def test_success_token_renewal_with_retry(self):
tokens = []
Expand Down
118 changes: 83 additions & 35 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,7 @@ def test_profile_query_params(client):
client.hset("b", "v", "aaaabaaa")
client.hset("c", "v", "aaaaabaa")
query = "*=>[KNN 2 @v $vec]"
q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2)
q = Query(query).return_field("__v_score").sort_by("__v_score", True)
if is_resp2_connection(client):
res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"})
assert det["Iterators profile"]["Counter"] == 2.0
Expand Down Expand Up @@ -2155,7 +2155,7 @@ def test_vector_field(client):
client.hset("c", "v", "aaaaabaa")

query = "*=>[KNN 2 @v $vec]"
q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2)
q = Query(query).return_field("__v_score").sort_by("__v_score", True)
res = client.ft().search(q, query_params={"vec": "aaaaaaaa"})

if is_resp2_connection(client):
Expand Down Expand Up @@ -2191,7 +2191,7 @@ def test_text_params(client):
client.hset("doc3", mapping={"name": "Carol"})

params_dict = {"name1": "Alice", "name2": "Bob"}
q = Query("@name:($name1 | $name2 )").dialect(2)
q = Query("@name:($name1 | $name2 )")
res = client.ft().search(q, query_params=params_dict)
if is_resp2_connection(client):
assert 2 == res.total
Expand All @@ -2214,7 +2214,7 @@ def test_numeric_params(client):
client.hset("doc3", mapping={"numval": 103})

params_dict = {"min": 101, "max": 102}
q = Query("@numval:[$min $max]").dialect(2)
q = Query("@numval:[$min $max]")
res = client.ft().search(q, query_params=params_dict)

if is_resp2_connection(client):
Expand All @@ -2236,7 +2236,7 @@ def test_geo_params(client):
client.hset("doc3", mapping={"g": "29.68746, 34.94882"})

params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"}
q = Query("@g:[$lon $lat $radius $units]").dialect(2)
q = Query("@g:[$lon $lat $radius $units]")
res = client.ft().search(q, query_params=params_dict)
_assert_search_result(client, res, ["doc1", "doc2", "doc3"])

Expand Down Expand Up @@ -2355,19 +2355,19 @@ def test_dialect(client):
with pytest.raises(redis.ResponseError) as err:
client.ft().explain(Query("(*)").dialect(1))
assert "Syntax error" in str(err)
assert "WILDCARD" in client.ft().explain(Query("(*)").dialect(2))
assert "WILDCARD" in client.ft().explain(Query("(*)"))

with pytest.raises(redis.ResponseError) as err:
client.ft().explain(Query("$hello").dialect(1))
assert "Syntax error" in str(err)
q = Query("$hello").dialect(2)
q = Query("$hello")
expected = "UNION {\n hello\n +hello(expanded)\n}\n"
assert expected in client.ft().explain(q, query_params={"hello": "hello"})

expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n"
assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1))
with pytest.raises(redis.ResponseError) as err:
client.ft().explain(Query("@title:(@num:[0 10])").dialect(2))
client.ft().explain(Query("@title:(@num:[0 10])"))
assert "Syntax error" in str(err)


Expand Down Expand Up @@ -2438,9 +2438,9 @@ def test_withsuffixtrie(client: redis.Redis):
@pytest.mark.redismod
def test_query_timeout(r: redis.Redis):
q1 = Query("foo").timeout(5000)
assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10]
assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10]
q1 = Query("foo").timeout(0)
assert q1.get_args() == ["foo", "TIMEOUT", 0, "LIMIT", 0, 10]
assert q1.get_args() == ["foo", "TIMEOUT", 0, "DIALECT", 2, "LIMIT", 0, 10]
q2 = Query("foo").timeout("not_a_number")
with pytest.raises(redis.ResponseError):
r.ft().search(q2)
Expand Down Expand Up @@ -2507,28 +2507,26 @@ def test_search_missing_fields(client):
)

with pytest.raises(redis.exceptions.ResponseError) as e:
client.ft().search(
Query("ismissing(@title)").dialect(2).return_field("id").no_content()
)
client.ft().search(Query("ismissing(@title)").return_field("id").no_content())
assert "to be defined with 'INDEXMISSING'" in e.value.args[0]

res = client.ft().search(
Query("ismissing(@features)").dialect(2).return_field("id").no_content()
Query("ismissing(@features)").return_field("id").no_content()
)
_assert_search_result(client, res, ["property:2"])

res = client.ft().search(
Query("-ismissing(@features)").dialect(2).return_field("id").no_content()
Query("-ismissing(@features)").return_field("id").no_content()
)
_assert_search_result(client, res, ["property:1", "property:3"])

res = client.ft().search(
Query("ismissing(@description)").dialect(2).return_field("id").no_content()
Query("ismissing(@description)").return_field("id").no_content()
)
_assert_search_result(client, res, ["property:3"])

res = client.ft().search(
Query("-ismissing(@description)").dialect(2).return_field("id").no_content()
Query("-ismissing(@description)").return_field("id").no_content()
)
_assert_search_result(client, res, ["property:1", "property:2"])

Expand Down Expand Up @@ -2578,31 +2576,25 @@ def test_search_empty_fields(client):
)

with pytest.raises(redis.exceptions.ResponseError) as e:
client.ft().search(
Query("@title:''").dialect(2).return_field("id").no_content()
)
client.ft().search(Query("@title:''").return_field("id").no_content())
assert "Use `INDEXEMPTY` in field creation" in e.value.args[0]

res = client.ft().search(
Query("@features:{$empty}").dialect(2).return_field("id").no_content(),
Query("@features:{$empty}").return_field("id").no_content(),
query_params={"empty": ""},
)
_assert_search_result(client, res, ["property:2"])

res = client.ft().search(
Query("-@features:{$empty}").dialect(2).return_field("id").no_content(),
Query("-@features:{$empty}").return_field("id").no_content(),
query_params={"empty": ""},
)
_assert_search_result(client, res, ["property:1", "property:3"])

res = client.ft().search(
Query("@description:''").dialect(2).return_field("id").no_content()
)
res = client.ft().search(Query("@description:''").return_field("id").no_content())
_assert_search_result(client, res, ["property:3"])

res = client.ft().search(
Query("-@description:''").dialect(2).return_field("id").no_content()
)
res = client.ft().search(Query("-@description:''").return_field("id").no_content())
_assert_search_result(client, res, ["property:1", "property:2"])


Expand Down Expand Up @@ -2643,29 +2635,85 @@ def test_special_characters_in_fields(client):

# no need to escape - when using params
res = client.ft().search(
Query("@uuid:{$uuid}").dialect(2),
Query("@uuid:{$uuid}"),
query_params={"uuid": "123e4567-e89b-12d3-a456-426614174000"},
)
_assert_search_result(client, res, ["resource:1"])

# with double quotes exact match no need to escape the - even without params
res = client.ft().search(
Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}').dialect(2)
)
res = client.ft().search(Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}'))
_assert_search_result(client, res, ["resource:1"])

res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}').dialect(2))
res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}'))
_assert_search_result(client, res, ["resource:2"])

# possible to search numeric fields by single value
res = client.ft().search(Query("@rating:[4]").dialect(2))
res = client.ft().search(Query("@rating:[4]"))
_assert_search_result(client, res, ["resource:2"])

# some chars still need escaping
res = client.ft().search(Query(r"@tags:{\$btc}").dialect(2))
res = client.ft().search(Query(r"@tags:{\$btc}"))
_assert_search_result(client, res, ["resource:1"])


@pytest.mark.redismod
@skip_ifmodversion_lt("2.4.3", "search")
def test_vector_search_with_default_dialect(client):
client.ft().create_index(
(
VectorField(
"v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}
),
)
)

client.hset("a", "v", "aaaaaaaa")
client.hset("b", "v", "aaaabaaa")
client.hset("c", "v", "aaaaabaa")

query = "*=>[KNN 2 @v $vec]"
q = Query(query)

assert "DIALECT" in q.get_args()
assert 2 in q.get_args()

res = client.ft().search(q, query_params={"vec": "aaaaaaaa"})
if is_resp2_connection(client):
assert res.total == 2
else:
assert res["total_results"] == 2


@pytest.mark.redismod
@skip_ifmodversion_lt("2.4.3", "search")
def test_search_query_with_different_dialects(client):
client.ft().create_index(
(TextField("name"), TextField("lastname")),
definition=IndexDefinition(prefix=["test:"]),
)

client.hset("test:1", "name", "James")
client.hset("test:1", "lastname", "Brown")

# Query with default DIALECT 2
query = "@name: James Brown"
q = Query(query)
res = client.ft().search(q)
if is_resp2_connection(client):
assert res.total == 1
else:
assert res["total_results"] == 1

# Query with explicit DIALECT 1
query = "@name: James Brown"
q = Query(query).dialect(1)
res = client.ft().search(q)
if is_resp2_connection(client):
assert res.total == 0
else:
assert res["total_results"] == 0


def _assert_search_result(client, result, expected_doc_ids):
"""
Make sure the result of a geo search is as expected, taking into account the RESP
Expand Down

0 comments on commit 7a6b412

Please sign in to comment.