From fcc600142a298d88758381fb85c311edc7042f16 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:09:17 +0300 Subject: [PATCH 001/180] Python: adds JSON.ARRLEN command (#2403) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 54 +++++++++++++++++++ .../tests/tests_server_modules/test_json.py | 36 +++++++++++++ 3 files changed, 91 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md index 72d205fefb..8ee5dda24f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) #### Breaking Changes
diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 40d7709646..d1709806bc 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -139,6 +139,60 @@ async def get( return cast(bytes, await client.custom_command(args)) +async def arrlen( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[TJsonResponse[int]]: + """ + Retrieves the length of the array at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Defaults to None. + + Returns: + Optional[TJsonResponse[int]]: + For JSONPath (`path` starts with `$`): + Returns a list of integer replies for every possible path, indicating the length of the array, + or None for JSON values matching the path that are not an array. + If `path` doesn't exist, an empty array will be returned. + For legacy path (`path` doesn't start with `$`): + Returns the length of the array at `path`. + If multiple paths match, the length of the first array match is returned. + If the JSON value at `path` is not an array or if `path` doesn't exist, an error is raised. + If `key` doesn't exist, None is returned. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}') + b'OK' # JSON is successfully set for doc + >>> await json.arrlen(client, "doc", "$") + [None] # No array at the root path. + >>> await json.arrlen(client, "doc", "$.a") + [3] # Retrieves the length of the array at path $.a. + >>> await json.arrlen(client, "doc", "$..a") + [3, 2, None] # Retrieves lengths of arrays found at all levels of the path `..a`. + >>> await json.arrlen(client, "doc", "..a") + 3 # Legacy path retrieves the first array match at path `..a`. + >>> await json.arrlen(client, "non_existing_key", "$.a") + None # Returns None because the key does not exist. + + >>> await json.set(client, "doc", "$", '[1, 2, 3, 4]') + b'OK' # JSON is successfully set for doc + >>> await json.arrlen(client, "doc") + 4 # Retrieves the length of the root array.
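+        >>> # Illustrative extra example (assuming standard JSONPath semantics): matches that are not arrays each yield None.
+        >>> await json.arrlen(client, "doc", "$[*]")
+            [None, None, None, None]  # Each element of the root array [1, 2, 3, 4] is a number, not an array.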
+ """ + args = ["JSON.ARRLEN", key] + if path: + args.append(path) + return cast( + Optional[TJsonResponse[int]], + await client.custom_command(args), + ) + + async def delete( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 67c8c7a112..a69c3010e2 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -276,3 +276,39 @@ async def test_json_type(self, glide_client: TGlideClient): # Check for all types in the JSON document using legacy path result = await json.type(glide_client, key, "[*]") assert result == b"string" # Expecting only the first type (string for key1) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrlen(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.arrlen(glide_client, key, "$.a") == [3] + + assert await json.arrlen(glide_client, key, "$..a") == [3, 2, None] + + # Legacy path retrieves the first array match at ..a + assert await json.arrlen(glide_client, key, "..a") == 3 + + # Value at path is not an array + assert await json.arrlen(glide_client, key, "$") == [None] + with pytest.raises(RequestError): + assert await json.arrlen(glide_client, key, ".") + + # Path doesn't exist + assert await json.arrlen(glide_client, key, "$.non_existing_path") == [] + with pytest.raises(RequestError): + assert await json.arrlen(glide_client, key, "non_existing_path") + + # Non-existing key + assert await json.arrlen(glide_client, "non_existing_key", "$.a") is None + assert await json.arrlen(glide_client, "non_existing_key", ".a") is None + + # No path + with pytest.raises(RequestError): + assert await json.arrlen(glide_client, key) + + assert await json.set(glide_client, key, "$", "[1, 2, 3, 4]") == OK + assert await json.arrlen(glide_client, key) == 4 From a2ea32e1991569ea44d106dfdaae420a441389d2 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:50:34 +0300 Subject: [PATCH 002/180] Python: adds JSON.CLEAR command (#2418) Signed-off-by: Shoham Elias --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 52 +++++++++++++++++++ .../tests/tests_server_modules/test_json.py | 47 +++++++++++++++++ 3 files changed, 100 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ee5dda24f..f8bc293a82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ #### Changes * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) +* Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418)) #### Breaking Changes diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index d1709806bc..1864132451 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -193,6 +193,58 @@ async def arrlen( ) +async def clear( + client: TGlideClient, + key: TEncodable, + path: Optional[str] = None, +) -> int: + """ + Clears arrays or objects at the specified JSON path in the document stored at `key`. 
+ Numeric values are set to `0`, boolean values are set to `False`, and string values are converted to empty strings. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[str]): The JSON path to the arrays or objects to be cleared. Defaults to None. + + Returns: + int: The number of containers cleared, numeric values zeroed, booleans set to `false`, + and string values converted to empty strings. + If `path` doesn't exist, or the value at `path` is already empty (e.g., an empty array, object, or string), 0 is returned. + If `key` doesn't exist, an error is raised. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"obj":{"a":1, "b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}') + b'OK' # JSON document is successfully set. + >>> await json.clear(client, "doc", "$.*") + 6 # 6 values are cleared (arrays/objects/strings/numbers/booleans), but `null` remains as is. + >>> await json.get(client, "doc", "$") + b'[{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}]' + >>> await json.clear(client, "doc", "$.*") + 0 # No further clearing needed since the containers are already empty and the values are defaults. + + >>> await json.set(client, "doc", "$", '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}') + b'OK' + >>> await json.clear(client, "doc", "b.a[1:3]") + 2 # 2 elements (`6` and `7`) are cleared. + >>> await json.clear(client, "doc", "b.a[1:3]") + 0 # No elements cleared since the specified slice has already been cleared. + >>> await json.get(client, "doc", "$..a") + b'[1,[5,0,0],true,"value",3.5,{"foo":"foo"}]' + + >>> await json.clear(client, "doc", "$..a") + 6 # All numeric, boolean, and string values across paths are cleared.
>>> await json.get(client, "doc", "$..a") + b'[0,[],false,"",0.0,{}]' + """ + args = ["JSON.CLEAR", key] + if path: + args.append(path) + + return cast(int, await client.custom_command(args)) + + async def delete( client: TGlideClient, key: TEncodable,
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index a69c3010e2..d21b11686b 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -312,3 +312,50 @@ async def test_json_arrlen(self, glide_client: TGlideClient): assert await json.set(glide_client, key, "$", "[1, 2, 3, 4]") == OK assert await json.arrlen(glide_client, key) == 4 + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_clear(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = '{"obj":{"a":1, "b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.clear(glide_client, key, "$.*") == 6 + result = await json.get(glide_client, key, "$") + assert ( + result + == b'[{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}]' + ) + assert await json.clear(glide_client, key, "$.*") == 0 + + assert await json.set(glide_client, key, "$", json_value) == OK + assert await json.clear(glide_client, key, "*") == 6 + + json_value = '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.clear(glide_client, key, "b.a[1:3]") == 2 + assert await json.clear(glide_client, key, "b.a[1:3]") == 0 + assert ( + await json.get(glide_client, key, "$..a") + == b'[1,[5,0,0],true,"value",3.5,{"foo":"foo"}]' + ) + assert await json.clear(glide_client, key, "..a") == 6 + assert await json.get(glide_client, key, "$..a") == b'[0,[],false,"",0.0,{}]' + + assert await json.clear(glide_client, key, "$..a") == 0 + + # Path doesn't exist + assert await json.clear(glide_client, key, "$.path") == 0 + assert await json.clear(glide_client, key, "path") == 0 + + # Key doesn't exist + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key") + + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key", "$") + + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key", ".")
From 8fea303fac384db4a109644e930d29b1dbb45594 Mon Sep 17 00:00:00 2001 From: prateek-kumar-improving Date: Thu, 10 Oct 2024 21:46:26 -0700 Subject: [PATCH 003/180] Python [FT.CREATE] command added (Created for release-1.2 branch) (#2426) * Python [FT.CREATE] command added Signed-off-by: Prateek Kumar --------- Signed-off-by: Prateek Kumar --- CHANGELOG.md | 1 + python/python/glide/__init__.py | 34 +- .../glide/async_commands/server_modules/ft.py | 56 +++ .../server_modules/ft_constants.py | 32 ++ .../ft_options/ft_create_options.py | 429 ++++++++++++++++++ .../tests/tests_server_modules/test_ft.py | 103 +++++ 6 files changed, 654 insertions(+), 1 deletion(-) create mode 100644 python/python/glide/async_commands/server_modules/ft.py create mode 100644 python/python/glide/async_commands/server_modules/ft_constants.py create mode 100644 
python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py create mode 100644 python/python/tests/tests_server_modules/test_ft.py
diff --git a/CHANGELOG.md b/CHANGELOG.md index f8bc293a82..bbb5c9d506 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Added FT.CREATE command ([#2413](https://github.com/valkey-io/valkey-glide/pull/2413)) * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418))
diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index cf817c128a..05910eb480 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -32,7 +32,23 @@ InsertPosition, UpdateOptions, ) -from glide.async_commands.server_modules import json +from glide.async_commands.server_modules import ft, json +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + DistanceMetricType, + Field, + FieldType, + FtCreateOptions, + NumericField, + TagField, + TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributes, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + VectorType, +) from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -185,6 +201,7 @@ "InfoSection", "InsertPosition", "json", + "ft", "LexBoundary", "Limit", "ListDirection", @@ -233,4 +250,19 @@ "GlideError", "RequestError", "TimeoutError", + # Ft + "DataType", + "DistanceMetricType", + "Field", + "FieldType", + "FtCreateOptions", + "NumericField", + "TagField", + "TextField", + "VectorAlgorithm", + "VectorField", + "VectorFieldAttributes", + "VectorFieldAttributesFlat", + "VectorFieldAttributesHnsw", + "VectorType", ]
diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py new file mode 100644 index 0000000000..b7c764cd0f --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -0,0 +1,56 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +""" +Module for vector search commands. +""" + +from typing import List, Optional, cast + +from glide.async_commands.server_modules.ft_constants import ( + CommandNames, + FtCreateKeywords, ) +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + Field, + FtCreateOptions, ) +from glide.constants import TOK, TEncodable +from glide.glide_client import TGlideClient + + +async def create( + client: TGlideClient, + indexName: TEncodable, + schema: List[Field], + options: Optional[FtCreateOptions] = None, +) -> TOK: + """ + Creates an index and initiates a backfill of that index. + + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The name of the index to be created. + schema (List[Field]): The fields of the index schema, specifying the fields and their types. + options (Optional[FtCreateOptions]): Optional arguments for the [FT.CREATE] command. + + Returns: + TOK: A simple "OK" response if the index is successfully created.
+ + Examples: + >>> from glide.async_commands.server_modules import ft + >>> schema: List[Field] = [] + >>> field: TextField = TextField("title") + >>> schema.append(field) + >>> prefixes: List[str] = [] + >>> prefixes.append("blog:post:") + >>> index = "idx" + >>> result = await ft.create(glide_client, index, schema, FtCreateOptions(DataType.HASH, prefixes)) + b'OK' # Indicates successful creation of index named 'idx' + """ + args: List[TEncodable] = [CommandNames.FT_CREATE, indexName] + if options: + args.extend(options.toArgs()) + if schema: + args.append(FtCreateKeywords.SCHEMA) + for field in schema: + args.extend(field.toArgs()) + return cast(TOK, await client.custom_command(args))
diff --git a/python/python/glide/async_commands/server_modules/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_constants.py new file mode 100644 index 0000000000..3c48f5b67c --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_constants.py @@ -0,0 +1,32 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + + +class CommandNames: + """ + Command name constants for vector search. + """ + + FT_CREATE = "FT.CREATE" + + +class FtCreateKeywords: + """ + Keywords used in the [FT.CREATE] command statement. + """ + + SCHEMA = "SCHEMA" + AS = "AS" + SORTABLE = "SORTABLE" + UNF = "UNF" + NO_INDEX = "NOINDEX" + ON = "ON" + PREFIX = "PREFIX" + SEPARATOR = "SEPARATOR" + CASESENSITIVE = "CASESENSITIVE" + DIM = "DIM" + DISTANCE_METRIC = "DISTANCE_METRIC" + TYPE = "TYPE" + INITIAL_CAP = "INITIAL_CAP" + M = "M" + EF_CONSTRUCTION = "EF_CONSTRUCTION" + EF_RUNTIME = "EF_RUNTIME"
diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py new file mode 100644 index 0000000000..d3db3dbe75 --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py @@ -0,0 +1,429 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Optional + +from glide.async_commands.server_modules.ft_constants import FtCreateKeywords +from glide.constants import TEncodable + + +class FieldType(Enum): + """ + All possible values for the data type of field identifier for the SCHEMA option. + """ + + TEXT = "TEXT" + """ + If the field contains any blob of data. + """ + TAG = "TAG" + """ + If the field contains a tag field. + """ + NUMERIC = "NUMERIC" + """ + If the field contains a number. + """ + VECTOR = "VECTOR" + """ + If the field is a vector field that supports vector search. + """ + + +class VectorAlgorithm(Enum): + """ + Algorithm for vector type fields used for vector similarity search. + """ + + HNSW = "HNSW" + """ + Hierarchical Navigable Small World algorithm. + """ + FLAT = "FLAT" + """ + Flat algorithm or the brute force algorithm. + """ + + +class DistanceMetricType(Enum): + """ + The distance metric options for vector type fields. + """ + + L2 = "L2" + """ + Euclidean distance + """ + IP = "IP" + """ + Inner product + """ + COSINE = "COSINE" + """ + Cosine distance + """ + + +class VectorType(Enum): + """ + The data type of the vector field. + """ + + FLOAT32 = "FLOAT32" + """ + FLOAT32 type of vector. The only supported type. + """ + + +class Field(ABC): + """ + Abstract base class for defining fields in a schema.
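+    Concrete subclasses serialize themselves into FT.CREATE arguments via `toArgs()`.
+
+    Example (illustrative; the output is derived from the `toArgs` implementations below):
+        >>> TextField("title", "t").toArgs()
+        ['title', 'AS', 't', 'TEXT']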
+ """ + + @abstractmethod + def __init__( + self, + name: TEncodable, + type: FieldType, + alias: Optional[str] = None, + ): + """ + Initialize a new field instance. + + Args: + name (TEncodable): The name of the field. + type (FieldType): The type of the field. + alias (Optional[str]): An alias for the field. + """ + self.name = name + self.type = type + self.alias = alias + + @abstractmethod + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments representing the field. + + Returns: + List[TEncodable]: A list of field arguments. + """ + args = [self.name] + if self.alias: + args.extend([FtCreateKeywords.AS, self.alias]) + args.append(self.type.value) + return args + + +class TextField(Field): + """ + Class for defining text fields in a schema. + """ + + def __init__(self, name: TEncodable, alias: Optional[str] = None): + """ + Initialize a new TextField instance. + + Args: + name (TEncodable): The name of the text field. + alias (Optional[str]): An alias for the field. + """ + super().__init__(name, FieldType.TEXT, alias) + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments representing the text field. + + Returns: + List[TEncodable]: A list of text field arguments. + """ + args = super().toArgs() + return args + + +class TagField(Field): + """ + Class for defining tag fields in a schema. + """ + + def __init__( + self, + name: TEncodable, + alias: Optional[str] = None, + separator: Optional[str] = None, + case_sensitive: bool = False, + ): + """ + Initialize a new TagField instance. + + Args: + name (TEncodable): The name of the tag field. + alias (Optional[str]): An alias for the field. + separator (Optional[str]): Specify how text in the attribute is split into individual tags. Must be a single character. + case_sensitive (bool): Preserve the original letter cases of tags. If set to False, characters are converted to lowercase by default. + """ + super().__init__(name, FieldType.TAG, alias) + self.separator = separator + self.case_sensitive = case_sensitive + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments representing the tag field. + + Returns: + List[TEncodable]: A list of tag field arguments. + """ + args = super().toArgs() + if self.separator: + args.extend([FtCreateKeywords.SEPARATOR, self.separator]) + if self.case_sensitive: + args.append(FtCreateKeywords.CASESENSITIVE) + return args + + +class NumericField(Field): + """ + Class for defining the numeric fields in a schema. + """ + + def __init__(self, name: TEncodable, alias: Optional[str] = None): + """ + Initialize a new NumericField instance. + + Args: + name (TEncodable): The name of the numeric field. + alias (Optional[str]): An alias for the field. + """ + super().__init__(name, FieldType.NUMERIC, alias) + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments representing the numeric field. + + Returns: + List[TEncodable]: A list of numeric field arguments. + """ + args = super().toArgs() + return args + + +class VectorFieldAttributes(ABC): + """ + Abstract base class for defining vector field attributes to be used after the vector algorithm name. + """ + + @abstractmethod + def __init__(self, dim: int, distance_metric: DistanceMetricType, type: VectorType): + """ + Initialize a new vector field attributes instance. + + Args: + dim (int): Number of dimensions in the vector. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. + type (VectorType): Vector type. The only supported type is FLOAT32. 
+ """ + self.dim = dim + self.distance_metric = distance_metric + self.type = type + + @abstractmethod + def toArgs(self) -> List[str]: + """ + Get the arguments to be used for the algorithm of the vector field. + + Returns: + List[str]: A list of arguments. + """ + args = [] + if self.dim: + args.extend([FtCreateKeywords.DIM, str(self.dim)]) + if self.distance_metric: + args.extend([FtCreateKeywords.DISTANCE_METRIC, self.distance_metric.name]) + if self.type: + args.extend([FtCreateKeywords.TYPE, self.type.name]) + return args + + +class VectorFieldAttributesFlat(VectorFieldAttributes): + """ + Get the arguments to be used for the FLAT algorithm of the vector field. + """ + + def __init__( + self, + dim: int, + distance_metric: DistanceMetricType, + type: VectorType, + initial_cap: Optional[int] = None, + ): + """ + Initialize a new flat vector field attributes instance. + + Args: + dim (int): Number of dimensions in the vector. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. + type (VectorType): Vector type. The only supported type is FLOAT32. + initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024. + """ + super().__init__(dim, distance_metric, type) + self.initial_cap = initial_cap + + def toArgs(self) -> List[str]: + """ + Get the arguments representing the vector field created with FLAT algorithm. + + Returns: + List[str]: A list of FLAT algorithm type vector arguments. + """ + args = super().toArgs() + if self.initial_cap: + args.extend([FtCreateKeywords.INITIAL_CAP, str(self.initial_cap)]) + return args + + +class VectorFieldAttributesHnsw(VectorFieldAttributes): + """ + Get the arguments to be used for the HNSW algorithm of the vector field. + """ + + def __init__( + self, + dim: int, + distance_metric: DistanceMetricType, + type: VectorType, + initial_cap: Optional[int] = None, + m: Optional[int] = None, + ef_contruction: Optional[int] = None, + ef_runtime: Optional[int] = None, + ): + """ + Initialize a new TagField instance. + + Args: + dim (int): Number of dimensions in the vector. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. + type (VectorType): Vector type. The only supported type is FLOAT32. + initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024. + m (Optional[int]): Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is 16, maximum is 512. + ef_contruction (Optional[int]): Controls the number of vectors examined during index construction. Default value is 200, Maximum value is 4096. + ef_runtime (Optional[int]): Controls the number of vectors examined during query operations. Default value is 10, Maximum value is 4096. + """ + super().__init__(dim, distance_metric, type) + self.initial_cap = initial_cap + self.m = m + self.ef_contruction = ef_contruction + self.ef_runtime = ef_runtime + + def toArgs(self) -> List[str]: + """ + Get the arguments representing the vector field created with HSNW algorithm. + + Returns: + List[str]: A list of HNSW algorithm type vector arguments. 
+ """ + args = super().toArgs() + if self.initial_cap: + args.extend([FtCreateKeywords.INITIAL_CAP, str(self.initial_cap)]) + if self.m: + args.extend([FtCreateKeywords.M, str(self.m)]) + if self.ef_contruction: + args.extend([FtCreateKeywords.EF_CONSTRUCTION, str(self.ef_contruction)]) + if self.ef_runtime: + args.extend([FtCreateKeywords.EF_RUNTIME, str(self.ef_runtime)]) + return args + + +class VectorField(Field): + """ + Class for defining vector field in a schema. + """ + + def __init__( + self, + name: TEncodable, + algorithm: VectorAlgorithm, + attributes: VectorFieldAttributes, + alias: Optional[str] = None, + ): + """ + Initialize a new VectorField instance. + + Args: + name (TEncodable): The name of the vector field. + algorithm (VectorAlgorithm): The vector indexing algorithm. + alias (Optional[str]): An alias for the field. + attributes (VectorFieldAttributes): Additional attributes to be passed with the vector field after the algorithm name. + """ + super().__init__(name, FieldType.VECTOR, alias) + self.algorithm = algorithm + self.attributes = attributes + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments representing the vector field. + + Returns: + List[TEncodable]: A list of vector field arguments. + """ + args = super().toArgs() + args.append(self.algorithm.value) + if self.attributes: + attribute_list = self.attributes.toArgs() + args.append(str(len(attribute_list))) + args.extend(attribute_list) + return args + + +class DataType(Enum): + """ + Options for the type of data for which the index is being created. + """ + + HASH = "HASH" + """ + If the created index will index HASH data. + """ + JSON = "JSON" + """ + If the created index will index JSON document data. + """ + + +class FtCreateOptions: + """ + This class represents the input options to be used in the [FT.CREATE] command. + All fields in this class are optional inputs for [FT.CREATE]. + """ + + def __init__( + self, + data_type: Optional[DataType] = None, + prefixes: Optional[List[str]] = None, + ): + """ + Initialize the [FT.CREATE] optional fields. + + Args: + data_type (Optional[DataType]): The type of data to be indexed using [FT.CREATE]. + prefixes (Optional[List[str]]): The prefix of the key to be indexed. + """ + self.data_type = data_type + self.prefixes = prefixes + + def toArgs(self) -> List[str]: + """ + Get the optional arguments for the [FT.CREATE] command. + + Returns: + List[str]: + List of [FT.CREATE] optional agruments. 
+ """ + args = [] + if self.data_type: + args.append(FtCreateKeywords.ON) + args.append(self.data_type.value) + if self.prefixes: + args.append(FtCreateKeywords.PREFIX) + args.append(str(len(self.prefixes))) + for prefix in self.prefixes: + args.append(prefix) + return args diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py new file mode 100644 index 0000000000..93f9efc9c1 --- /dev/null +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -0,0 +1,103 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +import uuid +from typing import List + +import pytest +from glide.async_commands.server_modules import ft +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + DistanceMetricType, + Field, + FtCreateOptions, + NumericField, + TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributesHnsw, + VectorType, +) +from glide.config import ProtocolVersion +from glide.constants import OK +from glide.glide_client import GlideClusterClient + + +@pytest.mark.asyncio +class TestVss: + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_vss_create(self, glide_client: GlideClusterClient): + fields: List[Field] = [] + textFieldTitle: TextField = TextField("$title") + numberField: NumericField = NumericField("$published_at") + textFieldCategory: TextField = TextField("$category") + fields.append(textFieldTitle) + fields.append(numberField) + fields.append(textFieldCategory) + + prefixes: List[str] = [] + prefixes.append("blog:post:") + + # Create an index with multiple fields with Hash data type. + index = str(uuid.uuid4()) + result = await ft.create( + glide_client, index, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + assert result == OK + + # Create an index with multiple fields with JSON data type. 
+ index2 = str(uuid.uuid4()) + result = await ft.create( + glide_client, index2, fields, FtCreateOptions(DataType.JSON, prefixes) + ) + assert result == OK + + # Create an index for vectors of size 2 + # FT.CREATE hash_idx1 ON HASH PREFIX 1 hash: SCHEMA vec AS VEC VECTOR HNSW 6 DIM 2 TYPE FLOAT32 DISTANCE_METRIC L2 + index3 = str(uuid.uuid4()) + prefixes = [] + prefixes.append("hash:") + fields = [] + vectorFieldHash: VectorField = VectorField( + name="vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dim=2, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + ), + alias="VEC", + ) + fields.append(vectorFieldHash) + + result = await ft.create( + glide_client, index3, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + assert result == OK + + # Create a 6-dimensional JSON index using the HNSW algorithm + # FT.CREATE json_idx1 ON JSON PREFIX 1 json: SCHEMA $.vec AS VEC VECTOR HNSW 6 DIM 6 TYPE FLOAT32 DISTANCE_METRIC L2 + index4 = str(uuid.uuid4()) + prefixes = [] + prefixes.append("json:") + fields = [] + vectorFieldJson: VectorField = VectorField( + name="$.vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dim=6, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + ), + alias="VEC", + ) + fields.append(vectorFieldJson) + + result = await ft.create( + glide_client, index4, fields, FtCreateOptions(DataType.JSON, prefixes) + ) + assert result == OK + + # Create an index without FtCreateOptions + + index5 = str(uuid.uuid4()) + result = await ft.create(glide_client, index5, fields, FtCreateOptions()) + assert result == OK + + # TO-DO: + # Add additional tests from VSS documentation that require a combination of commands to run. From bbfce44fce28b09edbbc864b49ef3b2cc9ca4a82 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Fri, 11 Oct 2024 11:01:18 -0700 Subject: [PATCH 004/180] Java: `FT.CREATE` (#2414) Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 2 + java/client/build.gradle | 4 +- .../glide/api/commands/servermodules/FT.java | 164 +++++++ .../models/commands/FT/FTCreateOptions.java | 413 ++++++++++++++++++ java/client/src/main/java/module-info.java | 2 + java/integTest/build.gradle | 3 - .../java/glide/modules/VectorSearchTests.java | 167 ++++++- 7 files changed, 747 insertions(+), 8 deletions(-) create mode 100644 java/client/src/main/java/glide/api/commands/servermodules/FT.java create mode 100644 java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java diff --git a/CHANGELOG.md b/CHANGELOG.md index bbb5c9d506..dc7d93523e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418)) +* Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) + #### Breaking Changes #### Fixes diff --git a/java/client/build.gradle b/java/client/build.gradle index 46fa8f4cee..364b09ca1e 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -165,8 +165,8 @@ jar.dependsOn('copyNativeLib') javadoc.dependsOn('copyNativeLib') copyNativeLib.dependsOn('buildRustRelease') compileTestJava.dependsOn('copyNativeLib') -test.dependsOn('buildRust') -testFfi.dependsOn('buildRust') +test.dependsOn('buildRustRelease') +testFfi.dependsOn('buildRustRelease') test { exclude "glide/ffi/FfiTest.class" diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java 
b/java/client/src/main/java/glide/api/commands/servermodules/FT.java new file mode 100644 index 0000000000..51bde7a03d --- /dev/null +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -0,0 +1,164 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.commands.servermodules; + +import static glide.api.models.GlideString.gs; + +import glide.api.BaseClient; +import glide.api.GlideClient; +import glide.api.GlideClusterClient; +import glide.api.models.ClusterValue; +import glide.api.models.GlideString; +import glide.api.models.commands.FT.FTCreateOptions; +import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import lombok.NonNull; + +/** Module for vector search commands. */ +public class FT { + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param fields Fields to populate into the index. + * @return OK. + * @example + *
{@code
+     * // Create an index for vectors of size 2:
+     * FT.create(client, "my_idx1", new FieldInfo[] {
+     *     new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
+     * }).get();
+     *
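+     * // Illustrative extra (hypothetical index name "my_idx_tags"): an index with TAG and NUMERIC fields:
+     * FT.create(client, "my_idx_tags", new FieldInfo[] {
+     *     new FieldInfo("category", new TagField()),
+     *     new FieldInfo("published_at", new NumericField())
+     * }).get();
+     *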
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, "my_idx2",
+     *     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     * }).get();
+     * }
+ */ + public static CompletableFuture<String> create( + @NonNull BaseClient client, @NonNull String indexName, @NonNull FieldInfo[] fields) { + // Note: bug in meme DB - command fails if cmd is too short even though all mandatory args are + // present + // TODO: confirm whether it is fixed and update docs if needed + return create(client, indexName, fields, FTCreateOptions.builder().build()); + } + + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param fields Fields to populate into the index. + * @param options Additional parameters for the command - see {@link FTCreateOptions}. + * @return OK. + * @example + *
{@code
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, "json_idx1",
+     *     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     *     },
+     *     FTCreateOptions.builder().indexType(JSON).prefixes(new String[] {"json:"}).build(),
+     * ).get();
+     * }
+ */ + public static CompletableFuture<String> create( + @NonNull BaseClient client, + @NonNull String indexName, + @NonNull FieldInfo[] fields, + @NonNull FTCreateOptions options) { + return create(client, gs(indexName), fields, options); + } + + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param fields Fields to populate into the index. + * @return OK. + * @example + *
{@code
+     * // Create an index for vectors of size 2:
+     * FT.create(client, gs("my_idx1"), new FieldInfo[] {
+     *     new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
+     * }).get();
+     *
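+     * // Illustrative extra (hypothetical index name "my_idx_text"): the same overload with text fields:
+     * FT.create(client, gs("my_idx_text"), new FieldInfo[] {
+     *     new FieldInfo(gs("title"), new TextField()),
+     *     new FieldInfo(gs("name"), new TextField())
+     * }).get();
+     *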
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, gs("my_idx2"),
+     *     new FieldInfo[] { new FieldInfo(gs("$.vec"), gs("VEC"),
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     * }).get();
+     * }
+ */ + public static CompletableFuture<String> create( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull FieldInfo[] fields) { + // Note: bug in meme DB - command fails if cmd is too short even though all mandatory args are + // present + // TODO: confirm whether it is fixed and update docs if needed + return create(client, indexName, fields, FTCreateOptions.builder().build()); + } + + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param fields Fields to populate into the index. + * @param options Additional parameters for the command - see {@link FTCreateOptions}. + * @return OK. + * @example + *
{@code
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, gs("json_idx1"),
+     *     new FieldInfo[] { new FieldInfo(gs("$.vec"), gs("VEC"),
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     *     },
+     *     FTCreateOptions.builder().indexType(JSON).prefixes(new String[] {"json:"}).build(),
+     * ).get();
+     * }
+ */ + public static CompletableFuture<String> create( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull FieldInfo[] fields, + @NonNull FTCreateOptions options) { + var args = + Stream.of( + new GlideString[] {gs("FT.CREATE"), indexName}, + options.toArgs(), + new GlideString[] {gs("SCHEMA")}, + Arrays.stream(fields) + .map(FieldInfo::toArgs) + .flatMap(Arrays::stream) + .toArray(GlideString[]::new)) + .flatMap(Arrays::stream) + .toArray(GlideString[]::new); + return executeCommand(client, args, false); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + * @param returnsMap - true if the command returns a map + */ + @SuppressWarnings("unchecked") + private static <T> CompletableFuture<T> executeCommand( + BaseClient client, GlideString[] args, boolean returnsMap) { + if (client instanceof GlideClient) { + return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r); + } else if (client instanceof GlideClusterClient) { + return ((GlideClusterClient) client) + .customCommand(args) + .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue) + .thenApply(r -> (T) r); + } + throw new IllegalArgumentException( + "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`"); + } +}
diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java new file mode 100644 index 0000000000..1cdb6c77d0 --- /dev/null +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java @@ -0,0 +1,413 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.models.commands.FT; + +import static glide.api.models.GlideString.gs; + +import glide.api.BaseClient; +import glide.api.commands.servermodules.FT; +import glide.api.models.GlideString; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.NonNull; + +/** + * Additional parameters for {@link FT#create(BaseClient, String, FieldInfo[], FTCreateOptions)} + * command. + */ +@Builder +public class FTCreateOptions { + /** The index type. If not given, a {@link IndexType#HASH} index is created. */ + private final IndexType indexType; + + /** A list of prefixes of index definitions. */ + private final GlideString[] prefixes; + + FTCreateOptions(IndexType indexType, GlideString[] prefixes) { + this.indexType = indexType; + this.prefixes = prefixes; + } + + public static FTCreateOptionsBuilder builder() { + return new FTCreateOptionsBuilder(); + } + + public GlideString[] toArgs() { + var args = new ArrayList<GlideString>(); + if (indexType != null) { + args.add(gs("ON")); + args.add(gs(indexType.toString())); + } + if (prefixes != null && prefixes.length > 0) { + args.add(gs("PREFIX")); + args.add(gs(Integer.toString(prefixes.length))); + args.addAll(List.of(prefixes)); + } + return args.toArray(GlideString[]::new); + } + + public static class FTCreateOptionsBuilder { + public FTCreateOptionsBuilder prefixes(String[] prefixes) { + this.prefixes = Stream.of(prefixes).map(GlideString::gs).toArray(GlideString[]::new); + return this; + } + } + + /** Type of the index dataset.
*/ + public enum IndexType { + /** Data stored in hashes, so field identifiers are field names within the hashes. */ + HASH, + /** Data stored in JSONs, so field identifiers are JSON Path expressions. */ + JSON + } + + /** + * A vector search field. Could be one of the following: + * + *
    + *
  • {@link NumericField} + *
  • {@link TextField} + *
  • {@link TagField} + *
  • {@link VectorFieldHnsw} + *
  • {@link VectorFieldFlat} + *
+ */ + public interface Field { + /** Convert to module API. */ + String[] toArgs(); + } + + private enum FieldType { + NUMERIC, + TEXT, + TAG, + VECTOR + } + + /** Field contains a number. */ + public static class NumericField implements Field { + @Override + public String[] toArgs() { + return new String[] {FieldType.NUMERIC.toString()}; + } + } + + /** Field contains any blob of data. */ + public static class TextField implements Field { + @Override + public String[] toArgs() { + return new String[] {FieldType.TEXT.toString()}; + } + } + + /** + * Tag fields are similar to full-text fields, but they interpret the text as a simple list of + * tags delimited by a separator character.
+ * For {@link IndexType#HASH} fields, separator default is a comma (,). For {@link + * IndexType#JSON} fields, there is no default separator; you must declare one explicitly if + * needed. + */ + public static class TagField implements Field { + private Optional<Character> separator; + private final boolean caseSensitive; + + /** Create a TAG field. */ + public TagField() { + this.separator = Optional.empty(); + this.caseSensitive = false; + } + + /** + * Create a TAG field. + * + * @param separator The tag separator. + */ + public TagField(char separator) { + this.separator = Optional.of(separator); + this.caseSensitive = false; + } + + /** + * Create a TAG field. + * + * @param separator The tag separator. + * @param caseSensitive Whether to keep the original case. + */ + public TagField(char separator, boolean caseSensitive) { + this.separator = Optional.of(separator); + this.caseSensitive = caseSensitive; + } + + /** + * Create a TAG field. + * + * @param caseSensitive Whether to keep the original case. + */ + public TagField(boolean caseSensitive) { + this.separator = Optional.empty(); // ensure separator is never null when toArgs() runs + this.caseSensitive = caseSensitive; + } + + @Override + public String[] toArgs() { + var args = new ArrayList<String>(); + args.add(FieldType.TAG.toString()); + if (separator.isPresent()) { + args.add("SEPARATOR"); + args.add(separator.get().toString()); + } + if (caseSensitive) { + args.add("CASESENSITIVE"); + } + return args.toArray(String[]::new); + } + } + + /** + * Distance metrics to measure the degree of similarity between two vectors.
+ * The above metrics calculate distance between two vectors, where the smaller the value is, the + * closer the two vectors are in the vector space. + */ + public enum DistanceMetric { + /** Euclidean distance between two vectors. */ + L2, + /** Inner product of two vectors. */ + IP, + /** Cosine distance of two vectors. */ + COSINE + } + + /** Superclass for vector field implementations, contains common logic. */ + @AllArgsConstructor(access = AccessLevel.PROTECTED) + abstract static class VectorField implements Field { + private final Map<VectorAlgorithmParam, String> params; + private final VectorAlgorithm algorithm; + + @Override + public String[] toArgs() { + var args = new ArrayList<String>(); + args.add(FieldType.VECTOR.toString()); + args.add(algorithm.toString()); + args.add(Integer.toString(params.size() * 2)); + params.forEach( + (name, value) -> { + args.add(name.toString()); + args.add(value); + }); + return args.toArray(String[]::new); + } + } + + private enum VectorAlgorithm { + HNSW, + FLAT + } + + private enum VectorAlgorithmParam { + M, + EF_CONSTRUCTION, + EF_RUNTIME, + TYPE, + DIM, + DISTANCE_METRIC, + INITIAL_CAP + } + + /** + * Vector field that supports vector search by HNSW (Hierarchical Navigable Small + * World) algorithm.
+ * The algorithm provides an approximation of the correct answer in exchange for substantially + * lower execution times. + */ + public static class VectorFieldHnsw extends VectorField { + private VectorFieldHnsw(Map<VectorAlgorithmParam, String> params) { + super(params, VectorAlgorithm.HNSW); + } + + /** + * Init a builder. + * + * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two + * vectors. + * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768 + */ + public static VectorFieldHnswBuilder builder( + @NonNull DistanceMetric distanceMetric, int dimensions) { + return new VectorFieldHnswBuilder(distanceMetric, dimensions); + } + } + + public static class VectorFieldHnswBuilder extends VectorFieldBuilder<VectorFieldHnswBuilder> { + VectorFieldHnswBuilder(DistanceMetric distanceMetric, int dimensions) { + super(distanceMetric, dimensions); + } + + @Override + public VectorFieldHnsw build() { + return new VectorFieldHnsw(params); + } + + /** + * Number of maximum allowed outgoing edges for each node in the graph in each layer. On layer + * zero the maximal number of outgoing edges is doubled. Default is 16, maximum is 512. + */ + public VectorFieldHnswBuilder numberOfEdges(int numberOfEdges) { + params.put(VectorAlgorithmParam.M, Integer.toString(numberOfEdges)); + return this; + } + + /** + * (Optional) The number of vectors examined during index construction. Higher values for this + * parameter will improve recall ratio at the expense of longer index creation times. Default + * value is 200. Maximum value is 4096. + */ + public VectorFieldHnswBuilder vectorsExaminedOnConstruction(int vectorsExaminedOnConstruction) { + params.put( + VectorAlgorithmParam.EF_CONSTRUCTION, Integer.toString(vectorsExaminedOnConstruction)); + return this; + } + + /** + * (Optional) The number of vectors examined during query operations. Higher values for this + * parameter can yield improved recall at the expense of longer query times. The value of this + * parameter can be overridden on a per-query basis. Default value is 10. Maximum value is 4096. + */ + public VectorFieldHnswBuilder vectorsExaminedOnRuntime(int vectorsExaminedOnRuntime) { + params.put(VectorAlgorithmParam.EF_RUNTIME, Integer.toString(vectorsExaminedOnRuntime)); + return this; + } + } + + /** + * Vector field that supports vector search by FLAT (brute force) algorithm.
+ * The algorithm is a brute force linear processing of each vector in the index, yielding exact + * answers within the bounds of the precision of the distance computations. + */ + public static class VectorFieldFlat extends VectorField { + + private VectorFieldFlat(Map<VectorAlgorithmParam, String> params) { + super(params, VectorAlgorithm.FLAT); + } + + /** + * Init a builder. + * + * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two + * vectors. + * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768 + */ + public static VectorFieldFlatBuilder builder( + @NonNull DistanceMetric distanceMetric, int dimensions) { + return new VectorFieldFlatBuilder(distanceMetric, dimensions); + } + } + + public static class VectorFieldFlatBuilder extends VectorFieldBuilder<VectorFieldFlatBuilder> { + VectorFieldFlatBuilder(DistanceMetric distanceMetric, int dimensions) { + super(distanceMetric, dimensions); + } + + @Override + public VectorFieldFlat build() { + return new VectorFieldFlat(params); + } + } + + abstract static class VectorFieldBuilder<T extends VectorFieldBuilder<T>> { + final Map<VectorAlgorithmParam, String> params = new HashMap<>(); + + VectorFieldBuilder(DistanceMetric distanceMetric, int dimensions) { + params.put(VectorAlgorithmParam.TYPE, "FLOAT32"); + params.put(VectorAlgorithmParam.DIM, Integer.toString(dimensions)); + params.put(VectorAlgorithmParam.DISTANCE_METRIC, distanceMetric.toString()); + } + + /** + * Initial vector capacity in the index affecting memory allocation size of the index. Defaults + * to 1024. + */ + @SuppressWarnings("unchecked") + public T initialCapacity(int initialCapacity) { + params.put(VectorAlgorithmParam.INITIAL_CAP, Integer.toString(initialCapacity)); + return (T) this; + } + + public abstract VectorField build(); + } + + /** Field definition to be added into index schema. */ + public static class FieldInfo { + private final GlideString identifier; + private final GlideString alias; + private final Field field; + + /** + * Field definition to be added into index schema. + * + * @param identifier Field identifier (name). + * @param field The {@link Field} itself. + */ + public FieldInfo(@NonNull String identifier, @NonNull Field field) { + this.identifier = gs(identifier); + this.field = field; + this.alias = null; + } + + /** + * Field definition to be added into index schema. + * + * @param identifier Field identifier (name). + * @param alias Field alias. + * @param field The {@link Field} itself. + */ + public FieldInfo(@NonNull String identifier, @NonNull String alias, @NonNull Field field) { + this.identifier = gs(identifier); + this.alias = gs(alias); + this.field = field; + } + + /** + * Field definition to be added into index schema. + * + * @param identifier Field identifier (name). + * @param field The {@link Field} itself. + */ + public FieldInfo(@NonNull GlideString identifier, @NonNull Field field) { + this.identifier = identifier; + this.field = field; + this.alias = null; + } + + /** + * Field definition to be added into index schema. + * + * @param identifier Field identifier (name). + * @param alias Field alias. + * @param field The {@link Field} itself. + */ + public FieldInfo( + @NonNull GlideString identifier, @NonNull GlideString alias, @NonNull Field field) { + this.identifier = identifier; + this.alias = alias; + this.field = field; + } + + /** Convert to module API.
*/ + public GlideString[] toArgs() { + var args = new ArrayList<GlideString>(); + args.add(identifier); + if (alias != null) { + args.add(gs("AS")); + args.add(alias); + } + args.addAll(Stream.of(field.toArgs()).map(GlideString::gs).collect(Collectors.toList())); + return args.toArray(GlideString[]::new); + } + } +}
diff --git a/java/client/src/main/java/module-info.java b/java/client/src/main/java/module-info.java index 99c4655082..183e6c0410 100644 --- a/java/client/src/main/java/module-info.java +++ b/java/client/src/main/java/module-info.java @@ -9,8 +9,10 @@ exports glide.api.models.commands.function; exports glide.api.models.commands.scan; exports glide.api.models.commands.stream; + exports glide.api.models.commands.FT; exports glide.api.models.configuration; exports glide.api.models.exceptions; + exports glide.api.commands.servermodules; requires com.google.protobuf; requires io.netty.codec;
diff --git a/java/integTest/build.gradle b/java/integTest/build.gradle index d467b4ebbb..c2032d05d1 100644 --- a/java/integTest/build.gradle +++ b/java/integTest/build.gradle @@ -102,7 +102,6 @@ tasks.register('startStandalone') { } } - test.dependsOn 'stopAllBeforeTests' stopAllBeforeTests.finalizedBy 'clearDirs' clearDirs.finalizedBy 'startStandalone' @@ -112,8 +111,6 @@ test.dependsOn ':client:buildRustRelease' tasks.withType(Test) { doFirst { - println "Cluster hosts = ${clusterHosts}" - println "Standalone hosts = ${standaloneHosts}" systemProperty 'test.server.standalone', standaloneHosts systemProperty 'test.server.cluster', clusterHosts systemProperty 'test.server.tls', System.getProperty("tls")
diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 07b0946b3d..67387026bd 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -2,23 +2,184 @@ package glide.modules; import static glide.TestUtilities.commonClusterClientConfig; +import static glide.api.BaseClient.OK; +import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import glide.api.GlideClusterClient; +import glide.api.commands.servermodules.FT; +import glide.api.models.commands.FT.FTCreateOptions; +import glide.api.models.commands.FT.FTCreateOptions.DistanceMetric; +import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; +import glide.api.models.commands.FT.FTCreateOptions.IndexType; +import glide.api.models.commands.FT.FTCreateOptions.NumericField; +import glide.api.models.commands.FT.FTCreateOptions.TagField; +import glide.api.models.commands.FT.FTCreateOptions.TextField; +import glide.api.models.commands.FT.FTCreateOptions.VectorFieldFlat; +import glide.api.models.commands.FT.FTCreateOptions.VectorFieldHnsw; +import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; +import glide.api.models.exceptions.RequestException; +import java.util.UUID; +import java.util.concurrent.ExecutionException; import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import 
org.junit.jupiter.api.Test; public class VectorSearchTests { - @Test + private static GlideClusterClient client; + + @BeforeAll @SneakyThrows - public void check_module_loaded() { - var client = + public static void init() { + client = GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(5000).build()) .get(); + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); + } + + @AfterAll + @SneakyThrows + public static void teardown() { + client.close(); + } + + @Test + @SneakyThrows + public void check_module_loaded() { var info = client.info(new Section[] {Section.MODULES}, RANDOM).get().getSingleValue(); assertTrue(info.contains("# search_index_stats")); } + + @SneakyThrows + @Test + public void ft_create() { + // create few simple indices + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) + }) + .get()); + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo( + "$.vec", "VEC", VectorFieldFlat.builder(DistanceMetric.L2, 6).build()) + }, + FTCreateOptions.builder() + .indexType(IndexType.JSON) + .prefixes(new String[] {"json:"}) + .build()) + .get()); + + // create an index with HNSW vector with additional parameters + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo( + "doc_embedding", + VectorFieldHnsw.builder(DistanceMetric.COSINE, 1536) + .numberOfEdges(40) + .vectorsExaminedOnConstruction(250) + .vectorsExaminedOnRuntime(40) + .build()) + }, + FTCreateOptions.builder() + .indexType(IndexType.HASH) + .prefixes(new String[] {"docs:"}) + .build()) + .get()); + + // create an index with multiple fields + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("published_at", new NumericField()), + new FieldInfo("category", new TagField()) + }, + FTCreateOptions.builder() + .indexType(IndexType.HASH) + .prefixes(new String[] {"blog:post:"}) + .build()) + .get()); + + // create an index with multiple prefixes + var name = UUID.randomUUID().toString(); + assertEquals( + OK, + FT.create( + client, + name, + new FieldInfo[] { + new FieldInfo("author_id", new TagField()), + new FieldInfo("author_ids", new TagField()), + new FieldInfo("title", new TextField()), + new FieldInfo("name", new TextField()) + }, + FTCreateOptions.builder() + .indexType(IndexType.HASH) + .prefixes(new String[] {"author:details:", "book:details:"}) + .build()) + .get()); + + // create a duplicating index + var exception = + assertThrows( + ExecutionException.class, + () -> + FT.create( + client, + name, + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("name", new TextField()) + }) + .get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("already exists")); + + // create an index without fields + exception = + assertThrows( + ExecutionException.class, + () -> FT.create(client, UUID.randomUUID().toString(), new FieldInfo[0]).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("wrong number of arguments")); + + // duplicated field name + exception = + assertThrows( + ExecutionException.class, + () -> + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("name", new 
TextField()), + new FieldInfo("name", new TextField()) + }) + .get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("already exists")); + } } From 855bb336bfe93ba011c633872d57fe31b6cbe41c Mon Sep 17 00:00:00 2001 From: Gilboab <97948000+GilboaAWS@users.noreply.github.com> Date: Sun, 13 Oct 2024 12:34:23 +0300 Subject: [PATCH 005/180] Added inflightRequestsLimit client config to java (#2443) Added inflightRequestsLimit client config to java (#2408) * Add inflight request limit config to java Signed-off-by: GilboaAWS --- .../BaseClientConfiguration.java | 8 +++ .../GlideClientConfiguration.java | 1 + .../GlideClusterClientConfiguration.java | 1 + .../glide/managers/ConnectionManager.java | 4 ++ .../glide/managers/ConnectionManagerTest.java | 4 ++ .../test/java/glide/SharedClientTests.java | 60 +++++++++++++++++++ 6 files changed, 78 insertions(+) diff --git a/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java b/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java index b6cc4e26ff..e0a4ed5500 100644 --- a/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java +++ b/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java @@ -66,4 +66,12 @@ public abstract class BaseClientConfiguration { private final ThreadPoolResource threadPoolResource; public abstract BaseSubscriptionConfiguration getSubscriptionConfiguration(); + + /** + * The maximum number of concurrent requests allowed to be in-flight (sent but not yet completed). + * This limit is used to control the memory usage and prevent the client from overwhelming the + * server or getting stuck in case of a queue backlog. If not set, a default value of 1000 will be + * used. 
+ */ + private final Integer inflightRequestsLimit; } diff --git a/java/client/src/main/java/glide/api/models/configuration/GlideClientConfiguration.java index edb7bbb326..83d84e7c1f 100644 --- a/java/client/src/main/java/glide/api/models/configuration/GlideClientConfiguration.java +++ b/java/client/src/main/java/glide/api/models/configuration/GlideClientConfiguration.java @@ -23,6 +23,7 @@ * .databaseId(1) * .clientName("GLIDE") * .subscriptionConfiguration(subscriptionConfiguration) + * .inflightRequestsLimit(1000) * .build(); * } */ diff --git a/java/client/src/main/java/glide/api/models/configuration/GlideClusterClientConfiguration.java index 2e49e7b66d..b1d1c7590c 100644 --- a/java/client/src/main/java/glide/api/models/configuration/GlideClusterClientConfiguration.java +++ b/java/client/src/main/java/glide/api/models/configuration/GlideClusterClientConfiguration.java @@ -22,6 +22,7 @@ * .requestTimeout(2000) * .clientName("GLIDE") * .subscriptionConfiguration(subscriptionConfiguration) + * .inflightRequestsLimit(1000) * .build(); * } */ diff --git a/java/client/src/main/java/glide/managers/ConnectionManager.java index 58328d375c..a5a8b9c5c3 100644 --- a/java/client/src/main/java/glide/managers/ConnectionManager.java +++ b/java/client/src/main/java/glide/managers/ConnectionManager.java @@ -118,6 +118,10 @@ private ConnectionRequest.Builder setupConnectionRequestBuilderBaseConfiguration connectionRequestBuilder.setClientName(configuration.getClientName()); } + + if (configuration.getInflightRequestsLimit() != null) { + connectionRequestBuilder.setInflightRequestsLimit(configuration.getInflightRequestsLimit()); + } + return connectionRequestBuilder; } diff --git a/java/client/src/test/java/glide/managers/ConnectionManagerTest.java index 9a3ebe6e19..7a8f1a0d44 100644 --- a/java/client/src/test/java/glide/managers/ConnectionManagerTest.java +++ b/java/client/src/test/java/glide/managers/ConnectionManagerTest.java @@ -64,6 +64,8 @@ public class ConnectionManagerTest { private static final String CLIENT_NAME = "ClientName"; + private static final int INFLIGHT_REQUESTS_LIMIT = 1000; + @BeforeEach public void setUp() { channel = mock(ChannelHandler.class); @@ -149,6 +151,7 @@ public void connection_request_protobuf_generation_with_all_fields_set() { .subscription(EXACT, gs("channel_2")) .subscription(PATTERN, gs("*chatRoom*")) .build()) + .inflightRequestsLimit(INFLIGHT_REQUESTS_LIMIT) .build(); ConnectionRequest expectedProtobufConnectionRequest = ConnectionRequest.newBuilder() @@ -193,6 +196,7 @@ public void connection_request_protobuf_generation_with_all_fields_set() { ByteString.copyFrom(gs("*chatRoom*").getBytes())) .build())) .build()) + .setInflightRequestsLimit(INFLIGHT_REQUESTS_LIMIT) .build(); CompletableFuture<Response> completedFuture = new CompletableFuture<>(); Response response = Response.newBuilder().setConstantResponse(ConstantResponse.OK).build();
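To make the new option concrete, here is a minimal sketch of how an application would use it. This is illustrative only: the host, port, limit value, and class name are assumptions; only the `inflightRequestsLimit` builder method, the default of 1000, and the "maximum inflight requests" failure mode come from the patch itself.

```java
import glide.api.GlideClient;
import glide.api.models.configuration.GlideClientConfiguration;
import glide.api.models.configuration.NodeAddress;

public class InflightLimitExample {
    public static void main(String[] args) throws Exception {
        // Hypothetical endpoint; replace with a real standalone server address.
        GlideClientConfiguration config =
                GlideClientConfiguration.builder()
                        .address(NodeAddress.builder().host("localhost").port(6379).build())
                        // Allow at most 250 sent-but-uncompleted requests at once,
                        // instead of the default of 1000.
                        .inflightRequestsLimit(250)
                        .build();

        try (GlideClient client = GlideClient.createClient(config).get()) {
            // A request issued while 250 others are still pending is expected to
            // fail with a RequestException mentioning "maximum inflight requests".
            System.out.println(client.ping().get());
        }
    }
}
```

The integration test below drives exactly this failure mode by issuing `inflightRequestsLimit + 1` blocking `BLPOP` calls against a nonexistent key.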
diff --git a/java/integTest/src/test/java/glide/SharedClientTests.java index bf106f1ff4..3650c079e3 100644 --- a/java/integTest/src/test/java/glide/SharedClientTests.java +++ b/java/integTest/src/test/java/glide/SharedClientTests.java @@ -6,18 +6,25 @@ import static glide.TestUtilities.getRandomString; import static glide.api.BaseClient.OK; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import glide.api.BaseClient; import glide.api.GlideClient; import glide.api.GlideClusterClient; +import glide.api.models.exceptions.RequestException; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import lombok.Getter; import lombok.SneakyThrows; +import net.bytebuddy.utility.RandomString; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; @@ -111,4 +118,57 @@ public void client_can_handle_concurrent_workload(BaseClient client, int valueSi executorService.shutdown(); } + + private static Stream<Arguments> inflightRequestsLimitSizeAndClusterMode() { + return Stream.of( + Arguments.of(false, 5), + Arguments.of(false, 100), + Arguments.of(false, 1000), + Arguments.of(true, 5), + Arguments.of(true, 100), + Arguments.of(true, 1000)); + } + + @SneakyThrows + @ParameterizedTest() + @MethodSource("inflightRequestsLimitSizeAndClusterMode") + public void inflight_requests_limit(boolean clusterMode, int inflightRequestsLimit) { + BaseClient testClient; + String keyName = "nonexistkeylist" + RandomString.make(4); + + if (clusterMode) { + testClient = + GlideClusterClient.createClient( + commonClusterClientConfig().inflightRequestsLimit(inflightRequestsLimit).build()) + .get(); + } else { + testClient = + GlideClient.createClient( + commonClientConfig().inflightRequestsLimit(inflightRequestsLimit).build()) + .get(); + } + + // exercise + List<CompletableFuture<String[]>> responses = new ArrayList<>(); + for (int i = 0; i < inflightRequestsLimit + 1; i++) { + responses.add(testClient.blpop(new String[] {keyName}, 0)); + } + + // verify + // Check that all requests except the last one are still pending + for (int i = 0; i < inflightRequestsLimit; i++) { + assertFalse(responses.get(i).isDone(), "Request " + i + " should still be pending"); + } + + // The last request should complete exceptionally + try { + responses.get(inflightRequestsLimit).get(100, TimeUnit.MILLISECONDS); + fail("Expected the last request to throw an exception"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof RequestException); + assertTrue(e.getCause().getMessage().contains("maximum inflight requests")); + } + + testClient.close(); + } } From 705b4047f4c4623fdf876e1af745f5f9c2763681 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Sun, 13 Oct 2024 17:27:08 +0300 Subject: [PATCH 006/180] Fix CI failing (#2444) Signed-off-by: Shoham Elias --- .github/workflows/python.yml | 7 +++---- .github/workflows/start-self-hosted-runner/action.yml | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index a1b5a16721..c85045df07 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -216,10 +216,9 @@ jobs: - name: Install dependencies if: always() - working-directory: ./python - run: | - python -m pip install --upgrade pip - pip install flake8 isort black + uses:
threeal/pipx-install-action@latest + with: + packages: flake8 isort black - name: Lint python with isort if: always() diff --git a/.github/workflows/start-self-hosted-runner/action.yml b/.github/workflows/start-self-hosted-runner/action.yml index 2e99795c67..45038b2d1d 100644 --- a/.github/workflows/start-self-hosted-runner/action.yml +++ b/.github/workflows/start-self-hosted-runner/action.yml @@ -23,8 +23,8 @@ runs: - name: Start EC2 self hosted runner shell: bash run: | - sudo apt update - sudo apt install awscli -y + sudo snap refresh + sudo snap install aws-cli --classic command_id=$(aws ssm send-command --instance-ids ${{ inputs.ec2-instance-id }} --document-name StartGithubSelfHostedRunner --query Command.CommandId --output text) while [[ "$invoke_status" != "Success" && "$invoke_status" != "Failed" ]]; do From 3661f8727b2ab8a33202d61c0cce609df5bb0902 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Mon, 14 Oct 2024 17:40:23 +0300 Subject: [PATCH 007/180] Fix mypy failing (#2453) --------- Signed-off-by: Shoham Elias --- python/python/tests/test_async_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 62d66801c6..7566194dcc 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -9725,7 +9725,7 @@ async def cluster_route_custom_command_slot_route( route_class = SlotKeyRoute if is_slot_key else SlotIdRoute route_second_arg = "foo" if is_slot_key else 4000 primary_res = await glide_client.custom_command( - ["CLUSTER", "NODES"], route_class(SlotType.PRIMARY, route_second_arg) + ["CLUSTER", "NODES"], route_class(SlotType.PRIMARY, route_second_arg) # type: ignore ) assert isinstance(primary_res, bytes) primary_res = primary_res.decode() @@ -9738,7 +9738,7 @@ async def cluster_route_custom_command_slot_route( expected_primary_node_id = node_line.split(" ")[0] replica_res = await glide_client.custom_command( - ["CLUSTER", "NODES"], route_class(SlotType.REPLICA, route_second_arg) + ["CLUSTER", "NODES"], route_class(SlotType.REPLICA, route_second_arg) # type: ignore ) assert isinstance(replica_res, bytes) replica_res = replica_res.decode() From 9d953d370e59f1666818b9b335d0850e1d757541 Mon Sep 17 00:00:00 2001 From: eifrah-aws Date: Tue, 15 Oct 2024 10:28:27 +0300 Subject: [PATCH 008/180] Make redis-rs part of this repo (#2456) * Make redis-rs part of this repo * Improved PYTHON DEVELOPERS.md file * Added Makefile for unified build method for all the languages --- .github/workflows/python.yml | 6 +- .github/workflows/semgrep.yml | 2 +- .gitignore | 2 + .gitmodules | 3 - Makefile | 118 + benchmarks/rust/Cargo.toml | 2 +- csharp/lib/Cargo.toml | 2 +- glide-core/Cargo.toml | 4 +- glide-core/redis-rs/Cargo.toml | 3 + glide-core/redis-rs/LICENSE | 33 + glide-core/redis-rs/Makefile | 96 + glide-core/redis-rs/README.md | 233 + glide-core/redis-rs/afl/.gitignore | 2 + glide-core/redis-rs/afl/parser/Cargo.toml | 17 + glide-core/redis-rs/afl/parser/in/array | 5 + glide-core/redis-rs/afl/parser/in/array-null | 1 + glide-core/redis-rs/afl/parser/in/bulkstring | 2 + .../redis-rs/afl/parser/in/bulkstring-null | 1 + glide-core/redis-rs/afl/parser/in/error | 1 + glide-core/redis-rs/afl/parser/in/integer | 1 + .../redis-rs/afl/parser/in/invalid-string | 2 + glide-core/redis-rs/afl/parser/in/string | 1 + glide-core/redis-rs/afl/parser/src/main.rs | 9 + .../redis-rs/afl/parser/src/reproduce.rs | 13 + 
glide-core/redis-rs/appveyor.yml | 23 + glide-core/redis-rs/redis-test/CHANGELOG.md | 44 + glide-core/redis-rs/redis-test/Cargo.toml | 26 + glide-core/redis-rs/redis-test/LICENSE | 33 + glide-core/redis-rs/redis-test/README.md | 4 + glide-core/redis-rs/redis-test/release.toml | 1 + glide-core/redis-rs/redis-test/src/lib.rs | 426 ++ glide-core/redis-rs/redis/CHANGELOG.md | 828 ++++ glide-core/redis-rs/redis/Cargo.toml | 227 + glide-core/redis-rs/redis/LICENSE | 33 + .../redis-rs/redis/benches/bench_basic.rs | 277 ++ .../redis-rs/redis/benches/bench_cluster.rs | 108 + .../redis/benches/bench_cluster_async.rs | 88 + .../redis-rs/redis/examples/async-await.rs | 24 + .../redis/examples/async-connection-loss.rs | 97 + .../redis/examples/async-multiplexed.rs | 46 + .../redis-rs/redis/examples/async-pub-sub.rs | 22 + .../redis-rs/redis/examples/async-scan.rs | 25 + glide-core/redis-rs/redis/examples/basic.rs | 169 + .../redis-rs/redis/examples/geospatial.rs | 68 + glide-core/redis-rs/redis/examples/streams.rs | 270 ++ glide-core/redis-rs/redis/release.toml | 2 + glide-core/redis-rs/redis/src/acl.rs | 312 ++ .../redis-rs/redis/src/aio/async_std.rs | 269 ++ .../redis-rs/redis/src/aio/connection.rs | 543 +++ .../redis/src/aio/connection_manager.rs | 310 ++ glide-core/redis-rs/redis/src/aio/mod.rs | 286 ++ .../redis/src/aio/multiplexed_connection.rs | 656 +++ glide-core/redis-rs/redis/src/aio/runtime.rs | 82 + glide-core/redis-rs/redis/src/aio/tokio.rs | 204 + glide-core/redis-rs/redis/src/client.rs | 855 ++++ glide-core/redis-rs/redis/src/cluster.rs | 1076 +++++ .../redis-rs/redis/src/cluster_async/LICENSE | 7 + .../cluster_async/connections_container.rs | 881 ++++ .../src/cluster_async/connections_logic.rs | 481 ++ .../redis-rs/redis/src/cluster_async/mod.rs | 2656 +++++++++++ .../redis-rs/redis/src/cluster_client.rs | 752 +++ .../redis-rs/redis/src/cluster_pipeline.rs | 151 + .../redis-rs/redis/src/cluster_routing.rs | 1374 ++++++ .../redis-rs/redis/src/cluster_slotmap.rs | 435 ++ .../redis-rs/redis/src/cluster_topology.rs | 645 +++ glide-core/redis-rs/redis/src/cmd.rs | 663 +++ .../redis/src/commands/cluster_scan.rs | 720 +++ .../redis-rs/redis/src/commands/json.rs | 390 ++ .../redis-rs/redis/src/commands/macros.rs | 275 ++ glide-core/redis-rs/redis/src/commands/mod.rs | 2190 +++++++++ glide-core/redis-rs/redis/src/connection.rs | 1997 ++++++++ glide-core/redis-rs/redis/src/geo.rs | 361 ++ glide-core/redis-rs/redis/src/lib.rs | 506 ++ glide-core/redis-rs/redis/src/macros.rs | 7 + glide-core/redis-rs/redis/src/parser.rs | 658 +++ glide-core/redis-rs/redis/src/pipeline.rs | 324 ++ glide-core/redis-rs/redis/src/push_manager.rs | 234 + glide-core/redis-rs/redis/src/r2d2.rs | 36 + glide-core/redis-rs/redis/src/script.rs | 255 + glide-core/redis-rs/redis/src/sentinel.rs | 778 +++ glide-core/redis-rs/redis/src/streams.rs | 670 +++ glide-core/redis-rs/redis/src/tls.rs | 142 + glide-core/redis-rs/redis/src/types.rs | 2460 ++++++++++ glide-core/redis-rs/redis/tests/parser.rs | 195 + .../redis-rs/redis/tests/support/cluster.rs | 792 +++ .../redis/tests/support/mock_cluster.rs | 487 ++ .../redis-rs/redis/tests/support/mod.rs | 887 ++++ .../redis-rs/redis/tests/support/sentinel.rs | 404 ++ .../redis-rs/redis/tests/support/util.rs | 23 + glide-core/redis-rs/redis/tests/test_acl.rs | 156 + glide-core/redis-rs/redis/tests/test_async.rs | 1132 +++++ .../redis/tests/test_async_async_std.rs | 328 ++ .../test_async_cluster_connections_logic.rs | 563 +++ glide-core/redis-rs/redis/tests/test_basic.rs | 1581 ++++++ 
.../redis-rs/redis/tests/test_bignum.rs | 61 + .../redis-rs/redis/tests/test_cluster.rs | 1093 +++++ .../redis/tests/test_cluster_async.rs | 4245 +++++++++++++++++ .../redis-rs/redis/tests/test_cluster_scan.rs | 849 ++++ .../redis-rs/redis/tests/test_geospatial.rs | 197 + .../redis-rs/redis/tests/test_module_json.rs | 540 +++ .../redis-rs/redis/tests/test_sentinel.rs | 496 ++ .../redis-rs/redis/tests/test_streams.rs | 627 +++ glide-core/redis-rs/redis/tests/test_types.rs | 606 +++ glide-core/redis-rs/release.sh | 15 + glide-core/redis-rs/rustfmt.toml | 2 + .../redis-rs/scripts/get_command_info.py | 227 + .../redis-rs/scripts/update-versions.sh | 20 + glide-core/redis-rs/upload-docs.sh | 26 + go/Cargo.toml | 2 +- go/DEVELOPER.md | 4 +- java/Cargo.toml | 2 +- node/DEVELOPER.md | 2 + node/rust-client/Cargo.toml | 2 +- python/Cargo.toml | 2 +- python/DEVELOPER.md | 175 +- python/python/tests/test_async_client.py | 1 + submodules/redis-rs | 1 - 117 files changed, 43683 insertions(+), 101 deletions(-) create mode 100644 Makefile create mode 100644 glide-core/redis-rs/Cargo.toml create mode 100644 glide-core/redis-rs/LICENSE create mode 100644 glide-core/redis-rs/Makefile create mode 100644 glide-core/redis-rs/README.md create mode 100644 glide-core/redis-rs/afl/.gitignore create mode 100644 glide-core/redis-rs/afl/parser/Cargo.toml create mode 100644 glide-core/redis-rs/afl/parser/in/array create mode 100644 glide-core/redis-rs/afl/parser/in/array-null create mode 100644 glide-core/redis-rs/afl/parser/in/bulkstring create mode 100644 glide-core/redis-rs/afl/parser/in/bulkstring-null create mode 100644 glide-core/redis-rs/afl/parser/in/error create mode 100644 glide-core/redis-rs/afl/parser/in/integer create mode 100644 glide-core/redis-rs/afl/parser/in/invalid-string create mode 100644 glide-core/redis-rs/afl/parser/in/string create mode 100644 glide-core/redis-rs/afl/parser/src/main.rs create mode 100644 glide-core/redis-rs/afl/parser/src/reproduce.rs create mode 100644 glide-core/redis-rs/appveyor.yml create mode 100644 glide-core/redis-rs/redis-test/CHANGELOG.md create mode 100644 glide-core/redis-rs/redis-test/Cargo.toml create mode 100644 glide-core/redis-rs/redis-test/LICENSE create mode 100644 glide-core/redis-rs/redis-test/README.md create mode 100644 glide-core/redis-rs/redis-test/release.toml create mode 100644 glide-core/redis-rs/redis-test/src/lib.rs create mode 100644 glide-core/redis-rs/redis/CHANGELOG.md create mode 100644 glide-core/redis-rs/redis/Cargo.toml create mode 100644 glide-core/redis-rs/redis/LICENSE create mode 100644 glide-core/redis-rs/redis/benches/bench_basic.rs create mode 100644 glide-core/redis-rs/redis/benches/bench_cluster.rs create mode 100644 glide-core/redis-rs/redis/benches/bench_cluster_async.rs create mode 100644 glide-core/redis-rs/redis/examples/async-await.rs create mode 100644 glide-core/redis-rs/redis/examples/async-connection-loss.rs create mode 100644 glide-core/redis-rs/redis/examples/async-multiplexed.rs create mode 100644 glide-core/redis-rs/redis/examples/async-pub-sub.rs create mode 100644 glide-core/redis-rs/redis/examples/async-scan.rs create mode 100644 glide-core/redis-rs/redis/examples/basic.rs create mode 100644 glide-core/redis-rs/redis/examples/geospatial.rs create mode 100644 glide-core/redis-rs/redis/examples/streams.rs create mode 100644 glide-core/redis-rs/redis/release.toml create mode 100644 glide-core/redis-rs/redis/src/acl.rs create mode 100644 glide-core/redis-rs/redis/src/aio/async_std.rs create mode 100644 
glide-core/redis-rs/redis/src/aio/connection.rs create mode 100644 glide-core/redis-rs/redis/src/aio/connection_manager.rs create mode 100644 glide-core/redis-rs/redis/src/aio/mod.rs create mode 100644 glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs create mode 100644 glide-core/redis-rs/redis/src/aio/runtime.rs create mode 100644 glide-core/redis-rs/redis/src/aio/tokio.rs create mode 100644 glide-core/redis-rs/redis/src/client.rs create mode 100644 glide-core/redis-rs/redis/src/cluster.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/LICENSE create mode 100644 glide-core/redis-rs/redis/src/cluster_async/connections_container.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/mod.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_client.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_pipeline.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_routing.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_slotmap.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_topology.rs create mode 100644 glide-core/redis-rs/redis/src/cmd.rs create mode 100644 glide-core/redis-rs/redis/src/commands/cluster_scan.rs create mode 100644 glide-core/redis-rs/redis/src/commands/json.rs create mode 100644 glide-core/redis-rs/redis/src/commands/macros.rs create mode 100644 glide-core/redis-rs/redis/src/commands/mod.rs create mode 100644 glide-core/redis-rs/redis/src/connection.rs create mode 100644 glide-core/redis-rs/redis/src/geo.rs create mode 100644 glide-core/redis-rs/redis/src/lib.rs create mode 100644 glide-core/redis-rs/redis/src/macros.rs create mode 100644 glide-core/redis-rs/redis/src/parser.rs create mode 100644 glide-core/redis-rs/redis/src/pipeline.rs create mode 100644 glide-core/redis-rs/redis/src/push_manager.rs create mode 100644 glide-core/redis-rs/redis/src/r2d2.rs create mode 100644 glide-core/redis-rs/redis/src/script.rs create mode 100644 glide-core/redis-rs/redis/src/sentinel.rs create mode 100644 glide-core/redis-rs/redis/src/streams.rs create mode 100644 glide-core/redis-rs/redis/src/tls.rs create mode 100644 glide-core/redis-rs/redis/src/types.rs create mode 100644 glide-core/redis-rs/redis/tests/parser.rs create mode 100644 glide-core/redis-rs/redis/tests/support/cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/support/mock_cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/support/mod.rs create mode 100644 glide-core/redis-rs/redis/tests/support/sentinel.rs create mode 100644 glide-core/redis-rs/redis/tests/support/util.rs create mode 100644 glide-core/redis-rs/redis/tests/test_acl.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async_async_std.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs create mode 100644 glide-core/redis-rs/redis/tests/test_basic.rs create mode 100644 glide-core/redis-rs/redis/tests/test_bignum.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster_async.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster_scan.rs create mode 100644 glide-core/redis-rs/redis/tests/test_geospatial.rs create mode 100644 glide-core/redis-rs/redis/tests/test_module_json.rs create mode 100644 glide-core/redis-rs/redis/tests/test_sentinel.rs create mode 100644 
glide-core/redis-rs/redis/tests/test_streams.rs create mode 100644 glide-core/redis-rs/redis/tests/test_types.rs create mode 100755 glide-core/redis-rs/release.sh create mode 100644 glide-core/redis-rs/rustfmt.toml create mode 100644 glide-core/redis-rs/scripts/get_command_info.py create mode 100755 glide-core/redis-rs/scripts/update-versions.sh create mode 100755 glide-core/redis-rs/upload-docs.sh delete mode 160000 submodules/redis-rs diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c85045df07..45d2c0cf0d 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -216,9 +216,9 @@ jobs: - name: Install dependencies if: always() - uses: threeal/pipx-install-action@latest - with: - packages: flake8 isort black + working-directory: ./python + run: | + sudo apt install -y python3-pip python3 flake8 isort black - name: Lint python with isort if: always() diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 6e4235abdb..4bfd9e12ac 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -33,4 +33,4 @@ jobs: # Fetch project source with GitHub Actions Checkout. - uses: actions/checkout@v3 # Run the "semgrep ci" command on the command line of the docker image. - - run: semgrep ci --config auto --no-suppress-errors + - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key diff --git a/.gitignore b/.gitignore index 573bfc218d..6799f31ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,8 @@ logger-rs.linux-x64-gnu.node utils/clusters/ utils/tls_crts/ utils/TestUtils.js +.build/ +.project # OSS Review Toolkit (ORT) files **/ort*/** diff --git a/.gitmodules b/.gitmodules index 87a3d9b855..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "submodules/redis-rs"] - path = submodules/redis-rs - url = https://github.com/amazon-contributing/redis-rs diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..92bcf7acc5 --- /dev/null +++ b/Makefile @@ -0,0 +1,118 @@ +.PHONY: all java java-test python python-test node node-test check-redis-server go go-test + +BLUE=\033[34m +YELLOW=\033[33m +GREEN=\033[32m +RESET=\033[0m +ROOT_DIR=$(shell pwd) +PYENV_DIR=$(shell pwd)/python/.env +PY_PATH=$(shell find python/.env -name "site-packages"|xargs readlink -f) +PY_GLIDE_PATH=$(shell pwd)/python/python/ + +all: java java-test python python-test node node-test go go-test python-lint java-lint + +## +## Java targets +## +java: + @echo "$(GREEN)Building for Java (release)$(RESET)" + @cd java && ./gradlew :client:buildAllRelease + +java-lint: + @echo "$(GREEN)Running spotlessCheck$(RESET)" + @cd java && ./gradlew :spotlessCheck + @echo "$(GREEN)Running spotlessApply$(RESET)" + @cd java && ./gradlew :spotlessApply + +java-test: check-redis-server + @echo "$(GREEN)Running integration tests$(RESET)" + @cd java && ./gradlew :integTest:test + +## +## Python targets +## +python: .build/python_deps + @echo "$(GREEN)Building for Python (release)$(RESET)" + @cd python && VIRTUAL_ENV=$(PYENV_DIR) .env/bin/maturin develop --release --strip + +python-lint: .build/python_deps + @echo "$(GREEN)Building Linters for python$(RESET)" + cd python && \ + export VIRTUAL_ENV=$(PYENV_DIR); \ + export PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH); \ + export PATH=$(PYENV_DIR)/bin:$(PATH); \ + isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env && \ + black . 
--exclude python/glide/protobuf --exclude .env && \ + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \ + --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 && \ + flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \ + --statistics --exclude=python/glide/protobuf,.env/* \ + --extend-ignore=E230 + +python-test: .build/python_deps check-redis-server + cd python && PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH) .env/bin/pytest --asyncio-mode=auto + +.build/python_deps: + @echo "$(GREEN)Generating protobuf files...$(RESET)" + @protoc -Iprotobuf=$(ROOT_DIR)/glide-core/src/protobuf/ \ + --python_out=$(ROOT_DIR)/python/python/glide $(ROOT_DIR)/glide-core/src/protobuf/*.proto + @echo "$(GREEN)Building environment...$(RESET)" + @cd python && python3 -m venv .env + @echo "$(GREEN)Installing requirements...$(RESET)" + @cd python && .env/bin/pip install -r requirements.txt + @cd python && .env/bin/pip install -r dev_requirements.txt + @mkdir -p .build/ && touch .build/python_deps + +## +## NodeJS targets +## +node: .build/node_deps + @echo "$(GREEN)Building for NodeJS (release)...$(RESET)" + @cd node && npm run build:release + +.build/node_deps: + @echo "$(GREEN)Installing NodeJS dependencies...$(RESET)" + @cd node && npm i + @cd node/rust-client && npm i + @mkdir -p .build/ && touch .build/node_deps + +node-test: .build/node_deps check-redis-server + @echo "$(GREEN)Running tests for NodeJS$(RESET)" + @cd node && npm run build + cd node && npm test + +node-lint: .build/node_deps + @echo "$(GREEN)Running linters for NodeJS$(RESET)" + @cd node && npm run lint:fix + +## +## Go targets +## + + +go: .build/go_deps + $(MAKE) -C go build + +go-test: .build/go_deps + $(MAKE) -C go test + +go-lint: .build/go_deps + $(MAKE) -C go lint + +.build/go_deps: + @echo "$(GREEN)Installing GO dependencies...$(RESET)" + $(MAKE) -C go install-build-tools install-dev-tools + @mkdir -p .build/ && touch .build/go_deps + +## +## Common targets +## +check-redis-server: + which redis-server + +clean: + rm -fr .build/ + +help: + @echo "$(GREEN)Listing Makefile targets:$(RESET)" + @echo $(shell grep '^[^#[:space:]].*:' Makefile|cut -d":" -f1|grep -v PHONY|grep -v "^.build"|sort) diff --git a/benchmarks/rust/Cargo.toml b/benchmarks/rust/Cargo.toml index 6f0849d505..d63bc98e57 100644 --- a/benchmarks/rust/Cargo.toml +++ b/benchmarks/rust/Cargo.toml @@ -11,7 +11,7 @@ authors = ["Valkey GLIDE Maintainers"] tokio = { version = "1", features = ["macros", "time", "rt-multi-thread"] } glide-core = { path = "../../glide-core" } logger_core = {path = "../../logger_core"} -redis = { path = "../../submodules/redis-rs/redis", features = ["aio"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio"] } futures = "0.3.28" rand = "0.8.5" itoa = "1.0.6" diff --git a/csharp/lib/Cargo.toml b/csharp/lib/Cargo.toml index 95981480b2..b49e098bf7 100644 --- a/csharp/lib/Cargo.toml +++ b/csharp/lib/Cargo.toml @@ -12,7 +12,7 @@ name = "glide_rs" crate-type = ["cdylib"] [dependencies] -redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } glide-core = { path = "../../glide-core" } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../../logger_core"} diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index e0a1b05368..51de808bd2 100644 --- a/glide-core/Cargo.toml +++
b/glide-core/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Valkey GLIDE Maintainers"] [dependencies] bytes = "1" futures = "^0.3" -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } +redis = { path = "./redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } tokio = { version = "1", features = ["macros", "time"] } logger_core = {path = "../logger_core"} dispose = "0.5.0" @@ -42,7 +42,7 @@ serial_test = "3" criterion = { version = "^0.5", features = ["html_reports", "async_tokio"] } which = "5" ctor = "0.2.2" -redis = { path = "../submodules/redis-rs/redis", features = ["tls-rustls-insecure"] } +redis = { path = "./redis-rs/redis", features = ["tls-rustls-insecure"] } iai-callgrind = "0.9" tokio = { version = "1", features = ["rt-multi-thread"] } glide-core = { path = ".", features = ["socket-layer"] } # always enable this feature in tests. diff --git a/glide-core/redis-rs/Cargo.toml b/glide-core/redis-rs/Cargo.toml new file mode 100644 index 0000000000..2f4ebbcbbe --- /dev/null +++ b/glide-core/redis-rs/Cargo.toml @@ -0,0 +1,3 @@ +[workspace] +members = ["redis", "redis-test"] +resolver = "2" diff --git a/glide-core/redis-rs/LICENSE b/glide-core/redis-rs/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/glide-core/redis-rs/Makefile b/glide-core/redis-rs/Makefile new file mode 100644 index 0000000000..9f177a6c22 --- /dev/null +++ b/glide-core/redis-rs/Makefile @@ -0,0 +1,96 @@ +build: + @cargo build + +test: + @echo "====================================================================" + @echo "Build all features with lock file" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" cargo build --locked --all-features + + @echo "====================================================================" + @echo "Testing Connection Type TCP without features" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --no-default-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and RESP2" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and RESP3" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp PROTOCOL=RESP3 cargo test -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and Rustls support" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and native-TLS support" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo test --locked -p redis --features=json,tokio-native-tls-comp,connection-manager,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type UNIX" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=unix RUST_BACKTRACE=1 cargo test --locked -p redis --test parser --test test_basic --test test_types --all-features -- --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type UNIX SOCKETS" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=unix RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --test-threads=1 --skip test_cluster --skip test_async_cluster --skip test_module --skip test_cluster_scan + + @echo "====================================================================" + @echo "Testing async-std with Rustls" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" 
REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --features=async-std-rustls-comp,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing async-std with native-TLS" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --features=async-std-native-tls-comp,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing redis-test" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" RUST_BACKTRACE=1 cargo test --locked -p redis-test + + +test-module: + @echo "====================================================================" + @echo "Testing RESP2 with module support enabled (currently only RedisJSON)" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked --all-features test_module -- --test-threads=1 + + @echo "====================================================================" + @echo "Testing RESP3 with module support enabled (currently only RedisJSON)" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 RESP3=true cargo test --all-features test_module -- --test-threads=1 + +test-single: test + +bench: + cargo bench --all-features + +docs: + @RUSTFLAGS="-D warnings" RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps + +upload-docs: docs + @./upload-docs.sh + +style-check: + @rustup component add rustfmt 2> /dev/null + cargo fmt --all -- --check + +lint: + @rustup component add clippy 2> /dev/null + cargo clippy --all-features --all --tests --examples -- -D clippy::all -D warnings + +fuzz: + cd afl/parser/ && \ + cargo afl build --bin fuzz-target && \ + cargo afl fuzz -i in -o out target/debug/fuzz-target + +.PHONY: build test bench docs upload-docs style-check lint fuzz diff --git a/glide-core/redis-rs/README.md b/glide-core/redis-rs/README.md new file mode 100644 index 0000000000..34cdfe4778 --- /dev/null +++ b/glide-core/redis-rs/README.md @@ -0,0 +1,233 @@ +# redis-rs + +[![Rust](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml/badge.svg)](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml) +[![crates.io](https://img.shields.io/crates/v/redis.svg)](https://crates.io/crates/redis) +[![Chat](https://img.shields.io/discord/976380008299917365?logo=discord)](https://discord.gg/WHKcJK9AKP) + +Redis-rs is a high level redis library for Rust. It provides convenient access +to all Redis functionality through a very flexible but low-level API. It +uses a customizable type conversion trait so that any operation can return +results in just the type you are expecting. This makes for a very pleasant +development experience. + +The crate is called `redis` and you can depend on it via cargo: + +```ini +[dependencies] +redis = "0.25.2" +``` + +Documentation on the library can be found at +[docs.rs/redis](https://docs.rs/redis). + +**Note: redis-rs requires at least Rust 1.60.** + +## Basic Operation + +To open a connection you need to create a client and then to fetch a +connection from it. 
In the future there will be a connection pool for +those; currently each connection is separate and not pooled. + +Many commands are implemented through the `Commands` trait but manual +command creation is also possible. + +```rust +use redis::Commands; + +fn fetch_an_integer() -> redis::RedisResult<isize> { + // connect to redis + let client = redis::Client::open("redis://127.0.0.1/")?; + let mut con = client.get_connection(None)?; + // throw away the result, just make sure it does not fail + let _ : () = con.set("my_key", 42)?; + // read back the key and return it. Because the return value + // from the function is a result for integer this will automatically + // convert into one. + con.get("my_key") +} +``` + +Variables are converted to and from the Redis format for a wide variety of types +(`String`, num types, tuples, `Vec<u8>`). If you want to use it with your own types, +you can implement the `FromRedisValue` and `ToRedisArgs` traits, or derive it with the +[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate. + +## Async support + +To enable asynchronous clients, enable the relevant feature in your Cargo.toml, +`tokio-comp` for tokio users or `async-std-comp` for async-std users. + +``` +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-comp"] } +``` + +## TLS Support + +To enable TLS support, you need to use the relevant feature entry in your Cargo.toml. +Currently, `native-tls` and `rustls` are supported. + +To use `native-tls`: + +``` +redis = { version = "0.25.2", features = ["tls-native-tls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-native-tls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-native-tls-comp"] } +``` + +To use `rustls`: + +``` +redis = { version = "0.25.2", features = ["tls-rustls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-rustls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-rustls-comp"] } +``` + +With `rustls`, you can add the following feature flags on top of other feature flags to enable additional features: + +- `tls-rustls-insecure`: Allow insecure TLS connections +- `tls-rustls-webpki-roots`: Use `webpki-roots` (Mozilla's root certificates) instead of native root certificates + +then you should be able to connect to a redis instance using the `rediss://` URL scheme: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/")?; +``` + +To enable insecure mode, append `#insecure` at the end of the URL: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/#insecure")?; +``` + +**Deprecation Notice:** If you were using the `tls` or `async-std-tls-comp` features, please use the `tls-native-tls` or `async-std-native-tls-comp` features respectively. + +## Cluster Support + +Support for Redis Cluster can be enabled by enabling the `cluster` feature in your Cargo.toml: + +`redis = { version = "0.25.2", features = [ "cluster"] }` + +Then you can simply use the `ClusterClient`, which accepts a list of available nodes. Note +that only one node in the cluster needs to be specified when instantiating the client, though +you can specify multiple. +```rust +use redis::cluster::ClusterClient; +use redis::Commands; + +fn fetch_an_integer() -> String { + let nodes = vec!["redis://127.0.0.1/"]; + let client = ClusterClient::new(nodes).unwrap(); + let mut connection = client.get_connection(None).unwrap(); + let _: () = connection.set("test", "test_data").unwrap(); + let rv: String = connection.get("test").unwrap(); + return rv; +} +``` + +Async Redis Cluster support can be enabled by enabling the `cluster-async` feature, along +with your preferred async runtime, e.g.: + +`redis = { version = "0.25.2", features = [ "cluster-async", "tokio-comp" ] }` + +```rust +use redis::cluster::ClusterClient; +use redis::AsyncCommands; + +async fn fetch_an_integer() -> String { + let nodes = vec!["redis://127.0.0.1/"]; + let client = ClusterClient::new(nodes).unwrap(); + let mut connection = client.get_async_connection().await.unwrap(); + let _: () = connection.set("test", "test_data").await.unwrap(); + let rv: String = connection.get("test").await.unwrap(); + return rv; +} +``` + +## JSON Support + +Support for the RedisJSON Module can be enabled by specifying "json" as a feature in your Cargo.toml. + +`redis = { version = "0.25.2", features = ["json"] }` + +Then you can simply import the `JsonCommands` trait which will add the `json` commands to all Redis Connections (not to be confused with just `Commands` which only adds the default commands). + +```rust +use redis::Client; +use redis::JsonCommands; +use redis::RedisResult; +use redis::ToRedisArgs; + +// Result returns Ok(true) if the value was set +// Result returns Err(e) if there was an error with the server itself OR serde_json was unable to serialize the boolean +fn set_json_bool<P: ToRedisArgs>(key: P, path: P, b: bool) -> RedisResult<bool> { + let client = Client::open("redis://127.0.0.1").unwrap(); + let mut connection = client.get_connection(None).unwrap(); + + // runs `JSON.SET {key} {path} {b}` + connection.json_set(key, path, &b) +} + +``` + +To parse the results, you'll need to use `serde_json` (or some other json lib) to deserialize +the results from the bytes. It will always be a `Vec`; if no results were found at the path, it'll +be an empty `Vec`. If you want to handle deserialization and `Vec` unwrapping automatically, +you can use the `Json` wrapper from the +[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate. + +## Development + +To test `redis` you're going to need to be able to test with the Redis Modules; to do this +you must set the following environment variable before running the test script: + +- `REDIS_RS_REDIS_JSON_PATH` = The absolute path to the RedisJSON module (Either `librejson.so` for Linux or `librejson.dylib` for MacOS). + +- Please refer to this [link](https://github.com/RedisJSON/RedisJSON) to access the RedisJSON module: + + + +If you want to develop on the library there are a few commands provided +by the makefile: + +To build: + + $ make + +To test: + + $ make test + +To run benchmarks: + + $ make bench + +To build the docs (requires a nightly compiler, see [rust-lang/rust#43781](https://github.com/rust-lang/rust/issues/43781)): + + $ make docs + +We encourage you to run `clippy` prior to seeking a merge for your work. The lints can be quite strict.
Running this on your own workstation can save you time, since Travis CI will fail any build that doesn't satisfy `clippy`: + + $ cargo clippy --all-features --all --tests --examples -- -D clippy::all -D warnings + +To run fuzz tests with afl, first install cargo-afl (`cargo install -f afl`), +then run: + + $ make fuzz + +If the fuzzer finds a crash, in order to reproduce it, run: + + $ cd afl/<target>/ + $ cargo run --bin reproduce -- out/crashes/<crashfile> diff --git a/glide-core/redis-rs/afl/.gitignore b/glide-core/redis-rs/afl/.gitignore new file mode 100644 index 0000000000..1776e13233 --- /dev/null +++ b/glide-core/redis-rs/afl/.gitignore @@ -0,0 +1,2 @@ +out/ +core.* diff --git a/glide-core/redis-rs/afl/parser/Cargo.toml b/glide-core/redis-rs/afl/parser/Cargo.toml new file mode 100644 index 0000000000..9f5202d86a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "fuzz-target-parser" +version = "0.1.0" +authors = ["redis-rs developers"] +edition = "2018" + +[[bin]] +name = "fuzz-target" +path = "src/main.rs" + +[[bin]] +name = "reproduce" +path = "src/reproduce.rs" + +[dependencies] +afl = "0.4" +redis = { path = "../../redis" } diff --git a/glide-core/redis-rs/afl/parser/in/array b/glide-core/redis-rs/afl/parser/in/array new file mode 100644 index 0000000000..c92e405790 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/array @@ -0,0 +1,5 @@ +*3 +:1 +$-1 +$2 +hi diff --git a/glide-core/redis-rs/afl/parser/in/array-null b/glide-core/redis-rs/afl/parser/in/array-null new file mode 100644 index 0000000000..e0f619c1b3 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/array-null @@ -0,0 +1 @@ +*-1 diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring b/glide-core/redis-rs/afl/parser/in/bulkstring new file mode 100644 index 0000000000..930878abea --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/bulkstring @@ -0,0 +1,2 @@ +$6 +foobar diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring-null b/glide-core/redis-rs/afl/parser/in/bulkstring-null new file mode 100644 index 0000000000..f4280bede5 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/bulkstring-null @@ -0,0 +1 @@ +$-1 diff --git a/glide-core/redis-rs/afl/parser/in/error b/glide-core/redis-rs/afl/parser/in/error new file mode 100644 index 0000000000..7cfd9a521a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/error @@ -0,0 +1 @@ +-ERR unknown command diff --git a/glide-core/redis-rs/afl/parser/in/integer b/glide-core/redis-rs/afl/parser/in/integer new file mode 100644 index 0000000000..49525f0d45 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/integer @@ -0,0 +1 @@ +:1337 diff --git a/glide-core/redis-rs/afl/parser/in/invalid-string b/glide-core/redis-rs/afl/parser/in/invalid-string new file mode 100644 index 0000000000..604dd3e85a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/invalid-string @@ -0,0 +1,2 @@ +$6 +foo diff --git a/glide-core/redis-rs/afl/parser/in/string b/glide-core/redis-rs/afl/parser/in/string new file mode 100644 index 0000000000..054430c700 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/string @@ -0,0 +1 @@ ++OK diff --git a/glide-core/redis-rs/afl/parser/src/main.rs b/glide-core/redis-rs/afl/parser/src/main.rs new file mode 100644 index 0000000000..6dc674edff --- /dev/null +++ b/glide-core/redis-rs/afl/parser/src/main.rs @@ -0,0 +1,9 @@ +use afl::fuzz; + +use redis::parse_redis_value; + +fn main() { + fuzz!(|data: &[u8]| { + let _ = parse_redis_value(data); + }); +} diff --git a/glide-core/redis-rs/afl/parser/src/reproduce.rs
b/glide-core/redis-rs/afl/parser/src/reproduce.rs new file mode 100644 index 0000000000..086dfffb50 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/src/reproduce.rs @@ -0,0 +1,13 @@ +use redis::parse_redis_value; + +fn main() { + let args: Vec<String> = std::env::args().collect(); + if args.len() != 2 { + println!("Usage: {} <crashfile>", args[0]); + std::process::exit(1); + } + + let data = std::fs::read(&args[1]).expect(&format!("Could not open file {}", args[1])); + let v = parse_redis_value(&data); + println!("Result: {:?}", v); +} diff --git a/glide-core/redis-rs/appveyor.yml b/glide-core/redis-rs/appveyor.yml new file mode 100644 index 0000000000..8310b8def5 --- /dev/null +++ b/glide-core/redis-rs/appveyor.yml @@ -0,0 +1,23 @@ +os: Visual Studio 2015 + +environment: + REDISRS_SERVER_TYPE: tcp + RUST_BACKTRACE: 1 + matrix: + - channel: stable + target: x86_64-pc-windows-msvc + - channel: stable + target: x86_64-pc-windows-gnu +install: + - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe + - rustup-init -yv --default-toolchain %channel% --default-host %target% + - set PATH=%PATH%;%USERPROFILE%\.cargo\bin + - rustc -vV + - cargo -vV + - cmd: nuget install redis-64 -excludeversion + - set PATH=%PATH%;%APPVEYOR_BUILD_FOLDER%\redis-64\tools\ + +build: false + +test_script: + - cargo test --verbose --no-default-features --features tokio-comp %cargoflags% diff --git a/glide-core/redis-rs/redis-test/CHANGELOG.md b/glide-core/redis-rs/redis-test/CHANGELOG.md new file mode 100644 index 0000000000..83d3ab3dc4 --- /dev/null +++ b/glide-core/redis-rs/redis-test/CHANGELOG.md @@ -0,0 +1,44 @@ +### 0.4.0 (2024-03-08) +* Track redis 0.25.0 release + +### 0.3.0 (2023-12-05) +* Track redis 0.24.0 release + +### 0.2.3 (2023-09-01) + +* Track redis 0.23.3 release + +### 0.2.2 (2023-08-10) + +* Track redis 0.23.2 release + +### 0.2.1 (2023-07-28) + +* Track redis 0.23.1 release + + +### 0.2.0 (2023-04-05) + +* Track redis 0.23.0 release + + +### 0.2.0-beta.1 (2023-03-28) + +* Track redis 0.23.0-beta.1 release + + +### 0.1.1 (2022-10-18) + +#### Changes +* Add README +* Update LICENSE file / symlink from parent directory + + + +### 0.1.0 (2022-10-05) + +This is the initial release of the redis-test crate, which aims to provide mocking +for connections and commands. Thanks @tdyas!
+ +#### Features +* Testing module with support for mocking redis connections and commands ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas) \ No newline at end of file diff --git a/glide-core/redis-rs/redis-test/Cargo.toml b/glide-core/redis-rs/redis-test/Cargo.toml new file mode 100644 index 0000000000..6e0bcc3a9f --- /dev/null +++ b/glide-core/redis-rs/redis-test/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "redis-test" +version = "0.4.0" +edition = "2021" +description = "Testing helpers for the `redis` crate" +homepage = "https://github.com/redis-rs/redis-rs" +repository = "https://github.com/redis-rs/redis-rs" +documentation = "https://docs.rs/redis-test" +license = "BSD-3-Clause" +rust-version = "1.65" + +[lib] +bench = false + +[dependencies] +redis = { version = "0.25.0", path = "../redis" } + +bytes = { version = "1", optional = true } +futures = { version = "0.3", optional = true } + +[features] +aio = ["futures", "redis/aio"] + +[dev-dependencies] +redis = { version = "0.25.0", path = "../redis", features = ["aio", "tokio-comp"] } +tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } diff --git a/glide-core/redis-rs/redis-test/LICENSE b/glide-core/redis-rs/redis-test/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis-test/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/redis-test/README.md b/glide-core/redis-rs/redis-test/README.md new file mode 100644 index 0000000000..b89bfc4edb --- /dev/null +++ b/glide-core/redis-rs/redis-test/README.md @@ -0,0 +1,4 @@ +# redis-test + +Testing utilities for the redis-rs crate. 
+ diff --git a/glide-core/redis-rs/redis-test/release.toml b/glide-core/redis-rs/redis-test/release.toml new file mode 100644 index 0000000000..7dc5b7a0a6 --- /dev/null +++ b/glide-core/redis-rs/redis-test/release.toml @@ -0,0 +1 @@ +tag-name = "redis-test-{{version}}" diff --git a/glide-core/redis-rs/redis-test/src/lib.rs b/glide-core/redis-rs/redis-test/src/lib.rs new file mode 100644 index 0000000000..cafe8a347b --- /dev/null +++ b/glide-core/redis-rs/redis-test/src/lib.rs @@ -0,0 +1,426 @@ +//! Testing support +//! +//! This module provides `MockRedisConnection` which implements ConnectionLike and can be +//! used in the same place as any other type that behaves like a Redis connection. This is useful +//! for writing unit tests without needing a Redis server. +//! +//! # Example +//! +//! ```rust +//! use redis::{ConnectionLike, RedisError}; +//! use redis_test::{MockCmd, MockRedisConnection}; +//! +//! fn my_exists(conn: &mut C, key: &str) -> Result { +//! let exists: bool = redis::cmd("EXISTS").arg(key).query(conn)?; +//! Ok(exists) +//! } +//! +//! let mut mock_connection = MockRedisConnection::new(vec![ +//! MockCmd::new(redis::cmd("EXISTS").arg("foo"), Ok("1")), +//! ]); +//! +//! let result = my_exists(&mut mock_connection, "foo").unwrap(); +//! assert_eq!(result, true); +//! ``` + +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; + +use redis::{Cmd, ConnectionLike, ErrorKind, Pipeline, RedisError, RedisResult, Value}; + +#[cfg(feature = "aio")] +use futures::{future, FutureExt}; + +#[cfg(feature = "aio")] +use redis::{aio::ConnectionLike as AioConnectionLike, RedisFuture}; + +/// Helper trait for converting test values into a `redis::Value` returned from a +/// `MockRedisConnection`. This is necessary because neither `redis::types::ToRedisArgs` +/// nor `redis::types::FromRedisValue` performs the precise conversion needed. +pub trait IntoRedisValue { + /// Convert a value into `redis::Value`. + fn into_redis_value(self) -> Value; +} + +impl IntoRedisValue for String { + fn into_redis_value(self) -> Value { + Value::BulkString(self.as_bytes().to_vec()) + } +} + +impl IntoRedisValue for &str { + fn into_redis_value(self) -> Value { + Value::BulkString(self.as_bytes().to_vec()) + } +} + +#[cfg(feature = "bytes")] +impl IntoRedisValue for bytes::Bytes { + fn into_redis_value(self) -> Value { + Value::BulkString(self.to_vec()) + } +} + +impl IntoRedisValue for Vec { + fn into_redis_value(self) -> Value { + Value::BulkString(self) + } +} + +impl IntoRedisValue for Value { + fn into_redis_value(self) -> Value { + self + } +} + +impl IntoRedisValue for i64 { + fn into_redis_value(self) -> Value { + Value::Int(self) + } +} + +/// Helper trait for converting `redis::Cmd` and `redis::Pipeline` instances into +/// encoded byte vectors. +pub trait IntoRedisCmdBytes { + /// Convert a command into an encoded byte vector. 
+    fn into_redis_cmd_bytes(self) -> Vec<u8>;
+}
+
+impl IntoRedisCmdBytes for Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for &Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for &mut Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+impl IntoRedisCmdBytes for &Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+impl IntoRedisCmdBytes for &mut Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+/// Represents a command to be executed against a `MockRedisConnection`.
+pub struct MockCmd {
+    cmd_bytes: Vec<u8>,
+    responses: Result<Vec<Value>, RedisError>,
+}
+
+impl MockCmd {
+    /// Create a new `MockCmd` given a Redis command and either a value convertible to
+    /// a `redis::Value` or a `RedisError`.
+    pub fn new<C, V>(cmd: C, response: Result<V, RedisError>) -> Self
+    where
+        C: IntoRedisCmdBytes,
+        V: IntoRedisValue,
+    {
+        MockCmd {
+            cmd_bytes: cmd.into_redis_cmd_bytes(),
+            responses: response.map(|r| vec![r.into_redis_value()]),
+        }
+    }
+
+    /// Create a new `MockCmd` given a Redis command/pipeline and either a vector of values
+    /// convertible to `redis::Value` or a `RedisError`.
+    pub fn with_values<C, V>(cmd: C, responses: Result<Vec<V>, RedisError>) -> Self
+    where
+        C: IntoRedisCmdBytes,
+        V: IntoRedisValue,
+    {
+        MockCmd {
+            cmd_bytes: cmd.into_redis_cmd_bytes(),
+            responses: responses.map(|xs| xs.into_iter().map(|x| x.into_redis_value()).collect()),
+        }
+    }
+}
+
+/// A mock Redis client for testing without a server. `MockRedisConnection` checks whether the
+/// client submits a specific sequence of commands and generates an error if it does not.
+#[derive(Clone)]
+pub struct MockRedisConnection {
+    commands: Arc<Mutex<VecDeque<MockCmd>>>,
+}
+
+impl MockRedisConnection {
+    /// Construct a new `MockRedisConnection` from the given sequence of commands.
+    pub fn new<I>(commands: I) -> Self
+    where
+        I: IntoIterator<Item = MockCmd>,
+    {
+        MockRedisConnection {
+            commands: Arc::new(Mutex::new(VecDeque::from_iter(commands))),
+        }
+    }
+}
+
+impl ConnectionLike for MockRedisConnection {
+    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
+        let mut commands = self.commands.lock().unwrap();
+        let next_cmd = commands.pop_front().ok_or_else(|| {
+            RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                "unexpected command".to_owned(),
+            ))
+        })?;
+
+        if cmd != next_cmd.cmd_bytes {
+            return Err(RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                format!(
+                    "unexpected command: expected={}, actual={}",
+                    String::from_utf8(next_cmd.cmd_bytes)
+                        .unwrap_or_else(|_| "decode error".to_owned()),
+                    String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()),
+                ),
+            )));
+        }
+
+        next_cmd
+            .responses
+            .and_then(|values| match values.as_slice() {
+                [value] => Ok(value.clone()),
+                [] => Err(RedisError::from((
+                    ErrorKind::ClientError,
+                    "no value configured as response",
+                ))),
+                _ => Err(RedisError::from((
+                    ErrorKind::ClientError,
+                    "multiple values configured as response for command expecting a single value",
+                ))),
+            })
+    }
+
+    fn req_packed_commands(
+        &mut self,
+        cmd: &[u8],
+        _offset: usize,
+        _count: usize,
+    ) -> RedisResult<Vec<Value>> {
+        let mut commands = self.commands.lock().unwrap();
+        let next_cmd = commands.pop_front().ok_or_else(|| {
+            RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                "unexpected command".to_owned(),
+            ))
+        })?;
+
+        if cmd != next_cmd.cmd_bytes {
+            return Err(RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                format!(
+                    "unexpected command: expected={}, actual={}",
+                    String::from_utf8(next_cmd.cmd_bytes)
+                        .unwrap_or_else(|_| "decode error".to_owned()),
+                    String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()),
+                ),
+            )));
+        }
+
+        next_cmd.responses
+    }
+
+    fn get_db(&self) -> i64 {
+        0
+    }
+
+    fn check_connection(&mut self) -> bool {
+        true
+    }
+
+    fn is_open(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(feature = "aio")]
+impl AioConnectionLike for MockRedisConnection {
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+        let packed_cmd = cmd.get_packed_command();
+        let response = <MockRedisConnection as ConnectionLike>::req_packed_command(
+            self,
+            packed_cmd.as_slice(),
+        );
+        future::ready(response).boxed()
+    }
+
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>> {
+        let packed_cmd = cmd.get_packed_pipeline();
+        let response = <MockRedisConnection as ConnectionLike>::req_packed_commands(
+            self,
+            packed_cmd.as_slice(),
+            offset,
+            count,
+        );
+        future::ready(response).boxed()
+    }
+
+    fn get_db(&self) -> i64 {
+        0
+    }
+
+    fn is_closed(&self) -> bool {
+        false
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{MockCmd, MockRedisConnection};
+    use redis::{cmd, pipe, ErrorKind, Value};
+
+    #[test]
+    fn sync_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+            MockCmd::new(cmd("GET").arg("bar"), Ok("foo")),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42));
+
+        cmd("SET").arg("bar").arg("foo").execute(&mut conn);
+        assert_eq!(
+            cmd("GET").arg("bar").query(&mut conn),
+            Ok(Value::BulkString(b"foo".as_ref().into()))
+        );
+    }
+
+    #[cfg(feature = "aio")]
+    #[tokio::test]
+    async fn async_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+            MockCmd::new(cmd("GET").arg("bar"), Ok("foo")),
+        ]);
+
+        cmd("SET")
+            .arg("foo")
+            .arg("42")
+            .query_async::<_, ()>(&mut conn)
+            .await
+            .unwrap();
+        let result: Result<i64, _> = cmd("GET").arg("foo").query_async(&mut conn).await;
+        assert_eq!(result, Ok(42));
+
+        cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query_async::<_, ()>(&mut conn)
+            .await
+            .unwrap();
+        let result: Result<Vec<u8>, _> = cmd("GET").arg("bar").query_async(&mut conn).await;
+        assert_eq!(result.as_deref(), Ok(&b"foo"[..]));
+    }
+
+    #[test]
+    fn errors_for_unexpected_commands() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42));
+
+        let err = cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query::<()>(&mut conn)
+            .unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::ClientError);
+        assert_eq!(err.detail(), Some("unexpected command"));
+    }
+
+    #[test]
+    fn errors_for_mismatched_commands() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        let err = cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query::<()>(&mut conn)
+            .unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::ClientError);
+        assert!(err.detail().unwrap().contains("unexpected command"));
+    }
+
+    #[test]
+    fn pipeline_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![MockCmd::with_values(
+            pipe().cmd("GET").arg("foo").cmd("GET").arg("bar"),
+            Ok(vec!["hello", "world"]),
+        )]);
+
+        let results: Vec<String> = pipe()
+            .cmd("GET")
+            .arg("foo")
+            .cmd("GET")
+            .arg("bar")
+            .query(&mut conn)
+            .expect("success");
+        assert_eq!(results, vec!["hello", "world"]);
+    }
+
+    #[test]
+    fn pipeline_atomic_test() {
+        let mut conn = MockRedisConnection::new(vec![MockCmd::with_values(
+            pipe().atomic().cmd("GET").arg("foo").cmd("GET").arg("bar"),
+            Ok(vec![Value::Array(
+                vec!["hello", "world"]
+                    .into_iter()
+                    .map(|x| Value::BulkString(x.as_bytes().into()))
+                    .collect(),
+            )]),
+        )]);
+
+        let results: Vec<String> = pipe()
+            .atomic()
+            .cmd("GET")
+            .arg("foo")
+            .cmd("GET")
+            .arg("bar")
+            .query(&mut conn)
+            .expect("success");
+        assert_eq!(results, vec!["hello", "world"]);
+    }
+}
diff --git a/glide-core/redis-rs/redis/CHANGELOG.md b/glide-core/redis-rs/redis/CHANGELOG.md
new file mode 100644
index 0000000000..9c3dd18524
--- /dev/null
+++ b/glide-core/redis-rs/redis/CHANGELOG.md
@@ -0,0 +1,828 @@
+### 0.25.2 (2024-03-15)
+
+* MultiplexedConnection: Separate response handling for pipeline.
([#1078](https://github.com/redis-rs/redis-rs/pull/1078)) + +### 0.25.1 (2024-03-12) + +* Fix small disambiguity in examples ([#1072](https://github.com/redis-rs/redis-rs/pull/1072) @sunhuachuang) +* Upgrade to socket2 0.5 ([#1073](https://github.com/redis-rs/redis-rs/pull/1073) @djc) +* Avoid library dependency on futures-time ([#1074](https://github.com/redis-rs/redis-rs/pull/1074) @djc) + + +### 0.25.0 (2024-03-08) + +#### Features + +* **Breaking change**: Add connection timeout to the cluster client ([#834](https://github.com/redis-rs/redis-rs/pull/834)) +* **Breaking change**: Deprecate aio::Connection ([#889](https://github.com/redis-rs/redis-rs/pull/889)) +* Cluster: fix read from replica & missing slots ([#965](https://github.com/redis-rs/redis-rs/pull/965)) +* Async cluster connection: Improve handling of missing connections ([#968](https://github.com/redis-rs/redis-rs/pull/968)) +* Add support for parsing to/from any sized arrays ([#981](https://github.com/redis-rs/redis-rs/pull/981)) +* Upgrade to rustls 0.22 ([#1000](https://github.com/redis-rs/redis-rs/pull/1000) @djc) +* add SMISMEMBER command ([#1002](https://github.com/redis-rs/redis-rs/pull/1002) @Zacaria) +* Add support for some big number types ([#1014](https://github.com/redis-rs/redis-rs/pull/1014) @AkiraMiyakoda) +* Add Support for UUIDs ([#1029](https://github.com/redis-rs/redis-rs/pull/1029) @Rabbitminers) +* Add FromRedisValue::from_owned_redis_value to reduce copies while parsing response ([#1030](https://github.com/redis-rs/redis-rs/pull/1030) @Nathan-Fenner) +* Save reconnected connections during retries ([#1033](https://github.com/redis-rs/redis-rs/pull/1033)) +* Avoid panic on connection failure ([#1035](https://github.com/redis-rs/redis-rs/pull/1035)) +* add disable client setinfo feature and its default mode is off ([#1036](https://github.com/redis-rs/redis-rs/pull/1036) @Ggiggle) +* Reconnect on parsing errors ([#1051](https://github.com/redis-rs/redis-rs/pull/1051)) +* preallocate buffer for evalsha in Script ([#1044](https://github.com/redis-rs/redis-rs/pull/1044) @framlog) + +#### Changes + +* Align more commands routings ([#938](https://github.com/redis-rs/redis-rs/pull/938)) +* Fix HashMap conversion ([#977](https://github.com/redis-rs/redis-rs/pull/977) @mxbrt) +* MultiplexedConnection: Remove unnecessary allocation in send ([#990](https://github.com/redis-rs/redis-rs/pull/990)) +* Tests: Reduce cluster setup flakiness ([#999](https://github.com/redis-rs/redis-rs/pull/999)) +* Remove the unwrap_or! 
macro ([#1010](https://github.com/redis-rs/redis-rs/pull/1010))
+* Remove allocation from command function ([#1008](https://github.com/redis-rs/redis-rs/pull/1008))
+* Catch panics from task::spawn in tests ([#1015](https://github.com/redis-rs/redis-rs/pull/1015))
+* Fix lint errors from new Rust version ([#1016](https://github.com/redis-rs/redis-rs/pull/1016))
+* Fix warnings that appear only with native-TLS ([#1018](https://github.com/redis-rs/redis-rs/pull/1018))
+* Hide the req_packed_commands from docs ([#1020](https://github.com/redis-rs/redis-rs/pull/1020))
+* Fix documentation error ([#1022](https://github.com/redis-rs/redis-rs/pull/1022) @rcl-viveksharma)
+* Fixes minor grammar mistake in json.rs file ([#1026](https://github.com/redis-rs/redis-rs/pull/1026) @RScrusoe)
+* Enable ignored pipe test ([#1027](https://github.com/redis-rs/redis-rs/pull/1027))
+* Fix names of existing async cluster tests ([#1028](https://github.com/redis-rs/redis-rs/pull/1028))
+* Add lock file to keep MSRV constant ([#1039](https://github.com/redis-rs/redis-rs/pull/1039))
+* Fail CI if lock file isn't updated ([#1042](https://github.com/redis-rs/redis-rs/pull/1042))
+* impl Clone/Copy for SetOptions ([#1046](https://github.com/redis-rs/redis-rs/pull/1046) @ahmadbky)
+* docs: add "connection-manager" cfg attr ([#1048](https://github.com/redis-rs/redis-rs/pull/1048) @DCNick3)
+* Remove the usage of aio::Connection in tests ([#1049](https://github.com/redis-rs/redis-rs/pull/1049))
+* Fix new clippy lints ([#1052](https://github.com/redis-rs/redis-rs/pull/1052))
+* Handle server errors in array response ([#1056](https://github.com/redis-rs/redis-rs/pull/1056))
+* Appease Clippy ([#1061](https://github.com/redis-rs/redis-rs/pull/1061))
+* make Pipeline handle returned bulks correctly ([#1063](https://github.com/redis-rs/redis-rs/pull/1063) @framlog)
+* Update mio dependency due to vulnerability ([#1064](https://github.com/redis-rs/redis-rs/pull/1064))
+* Simplify Sink polling logic ([#1065](https://github.com/redis-rs/redis-rs/pull/1065))
+* Separate parsing errors from general response errors ([#1069](https://github.com/redis-rs/redis-rs/pull/1069))
+
+### 0.24.0 (2023-12-05)
+
+#### Features
+* **Breaking change**: Support Mutual TLS ([#858](https://github.com/redis-rs/redis-rs/pull/858) @sp-angel)
+* Implement `FromRedisValue` for `Box<[T]>` and `Arc<[T]>` ([#799](https://github.com/redis-rs/redis-rs/pull/799) @JOT85)
+* Sync Cluster: support multi-slot operations. ([#967](https://github.com/redis-rs/redis-rs/pull/967))
+* Execute multi-node requests using try_request. ([#919](https://github.com/redis-rs/redis-rs/pull/919))
+* Sorted set blocking commands ([#962](https://github.com/redis-rs/redis-rs/pull/962) @gheorghitamutu)
+* Allow passing routing information to cluster. ([#899](https://github.com/redis-rs/redis-rs/pull/899))
+* Add `tcp_nodelay` feature ([#941](https://github.com/redis-rs/redis-rs/pull/941) @PureWhiteWu)
+* Add support for multi-shard commands. ([#900](https://github.com/redis-rs/redis-rs/pull/900))
+
+#### Changes
+* Order in usage of ClusterParams. ([#997](https://github.com/redis-rs/redis-rs/pull/997))
+* **Breaking change**: Fix StreamId::contains_key signature ([#783](https://github.com/redis-rs/redis-rs/pull/783) @Ayush1325)
+* **Breaking change**: Update Command expiration values to be an appropriate type ([#589](https://github.com/redis-rs/redis-rs/pull/589) @joshleeb)
+* **Breaking change**: Bump aHash to v0.8.6 ([#966](https://github.com/redis-rs/redis-rs/pull/966) @aumetra)
+* Fix features for `load_native_certs`. ([#996](https://github.com/redis-rs/redis-rs/pull/996))
+* Revert redis-test versioning changes ([#993](https://github.com/redis-rs/redis-rs/pull/993))
+* Tests: Add retries to test cluster creation ([#994](https://github.com/redis-rs/redis-rs/pull/994))
+* Fix sync cluster behavior with transactions. ([#983](https://github.com/redis-rs/redis-rs/pull/983))
+* Sync Pub/Sub - cache received pub/sub messages. ([#910](https://github.com/redis-rs/redis-rs/pull/910))
+* Prefer routing to primary in a transaction. ([#986](https://github.com/redis-rs/redis-rs/pull/986))
+* Accept iterator at `ClusterClient` initialization ([#987](https://github.com/redis-rs/redis-rs/pull/987) @ruanpetterson)
+* **Breaking change**: Change timeouts from usize and isize to f64 ([#988](https://github.com/redis-rs/redis-rs/pull/988) @eythorhel19)
+* Update minimal rust version to 1.65 ([#982](https://github.com/redis-rs/redis-rs/pull/982))
+* Disable JSON module tests for redis 6.2.4. ([#980](https://github.com/redis-rs/redis-rs/pull/980))
+* Add connection string examples ([#976](https://github.com/redis-rs/redis-rs/pull/976) @NuclearOreo)
+* Move response policy into multi-node routing. ([#952](https://github.com/redis-rs/redis-rs/pull/952))
+* Added functions that allow tests to check version. ([#963](https://github.com/redis-rs/redis-rs/pull/963))
+* Fix XREADGROUP command ordering as per Redis Docs, and compatibility with Upstash Redis ([#960](https://github.com/redis-rs/redis-rs/pull/960) @prabhpreet)
+* Optimize make_pipeline_results by pre-allocating memory ([#957](https://github.com/redis-rs/redis-rs/pull/957) @PureWhiteWu)
+* Run module tests sequentially. ([#956](https://github.com/redis-rs/redis-rs/pull/956))
+* Log cluster creation output in tests. ([#955](https://github.com/redis-rs/redis-rs/pull/955))
+* CI: Update and use better maintained github actions. ([#954](https://github.com/redis-rs/redis-rs/pull/954))
+* Call CLIENT SETINFO on new connections. ([#945](https://github.com/redis-rs/redis-rs/pull/945))
+* Deprecate functions that erroneously use `tokio` in their name. ([#913](https://github.com/redis-rs/redis-rs/pull/913))
+* CI: Increase timeouts and use newer redis. ([#949](https://github.com/redis-rs/redis-rs/pull/949))
+* Remove redis version from redis-test. ([#943](https://github.com/redis-rs/redis-rs/pull/943))
+
+### 0.23.4 (2023-11-26)
+**Yanked** -- Inadvertently introduced breaking changes (sorry!). The changes in this tag
+have been pushed to 0.24.0.
+
+### 0.23.3 (2023-09-01)
+
+Note that this release fixes a small regression in async Redis Cluster handling of the `PING` command.
+Based on updated response aggregation logic in [#888](https://github.com/redis-rs/redis-rs/pull/888), it
+will again return a single response instead of an array.
+
+#### Features
+* Add `key_type` command ([#933](https://github.com/redis-rs/redis-rs/pull/933) @bruaba)
+* Async cluster: Group responses by response_policy. ([#888](https://github.com/redis-rs/redis-rs/pull/888))
+
+
+#### Fixes
+* Remove unnecessary heap allocation ([#939](https://github.com/redis-rs/redis-rs/pull/939) @thechampagne)
+* Sentinel tests: Ensure no ports are used twice ([#915](https://github.com/redis-rs/redis-rs/pull/915))
+* Fix lint issues ([#937](https://github.com/redis-rs/redis-rs/pull/937))
+* Fix JSON serialization error test ([#928](https://github.com/redis-rs/redis-rs/pull/928))
+* Remove unused dependencies ([#916](https://github.com/redis-rs/redis-rs/pull/916))
+
+
+### 0.23.2 (2023-08-10)
+
+#### Fixes
+* Fix sentinel tests flakiness ([#912](https://github.com/redis-rs/redis-rs/pull/912))
+* Rustls: Remove usage of deprecated method ([#921](https://github.com/redis-rs/redis-rs/pull/921))
+* Fix compiling with sentinel feature, without aio feature ([#923](https://github.com/redis-rs/redis-rs/pull/923) @brocaar)
+* Add timeouts to tests github action ([#911](https://github.com/redis-rs/redis-rs/pull/911))
+
+### 0.23.1 (2023-07-28)
+
+#### Features
+* Add basic Sentinel functionality ([#836](https://github.com/redis-rs/redis-rs/pull/836) @felipou)
+* Enable keep alive on tcp connections via feature ([#886](https://github.com/redis-rs/redis-rs/pull/886) @DoumanAsh)
+* Support fan-out commands in cluster-async ([#843](https://github.com/redis-rs/redis-rs/pull/843) @nihohit)
+* connection_manager: retry and backoff on reconnect ([#804](https://github.com/redis-rs/redis-rs/pull/804) @nihohit)
+
+#### Changes
+* Tests: Wait for all servers ([#901](https://github.com/redis-rs/redis-rs/pull/901) @barshaul)
+* Pin `tempfile` dependency ([#902](https://github.com/redis-rs/redis-rs/pull/902))
+* Update routing data for commands. ([#887](https://github.com/redis-rs/redis-rs/pull/887) @nihohit)
+* Add basic benchmark reporting to CI ([#880](https://github.com/redis-rs/redis-rs/pull/880))
+* Add `set_options` cmd ([#879](https://github.com/redis-rs/redis-rs/pull/879) @RokasVaitkevicius)
+* Move random connection creation to when needed. ([#882](https://github.com/redis-rs/redis-rs/pull/882) @nihohit)
+* Clean up existing benchmarks ([#881](https://github.com/redis-rs/redis-rs/pull/881))
+* Improve async cluster client performance. ([#877](https://github.com/redis-rs/redis-rs/pull/877) @nihohit)
+* Allow configuration of cluster retry wait duration ([#859](https://github.com/redis-rs/redis-rs/pull/859) @nihohit)
+* Fix async connect when ns resolved to multi ip ([#872](https://github.com/redis-rs/redis-rs/pull/872) @hugefiver)
+* Reduce the number of unnecessary clones. ([#874](https://github.com/redis-rs/redis-rs/pull/874) @nihohit)
+* Remove connection checking on every request. ([#873](https://github.com/redis-rs/redis-rs/pull/873) @nihohit)
+* cluster_async: Wrap internal state with Arc. ([#864](https://github.com/redis-rs/redis-rs/pull/864) @nihohit)
+* Fix redirect routing on request with no route. ([#870](https://github.com/redis-rs/redis-rs/pull/870) @nihohit)
+* Amend README for macOS users ([#869](https://github.com/redis-rs/redis-rs/pull/869) @sarisssa)
+* Improved redirection error handling ([#857](https://github.com/redis-rs/redis-rs/pull/857))
+* Fix minor async client bug. ([#862](https://github.com/redis-rs/redis-rs/pull/862) @nihohit)
+* Split aio.rs to separate files. ([#821](https://github.com/redis-rs/redis-rs/pull/821) @nihohit)
+* Add time feature to tokio dependency ([#855](https://github.com/redis-rs/redis-rs/pull/855) @robjtede)
+* Refactor cluster error handling ([#844](https://github.com/redis-rs/redis-rs/pull/844))
+* Fix unnecessarily mutable variable ([#849](https://github.com/redis-rs/redis-rs/pull/849) @kamulos)
+* Newtype SlotMap ([#845](https://github.com/redis-rs/redis-rs/pull/845))
+* Bump MSRV to 1.60 ([#846](https://github.com/redis-rs/redis-rs/pull/846))
+* Improve error logging. ([#838](https://github.com/redis-rs/redis-rs/pull/838) @nihohit)
+* Improve documentation, add references to `redis-macros` ([#769](https://github.com/redis-rs/redis-rs/pull/769) @daniel7grant)
+* Allow creating Cmd with capacity. ([#817](https://github.com/redis-rs/redis-rs/pull/817) @nihohit)
+
+
+### 0.23.0 (2023-04-05)
+In addition to *everything mentioned in 0.23.0-beta.1 notes*, this release adds support for Rustls,
+a long-sought feature. Thanks to @rharish101 and @LeoRowan for getting this in!
+
+#### Changes
+* Update Rustls to v0.21.0 ([#820](https://github.com/redis-rs/redis-rs/pull/820) @rharish101)
+* Implement support for Rustls ([#725](https://github.com/redis-rs/redis-rs/pull/725) @rharish101, @LeoRowan)
+
+
+### 0.23.0-beta.1 (2023-03-28)
+
+This release adds the `cluster_async` module, which introduces async Redis Cluster support. The code therein
+is largely taken from @Marwes's [redis-cluster-async crate](https://github.com/redis-rs/redis-cluster-async), which itself
+appears to have started from a sync Redis Cluster implementation started by @atuk721. In any case, thanks to @Marwes and @atuk721
+for the great work, and we hope to keep development moving forward in `redis-rs`.
+
+Though async Redis Cluster functionality for the time being has been kept as close to the originating crate as possible, previous users of
+`redis-cluster-async` should note the following changes:
+* Retries, while still configurable, can no longer be set to `None`/infinite retries
+* Routing and slot parsing logic has been removed and merged with existing `redis-rs` functionality
+* The client has been removed and superseded by the common `ClusterClient`
+* Renamed `Connection` to `ClusterConnection`
+* Added support for reading from replicas
+* Added support for insecure TLS
+* Added support for setting both username and password
+
+#### Breaking Changes
+* Fix long-standing bug related to `AsyncIter`'s stream implementation in which polling the server
+  for additional data yielded broken data in most cases. Type bounds for `AsyncIter` have changed slightly,
+  making this a potentially breaking change. ([#597](https://github.com/redis-rs/redis-rs/pull/597) @roger)
+
+#### Changes
+* Commands: Add additional generic args for key arguments ([#795](https://github.com/redis-rs/redis-rs/pull/795) @MaxOhn)
+* Add `mset` / deprecate `set_multiple` ([#766](https://github.com/redis-rs/redis-rs/pull/766) @randomairborne)
+* More efficient interfaces for `MultiplexedConnection` and `ConnectionManager` ([#811](https://github.com/redis-rs/redis-rs/pull/811) @nihohit)
+* Refactor / remove flaky test ([#810](https://github.com/redis-rs/redis-rs/pull/810))
+* `cluster_async`: rename `Connection` to `ClusterConnection`, `Pipeline` to `ClusterConnInner` ([#808](https://github.com/redis-rs/redis-rs/pull/808))
+* Support parsing IPV6 cluster nodes ([#796](https://github.com/redis-rs/redis-rs/pull/796) @socs)
+* Common client for sync/async cluster connections ([#798](https://github.com/redis-rs/redis-rs/pull/798))
+  * `cluster::ClusterConnection` underlying connection type is now generic (with existing type as default)
+  * Support `read_from_replicas` in cluster_async
+  * Set retries in `ClusterClientBuilder`
+  * Add mock tests for `cluster`
+* cluster-async common slot parsing ([#793](https://github.com/redis-rs/redis-rs/pull/793))
+* Support async-std in cluster_async module ([#790](https://github.com/redis-rs/redis-rs/pull/790))
+* Async-Cluster use same routing as Sync-Cluster ([#789](https://github.com/redis-rs/redis-rs/pull/789))
+* Add Async Cluster Support ([#696](https://github.com/redis-rs/redis-rs/pull/696))
+* Fix broken json-module tests ([#786](https://github.com/redis-rs/redis-rs/pull/786))
+* `cluster`: Tls Builder support / simplify cluster connection map ([#718](https://github.com/redis-rs/redis-rs/pull/718) @0xWOF, @utkarshgupta137)
+
+
+### 0.22.3 (2023-01-23)
+
+#### Changes
+* Restore inherent `ClusterConnection::check_connection()` method ([#758](https://github.com/redis-rs/redis-rs/pull/758) @robjtede)
+
+
+
+### 0.22.2 (2023-01-07)
+
+This release adds various incremental improvements and fixes a few long-standing bugs. Thanks to all our
+contributors for making this release happen.
+
+#### Features
+* Implement ToRedisArgs for HashMap ([#722](https://github.com/redis-rs/redis-rs/pull/722) @gibranamparan)
+* Add explicit `MGET` command ([#729](https://github.com/redis-rs/redis-rs/pull/729) @vamshiaruru-virgodesigns)
+
+#### Bug fixes
+* Enable single-item-vector `get` responses ([#507](https://github.com/redis-rs/redis-rs/pull/507) @hank121314)
+* Fix empty result from xread_options with deleted entries ([#712](https://github.com/redis-rs/redis-rs/pull/712) @Quiwin)
+* Limit Parser Recursion ([#724](https://github.com/redis-rs/redis-rs/pull/724))
+* Improve MultiplexedConnection Error Handling ([#699](https://github.com/redis-rs/redis-rs/pull/699))
+
+#### Changes
+* Add test case for atomic pipeline ([#702](https://github.com/redis-rs/redis-rs/pull/702) @CNLHC)
+* Capture subscribe result error in PubSub doc example ([#739](https://github.com/redis-rs/redis-rs/pull/739) @baoyachi)
+* Use async-std name resolution when necessary ([#701](https://github.com/redis-rs/redis-rs/pull/701) @UgnilJoZ)
+* Add Script::invoke_async method ([#711](https://github.com/redis-rs/redis-rs/pull/711) @r-bk)
+* Cluster Refactorings ([#717](https://github.com/redis-rs/redis-rs/pull/717), [#716](https://github.com/redis-rs/redis-rs/pull/716), [#709](https://github.com/redis-rs/redis-rs/pull/709), [#707](https://github.com/redis-rs/redis-rs/pull/707), [#706](https://github.com/redis-rs/redis-rs/pull/706) @0xWOF, @utkarshgupta137)
+* Fix intermittent test failure ([#714](https://github.com/redis-rs/redis-rs/pull/714) @0xWOF, @utkarshgupta137)
+* Doc changes ([#705](https://github.com/redis-rs/redis-rs/pull/705) @0xWOF, @utkarshgupta137)
+* Lint fixes ([#704](https://github.com/redis-rs/redis-rs/pull/704) @0xWOF)
+
+
+
+### 0.22.1 (2022-10-18)
+
+#### Changes
+* Add README attribute to Cargo.toml
+* Update LICENSE file / symlink from parent directory
+
+
+### 0.22.0 (2022-10-05)
+
+This release adds various incremental improvements, including
+additional convenience commands, improved Cluster APIs, and various other bug
+fixes/library improvements.
+
+Although the changes here are incremental, this is a major release due to the
+breaking changes listed below.
+
+This release would not be possible without our many wonderful
+contributors -- thank you!
+
+#### Breaking changes
+* Box all large enum variants; changes enum signature ([#667](https://github.com/redis-rs/redis-rs/pull/667) @nihohit)
+* Support ACL commands by adding Rule::Other to cover newly defined flags; adds new enum variant ([#685](https://github.com/redis-rs/redis-rs/pull/685) @garyhai)
+* Switch from sha1 to sha1_smol; renames `sha1` feature ([#576](https://github.com/redis-rs/redis-rs/pull/576))
+
+#### Features
+* Add support for RedisJSON ([#657](https://github.com/redis-rs/redis-rs/pull/657) @d3rpp)
+* Add support for weights in zunionstore and zinterstore ([#641](https://github.com/redis-rs/redis-rs/pull/641) @ndd7xv)
+* Cluster: Create read_from_replicas option ([#635](https://github.com/redis-rs/redis-rs/pull/635) @utkarshgupta137)
+* Make Direction a public enum to use with Commands like BLMOVE ([#646](https://github.com/redis-rs/redis-rs/pull/646) @thorbadour)
+* Add `ahash` feature for using ahash internally & for redis values ([#636](https://github.com/redis-rs/redis-rs/pull/636) @utkarshgupta137)
+* Add Script::load function ([#603](https://github.com/redis-rs/redis-rs/pull/603) @zhiburt)
+* Add support for OBJECT ([#610](https://github.com/redis-rs/redis-rs/pull/610) @roger)
+* Add GETEX and GETDEL support ([#582](https://github.com/redis-rs/redis-rs/pull/582) @arpandaze)
+* Add support for ZMPOP ([#605](https://github.com/redis-rs/redis-rs/pull/605) @gkorland)
+
+#### Changes
+* Rust 2021 Edition / MSRV 1.59.0
+* Fix: Support IPV6 link-local address parsing ([#679](https://github.com/redis-rs/redis-rs/pull/679) @buaazp)
+* Derive Clone and add Deref trait to InfoDict ([#661](https://github.com/redis-rs/redis-rs/pull/661) @danni-m)
+* ClusterClient: add handling for empty initial_nodes, use ClusterParams to store cluster parameters, improve builder pattern ([#669](https://github.com/redis-rs/redis-rs/pull/669) @utkarshgupta137)
+* Implement Debug for MultiplexedConnection & Pipeline ([#664](https://github.com/redis-rs/redis-rs/pull/664) @elpiel)
+* Add support for casting RedisResult to CString ([#660](https://github.com/redis-rs/redis-rs/pull/660) @nihohit)
+* Move redis crate to subdirectory to support multiple crates in project ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas)
+* Stop versioning Cargo.lock ([#620](https://github.com/redis-rs/redis-rs/pull/620))
+* Auto-implement ConnectionLike for DerefMut ([#567](https://github.com/redis-rs/redis-rs/pull/567) @holmesmr)
+* Return errors from parsing inner items ([#608](https://github.com/redis-rs/redis-rs/pull/608))
+* Make dns resolution async, in async runtime ([#606](https://github.com/redis-rs/redis-rs/pull/606) @roger)
+* Make async_trait dependency optional ([#572](https://github.com/redis-rs/redis-rs/pull/572) @kamulos)
+* Add username to ClusterClient and ClusterConnection ([#596](https://github.com/redis-rs/redis-rs/pull/596) @gildaf)
+
+
+
+### 0.21.6 (2022-08-24)
+
+* Update dependencies ([#588](https://github.com/mitsuhiko/redis-rs/pull/588))
+
+
+### 0.21.5 (2022-01-10)
+
+#### Features
+
+* Add new list commands ([#570](https://github.com/mitsuhiko/redis-rs/pull/570))
+
+
+
+### 0.21.4 (2021-11-04)
+
+#### Features
+
+* Add convenience command: zrandmember ([#556](https://github.com/mitsuhiko/redis-rs/pull/556))
+
+
+
+
+### 0.21.3 (2021-10-15)
+
+#### Features
+
+* Add support for TLS with cluster mode ([#548](https://github.com/mitsuhiko/redis-rs/pull/548))
+
+#### Changes
+
+* Remove stunnel as a dep, use redis native tls ([#542](https://github.com/mitsuhiko/redis-rs/pull/542))
+
+
+
+
+### 0.21.2 (2021-09-02)
+
+
+#### Bug Fixes
+
+* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/))
+
+#### Breaking Changes
+
+* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/))
+
+
+
+
+### 0.21.1 (2021-08-25)
+
+
+#### Bug Fixes
+
+* pin futures dependency to required version ([9be392bc](https://github.com/mitsuhiko/redis-rs/commit/9be392bc5b22326a8a0eafc7aa18cc04c5d79e0e))
+
+
+
+
+### 0.21.0 (2021-07-16)
+
+
+#### Performance
+
+* Don't enqueue multiplexed commands if the receiver is dropped ([ca5019db](https://github.com/mitsuhiko/redis-rs/commit/ca5019dbe76cc56c93eaecb5721de8fcf74d1641))
+
+#### Features
+
+* Refactor ConnectionAddr to remove boxing and clarify fields
+
+
+### 0.20.2 (2021-06-17)
+
+#### Features
+
+* Provide a new_async_std function ([c3716d15](https://github.com/mitsuhiko/redis-rs/commit/c3716d154f067b71acdd5bd927e118305cd0830b))
+
+#### Bug Fixes
+
+* Return Ready(Ok(())) when we have flushed all messages ([ca319c06](https://github.com/mitsuhiko/redis-rs/commit/ca319c06ad80fc37f1f701aecebbd5dabb0dceb0))
+* Don't loop forever on shutdown of the multiplexed connection ([ddecce9e](https://github.com/mitsuhiko/redis-rs/commit/ddecce9e10b8ab626f41409aae289d62b4fb74be))
+
+
+
+
+### 0.20.1 (2021-05-18)
+
+
+#### Bug Fixes
+
+* Error properly if eof is reached in the decoder ([306797c3](https://github.com/mitsuhiko/redis-rs/commit/306797c3c55ab24e0a29b6517356af794731d326))
+
+
+
+
+## 0.20.0 (2021-02-17)
+
+
+#### Features
+
+* Make ErrorKind non_exhaustive for forwards compatibility ([ac5e1a60](https://github.com/mitsuhiko/redis-rs/commit/ac5e1a60d398814b18ed1b579fe3f6b337e545e9))
+* **aio:** Allow the underlying IO stream to be customized ([6d2fc8fa](https://github.com/mitsuhiko/redis-rs/commit/6d2fc8faa707fbbbaae9fe092bbc90ce01224523))
+
+
+
+
+## 0.19.0 (2020-12-26)
+
+
+#### Features
+
+* Update to tokio 1.0 ([41960194](https://github.com/mitsuhiko/redis-rs/commit/4196019494aafc2bab718bafd1fdfd5e8c195ffa))
+* use the node specified in the MOVED error ([8a53e269](https://github.com/mitsuhiko/redis-rs/commit/8a53e2699d7d7bd63f222de778ed6820b0a65665))
+
+
+
+
+## 0.18.0 (2020-12-03)
+
+
+#### Bug Fixes
+
+* Don't require tokio for the connection manager ([46be86f3](https://github.com/mitsuhiko/redis-rs/commit/46be86f3f07df4900559bf9a4dfd0b5138c3ac52))
+
+* Make ToRedisArgs and FromRedisValue consistent for booleans
+
+BREAKING CHANGE
+
+Booleans are now written as 0 and 1 instead of true and false. Parsing a bool still accepts
+true and false, so this should not break anything for most users; however, if you are reading
+something out that was stored as a bool, you may see different results.
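+
+For illustration, a minimal sketch of the new behavior (the `con` connection
+handle and `flag` key are hypothetical, not part of the original notes):
+
+```rust
+use redis::Commands;
+
+// `true` is now encoded as "1" on the wire (previously "true").
+let _: () = con.set("flag", true)?;
+
+// Reading the raw value back therefore yields "1".
+let raw: String = con.get("flag")?;
+assert_eq!(raw, "1");
+
+// Parsing into bool still accepts "true"/"false" as well as "1"/"0".
+let flag: bool = con.get("flag")?;
+assert!(flag);
+```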
+ +#### Features + +* Update tokio dependency to 0.3 ([bf5e0af3](https://github.com/mitsuhiko/redis-rs/commit/bf5e0af31c08be1785656031ffda96c355ee83c4), closes [#396](https://github.com/mitsuhiko/redis-rs/issues/396)) +* add doc_cfg for Makefile and docs.rs config ([1bf79517](https://github.com/mitsuhiko/redis-rs/commit/1bf795174521160934f3695326897458246e4978)) +* Impl FromRedisValue for i128 and u128 + + +# Changelog + +## [0.18.0](https://github.com/mitsuhiko/redis-rs/compare/0.17.0...0.18.0) - 2020-12-03 + +## [0.17.0](https://github.com/mitsuhiko/redis-rs/compare/0.16.0...0.17.0) - 2020-07-29 + +**Fixes and improvements** + +* Added Redis Streams commands ([#162](https://github.com/mitsuhiko/redis-rs/pull/319)) +* Added support for zpopmin and zpopmax ([#351](https://github.com/mitsuhiko/redis-rs/pull/351)) +* Added TLS support, gated by a feature flag ([#305](https://github.com/mitsuhiko/redis-rs/pull/305)) +* Added Debug and Clone implementations to redis::Script ([#365](https://github.com/mitsuhiko/redis-rs/pull/365)) +* Added FromStr for ConnectionInfo ([#368](https://github.com/mitsuhiko/redis-rs/pull/368)) +* Support SCAN methods on async connections ([#326](https://github.com/mitsuhiko/redis-rs/pull/326)) +* Removed unnecessary overhead around `Value` conversions ([#327](https://github.com/mitsuhiko/redis-rs/pull/327)) +* Support for Redis 6 auth ([#341](https://github.com/mitsuhiko/redis-rs/pull/341)) +* BUGFIX: Make aio::Connection Sync again ([#321](https://github.com/mitsuhiko/redis-rs/pull/321)) +* BUGFIX: Return UnexpectedEof if we try to decode at eof ([#322](https://github.com/mitsuhiko/redis-rs/pull/322)) +* Added support to create a connection from a (host, port) tuple ([#370](https://github.com/mitsuhiko/redis-rs/pull/370)) + +## [0.16.0](https://github.com/mitsuhiko/redis-rs/compare/0.15.1...0.16.0) - 2020-05-10 + +**Fixes and improvements** + +* Reduce dependencies without async IO ([#266](https://github.com/mitsuhiko/redis-rs/pull/266)) +* Add an afl fuzz target ([#274](https://github.com/mitsuhiko/redis-rs/pull/274)) +* Updated to combine 4 and avoid async dependencies for sync-only ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + * **BREAKING CHANGE**: The parser type now only persists the buffer and takes the Read instance in `parse_value` +* Implement a connection manager for automatic reconnection ([#278](https://github.com/mitsuhiko/redis-rs/pull/278)) +* Add async-std support ([#281](https://github.com/mitsuhiko/redis-rs/pull/281)) +* Fix key extraction for some stream commands ([#283](https://github.com/mitsuhiko/redis-rs/pull/283)) +* Add asynchronous PubSub support ([#287](https://github.com/mitsuhiko/redis-rs/pull/287)) + +### Breaking changes + +#### Changes to the `Parser` type ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + +The parser type now only persists the buffer and takes the Read instance in `parse_value`. +`redis::parse_redis_value` is unchanged and continues to work. 
+
+Old:
+
+```rust
+let mut parser = Parser::new(bytes);
+let result = parser.parse_value();
+```
+
+New:
+
+```rust
+let mut parser = Parser::new();
+let result = parser.parse_value(bytes);
+```
+
+## [0.15.1](https://github.com/mitsuhiko/redis-rs/compare/0.15.0...0.15.1) - 2020-01-15
+
+**Fixes and improvements**
+
+* Fixed the `r2d2` feature (re-added it) ([#265](https://github.com/mitsuhiko/redis-rs/pull/265))
+
+## [0.15.0](https://github.com/mitsuhiko/redis-rs/compare/0.14.0...0.15.0) - 2020-01-15
+
+**Fixes and improvements**
+
+* Added support for redis cluster ([#239](https://github.com/mitsuhiko/redis-rs/pull/239))
+
+## [0.14.0](https://github.com/mitsuhiko/redis-rs/compare/0.13.0...0.14.0) - 2020-01-08
+
+**Fixes and improvements**
+
+* Fix the command verb being sent to redis for `zremrangebyrank` ([#240](https://github.com/mitsuhiko/redis-rs/pull/240))
+* Add `get_connection_with_timeout` to Client ([#243](https://github.com/mitsuhiko/redis-rs/pull/243))
+* **Breaking change:** Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253))
+* Async-ify the API ([#232](https://github.com/mitsuhiko/redis-rs/pull/232))
+* Bump minimal required Rust version to 1.39 (required for the async/await API)
+* Add async/await examples ([#261](https://github.com/mitsuhiko/redis-rs/pull/261), [#263](https://github.com/mitsuhiko/redis-rs/pull/263))
+* Added support for PSETEX and PTTL commands. ([#259](https://github.com/mitsuhiko/redis-rs/pull/259))
+
+### Breaking changes
+
+#### Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253))
+
+If you are using pipelines and were importing the `PipelineCommands` trait you can now remove that import
+and only use the `Commands` trait.
+
+Old:
+
+```rust
+use redis::{Commands, PipelineCommands};
+```
+
+New:
+
+```rust
+use redis::Commands;
+```
+
+## [0.13.0](https://github.com/mitsuhiko/redis-rs/compare/0.12.0...0.13.0) - 2019-10-14
+
+**Fixes and improvements**
+
+* **Breaking change:** rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)).
+* Run clippy over the entire codebase ([#238](https://github.com/mitsuhiko/redis-rs/pull/238))
+* **Breaking change:** Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242))
+
+### BREAKING CHANGES
+
+#### Rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)).
+
+If you used `redis::parse_async` before, you now need to change this to `redis::parse_redis_value_async`
+or import the method under the new name: `use redis::parse_redis_value_async`.
+
+#### Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242))
+
+`Script#invoke_async` was changed to be generic over `aio::ConnectionLike` in order to support wrapping a `SharedConnection` in user code.
+This required adding a new generic parameter to the method, causing an error when the return type is defined using the turbofish syntax.
+
+Old:
+
+```rust
+redis::Script::new("return ...")
+    .key("key1")
+    .arg("an argument")
+    .invoke_async::<String>()
+```
+
+New:
+
+```rust
+redis::Script::new("return ...")
+    .key("key1")
+    .arg("an argument")
+    .invoke_async::<_, String>()
+```
+
+## [0.12.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0...0.12.0) - 2019-08-26
+
+**Fixes and improvements**
+
+* **Breaking change:** Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223))
+* **Breaking change:** Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234))
+* **Breaking change:** (async) Fix `Script::invoke_async` return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233))
+* Add `GETRANGE` and `SETRANGE` commands ([#202](https://github.com/mitsuhiko/redis-rs/pull/202))
+* Fix `SINTERSTORE` wrapper name, it's now correctly `sinterstore` ([#225](https://github.com/mitsuhiko/redis-rs/pull/225))
+* Allow running `SharedConnection` with any other runtime ([#229](https://github.com/mitsuhiko/redis-rs/pull/229))
+* Reformatted as Edition 2018 code ([#235](https://github.com/mitsuhiko/redis-rs/pull/235))
+
+### BREAKING CHANGES
+
+#### Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223))
+
+Rust nightly deprecated bare trait objects.
+This PR adds the `dyn` keyword to all trait objects in order to get rid of the warning.
+This bumps the minimal supported Rust version to [Rust 1.27](https://blog.rust-lang.org/2018/06/21/Rust-1.27.html).
+
+#### Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234))
+
+We updated the `url` dependency to v2. We do expose this on our public API on the `redis::parse_redis_url` function. If you depend on that, make sure to also upgrade your direct dependency.
+
+#### (async) Fix Script::invoke_async return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233))
+
+Previously, invoking a script with a complex return type would cause the following error:
+
+```
+Response was of incompatible type: "Not a bulk response" (response was string data('"4b98bef92b171357ddc437b395c7c1a5145ca2bd"'))
+```
+
+This was because the Future returned when loading the script into the database returns the hash of the script, and thus the return type of `String` would not match the intended return type.
+
+This commit adds an enum to account for the different Future return types.
+
+
+## [0.11.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.2...0.11.0) - 2019-07-19
+
+This release includes all fixes & improvements from the two beta releases listed below.
+This release contains breaking changes.
+
+**Fixes and improvements**
+
+* (async) Fix performance problem for SharedConnection ([#222](https://github.com/mitsuhiko/redis-rs/pull/222))
+
+## [0.11.0-beta.2](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.1...0.11.0-beta.2) - 2019-07-14
+
+**Fixes and improvements**
+
+* (async) Don't block the executor from shutting down ([#217](https://github.com/mitsuhiko/redis-rs/pull/217))
+
+## [0.11.0-beta.1](https://github.com/mitsuhiko/redis-rs/compare/0.10.0...0.11.0-beta.1) - 2019-05-30
+
+**Fixes and improvements**
+
+* (async) Simplify implicit pipeline handling ([#182](https://github.com/mitsuhiko/redis-rs/pull/182))
+* (async) Use `tokio_sync`'s channels instead of futures ([#195](https://github.com/mitsuhiko/redis-rs/pull/195))
+* (async) Only allocate one oneshot per request ([#194](https://github.com/mitsuhiko/redis-rs/pull/194))
+* Remove redundant BufReader when parsing ([#197](https://github.com/mitsuhiko/redis-rs/pull/197))
+* Hide actual type returned from async parser ([#193](https://github.com/mitsuhiko/redis-rs/pull/193))
+* Use more performant operations for line parsing ([#198](https://github.com/mitsuhiko/redis-rs/pull/198))
+* Optimize the command encoding, see below for **breaking changes** ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+* Add support for geospatial commands ([#130](https://github.com/mitsuhiko/redis-rs/pull/130))
+* (async) Add support for async Script invocation ([#206](https://github.com/mitsuhiko/redis-rs/pull/206))
+
+### BREAKING CHANGES
+
+#### Renamed the async module to aio ([#189](https://github.com/mitsuhiko/redis-rs/pull/189))
+
+`async` is a reserved keyword in Rust 2018, so this avoids the need to write `r#async` in it.
+
+Old code:
+
+```rust
+use redis::async::SharedConnection;
+```
+
+New code:
+
+```rust
+use redis::aio::SharedConnection;
+```
+
+#### The trait `ToRedisArgs` was changed ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+`ToRedisArgs` has been changed to take an instance of `RedisWrite` instead of `Vec<Vec<u8>>`. Use the `write_arg` method instead of `Vec::push`.
+
+#### Minimum Rust version is now 1.26 ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+Upgrade your compiler.
+`impl Iterator` is used, requiring a more recent version of the Rust compiler.
+
+#### `iter` now takes `self` by value ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+`iter` now takes `self` by value instead of cloning `self` inside the method.
+
+Old code:
+
+```rust
+let mut iter : redis::Iter<isize> = cmd.arg("my_set").cursor_arg(0).iter(&con).unwrap();
+```
+
+New code:
+
+```rust
+let mut iter : redis::Iter<isize> = cmd.arg("my_set").cursor_arg(0).clone().iter(&con).unwrap();
+```
+
+(The above line calls `clone()` because `iter` now takes the command by value.)
+
+#### A mutable connection object is now required ([#148](https://github.com/mitsuhiko/redis-rs/pull/148))
+
+We removed the internal usage of `RefCell` and `Cell` and instead require a mutable reference, `&mut ConnectionLike`,
+on all command calls.
+
+Old code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/")?;
+let con = client.get_connection(None)?;
+redis::cmd("SET").arg("my_key").arg(42).execute(&con);
+```
+
+New code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/")?;
+let mut con = client.get_connection(None)?;
+redis::cmd("SET").arg("my_key").arg(42).execute(&mut con);
+```
+
+Due to this, `transaction` has changed. The callback now also receives a mutable reference to the used connection.
+
+Old code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+let con = client.get_connection(None).unwrap();
+let key = "the_key";
+let (new_val,) : (isize,) = redis::transaction(&con, &[key], |pipe| {
+    let old_val : isize = con.get(key)?;
+    pipe
+        .set(key, old_val + 1).ignore()
+        .get(key).query(&con)
+})?;
+```
+
+New code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+let mut con = client.get_connection(None).unwrap();
+let key = "the_key";
+let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
+    let old_val : isize = con.get(key)?;
+    pipe
+        .set(key, old_val + 1).ignore()
+        .get(key).query(con)
+})?;
+```
+
+#### Remove `rustc-serialize` feature ([#200](https://github.com/mitsuhiko/redis-rs/pull/200))
+
+We removed serialization to/from JSON. The underlying library has been deprecated for a long time.
+
+Old code in `Cargo.toml`:
+
+```
+[dependencies.redis]
+version = "0.9.1"
+features = ["with-rustc-json"]
+```
+
+There's no replacement for the feature.
+Use [serde](https://serde.rs/) and handle the serialization/deserialization in your own code.
+
+#### Remove `with-unix-sockets` feature ([#201](https://github.com/mitsuhiko/redis-rs/pull/201))
+
+We removed the Unix socket feature. It is now always enabled.
+We also removed auto-detection.
+
+Old code in `Cargo.toml`:
+
+```
+[dependencies.redis]
+version = "0.9.1"
+features = ["with-unix-sockets"]
+```
+
+There's no replacement for the feature. Unix sockets will continue to work by default.
+
+## [0.10.0](https://github.com/mitsuhiko/redis-rs/compare/0.9.1...0.10.0) - 2019-02-19
+
+* Fix handling of passwords with special characters (#163)
+* Better performance for async code due to less boxing (#167)
+  * CAUTION: redis-rs will now require Rust 1.26
+* Add `clear` method to the pipeline (#176)
+* Better benchmarking (#179)
+* Fully formatted source code (#181)
+
+## [0.9.1](https://github.com/mitsuhiko/redis-rs/compare/0.9.0...0.9.1) (2018-09-10)
+
+* Add ttl command
+
+## [0.9.0](https://github.com/mitsuhiko/redis-rs/compare/0.8.0...0.9.0) (2018-08-08)
+
+Some time has passed since the last release.
+This new release will bring fewer bugs, more commands, experimental async support and better performance.
+
+Highlights:
+
+* Implement flexible PubSub API (#136)
+* Avoid allocating some redundant Vec's during encoding (#140)
+* Add an async interface using futures-rs (#141)
+* Allow the async connection to have multiple in flight requests (#143)
+
+The async support is currently experimental.
+
+## [0.8.0](https://github.com/mitsuhiko/redis-rs/compare/0.7.1...0.8.0) (2016-12-26)
+
+* Add publish command
+
+## [0.7.1](https://github.com/mitsuhiko/redis-rs/compare/0.7.0...0.7.1) (2016-12-17)
+
+* Fix unix socket builds
+* Relax lifetimes for scripts
+
+## [0.7.0](https://github.com/mitsuhiko/redis-rs/compare/0.6.0...0.7.0) (2016-07-23)
+
+* Add support for built-in unix sockets
+
+## [0.6.0](https://github.com/mitsuhiko/redis-rs/compare/0.5.4...0.6.0) (2016-07-14)
+
+* feat: Make rustc-serialize an optional feature (#96)
+
+## [0.5.4](https://github.com/mitsuhiko/redis-rs/compare/0.5.3...0.5.4) (2016-06-25)
+
+* fix: Improved single arg handling (#95)
+* feat: Implement ToRedisArgs for &String (#89)
+* feat: Faster command encoding (#94)
+
+## [0.5.3](https://github.com/mitsuhiko/redis-rs/compare/0.5.2...0.5.3) (2016-05-03)
+
+* fix: Use explicit versions for dependencies
+* fix: Send `AUTH` command before other commands
+* fix: Shutdown connection upon protocol error
+* feat: Add `keys` method
+* feat: Possibility to set read and write timeouts for the connection
diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml
new file mode 100644
index 0000000000..fd79ff079e
--- /dev/null
+++ b/glide-core/redis-rs/redis/Cargo.toml
@@ -0,0 +1,227 @@
+[package]
+name = "redis"
+version = "0.25.2"
+keywords = ["redis", "database"]
+description = "Redis driver for Rust."
+homepage = "https://github.com/redis-rs/redis-rs"
+repository = "https://github.com/redis-rs/redis-rs"
+documentation = "https://docs.rs/redis"
+license = "BSD-3-Clause"
+edition = "2021"
+rust-version = "1.65"
+readme = "../README.md"
+
+[package.metadata.docs.rs]
+all-features = true
+rustdoc-args = ["--cfg", "docsrs"]
+
+[lib]
+bench = false
+
+[dependencies]
+# These two are really common, simple dependencies, so there does not seem to be
+# much point in optimizing them, but they could in theory be removed in favor of
+# an indirection through std::Formatter.
+ryu = "1.0"
+itoa = "1.0"
+
+# Strum is a set of macros and traits that make working with enums and strings easier in Rust.
+strum = "0.26" +strum_macros = "0.26" + +# This is a dependency that already exists in url +percent-encoding = "2.1" + +# We need this for redis url parsing +url = "= 2.5.0" + +# We need this for script support +sha1_smol = { version = "1.0", optional = true } + +combine = { version = "4.6", default-features = false, features = ["std"] } + +# Only needed for AIO +bytes = { version = "1", optional = true } +futures-util = { version = "0.3.15", default-features = false, optional = true } +pin-project-lite = { version = "0.2", optional = true } +tokio-util = { version = "0.7", optional = true } +tokio = { version = "1", features = ["rt", "net", "time", "sync"] } +socket2 = { version = "0.5", features = ["all"], optional = true } +fast-math = { version = "0.1.1", optional = true } +dispose = { version = "0.5.0", optional = true } + +# Only needed for the connection manager +arc-swap = { version = "1.7.1" } +futures = { version = "0.3.3", optional = true } +tokio-retry = { version = "0.3.0", optional = true } + +# Only needed for the r2d2 feature +r2d2 = { version = "0.8.8", optional = true } + +# Only needed for cluster +crc16 = { version = "0.4", optional = true } +rand = { version = "0.8", optional = true } +derivative = { version = "2.2.0", optional = true } + +# Only needed for async cluster +dashmap = { version = "6.0", optional = true } + +# Only needed for async_std support +async-std = { version = "1.8.0", optional = true } +async-trait = { version = "0.1.24", optional = true } +# To avoid conflicts, backoff-std-async.version != backoff-tokio.version so we could run tests with --all-features +backoff-std-async = { package = "backoff", version = "0.3.0", optional = true, features = ["async-std"] } + +# Only needed for tokio support +backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = ["tokio"] } + +# Only needed for native tls +native-tls = { version = "0.2", optional = true } +tokio-native-tls = { version = "0.3", optional = true } +async-native-tls = { version = "0.4", optional = true } + +# Only needed for rustls +rustls = { version = "0.22", optional = true } +webpki-roots = { version = "0.26", optional = true } +rustls-native-certs = { version = "0.7", optional = true } +tokio-rustls = { version = "0.25", optional = true } +futures-rustls = { version = "0.25", optional = true } +rustls-pemfile = { version = "2", optional = true } +rustls-pki-types = { version = "1", optional = true } + +# Only needed for RedisJSON Support +serde = { version = "1.0.82", optional = true } +serde_json = { version = "1.0.82", optional = true } + +# Only needed for bignum Support +rust_decimal = { version = "1.33.1", optional = true } +bigdecimal = { version = "0.4.2", optional = true } +num-bigint = "0.4.4" + +# Optional aHash support +ahash = { version = "0.8.11", optional = true } + +tracing = "0.1" +arcstr = "1.1.5" + +# Optional uuid support +uuid = { version = "1.6.1", optional = true } + +[features] +default = ["acl", "streams", "geospatial", "script", "keep-alive"] +acl = [] +aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "combine/tokio", "async-trait", "fast-math", "dispose"] +geospatial = [] +json = ["serde", "serde/derive", "serde_json"] +cluster = ["crc16", "rand", "derivative"] +script = ["sha1_smol"] +tls-native-tls = ["native-tls"] +tls-rustls = ["rustls", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types"] +tls-rustls-insecure = ["tls-rustls"] 
+tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] +async-std-comp = ["aio", "async-std", "backoff-std-async"] +async-std-native-tls-comp = ["async-std-comp", "async-native-tls", "tls-native-tls"] +async-std-rustls-comp = ["async-std-comp", "futures-rustls", "tls-rustls"] +tokio-comp = ["aio", "tokio/net", "backoff-tokio"] +tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"] +tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"] +connection-manager = ["futures", "aio", "tokio-retry"] +streams = [] +cluster-async = ["cluster", "futures", "futures-util", "dashmap"] +keep-alive = ["socket2"] +sentinel = ["rand"] +tcp_nodelay = [] +rust_decimal = ["dep:rust_decimal"] +bigdecimal = ["dep:bigdecimal"] +num-bigint = [] +uuid = ["dep:uuid"] +disable-client-setinfo = [] + +# Deprecated features +tls = ["tls-native-tls"] # use "tls-native-tls" instead +async-std-tls-comp = ["async-std-native-tls-comp"] # use "async-std-native-tls-comp" instead + +[dev-dependencies] +rand = "0.8" +socket2 = "0.5" +assert_approx_eq = "1.0" +fnv = "1.0.5" +futures = "0.3" +futures-time = "3" +criterion = "0.4" +partial-io = { version = "0.5", features = ["tokio", "quickcheck1"] } +quickcheck = "1.0.3" +tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } +tempfile = "=3.6.0" +once_cell = "1" +anyhow = "1" +sscanf = "0.4.1" + +[[test]] +name = "test_async" +required-features = ["tokio-comp"] + +[[test]] +name = "test_async_async_std" +required-features = ["async-std-comp"] + +[[test]] +name = "parser" +required-features = ["aio"] + +[[test]] +name = "test_acl" + +[[test]] +name = "test_module_json" +required-features = ["json", "serde/derive"] + +[[test]] +name = "test_cluster_async" +required-features = ["cluster-async"] + +[[test]] +name = "test_async_cluster_connections_logic" +required-features = ["cluster-async"] + +[[test]] +name = "test_bignum" + +[[bench]] +name = "bench_basic" +harness = false +required-features = ["tokio-comp"] + +[[bench]] +name = "bench_cluster" +harness = false +required-features = ["cluster"] + +[[bench]] +name = "bench_cluster_async" +harness = false +required-features = ["cluster-async", "tokio-comp"] + +[[example]] +name = "async-multiplexed" +required-features = ["tokio-comp"] + +[[example]] +name = "async-await" +required-features = ["aio"] + +[[example]] +name = "async-pub-sub" +required-features = ["aio"] + +[[example]] +name = "async-scan" +required-features = ["aio"] + +[[example]] +name = "async-connection-loss" +required-features = ["connection-manager"] + +[[example]] +name = "streams" +required-features = ["streams"] diff --git a/glide-core/redis-rs/redis/LICENSE b/glide-core/redis-rs/redis/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. 
+
+  * The names of the contributors may not be used to endorse or
+    promote products derived from this software without specific
+    prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/glide-core/redis-rs/redis/benches/bench_basic.rs b/glide-core/redis-rs/redis/benches/bench_basic.rs
new file mode 100644
index 0000000000..356f74217e
--- /dev/null
+++ b/glide-core/redis-rs/redis/benches/bench_basic.rs
@@ -0,0 +1,277 @@
+use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput};
+use futures::prelude::*;
+use redis::{RedisError, Value};
+
+use support::*;
+
+#[path = "../tests/support/mod.rs"]
+mod support;
+
+fn bench_simple_getsetdel(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    b.iter(|| {
+        let key = "test_key";
+        redis::cmd("SET").arg(key).arg(42).execute(&mut con);
+        let _: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap();
+        redis::cmd("DEL").arg(key).execute(&mut con);
+    });
+}
+
+fn bench_simple_getsetdel_async(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let runtime = current_thread_runtime();
+    let mut con = runtime.block_on(ctx.async_connection()).unwrap();
+
+    b.iter(|| {
+        runtime
+            .block_on(async {
+                let key = "test_key";
+                () = redis::cmd("SET")
+                    .arg(key)
+                    .arg(42)
+                    .query_async(&mut con)
+                    .await?;
+                let _: isize = redis::cmd("GET").arg(key).query_async(&mut con).await?;
+                () = redis::cmd("DEL").arg(key).query_async(&mut con).await?;
+                Ok::<_, RedisError>(())
+            })
+            .unwrap()
+    });
+}
+
+fn bench_simple_getsetdel_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    b.iter(|| {
+        let key = "test_key";
+        let _: (usize,) = redis::pipe()
+            .cmd("SET")
+            .arg(key)
+            .arg(42)
+            .ignore()
+            .cmd("GET")
+            .arg(key)
+            .cmd("DEL")
+            .arg(key)
+            .ignore()
+            .query(&mut con)
+            .unwrap();
+    });
+}
+
+fn bench_simple_getsetdel_pipeline_precreated(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+    let key = "test_key";
+    let mut pipe = redis::pipe();
+    pipe.cmd("SET")
+        .arg(key)
+        .arg(42)
+        .ignore()
+        .cmd("GET")
+        .arg(key)
+        .cmd("DEL")
+        .arg(key)
+        .ignore();
+
+    b.iter(|| {
+        let _: (usize,) = pipe.query(&mut con).unwrap();
+    });
+}
+
+const PIPELINE_QUERIES: usize = 1_000;
+
+fn long_pipeline() -> redis::Pipeline {
+    let mut pipe = redis::pipe();
+
+    for i in 0..PIPELINE_QUERIES {
+        pipe.set(format!("foo{i}"), "bar").ignore();
+    }
+    pipe
+}
+
+fn bench_long_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    let pipe = long_pipeline();
+
+    b.iter(|| {
+        pipe.query::<()>(&mut con).unwrap();
+    });
+}
+
+fn bench_async_long_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let runtime = current_thread_runtime();
+    let mut con = runtime.block_on(ctx.async_connection()).unwrap();
+
+    let pipe = long_pipeline();
+
+    b.iter(|| {
+        runtime
+            .block_on(async { pipe.query_async::<_, ()>(&mut con).await })
+            .unwrap();
+    });
+}
+
+fn bench_multiplexed_async_long_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let runtime = current_thread_runtime();
+    let mut con = runtime
+        .block_on(ctx.multiplexed_async_connection_tokio())
+        .unwrap();
+
+    let pipe = long_pipeline();
+
+    b.iter(|| {
+        runtime
+            .block_on(async { pipe.query_async::<_, ()>(&mut con).await })
+            .unwrap();
+    });
+}
+
+fn bench_multiplexed_async_implicit_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let runtime = current_thread_runtime();
+    let con = runtime
+        .block_on(ctx.multiplexed_async_connection_tokio())
+        .unwrap();
+
+    let cmds: Vec<_> = (0..PIPELINE_QUERIES)
+        .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone())
+        .collect();
+
+    let mut connections = (0..PIPELINE_QUERIES)
+        .map(|_| con.clone())
+        .collect::<Vec<_>>();
+
+    b.iter(|| {
+        runtime
+            .block_on(async {
+                cmds.iter()
+                    .zip(&mut connections)
+                    .map(|(cmd, con)| cmd.query_async::<_, ()>(con))
+                    .collect::<stream::FuturesUnordered<_>>()
+                    .try_for_each(|()| async { Ok(()) })
+                    .await
+            })
+            .unwrap();
+    });
+}
+
+fn bench_query(c: &mut Criterion) {
+    let mut group = c.benchmark_group("query");
+    group
+        .bench_function("simple_getsetdel", bench_simple_getsetdel)
+        .bench_function("simple_getsetdel_async", bench_simple_getsetdel_async)
+        .bench_function("simple_getsetdel_pipeline", bench_simple_getsetdel_pipeline)
+        .bench_function(
+            "simple_getsetdel_pipeline_precreated",
+            bench_simple_getsetdel_pipeline_precreated,
+        );
+    group.finish();
+
+    let mut group = c.benchmark_group("query_pipeline");
+    group
+        .bench_function(
+            "multiplexed_async_implicit_pipeline",
+            bench_multiplexed_async_implicit_pipeline,
+        )
+        .bench_function(
+            "multiplexed_async_long_pipeline",
+            bench_multiplexed_async_long_pipeline,
+        )
+        .bench_function("async_long_pipeline", bench_async_long_pipeline)
+        .bench_function("long_pipeline", bench_long_pipeline)
+        .throughput(Throughput::Elements(PIPELINE_QUERIES as u64));
+    group.finish();
+}
+
+fn bench_encode_small(b: &mut Bencher) {
+    b.iter(|| {
+        let mut cmd = redis::cmd("HSETX");
+
+        cmd.arg("ABC:1237897325302:878241asdyuxpioaswehqwu")
+            .arg("some hash key")
+            .arg(124757920);
+
+        cmd.get_packed_command()
+    });
+}
+
+fn bench_encode_integer(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..1_000 {
+            pipe.set(123, 45679123).ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode_pipeline(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..1_000 {
+            pipe.set("foo", "bar").ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode_pipeline_nested(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..200 {
+            pipe.set(
+                "foo",
+                ("bar", 123, b"1231279712", &["test", "test", "test"][..]),
+            )
+            .ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode(c: &mut Criterion) {
+    let mut group = c.benchmark_group("encode");
+    group
+        .bench_function("pipeline", bench_encode_pipeline)
+        .bench_function("pipeline_nested", bench_encode_pipeline_nested)
+        .bench_function("integer", bench_encode_integer)
+        .bench_function("small", bench_encode_small);
+    group.finish();
+}
+
+fn bench_decode_simple(b: &mut Bencher, input: &[u8]) {
+    b.iter(|| redis::parse_redis_value(input).unwrap());
+}
+fn bench_decode(c: &mut Criterion) {
+    let value = Value::Array(vec![
+        Value::Okay,
+        Value::SimpleString("testing".to_string()),
+        Value::Array(vec![]),
+        Value::Nil,
+        Value::BulkString(vec![b'a'; 10]),
+        Value::Int(7512182390),
+    ]);
+
+    let mut group = c.benchmark_group("decode");
+    {
+        let mut input = Vec::new();
+        support::encode_value(&value, &mut input).unwrap();
+        assert_eq!(redis::parse_redis_value(&input).unwrap(), value);
+        group.bench_function("decode", move |b| bench_decode_simple(b, &input));
+    }
+    group.finish();
+}
+
+criterion_group!(bench, bench_query, bench_encode, bench_decode);
+criterion_main!(bench);
diff --git a/glide-core/redis-rs/redis/benches/bench_cluster.rs b/glide-core/redis-rs/redis/benches/bench_cluster.rs
new file mode 100644
index 0000000000..da854474ae
--- /dev/null
+++ b/glide-core/redis-rs/redis/benches/bench_cluster.rs
@@ -0,0 +1,108 @@
+#![allow(clippy::unit_arg)] // want to allow this for `black_box()`
+#![cfg(feature = "cluster")]
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use redis::cluster::cluster_pipe;
+
+use support::*;
+
+#[path = "../tests/support/mod.rs"]
+mod support;
+
+const PIPELINE_QUERIES: usize = 100;
+
+fn bench_set_get_and_del(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) {
+    let key = "test_key";
+
+    let mut group = c.benchmark_group("cluster_basic");
+
+    group.bench_function("set", |b| {
+        b.iter(|| {
+            redis::cmd("SET").arg(key).arg(42).execute(con);
+            black_box(())
+        })
+    });
+
+    group.bench_function("get", |b| {
+        b.iter(|| black_box(redis::cmd("GET").arg(key).query::<isize>(con).unwrap()))
+    });
+
+    let mut set_and_del = || {
+        redis::cmd("SET").arg(key).arg(42).execute(con);
+        redis::cmd("DEL").arg(key).execute(con);
+    };
+    group.bench_function("set_and_del", |b| {
+        b.iter(|| {
+            set_and_del();
+            black_box(())
+        })
+    });
+
+    group.finish();
+}
+
+fn bench_pipeline(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) {
+    let mut group = c.benchmark_group("cluster_pipeline");
+    group.throughput(Throughput::Elements(PIPELINE_QUERIES as u64));
+
+    let mut queries = Vec::new();
+    for i in 0..PIPELINE_QUERIES {
+        queries.push(format!("foo{i}"));
+    }
+
+    let build_pipeline = || {
+        let mut pipe = cluster_pipe();
+        for q in &queries {
+            pipe.set(q, "bar").ignore();
+        }
+    };
+    group.bench_function("build_pipeline", |b| {
+        b.iter(|| {
+            build_pipeline();
+            black_box(())
+        })
+    });
+
+    let mut pipe = cluster_pipe();
+    for q in &queries {
+        pipe.set(q, "bar").ignore();
+    }
+    group.bench_function("query_pipeline", |b| {
+        b.iter(|| {
+            pipe.query::<()>(con).unwrap();
+            black_box(())
+        })
+    });
+
+    group.finish();
+}
+
+fn bench_cluster_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new(6, 1);
+    cluster.wait_for_cluster_up();
+
+    let mut con = cluster.connection();
+    bench_set_get_and_del(c, &mut con);
+    bench_pipeline(c, &mut con);
+}
+
+#[allow(dead_code)]
+fn bench_cluster_read_from_replicas_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        6,
+        1,
+        |builder| builder.read_from_replicas(),
+        false,
+    );
+    cluster.wait_for_cluster_up();
+
+    let mut con = cluster.connection();
+    bench_set_get_and_del(c, &mut con);
+    bench_pipeline(c, &mut con);
+}
+
+criterion_group!(
+    cluster_bench,
+    bench_cluster_setup,
+    // bench_cluster_read_from_replicas_setup
+);
+criterion_main!(cluster_bench);
diff --git a/glide-core/redis-rs/redis/benches/bench_cluster_async.rs b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs
new file mode 100644
index 0000000000..28c3b83c87
--- /dev/null
+++ b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs
@@ -0,0 +1,88 @@
+#![allow(clippy::unit_arg)] // want to allow this for `black_box()`
+#![cfg(feature = "cluster")]
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use futures_util::{stream, TryStreamExt};
+use redis::RedisError;
+
+use support::*;
+use tokio::runtime::Runtime;
+
+#[path = "../tests/support/mod.rs"]
+mod support;
+
+fn bench_cluster_async(
+    c: &mut Criterion,
+    con: &mut redis::cluster_async::ClusterConnection,
+    runtime: &Runtime,
+) {
+    let mut group = c.benchmark_group("cluster_async");
+    group.bench_function("set_get_and_del", |b| {
+        b.iter(|| {
+            runtime
+                .block_on(async {
+                    let key = "test_key";
+                    () = redis::cmd("SET").arg(key).arg(42).query_async(con).await?;
+                    let _: isize = redis::cmd("GET").arg(key).query_async(con).await?;
+                    () = redis::cmd("DEL").arg(key).query_async(con).await?;
+
+                    Ok::<_, RedisError>(())
+                })
+                .unwrap();
+            black_box(())
+        })
+    });
+
+    group.bench_function("parallel_requests", |b| {
+        let num_parallel = 100;
+        let cmds: Vec<_> = (0..num_parallel)
+            .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone())
+            .collect();
+
+        let mut connections = (0..num_parallel).map(|_| con.clone()).collect::<Vec<_>>();
+
+        b.iter(|| {
+            runtime
+                .block_on(async {
+                    cmds.iter()
+                        .zip(&mut connections)
+                        .map(|(cmd, con)| cmd.query_async::<_, ()>(con))
+                        .collect::<stream::FuturesUnordered<_>>()
+                        .try_for_each(|()| async { Ok(()) })
+                        .await
+                })
+                .unwrap();
+            black_box(())
+        });
+    });
+
+    group.bench_function("pipeline", |b| {
+        let num_queries = 100;
+
+        let mut pipe = redis::pipe();
+
+        for _ in 0..num_queries {
+            pipe.set("foo".to_string(), "bar").ignore();
+        }
+
+        b.iter(|| {
+            runtime
+                .block_on(async { pipe.query_async::<_, ()>(con).await })
+                .unwrap();
+            black_box(())
+        });
+    });
+
+    group.finish();
+}
+
+fn bench_cluster_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new(6, 1);
+    cluster.wait_for_cluster_up();
+    let runtime = current_thread_runtime();
+    let mut con = runtime.block_on(cluster.async_connection(None));
+
+    bench_cluster_async(c, &mut con, &runtime);
+}
+
+criterion_group!(cluster_async_bench, bench_cluster_setup,);
+criterion_main!(cluster_async_bench);
diff --git a/glide-core/redis-rs/redis/examples/async-await.rs b/glide-core/redis-rs/redis/examples/async-await.rs
new file mode 100644
index 0000000000..2d829c7d60
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-await.rs
@@ -0,0 +1,24 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use redis::{AsyncCommands, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut con = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+
+    con.set("key1", b"foo").await?;
+
+    redis::cmd("SET")
+        .arg(&["key2", "bar"])
+        .query_async(&mut con)
+        .await?;
+
+    let result = redis::cmd("MGET")
+        .arg(&["key1", "key2"])
+        .query_async(&mut con)
+        .await;
+    assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-connection-loss.rs b/glide-core/redis-rs/redis/examples/async-connection-loss.rs
new file mode 100644
index 0000000000..a7dba3ab89
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-connection-loss.rs
@@ -0,0 +1,97 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+//! This example will connect to Redis in one of three modes:
+//!
+//! - Regular async connection
+//! - Async multiplexed connection
+//! - Async connection manager
+//!
+//! It will then send a PING every 100 ms and print the result.
+
+use std::env;
+use std::process;
+use std::time::Duration;
+
+use futures::future;
+use redis::aio::ConnectionLike;
+use redis::GlideConnectionOptions;
+use redis::RedisResult;
+use tokio::time::interval;
+
+enum Mode {
+    Deprecated,
+    Default,
+    Reconnect,
+}
+
+async fn run_single<C: ConnectionLike>(mut con: C) -> RedisResult<()> {
+    let mut interval = interval(Duration::from_millis(100));
+    loop {
+        interval.tick().await;
+        println!();
+        println!("> PING");
+        let result: RedisResult<String> = redis::cmd("PING").query_async(&mut con).await;
+        println!("< {result:?}");
+    }
+}
+
+async fn run_multi<C: ConnectionLike + Clone>(mut con: C) -> RedisResult<()> {
+    let mut interval = interval(Duration::from_millis(100));
+    loop {
+        interval.tick().await;
+        println!();
+        println!("> PING");
+        println!("> PING");
+        println!("> PING");
+        let results: (
+            RedisResult<String>,
+            RedisResult<String>,
+            RedisResult<String>,
+        ) = future::join3(
+            redis::cmd("PING").query_async(&mut con.clone()),
+            redis::cmd("PING").query_async(&mut con.clone()),
+            redis::cmd("PING").query_async(&mut con),
+        )
+        .await;
+        println!("< {:?}", results.0);
+        println!("< {:?}", results.1);
+        println!("< {:?}", results.2);
+    }
+}
+
+#[tokio::main]
+async fn main() -> RedisResult<()> {
+    let mode = match env::args().nth(1).as_deref() {
+        Some("default") => {
+            println!("Using default connection mode\n");
+            Mode::Default
+        }
+        Some("reconnect") => {
+            println!("Using reconnect manager mode\n");
+            Mode::Reconnect
+        }
+        Some("deprecated") => {
+            println!("Using deprecated connection mode\n");
+            Mode::Deprecated
+        }
+        Some(_) | None => {
+            println!("Usage: reconnect-manager (default|reconnect|deprecated)");
+            process::exit(1);
+        }
+    };
+
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    match mode {
+        Mode::Default => {
+            run_multi(
+                client
+                    .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+                    .await?,
+            )
+            .await?
+        }
+        Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?,
+        #[allow(deprecated)]
+        Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?,
+    };
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-multiplexed.rs b/glide-core/redis-rs/redis/examples/async-multiplexed.rs
new file mode 100644
index 0000000000..2e5332359b
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-multiplexed.rs
@@ -0,0 +1,46 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures::prelude::*;
+use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult};
+
+async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> {
+    let mut con = con.clone();
+
+    let key = format!("key{i}");
+    let key2 = format!("key{i}_2");
+    let value = format!("foo{i}");
+
+    redis::cmd("SET")
+        .arg(&key)
+        .arg(&value)
+        .query_async(&mut con)
+        .await?;
+
+    redis::cmd("SET")
+        .arg(&[&key2, "bar"])
+        .query_async(&mut con)
+        .await?;
+
+    redis::cmd("MGET")
+        .arg(&[&key, &key2])
+        .query_async(&mut con)
+        .map(|result| {
+            assert_eq!(Ok((value, b"bar".to_vec())), result);
+            Ok(())
+        })
+        .await
+}
+
+#[tokio::main]
+async fn main() {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+
+    let con = client
+        .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+        .await
+        .unwrap();
+
+    let cmds = (0..100).map(|i| test_cmd(&con, i));
+    let result = future::try_join_all(cmds).await.unwrap();
+
+    assert_eq!(100, result.len());
+}
diff --git a/glide-core/redis-rs/redis/examples/async-pub-sub.rs b/glide-core/redis-rs/redis/examples/async-pub-sub.rs
new file mode 100644
index 0000000000..fe84b44fb9
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-pub-sub.rs
@@ -0,0 +1,22 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures_util::StreamExt as _;
+use redis::{AsyncCommands, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut publish_conn = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+    let mut pubsub_conn = client.get_async_pubsub().await?;
+
+    pubsub_conn.subscribe("wavephone").await?;
+    let mut pubsub_stream = pubsub_conn.on_message();
+
+    publish_conn.publish("wavephone", "banana").await?;
+
+    let pubsub_msg: String = pubsub_stream.next().await.unwrap().get_payload()?;
+    assert_eq!(&pubsub_msg, "banana");
+
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-scan.rs b/glide-core/redis-rs/redis/examples/async-scan.rs
new file mode 100644
index 0000000000..06a66fe83e
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-scan.rs
@@ -0,0 +1,25 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures::stream::StreamExt;
+use redis::{AsyncCommands, AsyncIter, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut con = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+
+    con.set("async-key1", b"foo").await?;
+    con.set("async-key2", b"foo").await?;
+
+    let iter: AsyncIter<String> = con.scan().await?;
+    let mut keys: Vec<_> = iter.collect().await;
+
+    keys.sort();
+
+    assert_eq!(
+        keys,
+        vec!["async-key1".to_string(), "async-key2".to_string()]
+    );
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/basic.rs b/glide-core/redis-rs/redis/examples/basic.rs
new file mode 100644
index 0000000000..622dc36e59
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/basic.rs
@@ -0,0 +1,169 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use redis::{transaction, Commands};
+
+use std::collections::HashMap;
+use std::env;
+
+/// This function demonstrates how a return value can be coerced into a
+/// hashmap of tuples. This is particularly useful for responses like
+/// CONFIG GET or almost all H functions which will return responses in
+/// such list of implied tuples.
+fn do_print_max_entry_limits(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    // since rust cannot know what format we actually want we need to be
+    // explicit here and define the type of our response. In this case
+    // String -> int fits all the items we query for.
+    let config: HashMap<String, isize> = redis::cmd("CONFIG")
+        .arg("GET")
+        .arg("*-max-*-entries")
+        .query(con)?;
+
+    println!("Max entry limits:");
+
+    println!(
+        "  max-intset: {}",
+        config.get("set-max-intset-entries").unwrap_or(&0)
+    );
+    println!(
+        "  hash-max-ziplist: {}",
+        config.get("hash-max-ziplist-entries").unwrap_or(&0)
+    );
+    println!(
+        "  list-max-ziplist: {}",
+        config.get("list-max-ziplist-entries").unwrap_or(&0)
+    );
+    println!(
+        "  zset-max-ziplist: {}",
+        config.get("zset-max-ziplist-entries").unwrap_or(&0)
+    );
+
+    Ok(())
+}
+
+/// This is a pretty stupid example that demonstrates how to create a large
+/// set through a pipeline and how to iterate over it through implied
+/// cursors.
+fn do_show_scanning(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    // This makes a large pipeline of commands. Because the pipeline is
+    // modified in place we can just ignore the return value upon the end
+    // of each iteration.
+    let mut pipe = redis::pipe();
+    for num in 0..1000 {
+        pipe.cmd("SADD").arg("my_set").arg(num).ignore();
+    }
+
+    // since we don't care about the return value of the pipeline we can
+    // just cast it into the unit type.
+    pipe.query(con)?;
+
+    // since rust currently does not track temporaries for us, we need to
+    // store it in a local variable.
+    let mut cmd = redis::cmd("SSCAN");
+    cmd.arg("my_set").cursor_arg(0);
+
+    // as a simple exercise we just sum up the iterator. Since the fold
+    // method carries an initial value we do not need to define the
+    // type of the iterator, rust will figure "int" out for us.
+    let sum: i32 = cmd.iter::<i32>(con)?.sum();
+
+    println!("The sum of all numbers in the set 0-1000: {sum}");
+
+    Ok(())
+}
+
+/// Demonstrates how to do an atomic increment in a very low level way.
+fn do_atomic_increment_lowlevel(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    let key = "the_key";
+    println!("Run low-level atomic increment:");
+
+    // set the initial value so we have something to test with.
+    redis::cmd("SET").arg(key).arg(42).query(con)?;
+
+    loop {
+        // we need to start watching the key we care about, so that our
+        // exec fails if the key changes.
+        redis::cmd("WATCH").arg(key).query(con)?;
+
+        // load the old value, so we know what to increment.
+        let val: isize = redis::cmd("GET").arg(key).query(con)?;
+
+        // at this point we can go into an atomic pipe (a multi block)
+        // and set up the keys.
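+        // An aborted transaction (the watched key changed before EXEC ran)
+        // surfaces below as `None`, which is why the `None` arm simply
+        // retries the loop.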
+        let response: Option<(isize,)> = redis::pipe()
+            .atomic()
+            .cmd("SET")
+            .arg(key)
+            .arg(val + 1)
+            .ignore()
+            .cmd("GET")
+            .arg(key)
+            .query(con)?;
+
+        match response {
+            None => {
+                continue;
+            }
+            Some(response) => {
+                let (new_val,) = response;
+                println!("  New value: {new_val}");
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Demonstrates how to do an atomic increment with transaction support.
+fn do_atomic_increment(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    let key = "the_key";
+    println!("Run high-level atomic increment:");
+
+    // set the initial value so we have something to test with.
+    con.set(key, 42)?;
+
+    // run the transaction block.
+    let (new_val,): (isize,) = transaction(con, &[key], |con, pipe| {
+        // load the old value, so we know what to increment.
+        let val: isize = con.get(key)?;
+        // increment
+        pipe.set(key, val + 1).ignore().get(key).query(con)
+    })?;
+
+    // and print the result
+    println!("New value: {new_val}");
+
+    Ok(())
+}
+
+/// Runs all the examples and propagates errors up.
+fn do_redis_code(url: &str) -> redis::RedisResult<()> {
+    // general connection handling
+    let client = redis::Client::open(url)?;
+    let mut con = client.get_connection(None)?;
+
+    // read some config and print it.
+    do_print_max_entry_limits(&mut con)?;
+
+    // demonstrate how scanning works.
+    do_show_scanning(&mut con)?;
+
+    // shows an atomic increment.
+    do_atomic_increment_lowlevel(&mut con)?;
+    do_atomic_increment(&mut con)?;
+
+    Ok(())
+}
+
+fn main() {
+    // at this point the errors are fatal, let's just fail hard.
+    let url = if env::args().nth(1) == Some("--tls".into()) {
+        "rediss://127.0.0.1:6380/#insecure"
+    } else {
+        "redis://127.0.0.1:6379/"
+    };
+
+    if let Err(err) = do_redis_code(url) {
+        println!("Could not execute example:");
+        println!("  {}: {}", err.category(), err);
+    }
+}
diff --git a/glide-core/redis-rs/redis/examples/geospatial.rs b/glide-core/redis-rs/redis/examples/geospatial.rs
new file mode 100644
index 0000000000..5033b6c775
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/geospatial.rs
@@ -0,0 +1,68 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use std::process::exit;
+
+use redis::RedisResult;
+
+#[cfg(feature = "geospatial")]
+fn run() -> RedisResult<()> {
+    use redis::{geo, Commands};
+    use std::env;
+    use std::f64;
+
+    let redis_url = match env::var("REDIS_URL") {
+        Ok(url) => url,
+        Err(..) => "redis://127.0.0.1/".to_string(),
+    };
+
+    let client = redis::Client::open(redis_url.as_str())?;
+    let mut con = client.get_connection(None)?;
+
+    // Add some members to the geospatial index.
+
+    let added: isize = con.geo_add(
+        "gis",
+        &[
+            (geo::Coord::lon_lat("13.361389", "38.115556"), "Palermo"),
+            (geo::Coord::lon_lat("15.087269", "37.502669"), "Catania"),
+            (geo::Coord::lon_lat("13.5833332", "37.316667"), "Agrigento"),
+        ],
+    )?;
+
+    println!("[geo_add] Added {added} members.");
+
+    // Get the position of one of them.
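+    // GEOPOS returns one [longitude, latitude] pair per requested member,
+    // hence the nested Vec in the type below.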
+
+    let position: Vec<Vec<f64>> = con.geo_pos("gis", "Palermo")?;
+    println!("[geo_pos] Position for Palermo: {position:?}");
+
+    // Search members near (13.5, 37.75)
+
+    let options = geo::RadiusOptions::default()
+        .order(geo::RadiusOrder::Asc)
+        .with_dist()
+        .limit(2);
+    let items: Vec<geo::RadiusSearchResult> =
+        con.geo_radius("gis", 13.5, 37.75, 150.0, geo::Unit::Kilometers, options)?;
+
+    for item in items {
+        println!(
+            "[geo_radius] {}, dist = {} Km",
+            item.name,
+            item.dist.unwrap_or(f64::NAN)
+        );
+    }
+
+    Ok(())
+}
+
+#[cfg(not(feature = "geospatial"))]
+fn run() -> RedisResult<()> {
+    Ok(())
+}
+
+fn main() {
+    if let Err(e) = run() {
+        println!("{e:?}");
+        exit(1);
+    }
+}
diff --git a/glide-core/redis-rs/redis/examples/streams.rs b/glide-core/redis-rs/redis/examples/streams.rs
new file mode 100644
index 0000000000..8c40ea487d
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/streams.rs
@@ -0,0 +1,270 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+#![cfg(feature = "streams")]
+
+use redis::streams::{StreamId, StreamKey, StreamMaxlen, StreamReadOptions, StreamReadReply};
+
+use redis::{Commands, RedisResult, Value};
+
+use std::thread;
+use std::time::Duration;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+const DOG_STREAM: &str = "example-dog";
+const CAT_STREAM: &str = "example-cat";
+const DUCK_STREAM: &str = "example-duck";
+
+const STREAMS: &[&str] = &[DOG_STREAM, CAT_STREAM, DUCK_STREAM];
+
+const SLOWNESSES: &[u8] = &[2, 3, 4];
+
+/// This program generates an arbitrary set of records across three
+/// different streams. It then reads the data back in such a way
+/// that demonstrates basic usage of both the XREAD and XREADGROUP
+/// commands.
+fn main() {
+    let client = redis::Client::open("redis://127.0.0.1/").expect("client");
+
+    println!("Demonstrating XADD followed by XREAD, single threaded\n");
+
+    add_records(&client).expect("contrived record generation");
+
+    read_records(&client).expect("simple read");
+
+    demo_group_reads(&client);
+
+    clean_up(&client)
+}
+
+fn demo_group_reads(client: &redis::Client) {
+    println!("\n\nDemonstrating a longer stream of data flowing\nin over time, consumed by multiple threads using XREADGROUP\n");
+
+    let mut handles = vec![];
+
+    let cc = client.clone();
+    // Launch a producer thread which repeatedly adds records,
+    // with only a small delay between writes.
+    handles.push(thread::spawn(move || {
+        let repeat = 30;
+        let slowness = 1;
+        for _ in 0..repeat {
+            add_records(&cc).expect("add");
+            thread::sleep(Duration::from_millis(random_wait_millis(slowness)))
+        }
+    }));
+
+    // Launch consumer threads which repeatedly read from the
+    // streams at various speeds. They'll effectively compete
+    // to consume the stream.
+    //
+    // Consumer groups are only appropriate for cases where you
+    // do NOT want each consumer to read ALL of the data. This
+    // example is a contrived scenario so that each consumer
+    // receives its own, specific chunk of data.
+    //
+    // Once the data is read, the redis-rs lib will automatically
+    // acknowledge its receipt via XACK.
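+    // (See also the explicit xack call below, issued once the messages
+    // have been processed.)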
+    //
+    // Read more about reading with consumer groups here:
+    // https://redis.io/commands/xreadgroup
+    for slowness in SLOWNESSES {
+        let repeat = 5;
+        let ca = client.clone();
+        handles.push(thread::spawn(move || {
+            let mut con = ca.get_connection(None).expect("con");
+
+            // We must create each group and each consumer
+            // See https://redis.io/commands/xreadgroup#differences-between-xread-and-xreadgroup
+
+            for key in STREAMS {
+                let created: Result<(), _> = con.xgroup_create_mkstream(*key, GROUP_NAME, "$");
+                if let Err(e) = created {
+                    println!("Group already exists: {e:?}")
+                }
+            }
+
+            for _ in 0..repeat {
+                let read_reply = read_group_records(&ca, *slowness).expect("group read");
+
+                // fake some expensive work
+                for StreamKey { key, ids } in read_reply.keys {
+                    for StreamId { id, map: _ } in &ids {
+                        thread::sleep(Duration::from_millis(random_wait_millis(*slowness)));
+                        println!(
+                            "Stream {} ID {} Consumer slowness {} SysTime {}",
+                            key,
+                            id,
+                            slowness,
+                            SystemTime::now()
+                                .duration_since(UNIX_EPOCH)
+                                .expect("time")
+                                .as_millis()
+                        );
+                    }
+
+                    // acknowledge each stream and message ID once all messages are
+                    // correctly processed
+                    let id_strs: Vec<&String> =
+                        ids.iter().map(|StreamId { id, map: _ }| id).collect();
+                    con.xack(key, GROUP_NAME, &id_strs).expect("ack")
+                }
+            }
+        }))
+    }
+
+    for h in handles {
+        h.join().expect("Join")
+    }
+}
+
+/// Generate some contrived records and add them to various
+/// streams.
+fn add_records(client: &redis::Client) -> RedisResult<()> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let maxlen = StreamMaxlen::Approx(1000);
+
+    // a stream whose records have two fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            DOG_STREAM,
+            maxlen,
+            "*",
+            &[("bark", arbitrary_value()), ("groom", arbitrary_value())],
+        )?;
+    }
+
+    // a stream whose records have three fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            CAT_STREAM,
+            maxlen,
+            "*",
+            &[
+                ("meow", arbitrary_value()),
+                ("groom", arbitrary_value()),
+                ("hunt", arbitrary_value()),
+            ],
+        )?;
+    }
+
+    // a stream whose records have four fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            DUCK_STREAM,
+            maxlen,
+            "*",
+            &[
+                ("quack", arbitrary_value()),
+                ("waddle", arbitrary_value()),
+                ("splash", arbitrary_value()),
+                ("flap", arbitrary_value()),
+            ],
+        )?;
+    }
+
+    Ok(())
+}
+
+/// An approximation of randomness, without leaving the stdlib.
+fn thrifty_rand() -> u8 {
+    let penultimate_num = 2;
+    (SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .expect("Time travel")
+        .as_nanos()
+        % penultimate_num) as u8
+        + 1
+}
+
+const MAGIC: u64 = 11;
+fn random_wait_millis(slowness: u8) -> u64 {
+    thrifty_rand() as u64 * thrifty_rand() as u64 * MAGIC * slowness as u64
+}
+
+/// Generate a potentially unique value.
+fn arbitrary_value() -> String {
+    format!(
+        "{}",
+        SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("Time travel")
+            .as_nanos()
+    )
+}
+
+/// Block the thread for this many milliseconds while
+/// waiting for data to arrive on the stream.
+const BLOCK_MILLIS: usize = 5000;
+
+/// Read back records from all three streams, if they're available.
+/// Doesn't bother with consumer groups. Generally the user
+/// would be responsible for keeping track of the most recent
+/// ID from which they need to read, but in this example, we
+/// just go back to the beginning of time and ask for all the
+/// records in the stream.
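+/// Note that each read below blocks for up to BLOCK_MILLIS milliseconds
+/// while waiting for new entries to arrive.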
+fn read_records(client: &redis::Client) -> RedisResult<()> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let opts = StreamReadOptions::default().block(BLOCK_MILLIS);
+
+    // Oldest known time index
+    let starting_id = "0-0";
+    // Same as above
+    let another_form = "0";
+
+    let srr: StreamReadReply = con
+        .xread_options(STREAMS, &[starting_id, another_form, starting_id], &opts)
+        .expect("read");
+
+    for StreamKey { key, ids } in srr.keys {
+        println!("Stream {key}");
+        for StreamId { id, map } in ids {
+            println!("\tID {id}");
+            for (n, s) in map {
+                if let Value::BulkString(bytes) = s {
+                    println!("\t\t{}: {}", n, String::from_utf8(bytes).expect("utf8"))
+                } else {
+                    panic!("Weird data")
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn consumer_name(slowness: u8) -> String {
+    format!("example-consumer-{slowness}")
+}
+
+const GROUP_NAME: &str = "example-group-aaa";
+
+fn read_group_records(client: &redis::Client, slowness: u8) -> RedisResult<StreamReadReply> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let opts = StreamReadOptions::default()
+        .block(BLOCK_MILLIS)
+        .count(3)
+        .group(GROUP_NAME, consumer_name(slowness));
+
+    let srr: StreamReadReply = con
+        .xread_options(
+            &[DOG_STREAM, CAT_STREAM, DUCK_STREAM],
+            &[">", ">", ">"],
+            &opts,
+        )
+        .expect("records");
+
+    Ok(srr)
+}
+
+fn clean_up(client: &redis::Client) {
+    let mut con = client.get_connection(None).expect("con");
+    for k in STREAMS {
+        let trimmed: RedisResult<()> = con.xtrim(*k, StreamMaxlen::Equals(0));
+        trimmed.expect("trim");
+
+        let destroyed: RedisResult<()> = con.xgroup_destroy(*k, GROUP_NAME);
+        destroyed.expect("xgroup destroy");
+    }
+}
diff --git a/glide-core/redis-rs/redis/release.toml b/glide-core/redis-rs/redis/release.toml
new file mode 100644
index 0000000000..942730e0b6
--- /dev/null
+++ b/glide-core/redis-rs/redis/release.toml
@@ -0,0 +1,2 @@
+pre-release-hook = "../scripts/update-versions.sh"
+tag-name = "{{version}}"
diff --git a/glide-core/redis-rs/redis/src/acl.rs b/glide-core/redis-rs/redis/src/acl.rs
new file mode 100644
index 0000000000..ef85877ba6
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/acl.rs
@@ -0,0 +1,312 @@
+//! Defines types to use with the ACL commands.
+
+use crate::types::{
+    ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value,
+};
+
+macro_rules! not_convertible_error {
+    ($v:expr, $det:expr) => {
+        RedisError::from((
+            ErrorKind::TypeError,
+            "Response type not convertible",
+            format!("{:?} (response was {:?})", $det, $v),
+        ))
+    };
+}
+
+/// ACL rules are used in order to activate or remove a flag, or to perform a
+/// given change to the user ACL, which under the hood are just single words.
+#[derive(Debug, Eq, PartialEq)]
+pub enum Rule {
+    /// Enable the user: it is possible to authenticate as this user.
+    On,
+    /// Disable the user: it's no longer possible to authenticate with this
+    /// user, however the already authenticated connections will still work.
+    Off,
+
+    /// Add the command to the list of commands the user can call.
+    AddCommand(String),
+    /// Remove the command from the list of commands the user can call.
+    RemoveCommand(String),
+    /// Add all the commands in such category to be called by the user.
+    AddCategory(String),
+    /// Remove the commands from such category the client can call.
+    RemoveCategory(String),
+    /// Alias for `+@all`. Note that it implies the ability to execute all the
+    /// future commands loaded via the modules system.
+    AllCommands,
+    /// Alias for `-@all`.
+    NoCommands,
+
+    /// Add this password to the list of valid passwords for the user.
+    AddPass(String),
+    /// Remove this password from the list of valid passwords.
+    RemovePass(String),
+    /// Add this SHA-256 hash value to the list of valid passwords for the user.
+    AddHashedPass(String),
+    /// Remove this hash value from the list of valid passwords.
+    RemoveHashedPass(String),
+    /// All the set passwords of the user are removed, and the user is flagged
+    /// as requiring no password: it means that every password will work
+    /// against this user.
+    NoPass,
+    /// Flush the list of allowed passwords. Moreover removes the _nopass_ status.
+    ResetPass,
+
+    /// Add a pattern of keys that can be mentioned as part of commands.
+    Pattern(String),
+    /// Alias for `~*`.
+    AllKeys,
+    /// Flush the list of allowed keys patterns.
+    ResetKeys,
+
+    /// Performs the following actions: `resetpass`, `resetkeys`, `off`, `-@all`.
+    /// The user returns to the same state it has immediately after its creation.
+    Reset,
+
+    /// Raw text of an [`ACL rule`][1] that is not enumerated above.
+    ///
+    /// [1]: https://redis.io/docs/manual/security/acl
+    Other(String),
+}
+
+impl ToRedisArgs for Rule {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        use self::Rule::*;
+
+        match self {
+            On => out.write_arg(b"on"),
+            Off => out.write_arg(b"off"),
+
+            AddCommand(cmd) => out.write_arg_fmt(format_args!("+{cmd}")),
+            RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{cmd}")),
+            AddCategory(cat) => out.write_arg_fmt(format_args!("+@{cat}")),
+            RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{cat}")),
+            AllCommands => out.write_arg(b"allcommands"),
+            NoCommands => out.write_arg(b"nocommands"),
+
+            AddPass(pass) => out.write_arg_fmt(format_args!(">{pass}")),
+            RemovePass(pass) => out.write_arg_fmt(format_args!("<{pass}")),
+            AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{pass}")),
+            RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{pass}")),
+            NoPass => out.write_arg(b"nopass"),
+            ResetPass => out.write_arg(b"resetpass"),
+
+            Pattern(pat) => out.write_arg_fmt(format_args!("~{pat}")),
+            AllKeys => out.write_arg(b"allkeys"),
+            ResetKeys => out.write_arg(b"resetkeys"),
+
+            Reset => out.write_arg(b"reset"),
+
+            Other(rule) => out.write_arg(rule.as_bytes()),
+        };
+    }
+}
+
+/// An info dictionary type storing Redis ACL information as multiple `Rule`.
+/// This type collects key/value data returned by the [`ACL GETUSER`][1] command.
+///
+/// [1]: https://redis.io/commands/acl-getuser
+#[derive(Debug, Eq, PartialEq)]
+pub struct AclInfo {
+    /// Describes flag rules for the user. Represented by [`Rule::On`][1],
+    /// [`Rule::Off`][2], [`Rule::AllKeys`][3], [`Rule::AllCommands`][4] and
+    /// [`Rule::NoPass`][5].
+    ///
+    /// [1]: ./enum.Rule.html#variant.On
+    /// [2]: ./enum.Rule.html#variant.Off
+    /// [3]: ./enum.Rule.html#variant.AllKeys
+    /// [4]: ./enum.Rule.html#variant.AllCommands
+    /// [5]: ./enum.Rule.html#variant.NoPass
+    pub flags: Vec<Rule>,
+    /// Describes the user's passwords. Represented by [`Rule::AddHashedPass`][1].
+    ///
+    /// [1]: ./enum.Rule.html#variant.AddHashedPass
+    pub passwords: Vec<Rule>,
+    /// Describes capabilities of which commands the user can call.
+    /// Represented by [`Rule::AddCommand`][1], [`Rule::AddCategory`][2],
+    /// [`Rule::RemoveCommand`][3] and [`Rule::RemoveCategory`][4].
+    ///
+    /// [1]: ./enum.Rule.html#variant.AddCommand
+    /// [2]: ./enum.Rule.html#variant.AddCategory
+    /// [3]: ./enum.Rule.html#variant.RemoveCommand
+    /// [4]: ./enum.Rule.html#variant.RemoveCategory
+    pub commands: Vec<Rule>,
+    /// Describes patterns of keys which the user can access. Represented by
+    /// [`Rule::Pattern`][1].
+    ///
+    /// [1]: ./enum.Rule.html#variant.Pattern
+    pub keys: Vec<Rule>,
+}
+
+impl FromRedisValue for AclInfo {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        let mut it = v
+            .as_sequence()
+            .ok_or_else(|| not_convertible_error!(v, ""))?
+            .iter()
+            .skip(1)
+            .step_by(2);
+
+        let (flags, passwords, commands, keys) = match (it.next(), it.next(), it.next(), it.next())
+        {
+            (Some(flags), Some(passwords), Some(commands), Some(keys)) => {
+                // Parse flags
+                // Ref: https://github.com/redis/redis/blob/0cabe0cfa7290d9b14596ec38e0d0a22df65d1df/src/acl.c#L83-L90
+                let flags = flags
+                    .as_sequence()
+                    .ok_or_else(|| {
+                        not_convertible_error!(flags, "Expect an array response of ACL flags")
+                    })?
+                    .iter()
+                    .map(|flag| match flag {
+                        Value::BulkString(flag) => match flag.as_slice() {
+                            b"on" => Ok(Rule::On),
+                            b"off" => Ok(Rule::Off),
+                            b"allkeys" => Ok(Rule::AllKeys),
+                            b"allcommands" => Ok(Rule::AllCommands),
+                            b"nopass" => Ok(Rule::NoPass),
+                            other => Ok(Rule::Other(String::from_utf8_lossy(other).into_owned())),
+                        },
+                        _ => Err(not_convertible_error!(
+                            flag,
+                            "Expect an arbitrary binary data"
+                        )),
+                    })
+                    .collect::<RedisResult<_>>()?;
+
+                let passwords = passwords
+                    .as_sequence()
+                    .ok_or_else(|| {
+                        not_convertible_error!(flags, "Expect an array response of ACL flags")
+                    })?
+                    .iter()
+                    .map(|pass| Ok(Rule::AddHashedPass(String::from_redis_value(pass)?)))
+                    .collect::<RedisResult<_>>()?;
+
+                let commands = match commands {
+                    Value::BulkString(cmd) => std::str::from_utf8(cmd)?,
+                    _ => {
+                        return Err(not_convertible_error!(
+                            commands,
+                            "Expect a valid UTF8 string"
+                        ))
+                    }
+                }
+                .split_terminator(' ')
+                .map(|cmd| match cmd {
+                    x if x.starts_with("+@") => Ok(Rule::AddCategory(x[2..].to_owned())),
+                    x if x.starts_with("-@") => Ok(Rule::RemoveCategory(x[2..].to_owned())),
+                    x if x.starts_with('+') => Ok(Rule::AddCommand(x[1..].to_owned())),
+                    x if x.starts_with('-') => Ok(Rule::RemoveCommand(x[1..].to_owned())),
+                    _ => Err(not_convertible_error!(
+                        cmd,
+                        "Expect a command addition/removal"
+                    )),
+                })
+                .collect::<RedisResult<_>>()?;
+
+                let keys = keys
+                    .as_sequence()
+                    .ok_or_else(|| not_convertible_error!(keys, ""))?
+                    .iter()
+                    .map(|pat| Ok(Rule::Pattern(String::from_redis_value(pat)?)))
+                    .collect::<RedisResult<_>>()?;
+
+                (flags, passwords, commands, keys)
+            }
+            _ => {
+                return Err(not_convertible_error!(
+                    v,
+                    "Expect a response from `ACL GETUSER`"
+                ))
+            }
+        };
+
+        Ok(Self {
+            flags,
+            passwords,
+            commands,
+            keys,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_args {
assert_args { + ($rule:expr, $arg:expr) => { + assert_eq!($rule.to_redis_args(), vec![$arg.to_vec()]); + }; + } + + #[test] + fn test_rule_to_arg() { + use self::Rule::*; + + assert_args!(On, b"on"); + assert_args!(Off, b"off"); + assert_args!(AddCommand("set".to_owned()), b"+set"); + assert_args!(RemoveCommand("set".to_owned()), b"-set"); + assert_args!(AddCategory("hyperloglog".to_owned()), b"+@hyperloglog"); + assert_args!(RemoveCategory("hyperloglog".to_owned()), b"-@hyperloglog"); + assert_args!(AllCommands, b"allcommands"); + assert_args!(NoCommands, b"nocommands"); + assert_args!(AddPass("mypass".to_owned()), b">mypass"); + assert_args!(RemovePass("mypass".to_owned()), b" io::Result { + let socket = TcpStream::connect(addr).await?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let mut std_socket = std::net::TcpStream::try_from(socket)?; + let socket2: socket2::Socket = std_socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + std_socket = socket2.into(); + Ok(std_socket.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +pin_project_lite::pin_project! { + /// Wraps the async_std `AsyncRead/AsyncWrite` in order to implement the required the tokio traits + /// for it + pub struct AsyncStdWrapped { #[pin] inner: T } +} + +impl AsyncStdWrapped { + pub(super) fn new(inner: T) -> Self { + Self { inner } + } +} + +impl AsyncWrite for AsyncStdWrapped +where + T: async_std::io::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + buf: &[u8], + ) -> std::task::Poll> { + async_std::io::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + ) -> std::task::Poll> { + async_std::io::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + ) -> std::task::Poll> { + async_std::io::Write::poll_close(self.project().inner, cx) + } +} + +impl AsyncRead for AsyncStdWrapped +where + T: async_std::io::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + buf: &mut ReadBuf<'_>, + ) -> std::task::Poll> { + let n = ready!(async_std::io::Read::poll_read( + self.project().inner, + cx, + buf.initialize_unfilled() + ))?; + buf.advance(n); + std::task::Poll::Ready(Ok(())) + } +} + +/// Represents an AsyncStd connectable +pub enum AsyncStd { + /// Represents an Async_std TCP connection. + Tcp(AsyncStdWrapped), + /// Represents an Async_std TLS encrypted TCP connection. + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + TcpTls(AsyncStdWrapped>>), + /// Represents an Async_std Unix connection. 
+    #[cfg(unix)]
+    Unix(AsyncStdWrapped<UnixStream>),
+}
+
+impl AsyncWrite for AsyncStd {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match &mut *self {
+            AsyncStd::Tcp(r) => Pin::new(r).poll_write(cx, buf),
+            #[cfg(any(
+                feature = "async-std-native-tls-comp",
+                feature = "async-std-rustls-comp"
+            ))]
+            AsyncStd::TcpTls(r) => Pin::new(r).poll_write(cx, buf),
+            #[cfg(unix)]
+            AsyncStd::Unix(r) => Pin::new(r).poll_write(cx, buf),
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
+        match &mut *self {
+            AsyncStd::Tcp(r) => Pin::new(r).poll_flush(cx),
+            #[cfg(any(
+                feature = "async-std-native-tls-comp",
+                feature = "async-std-rustls-comp"
+            ))]
+            AsyncStd::TcpTls(r) => Pin::new(r).poll_flush(cx),
+            #[cfg(unix)]
+            AsyncStd::Unix(r) => Pin::new(r).poll_flush(cx),
+        }
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
+        match &mut *self {
+            AsyncStd::Tcp(r) => Pin::new(r).poll_shutdown(cx),
+            #[cfg(any(
+                feature = "async-std-native-tls-comp",
+                feature = "async-std-rustls-comp"
+            ))]
+            AsyncStd::TcpTls(r) => Pin::new(r).poll_shutdown(cx),
+            #[cfg(unix)]
+            AsyncStd::Unix(r) => Pin::new(r).poll_shutdown(cx),
+        }
+    }
+}
+
+impl AsyncRead for AsyncStd {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match &mut *self {
+            AsyncStd::Tcp(r) => Pin::new(r).poll_read(cx, buf),
+            #[cfg(any(
+                feature = "async-std-native-tls-comp",
+                feature = "async-std-rustls-comp"
+            ))]
+            AsyncStd::TcpTls(r) => Pin::new(r).poll_read(cx, buf),
+            #[cfg(unix)]
+            AsyncStd::Unix(r) => Pin::new(r).poll_read(cx, buf),
+        }
+    }
+}
+
+#[async_trait]
+impl RedisRuntime for AsyncStd {
+    async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self> {
+        Ok(connect_tcp(&socket_addr)
+            .await
+            .map(|con| Self::Tcp(AsyncStdWrapped::new(con)))?)
+    }
+
+    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+    async fn connect_tcp_tls(
+        hostname: &str,
+        socket_addr: SocketAddr,
+        insecure: bool,
+        _tls_params: &Option<TlsConnParams>,
+    ) -> RedisResult<Self> {
+        let tcp_stream = connect_tcp(&socket_addr).await?;
+        let tls_connector = if insecure {
+            TlsConnector::new()
+                .danger_accept_invalid_certs(true)
+                .danger_accept_invalid_hostnames(true)
+                .use_sni(false)
+        } else {
+            TlsConnector::new()
+        };
+        Ok(tls_connector
+            .connect(hostname, tcp_stream)
+            .await
+            .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?)
+    }
+
+    #[cfg(feature = "tls-rustls")]
+    async fn connect_tcp_tls(
+        hostname: &str,
+        socket_addr: SocketAddr,
+        insecure: bool,
+        tls_params: &Option<TlsConnParams>,
+    ) -> RedisResult<Self> {
+        let tcp_stream = connect_tcp(&socket_addr).await?;
+
+        let config = create_rustls_config(insecure, tls_params.clone())?;
+        let tls_connector = TlsConnector::from(Arc::new(config));
+
+        Ok(tls_connector
+            .connect(
+                rustls_pki_types::ServerName::try_from(hostname)?.to_owned(),
+                tcp_stream,
+            )
+            .await
+            .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?)
+    }
+
+    #[cfg(unix)]
+    async fn connect_unix(path: &Path) -> RedisResult<Self> {
+        Ok(UnixStream::connect(path)
+            .await
+            .map(|con| Self::Unix(AsyncStdWrapped::new(con)))?)
+    }
+
+    fn spawn(f: impl Future<Output = ()> + Send + 'static) {
+        async_std::task::spawn(f);
+    }
+
+    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
+        match self {
+            AsyncStd::Tcp(x) => Box::pin(x),
+            #[cfg(any(
+                feature = "async-std-native-tls-comp",
+                feature = "async-std-rustls-comp"
+            ))]
+            AsyncStd::TcpTls(x) => Box::pin(x),
+            #[cfg(unix)]
+            AsyncStd::Unix(x) => Box::pin(x),
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/aio/connection.rs b/glide-core/redis-rs/redis/src/aio/connection.rs
new file mode 100644
index 0000000000..6b1f6e657a
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/aio/connection.rs
@@ -0,0 +1,543 @@
+#![allow(deprecated)]
+
+#[cfg(feature = "async-std-comp")]
+use super::async_std;
+use super::ConnectionLike;
+use super::{setup_connection, AsyncStream, RedisRuntime};
+use crate::cmd::{cmd, Cmd};
+use crate::connection::{
+    resp2_is_pub_sub_state_cleared, resp3_is_pub_sub_state_cleared, ConnectionAddr, ConnectionInfo,
+    Msg, RedisConnectionInfo,
+};
+#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
+use crate::parser::ValueCodec;
+use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value};
+use crate::{from_owned_redis_value, ProtocolVersion, ToRedisArgs};
+#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+use ::async_std::net::ToSocketAddrs;
+use ::tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
+#[cfg(feature = "tokio-comp")]
+use ::tokio::net::lookup_host;
+use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
+use futures_util::future::select_ok;
+use futures_util::{
+    future::FutureExt,
+    stream::{Stream, StreamExt},
+};
+use std::net::{IpAddr, SocketAddr};
+use std::pin::Pin;
+#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
+use tokio_util::codec::Decoder;
+use tracing::info;
+
+/// Represents a stateful redis TCP connection.
+#[deprecated(note = "aio::Connection is deprecated. Use aio::MultiplexedConnection instead.")]
+pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
+    con: C,
+    buf: Vec<u8>,
+    decoder: combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
+    db: i64,
+
+    // Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
+    //
+    // This flag is checked when attempting to send a command, and if it's raised, we attempt to
+    // exit the pubsub state before executing the new request.
+    pubsub: bool,
+
+    // Field indicating which protocol to use for server communications.
+    protocol: ProtocolVersion,
+}
+
+fn assert_sync<T: Sync>() {}
+
+#[allow(unused)]
+fn test() {
+    assert_sync::<Connection>();
+}
+
+impl<C> Connection<C> {
+    pub(crate) fn map<D>(self, f: impl FnOnce(C) -> D) -> Connection<D> {
+        let Self {
+            con,
+            buf,
+            decoder,
+            db,
+            pubsub,
+            protocol,
+        } = self;
+        Connection {
+            con: f(con),
+            buf,
+            decoder,
+            db,
+            pubsub,
+            protocol,
+        }
+    }
+}
+
+impl<C> Connection<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    /// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object
+    /// and a `RedisConnectionInfo`
+    pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
+        let mut rv = Connection {
+            con,
+            buf: Vec::new(),
+            decoder: combine::stream::Decoder::new(),
+            db: connection_info.db,
+            pubsub: false,
+            protocol: connection_info.protocol,
+        };
+        setup_connection(connection_info, &mut rv).await?;
+        Ok(rv)
+    }
+
+    /// Converts this [`Connection`] into [`PubSub`].
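+    ///
+    /// A minimal usage sketch (the channel name is illustrative, and `con` is
+    /// assumed to be an established `Connection`):
+    ///
+    /// ```rust,ignore
+    /// let mut pubsub = con.into_pubsub();
+    /// pubsub.subscribe("updates").await?;
+    /// ```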
+    pub fn into_pubsub(self) -> PubSub<C> {
+        PubSub::new(self)
+    }
+
+    /// Converts this [`Connection`] into [`Monitor`]
+    pub fn into_monitor(self) -> Monitor<C> {
+        Monitor::new(self)
+    }
+
+    /// Fetches a single response from the connection.
+    async fn read_response(&mut self) -> RedisResult<Value> {
+        crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await
+    }
+
+    /// Brings [`Connection`] out of `PubSub` mode.
+    ///
+    /// This will unsubscribe this [`Connection`] from all subscriptions.
+    ///
+    /// If this function returns an error, every subsequent attempt to send a
+    /// command will first try to exit `PubSub` mode until one succeeds.
+    async fn exit_pubsub(&mut self) -> RedisResult<()> {
+        let res = self.clear_active_subscriptions().await;
+        if res.is_ok() {
+            self.pubsub = false;
+        } else {
+            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
+            self.pubsub = true;
+        }
+
+        res
+    }
+
+    /// Get the inner connection out of a PubSub
+    ///
+    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
+    /// dropped.
+    async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
+        // Responses to unsubscribe commands return in a 3-tuple with values
+        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
+        // The "count of remaining subs" includes both pattern subscriptions and non pattern
+        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
+        // server, both commands need to be executed at once.
+        {
+            // Prepare both unsubscribe commands
+            let unsubscribe = crate::Pipeline::new()
+                .add_command(cmd("UNSUBSCRIBE"))
+                .add_command(cmd("PUNSUBSCRIBE"))
+                .get_packed_pipeline();
+
+            // Execute commands
+            self.con.write_all(&unsubscribe).await?;
+        }
+
+        // Receive responses
+        //
+        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
+        // commands. There may be more responses if there are active subscriptions. In this case,
+        // messages are received until the _subscription count_ in the responses reach zero.
+        let mut received_unsub = false;
+        let mut received_punsub = false;
+        if self.protocol != ProtocolVersion::RESP2 {
+            while let Value::Push { kind, data } =
+                from_owned_redis_value(self.read_response().await?)?
+            {
+                if data.len() >= 2 {
+                    if let Value::Int(num) = data[1] {
+                        if resp3_is_pub_sub_state_cleared(
+                            &mut received_unsub,
+                            &mut received_punsub,
+                            &kind,
+                            num as isize,
+                        ) {
+                            break;
+                        }
+                    }
+                }
+            }
+        } else {
+            loop {
+                let res: (Vec<u8>, (), isize) =
+                    from_owned_redis_value(self.read_response().await?)?;
+                if resp2_is_pub_sub_state_cleared(
+                    &mut received_unsub,
+                    &mut received_punsub,
+                    &res.0,
+                    res.2,
+                ) {
+                    break;
+                }
+            }
+        }
+
+        // Finally, the connection is back in its normal state since all subscriptions were
+        // cancelled *and* all unsubscribe messages were received.
+        Ok(())
+    }
+}
+
+#[cfg(feature = "async-std-comp")]
+#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
+impl<C> Connection<async_std::AsyncStdWrapped<C>>
+where
+    C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send,
+{
+    /// Constructs a new `Connection` out of a `async_std::io::AsyncRead + async_std::io::AsyncWrite` object
+    /// and a `RedisConnectionInfo`
+    pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
+        Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await
+    }
+}
+
+pub(crate) async fn connect<C>(
+    connection_info: &ConnectionInfo,
+    socket_addr: Option<SocketAddr>,
+) -> RedisResult<Connection<C>>
+where
+    C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
+{
+    let (con, _ip) = connect_simple::<C>(connection_info, socket_addr).await?;
+    Connection::new(&connection_info.redis, con).await
+}
+
+impl<C> ConnectionLike for Connection<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+        (async move {
+            if self.pubsub {
+                self.exit_pubsub().await?;
+            }
+            self.buf.clear();
+            cmd.write_packed_command(&mut self.buf);
+            self.con.write_all(&self.buf).await?;
+            if cmd.is_no_response() {
+                return Ok(Value::Nil);
+            }
+            loop {
+                match self.read_response().await? {
+                    Value::Push { .. } => continue,
+                    val => return Ok(val),
+                }
+            }
+        })
+        .boxed()
+    }
+
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>> {
+        (async move {
+            if self.pubsub {
+                self.exit_pubsub().await?;
+            }
+
+            self.buf.clear();
+            cmd.write_packed_pipeline(&mut self.buf);
+            self.con.write_all(&self.buf).await?;
+
+            let mut first_err = None;
+
+            for _ in 0..offset {
+                let response = self.read_response().await;
+                if let Err(err) = response {
+                    if first_err.is_none() {
+                        first_err = Some(err);
+                    }
+                }
+            }
+
+            let mut rv = Vec::with_capacity(count);
+            let mut count = count;
+            let mut idx = 0;
+            while idx < count {
+                let response = self.read_response().await;
+                match response {
+                    Ok(item) => {
+                        // RESP3 can insert push data between command replies
+                        if let Value::Push { .. } = item {
+                            // if that is the case we have to extend the loop and handle push data
+                            count += 1;
+                        } else {
+                            rv.push(item);
+                        }
+                    }
+                    Err(err) => {
+                        if first_err.is_none() {
+                            first_err = Some(err);
+                        }
+                    }
+                }
+                idx += 1;
+            }
+
+            if let Some(err) = first_err {
+                Err(err)
+            } else {
+                Ok(rv)
+            }
+        })
+        .boxed()
+    }
+
+    fn get_db(&self) -> i64 {
+        self.db
+    }
+
+    fn is_closed(&self) -> bool {
+        // always false for AsyncRead + AsyncWrite (can't do better)
+        false
+    }
+}
+
+/// Represents a `PubSub` connection.
+pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
+
+/// Represents a `Monitor` connection.
+pub struct Monitor<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
+
+impl<C> PubSub<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    fn new(con: Connection<C>) -> Self {
+        Self(con)
+    }
+
+    /// Subscribes to a new channel.
+    pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
+        let mut cmd = cmd("SUBSCRIBE");
+        cmd.arg(channel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Subscribes to a new channel with a pattern.
+    pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
+        let mut cmd = cmd("PSUBSCRIBE");
+        cmd.arg(pchannel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Unsubscribes from a channel.
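+    ///
+    /// A minimal sketch (assuming an earlier `subscribe("updates")` call on
+    /// this `PubSub`):
+    ///
+    /// ```rust,ignore
+    /// pubsub.unsubscribe("updates").await?;
+    /// ```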
+ pub async fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { + let mut cmd = cmd("UNSUBSCRIBE"); + cmd.arg(channel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Unsubscribes from a channel with a pattern. + pub async fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { + let mut cmd = cmd("PUNSUBSCRIBE"); + cmd.arg(pchannel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions. + /// + /// The message itself is still generic and can be converted into an appropriate type through + /// the helper methods on it. + pub fn on_message(&mut self) -> impl Stream + '_ { + ValueCodec::default() + .framed(&mut self.0.con) + .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) })) + } + + /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it. + /// + /// The message itself is still generic and can be converted into an appropriate type through + /// the helper methods on it. + /// This can be useful in cases where the stream needs to be returned or held by something other + /// than the [`PubSub`]. + pub fn into_on_message(self) -> impl Stream { + ValueCodec::default() + .framed(self.0.con) + .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) })) + } + + /// Exits from `PubSub` mode and converts [`PubSub`] into [`Connection`]. + #[deprecated(note = "aio::Connection is deprecated")] + pub async fn into_connection(mut self) -> Connection { + self.0.exit_pubsub().await.ok(); + + self.0 + } +} + +impl Monitor +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + /// Create a [`Monitor`] from a [`Connection`] + pub fn new(con: Connection) -> Self { + Self(con) + } + + /// Deliver the MONITOR command to this [`Monitor`]ing wrapper. + pub async fn monitor(&mut self) -> RedisResult<()> { + cmd("MONITOR").query_async(&mut self.0).await + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn on_message(&mut self) -> impl Stream + '_ { + ValueCodec::default() + .framed(&mut self.0.con) + .filter_map(|value| { + Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() }) + }) + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn into_on_message(self) -> impl Stream { + ValueCodec::default() + .framed(self.0.con) + .filter_map(|value| { + Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() }) + }) + } +} + +pub(crate) async fn get_socket_addrs( + host: &str, + port: u16, +) -> RedisResult + Send + '_> { + #[cfg(feature = "tokio-comp")] + let socket_addrs = lookup_host((host, port)).await?; + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + let socket_addrs = (host, port).to_socket_addrs().await?; + + let mut socket_addrs = socket_addrs.peekable(); + match socket_addrs.peek() { + Some(_) => Ok(socket_addrs), + None => Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "No address found for host", + ))), + } +} + +/// Logs the creation of a connection, including its type, the node, and optionally its IP address. 
+fn log_conn_creation(conn_type: &str, node: T, ip: Option) +where + T: std::fmt::Debug, +{ + info!( + "Creating {conn_type} connection for node: {node:?}{}", + ip.map(|ip| format!(", IP: {:?}", ip)).unwrap_or_default() + ); +} + +pub(crate) async fn connect_simple( + connection_info: &ConnectionInfo, + _socket_addr: Option, +) -> RedisResult<(T, Option)> { + Ok(match connection_info.addr { + ConnectionAddr::Tcp(ref host, port) => { + if let Some(socket_addr) = _socket_addr { + return Ok::<_, RedisError>(( + ::connect_tcp(socket_addr).await?, + Some(socket_addr.ip()), + )); + } + let socket_addrs = get_socket_addrs(host, port).await?; + select_ok(socket_addrs.map(|socket_addr| { + log_conn_creation("TCP", format!("{host}:{port}"), Some(socket_addr.ip())); + Box::pin(async move { + Ok::<_, RedisError>(( + ::connect_tcp(socket_addr).await?, + Some(socket_addr.ip()), + )) + }) + })) + .await? + .0 + } + + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + ref tls_params, + } => { + if let Some(socket_addr) = _socket_addr { + return Ok::<_, RedisError>(( + ::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?, + Some(socket_addr.ip()), + )); + } + let socket_addrs = get_socket_addrs(host, port).await?; + select_ok(socket_addrs.map(|socket_addr| { + log_conn_creation( + "TCP with TLS", + format!("{host}:{port}"), + Some(socket_addr.ip()), + ); + Box::pin(async move { + Ok::<_, RedisError>(( + ::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?, + Some(socket_addr.ip()), + )) + }) + })) + .await? + .0 + } + + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + ConnectionAddr::TcpTls { .. } => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to TCP with TLS without the tls feature" + )); + } + + #[cfg(unix)] + ConnectionAddr::Unix(ref path) => { + log_conn_creation("UDS", path, None); + (::connect_unix(path).await?, None) + } + + #[cfg(not(unix))] + ConnectionAddr::Unix(_) => { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot connect to unix sockets \ + on this platform", + ))) + } + }) +} diff --git a/glide-core/redis-rs/redis/src/aio/connection_manager.rs b/glide-core/redis-rs/redis/src/aio/connection_manager.rs new file mode 100644 index 0000000000..61df9bc31a --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/connection_manager.rs @@ -0,0 +1,310 @@ +use super::RedisFuture; +use crate::client::GlideConnectionOptions; +use crate::cmd::Cmd; +use crate::push_manager::PushManager; +use crate::types::{RedisError, RedisResult, Value}; +use crate::{ + aio::{ConnectionLike, MultiplexedConnection, Runtime}, + Client, +}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use ::async_std::net::ToSocketAddrs; +use arc_swap::ArcSwap; +use futures::{ + future::{self, Shared}, + FutureExt, +}; +use futures_util::future::BoxFuture; +use std::sync::Arc; +use tokio_retry::strategy::{jitter, ExponentialBackoff}; +use tokio_retry::Retry; + +/// A `ConnectionManager` is a proxy that wraps a [multiplexed +/// connection][multiplexed-connection] and automatically reconnects to the +/// server when necessary. +/// +/// Like the [`MultiplexedConnection`][multiplexed-connection], this +/// manager can be cloned, allowing requests to be be sent concurrently on +/// the same underlying connection (tcp/unix socket). 
+///
+/// ## Behavior
+///
+/// - When creating an instance of the `ConnectionManager`, an initial
+///   connection will be established and awaited. Connection errors will be
+///   returned directly.
+/// - When a command sent to the server fails with an error that represents
+///   a "connection dropped" condition, that error will be passed on to the
+///   user, but it will trigger a reconnection in the background.
+/// - The reconnect code will atomically swap the current (dead) connection
+///   with a future that will eventually resolve to a `MultiplexedConnection`
+///   or to a `RedisError`.
+/// - All commands that are issued after the reconnect process has been
+///   initiated will have to await the connection future.
+/// - If reconnecting fails, all pending commands will be failed as well. A
+///   new reconnection attempt will be triggered if the error is an I/O error.
+///
+/// [multiplexed-connection]: struct.MultiplexedConnection.html
+#[derive(Clone)]
+pub struct ConnectionManager {
+    /// Information used for the connection. This is needed to be able to reconnect.
+    client: Client,
+    /// The connection future.
+    ///
+    /// The `ArcSwap` is required to be able to replace the connection
+    /// without making the `ConnectionManager` mutable.
+    connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,
+
+    runtime: Runtime,
+    retry_strategy: ExponentialBackoff,
+    number_of_retries: usize,
+    response_timeout: std::time::Duration,
+    connection_timeout: std::time::Duration,
+    push_manager: PushManager,
+}
+
+/// A `RedisResult` that can be cloned because `RedisError` is behind an `Arc`.
+type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;
+
+/// Type alias for a shared boxed future that will resolve to a `CloneableRedisResult`.
+type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;
+
+/// Handle a command result. If the connection was dropped, reconnect.
+macro_rules! reconnect_if_dropped {
+    ($self:expr, $result:expr, $current:expr) => {
+        if let Err(ref e) = $result {
+            if e.is_unrecoverable_error() {
+                $self.reconnect($current);
+            }
+        }
+    };
+}
+
+/// Handle a connection result. If there's an I/O error, reconnect.
+/// Propagate any error.
+macro_rules! reconnect_if_io_error {
+    ($self:expr, $result:expr, $current:expr) => {
+        if let Err(e) = $result {
+            if e.is_io_error() {
+                $self.reconnect($current);
+            }
+            return Err(e);
+        }
+    };
+}
+
+impl ConnectionManager {
+    const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2;
+    const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100;
+    const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;
+
+    /// Connect to the server and store the connection inside the returned `ConnectionManager`.
+    ///
+    /// This requires the `connection-manager` feature, which will also pull in
+    /// the Tokio executor.
+    pub async fn new(client: Client) -> RedisResult<Self> {
+        Self::new_with_backoff(
+            client,
+            Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
+            Self::DEFAULT_CONNECTION_RETRY_FACTOR,
+            Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
+        )
+        .await
+    }
+
+    /// Connect to the server and store the connection inside the returned `ConnectionManager`.
+    ///
+    /// This requires the `connection-manager` feature, which will also pull in
+    /// the Tokio executor.
+    ///
+    /// In case of reconnection issues, the manager will retry reconnection
+    /// number_of_retries times, with an exponentially increasing delay, calculated as
+    /// rand(0 .. factor * (exponent_base ^ current-try)).
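+    ///
+    /// For instance (illustrative numbers): with `exponent_base = 2`, `factor = 100` and
+    /// millisecond units, the delay before the third retry is drawn from
+    /// `rand(0 .. 100 * 2^3)`, i.e. at most 800 milliseconds.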
+ pub async fn new_with_backoff( + client: Client, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + Self::new_with_backoff_and_timeouts( + client, + exponent_base, + factor, + number_of_retries, + std::time::Duration::MAX, + std::time::Duration::MAX, + ) + .await + } + + /// Connect to the server and store the connection inside the returned `ConnectionManager`. + /// + /// This requires the `connection-manager` feature, which will also pull in + /// the Tokio executor. + /// + /// In case of reconnection issues, the manager will retry reconnection + /// number_of_retries times, with an exponentially increasing delay, calculated as + /// rand(0 .. factor * (exponent_base ^ current-try)). + /// + /// The new connection will timeout operations after `response_timeout` has passed. + /// Each connection attempt to the server will timeout after `connection_timeout`. + pub async fn new_with_backoff_and_timeouts( + client: Client, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + // Create a MultiplexedConnection and wait for it to be established + let push_manager = PushManager::default(); + let runtime = Runtime::locate(); + let retry_strategy = ExponentialBackoff::from_millis(exponent_base).factor(factor); + let mut connection = Self::new_connection( + client.clone(), + retry_strategy.clone(), + number_of_retries, + response_timeout, + connection_timeout, + ) + .await?; + + // Wrap the connection in an `ArcSwap` instance for fast atomic access + connection.set_push_manager(push_manager.clone()).await; + Ok(Self { + client, + connection: Arc::new(ArcSwap::from_pointee( + future::ok(connection).boxed().shared(), + )), + runtime, + number_of_retries, + retry_strategy, + response_timeout, + connection_timeout, + push_manager, + }) + } + + async fn new_connection( + client: Client, + exponential_backoff: ExponentialBackoff, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + let retry_strategy = exponential_backoff.map(jitter).take(number_of_retries); + Retry::spawn(retry_strategy, || { + client.get_multiplexed_async_connection_with_timeouts( + response_timeout, + connection_timeout, + GlideConnectionOptions::default(), + ) + }) + .await + } + + /// Reconnect and overwrite the old connection. + /// + /// The `current` guard points to the shared future that was active + /// when the connection loss was detected. + fn reconnect(&self, current: arc_swap::Guard>>) { + let client = self.client.clone(); + let retry_strategy = self.retry_strategy.clone(); + let number_of_retries = self.number_of_retries; + let response_timeout = self.response_timeout; + let connection_timeout = self.connection_timeout; + let pmc = self.push_manager.clone(); + let new_connection: SharedRedisFuture = async move { + let mut con = Self::new_connection( + client, + retry_strategy, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await?; + con.set_push_manager(pmc).await; + Ok(con) + } + .boxed() + .shared(); + + // Update the connection in the connection manager + let new_connection_arc = Arc::new(new_connection.clone()); + let prev = self + .connection + .compare_and_swap(¤t, new_connection_arc); + + // If the swap happened... + if Arc::ptr_eq(&prev, ¤t) { + // ...start the connection attempt immediately but do not wait on it. 
+ self.runtime.spawn(new_connection.map(|_| ())); + } + } + + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + // Clone connection to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result?.send_packed_command(cmd).await; + reconnect_if_dropped!(self, &result, guard); + result + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + // Clone shared connection future to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result? + .send_packed_commands(cmd, offset, count) + .await; + reconnect_if_dropped!(self, &result, guard); + result + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } +} + +impl ConnectionLike for ConnectionManager { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + self.client.connection_info().redis.db + } + + fn is_closed(&self) -> bool { + // always return false due to automatic reconnect + false + } +} diff --git a/glide-core/redis-rs/redis/src/aio/mod.rs b/glide-core/redis-rs/redis/src/aio/mod.rs new file mode 100644 index 0000000000..ffe2c9e3a2 --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/mod.rs @@ -0,0 +1,286 @@ +//! Adds async IO support to redis. 
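+//!
+//! A minimal entry-point sketch (illustrative; assumes a reachable local server and the
+//! `tokio-comp` feature):
+//!
+//! ```rust,no_run
+//! # async fn example() -> redis::RedisResult<()> {
+//! let client = redis::Client::open("redis://127.0.0.1/")?;
+//! let mut con = client.get_multiplexed_async_connection(Default::default()).await?;
+//! redis::cmd("PING").query_async::<_, String>(&mut con).await?;
+//! # Ok(())
+//! # }
+//! ```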
+use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + get_resp3_hello_command_error, PubSubSubscriptionKind, RedisConnectionInfo, +}; +use crate::types::{ErrorKind, ProtocolVersion, RedisFuture, RedisResult, Value}; +use crate::PushKind; +use ::tokio::io::{AsyncRead, AsyncWrite}; +use async_trait::async_trait; +use futures_util::Future; +use std::net::SocketAddr; +#[cfg(unix)] +use std::path::Path; +use std::pin::Pin; +use std::time::Duration; + +/// Enables the async_std compatibility +#[cfg(feature = "async-std-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] +pub mod async_std; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +/// Enables the tokio compatibility +#[cfg(feature = "tokio-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] +pub mod tokio; + +/// Represents the ability of connecting via TCP or via Unix socket +#[async_trait] +pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static { + /// Performs a TCP connection + async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult; + + // Performs a TCP TLS connection + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult; + + /// Performs a UNIX connection + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult; + + fn spawn(f: impl Future + Send + 'static); + + fn boxed(self) -> Pin> { + Box::pin(self) + } +} + +/// Trait for objects that implements `AsyncRead` and `AsyncWrite` +pub trait AsyncStream: AsyncRead + AsyncWrite {} +impl AsyncStream for S where S: AsyncRead + AsyncWrite {} + +/// An async abstraction over connections. +pub trait ConnectionLike { + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>; + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + /// Important - this function is meant for internal usage, since it's + /// easy to pass incorrect `offset` & `count` parameters, which might + /// cause the connection to enter an erroneous state. Users shouldn't + /// call it, instead using the Pipeline::query_async function. + #[doc(hidden)] + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec>; + + /// Returns the database this connection is bound to. Note that this + /// information might be unreliable because it's initially cached and + /// also might be incorrect if the connection like object is not + /// actually connected. + fn get_db(&self) -> i64; + + /// Returns the state of the connection + fn is_closed(&self) -> bool; +} + +/// Implements ability to notify about disconnection events +#[async_trait] +pub trait DisconnectNotifier: Send + Sync { + /// Notify about disconnect event + fn notify_disconnect(&mut self); + + /// Wait for disconnect event with timeout + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration); + + /// Intended to be used with Box + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + +// Initial setup for every connection. 
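+//
+// In order, the setup below performs: a HELLO handshake (RESP3) or AUTH (RESP2), an optional
+// SELECT for a non-zero database, an optional CLIENT SETNAME, a best-effort CLIENT SETINFO,
+// and, for RESP3 connections, re-subscription to any configured pub/sub channels.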
+async fn setup_connection<C>(connection_info: &RedisConnectionInfo, con: &mut C) -> RedisResult<()>
+where
+    C: ConnectionLike,
+{
+    if connection_info.protocol != ProtocolVersion::RESP2 {
+        let hello_cmd = resp3_hello(connection_info);
+        let val: RedisResult<Value> = hello_cmd.query_async(con).await;
+        if let Err(err) = val {
+            return Err(get_resp3_hello_command_error(err));
+        }
+    } else if let Some(password) = &connection_info.password {
+        let mut command = cmd("AUTH");
+        if let Some(username) = &connection_info.username {
+            command.arg(username);
+        }
+        match command.arg(password).query_async(con).await {
+            Ok(Value::Okay) => (),
+            Err(e) => {
+                let err_msg = e.detail().ok_or((
+                    ErrorKind::AuthenticationFailed,
+                    "Password authentication failed",
+                ))?;
+
+                if !err_msg.contains("wrong number of arguments for 'auth' command") {
+                    fail!((
+                        ErrorKind::AuthenticationFailed,
+                        "Password authentication failed",
+                    ));
+                }
+
+                let mut command = cmd("AUTH");
+                match command.arg(password).query_async(con).await {
+                    Ok(Value::Okay) => (),
+                    _ => {
+                        fail!((
+                            ErrorKind::AuthenticationFailed,
+                            "Password authentication failed"
+                        ));
+                    }
+                }
+            }
+            _ => {
+                fail!((
+                    ErrorKind::AuthenticationFailed,
+                    "Password authentication failed"
+                ));
+            }
+        }
+    }
+
+    if connection_info.db != 0 {
+        match cmd("SELECT").arg(connection_info.db).query_async(con).await {
+            Ok(Value::Okay) => (),
+            _ => fail!((
+                ErrorKind::ResponseError,
+                "Redis server refused to switch database"
+            )),
+        }
+    }
+
+    if let Some(client_name) = &connection_info.client_name {
+        match cmd("CLIENT")
+            .arg("SETNAME")
+            .arg(client_name)
+            .query_async(con)
+            .await
+        {
+            Ok(Value::Okay) => {}
+            _ => fail!((
+                ErrorKind::ResponseError,
+                "Redis server refused to set client name"
+            )),
+        }
+    }
+
+    // result is ignored, as per the command's instructions.
+    // https://redis.io/commands/client-setinfo/
+    #[cfg(not(feature = "disable-client-setinfo"))]
+    let _: RedisResult<()> = crate::connection::client_set_info_pipeline()
+        .query_async(con)
+        .await;
+
+    // resubscribe
+    if connection_info.protocol != ProtocolVersion::RESP3 {
+        return Ok(());
+    }
+    static KIND_TO_COMMAND: [(PubSubSubscriptionKind, &str); 3] = [
+        (PubSubSubscriptionKind::Exact, "SUBSCRIBE"),
+        (PubSubSubscriptionKind::Pattern, "PSUBSCRIBE"),
+        (PubSubSubscriptionKind::Sharded, "SSUBSCRIBE"),
+    ];
+
+    if connection_info.pubsub_subscriptions.is_none() {
+        return Ok(());
+    }
+
+    for (subscription_kind, channels_patterns) in
+        connection_info.pubsub_subscriptions.as_ref().unwrap()
+    {
+        for channel_pattern in channels_patterns.iter() {
+            let mut subscribe_command =
+                cmd(KIND_TO_COMMAND[Into::<usize>::into(*subscription_kind)].1);
+            subscribe_command.arg(channel_pattern);
+
+            // This code is quite intricate - per RESP3, subscription commands do not return anything.
+            // Instead, push messages will be pushed for each channel. Thus, this is not a typical request-response pattern.
+            // The act of pushing is asynchronous with regard to the subscription command, and might be delayed for some time after the server state was already updated
+            // (i.e. the behaviour is implementation-defined).
+            // We will assume the configured timeout is enough for the server to push the notifications.
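+            // Illustrative expected shape: after `SUBSCRIBE foo`, the server pushes something
+            // like `Value::Push { kind: PushKind::Subscribe, data: [BulkString("foo"), Int(1)] }`,
+            // which the match below validates against the requested channel.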
+ match subscribe_command.query_async(con).await { + Ok(Value::Push { kind, data }) => { + match *subscription_kind { + PubSubSubscriptionKind::Exact => { + if kind != PushKind::Subscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Exact subscription channels" + )); + } + } + PubSubSubscriptionKind::Pattern => { + if kind != PushKind::PSubscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Pattern subscription channels" + )); + } + } + PubSubSubscriptionKind::Sharded => { + if kind != PushKind::SSubscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Sharded subscription channels" + )); + } + } + } + } + _ => { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to receive subscription notification while restoring subscription channels" + )); + } + } + } + } + + Ok(()) +} + +mod connection; +pub use connection::*; +mod multiplexed_connection; +pub use multiplexed_connection::*; +#[cfg(feature = "connection-manager")] +mod connection_manager; +#[cfg(feature = "connection-manager")] +#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] +pub use connection_manager::*; +mod runtime; +use crate::commands::resp3_hello; +pub(super) use runtime::*; diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs new file mode 100644 index 0000000000..1067bc2df5 --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -0,0 +1,656 @@ +use super::{ConnectionLike, Runtime}; +use crate::aio::setup_connection; +use crate::aio::DisconnectNotifier; +use crate::client::GlideConnectionOptions; +use crate::cmd::Cmd; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use crate::parser::ValueCodec; +use crate::push_manager::PushManager; +use crate::types::{RedisError, RedisFuture, RedisResult, Value}; +use crate::{cmd, ConnectionInfo, ProtocolVersion, PushKind}; +use ::tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot}, +}; +use arc_swap::ArcSwap; +use futures_util::{ + future::{Future, FutureExt}, + ready, + sink::Sink, + stream::{self, Stream, StreamExt, TryStreamExt as _}, +}; +use pin_project_lite::pin_project; +use std::collections::VecDeque; +use std::fmt; +use std::fmt::Debug; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use tokio_util::codec::Decoder; + +// Senders which the result of a single request are sent through +type PipelineOutput = oneshot::Sender>; + +enum ResponseAggregate { + SingleCommand, + Pipeline { + expected_response_count: usize, + current_response_count: usize, + buffer: Vec, + first_err: Option, + }, +} + +impl ResponseAggregate { + fn new(pipeline_response_count: Option) -> Self { + match pipeline_response_count { + Some(response_count) => ResponseAggregate::Pipeline { + expected_response_count: response_count, + current_response_count: 0, + buffer: Vec::new(), + first_err: None, + }, + None => ResponseAggregate::SingleCommand, + } + } +} + +struct InFlight { + 
output: PipelineOutput, + response_aggregate: ResponseAggregate, +} + +// A single message sent through the pipeline +struct PipelineMessage { + input: S, + output: PipelineOutput, + // If `None`, this is a single request, not a pipeline of multiple requests. + pipeline_response_count: Option, +} + +/// Wrapper around a `Stream + Sink` where each item sent through the `Sink` results in one or more +/// items being output by the `Stream` (the number is specified at time of sending). With the +/// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream` +/// and `Sink`. +#[derive(Clone)] +struct Pipeline { + sender: mpsc::Sender>, + push_manager: Arc>, + is_stream_closed: Arc, +} + +impl Debug for Pipeline +where + SinkItem: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Pipeline").field(&self.sender).finish() + } +} + +pin_project! { + struct PipelineSink { + #[pin] + sink_stream: T, + in_flight: VecDeque, + error: Option, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + } +} + +impl PipelineSink +where + T: Stream> + 'static, +{ + fn new( + sink_stream: T, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + ) -> Self + where + T: Sink + Stream> + 'static, + { + PipelineSink { + sink_stream, + in_flight: VecDeque::new(), + error: None, + push_manager, + disconnect_notifier, + is_stream_closed, + } + } + + // Read messages from the stream and send them back to the caller + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + loop { + let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) { + Some(result) => result, + // The redis response stream is not going to produce any more items so we `Err` + // to break out of the `forward` combinator and stop handling requests + None => { + // this is the right place to notify about the passive TCP disconnect + // In other places we cannot distinguish between the active destruction of MultiplexedConnection and passive disconnect + if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier { + disconnect_notifier.notify_disconnect(); + } + self.is_stream_closed.store(true, Ordering::Relaxed); + return Poll::Ready(Err(())); + } + }; + self.as_mut().send_result(item); + } + } + + fn send_result(self: Pin<&mut Self>, result: RedisResult) { + let self_ = self.project(); + let mut skip_value = false; + if let Ok(res) = &result { + if let Value::Push { kind, data: _data } = res { + self_.push_manager.load().try_send_raw(res); + if !kind.has_reply() { + // If it's not true then push kind is converted to reply of a command + skip_value = true; + } + } + } + + let mut entry = match self_.in_flight.pop_front() { + Some(entry) => entry, + None => return, + }; + + if skip_value { + self_.in_flight.push_front(entry); + return; + } + + match &mut entry.response_aggregate { + ResponseAggregate::SingleCommand => { + entry.output.send(result).ok(); + } + ResponseAggregate::Pipeline { + expected_response_count, + current_response_count, + buffer, + first_err, + } => { + match result { + Ok(item) => { + buffer.push(item); + } + Err(err) => { + if first_err.is_none() { + *first_err = Some(err); + } + } + } + + *current_response_count += 1; + if current_response_count < expected_response_count { + // Need to gather more response values + self_.in_flight.push_front(entry); + return; + } + + let response = match first_err.take() { + Some(err) => Err(err), + None => 
Ok(Value::Array(std::mem::take(buffer))), + }; + + // `Err` means that the receiver was dropped in which case it does not + // care about the output and we can continue by just dropping the value + // and sender + entry.output.send(response).ok(); + } + } + } +} + +impl Sink> for PipelineSink +where + T: Sink + Stream> + 'static, +{ + type Error = (); + + // Retrieve incoming messages and write them to the sink + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) { + Ok(()) => Ok(()).into(), + Err(err) => { + *self.project().error = Some(err); + Ok(()).into() + } + } + } + + fn start_send( + mut self: Pin<&mut Self>, + PipelineMessage { + input, + output, + pipeline_response_count, + }: PipelineMessage, + ) -> Result<(), Self::Error> { + // If there is nothing to receive our output we do not need to send the message as it is + // ambiguous whether the message will be sent anyway. Helps shed some load on the + // connection. + if output.is_closed() { + return Ok(()); + } + + let self_ = self.as_mut().project(); + + if let Some(err) = self_.error.take() { + let _ = output.send(Err(err)); + return Err(()); + } + + match self_.sink_stream.start_send(input) { + Ok(()) => { + let response_aggregate = ResponseAggregate::new(pipeline_response_count); + let entry = InFlight { + output, + response_aggregate, + }; + + self_.in_flight.push_back(entry); + Ok(()) + } + Err(err) => { + let _ = output.send(Err(err)); + Err(()) + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + ready!(self + .as_mut() + .project() + .sink_stream + .poll_flush(cx) + .map_err(|err| { + self.as_mut().send_result(Err(err)); + }))?; + self.poll_read(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // No new requests will come in after the first call to `close` but we need to complete any + // in progress requests before closing + if !self.in_flight.is_empty() { + ready!(self.as_mut().poll_flush(cx))?; + } + let this = self.as_mut().project(); + this.sink_stream.poll_close(cx).map_err(|err| { + self.send_result(Err(err)); + }) + } +} + +impl Pipeline +where + SinkItem: Send + 'static, +{ + fn new( + sink_stream: T, + disconnect_notifier: Option>, + ) -> (Self, impl Future) + where + T: Sink + Stream> + 'static, + T: Send + 'static, + T::Item: Send, + T::Error: Send, + T::Error: ::std::fmt::Debug, + { + const BUFFER_SIZE: usize = 50; + let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE); + let push_manager: Arc> = + Arc::new(ArcSwap::new(Arc::new(PushManager::default()))); + let is_stream_closed = Arc::new(AtomicBool::new(false)); + let sink = PipelineSink::new::( + sink_stream, + push_manager.clone(), + disconnect_notifier, + is_stream_closed.clone(), + ); + let f = stream::poll_fn(move |cx| receiver.poll_recv(cx)) + .map(Ok) + .forward(sink) + .map(|_| ()); + ( + Pipeline { + sender, + push_manager, + is_stream_closed, + }, + f, + ) + } + + // `None` means that the stream was out of items causing that poll loop to shut down. + async fn send_single( + &mut self, + item: SinkItem, + timeout: Duration, + ) -> Result> { + self.send_recv(item, None, timeout).await + } + + async fn send_recv( + &mut self, + input: SinkItem, + // If `None`, this is a single request, not a pipeline of multiple requests. 
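+        // For a pipeline, this is the total number of expected replies; `send_packed_commands`
+        // below passes `Some(offset + count)`.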
+ pipeline_response_count: Option, + timeout: Duration, + ) -> Result> { + let (sender, receiver) = oneshot::channel(); + + self.sender + .send(PipelineMessage { + input, + pipeline_response_count, + output: sender, + }) + .await + .map_err(|_| None)?; + match Runtime::locate().timeout(timeout, receiver).await { + Ok(Ok(result)) => result.map_err(Some), + Ok(Err(_)) => { + // The `sender` was dropped which likely means that the stream part + // failed for one reason or another + Err(None) + } + Err(elapsed) => Err(Some(elapsed.into())), + } + } + + /// Sets `PushManager` of Pipeline + async fn set_push_manager(&mut self, push_manager: PushManager) { + self.push_manager.store(Arc::new(push_manager)); + } + + pub fn is_closed(&self) -> bool { + self.is_stream_closed.load(Ordering::Relaxed) + } +} + +/// A connection object which can be cloned, allowing requests to be be sent concurrently +/// on the same underlying connection (tcp/unix socket). +#[derive(Clone)] +pub struct MultiplexedConnection { + pipeline: Pipeline>, + db: i64, + response_timeout: Duration, + protocol: ProtocolVersion, + push_manager: PushManager, +} + +impl Debug for MultiplexedConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MultiplexedConnection") + .field("pipeline", &self.pipeline) + .field("db", &self.db) + .finish() + } +} + +impl MultiplexedConnection { + /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object + /// and a `ConnectionInfo` + pub async fn new( + connection_info: &ConnectionInfo, + stream: C, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(Self, impl Future)> + where + C: Unpin + AsyncRead + AsyncWrite + Send + 'static, + { + Self::new_with_response_timeout( + connection_info, + stream, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object + /// and a `ConnectionInfo`. The new object will wait on operations for the given `response_timeout`. 
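+    ///
+    /// A rough construction sketch (the address and timeout are illustrative, and the exact
+    /// re-export paths for `MultiplexedConnection` and `GlideConnectionOptions` are assumed
+    /// here). The returned driver future must be spawned for the connection to make progress:
+    ///
+    /// ```rust,no_run
+    /// # async fn example() -> redis::RedisResult<()> {
+    /// use redis::{aio::MultiplexedConnection, GlideConnectionOptions};
+    /// let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// let info = client.get_connection_info().clone();
+    /// let stream = tokio::net::TcpStream::connect("127.0.0.1:6379").await?;
+    /// let (con, driver) = MultiplexedConnection::new_with_response_timeout(
+    ///     &info,
+    ///     stream,
+    ///     std::time::Duration::from_secs(1),
+    ///     GlideConnectionOptions::default(),
+    /// )
+    /// .await?;
+    /// tokio::spawn(driver);
+    /// # Ok(())
+    /// # }
+    /// ```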
+ pub async fn new_with_response_timeout( + connection_info: &ConnectionInfo, + stream: C, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(Self, impl Future)> + where + C: Unpin + AsyncRead + AsyncWrite + Send + 'static, + { + fn boxed( + f: impl Future + Send + 'static, + ) -> Pin + Send>> { + Box::pin(f) + } + + #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] + compile_error!("tokio-comp or async-std-comp features required for aio feature"); + + let redis_connection_info = &connection_info.redis; + let codec = ValueCodec::default() + .framed(stream) + .and_then(|msg| async move { msg }); + let (mut pipeline, driver) = + Pipeline::new(codec, glide_connection_options.disconnect_notifier); + let driver = boxed(driver); + let pm = PushManager::default(); + if let Some(sender) = glide_connection_options.push_sender { + pm.replace_sender(sender); + } + + pipeline.set_push_manager(pm.clone()).await; + let mut con = MultiplexedConnection { + pipeline, + db: connection_info.redis.db, + response_timeout, + push_manager: pm, + protocol: redis_connection_info.protocol, + }; + let driver = { + let auth = setup_connection(&connection_info.redis, &mut con); + + futures_util::pin_mut!(auth); + + match futures_util::future::select(auth, driver).await { + futures_util::future::Either::Left((result, driver)) => { + result?; + driver + } + futures_util::future::Either::Right(((), _)) => { + return Err(RedisError::from(( + crate::ErrorKind::IoError, + "Multiplexed connection driver unexpectedly terminated", + ))); + } + } + }; + Ok((con, driver)) + } + + /// Sets the time that the multiplexer will wait for responses on operations before failing. + pub fn set_response_timeout(&mut self, timeout: std::time::Duration) { + self.response_timeout = timeout; + } + + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + let result = self + .pipeline + .send_single(cmd.get_packed_command(), self.response_timeout) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + }); + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + result + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. 
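+    ///
+    /// For example (illustrative): an atomic pipeline packing `MULTI`, two queued commands
+    /// and `EXEC` is read back with `offset = 3` (the `OK` and two `QUEUED` replies are
+    /// skipped) and `count = 1` (only the `EXEC` reply is kept).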
+ pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + let result = self + .pipeline + .send_recv( + cmd.get_packed_pipeline(), + Some(offset + count), + self.response_timeout, + ) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + }); + + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + let value = result?; + match value { + Value::Array(mut values) => { + values.drain(..offset); + Ok(values) + } + _ => Ok(vec![value]), + } + } + + /// Sets `PushManager` of connection + pub async fn set_push_manager(&mut self, push_manager: PushManager) { + self.push_manager = push_manager.clone(); + self.pipeline.set_push_manager(push_manager).await; + } +} + +impl ConnectionLike for MultiplexedConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + self.db + } + + fn is_closed(&self) -> bool { + self.pipeline.is_closed() + } +} +impl MultiplexedConnection { + /// Subscribes to a new channel. + pub async fn subscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("SUBSCRIBE"); + cmd.arg(channel_name.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel. + pub async fn unsubscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("UNSUBSCRIBE"); + cmd.arg(channel_name); + cmd.query_async(self).await?; + Ok(()) + } + + /// Subscribes to a new channel with pattern. + pub async fn psubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PSUBSCRIBE"); + cmd.arg(channel_pattern.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel pattern. 
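+    ///
+    /// A short sketch (the pattern is illustrative; these helpers require RESP3):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(mut con: redis::aio::MultiplexedConnection) -> redis::RedisResult<()> {
+    /// con.psubscribe("news.*".to_string()).await?;
+    /// con.punsubscribe("news.*".to_string()).await?;
+    /// # Ok(())
+    /// # }
+    /// ```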
+ pub async fn punsubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PUNSUBSCRIBE"); + cmd.arg(channel_pattern); + cmd.query_async(self).await?; + Ok(()) + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/runtime.rs b/glide-core/redis-rs/redis/src/aio/runtime.rs new file mode 100644 index 0000000000..5755f62c9f --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/runtime.rs @@ -0,0 +1,82 @@ +use std::{io, time::Duration}; + +use futures_util::Future; + +#[cfg(feature = "async-std-comp")] +use super::async_std; +#[cfg(feature = "tokio-comp")] +use super::tokio; +use super::RedisRuntime; +use crate::types::RedisError; + +#[derive(Clone, Debug)] +pub(crate) enum Runtime { + #[cfg(feature = "tokio-comp")] + Tokio, + #[cfg(feature = "async-std-comp")] + AsyncStd, +} + +impl Runtime { + pub(crate) fn locate() -> Self { + #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))] + { + Runtime::Tokio + } + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + Runtime::AsyncStd + } + + #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] + { + if ::tokio::runtime::Handle::try_current().is_ok() { + Runtime::Tokio + } else { + Runtime::AsyncStd + } + } + + #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] + { + compile_error!("tokio-comp or async-std-comp features required for aio feature") + } + } + + #[allow(dead_code)] + pub(super) fn spawn(&self, f: impl Future + Send + 'static) { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => tokio::Tokio::spawn(f), + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => async_std::AsyncStd::spawn(f), + } + } + + pub(crate) async fn timeout( + &self, + duration: Duration, + future: F, + ) -> Result { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => ::tokio::time::timeout(duration, future) + .await + .map_err(|_| Elapsed(())), + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => ::async_std::future::timeout(duration, future) + .await + .map_err(|_| Elapsed(())), + } + } +} + +#[derive(Debug)] +pub(crate) struct Elapsed(()); + +impl From for RedisError { + fn from(_: Elapsed) -> Self { + io::Error::from(io::ErrorKind::TimedOut).into() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/tokio.rs b/glide-core/redis-rs/redis/src/aio/tokio.rs new file mode 100644 index 0000000000..3a68c0ebfc --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/tokio.rs @@ -0,0 +1,204 @@ +use super::{AsyncStream, RedisResult, RedisRuntime, SocketAddr}; +use async_trait::async_trait; +#[allow(unused_imports)] // fixes "Duration" unused when built with non-default feature set +use std::{ + future::Future, + io, + pin::Pin, + task::{self, Poll}, + time::Duration, +}; +#[cfg(unix)] +use tokio::net::UnixStream as UnixStreamTokio; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpStream as TcpStreamTokio, +}; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::TlsConnector; + +#[cfg(feature = "tls-rustls")] +use crate::connection::create_rustls_config; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; +#[cfg(feature = 
"tls-rustls")] +use tokio_rustls::{client::TlsStream, TlsConnector}; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] +use tokio_native_tls::TlsStream; + +#[cfg(feature = "tokio-rustls-comp")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +#[cfg(unix)] +use super::Path; + +#[inline(always)] +async fn connect_tcp(addr: &SocketAddr) -> io::Result { + let socket = TcpStreamTokio::connect(addr).await?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let std_socket = socket.into_std()?; + let socket2: socket2::Socket = std_socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + // TCP_USER_TIMEOUT configuration isn't supported across all operation systems + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + { + // TODO: Replace this hardcoded timeout with a configurable timeout when https://github.com/redis-rs/redis-rs/issues/1147 is resolved + const DFEAULT_USER_TCP_TIMEOUT: Duration = Duration::from_secs(5); + socket2.set_tcp_user_timeout(Some(DFEAULT_USER_TCP_TIMEOUT))?; + } + TcpStreamTokio::from_std(socket2.into()) + } + + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +pub(crate) enum Tokio { + /// Represents a Tokio TCP connection. + Tcp(TcpStreamTokio), + /// Represents a Tokio TLS encrypted TCP connection + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + TcpTls(Box>), + /// Represents a Tokio Unix connection. + #[cfg(unix)] + Unix(UnixStreamTokio), +} + +impl AsyncWrite for Tokio { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_flush(cx), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx), + } + } +} + +impl AsyncRead for Tokio { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf), + } + } +} + +#[async_trait] +impl RedisRuntime for Tokio { + async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult { + Ok(connect_tcp(&socket_addr).await.map(Tokio::Tcp)?) 
+ } + + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + _: &Option, + ) -> RedisResult { + let tls_connector: tokio_native_tls::TlsConnector = if insecure { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + .build()? + } else { + TlsConnector::new()? + } + .into(); + Ok(tls_connector + .connect(hostname, connect_tcp(&socket_addr).await?) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult { + let config = create_rustls_config(insecure, tls_params.clone())?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + rustls_pki_types::ServerName::try_from(hostname)?.to_owned(), + connect_tcp(&socket_addr).await?, + ) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult { + Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?) + } + + #[cfg(feature = "tokio-comp")] + fn spawn(f: impl Future + Send + 'static) { + tokio::spawn(f); + } + + #[cfg(not(feature = "tokio-comp"))] + fn spawn(_: impl Future + Send + 'static) { + unreachable!() + } + + fn boxed(self) -> Pin> { + match self { + Tokio::Tcp(x) => Box::pin(x), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(x) => Box::pin(x), + #[cfg(unix)] + Tokio::Unix(x) => Box::pin(x), + } + } +} diff --git a/glide-core/redis-rs/redis/src/client.rs b/glide-core/redis-rs/redis/src/client.rs new file mode 100644 index 0000000000..5e3f144e71 --- /dev/null +++ b/glide-core/redis-rs/redis/src/client.rs @@ -0,0 +1,855 @@ +use std::time::Duration; + +#[cfg(feature = "aio")] +use crate::aio::DisconnectNotifier; + +use crate::{ + connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, + push_manager::PushInfo, + types::{RedisResult, Value}, +}; +#[cfg(feature = "aio")] +use std::net::IpAddr; +#[cfg(feature = "aio")] +use std::net::SocketAddr; +#[cfg(feature = "aio")] +use std::pin::Pin; +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{inner_build_with_tls, TlsCertificates}; + +/// The client type. +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) connection_info: ConnectionInfo, +} + +/// The client acts as connector to the redis server. By itself it does not +/// do much other than providing a convenient way to fetch a connection from +/// it. In the future the plan is to provide a connection pool in the client. +/// +/// When opening a client a URL in the following format should be used: +/// +/// ```plain +/// redis://host:port/db +/// ``` +/// +/// Example usage:: +/// +/// ```rust,no_run +/// let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// let con = client.get_connection(None).unwrap(); +/// ``` +impl Client { + /// Connects to a redis server and returns a client. This does not + /// actually open a connection yet but it does perform some basic + /// checks on the URL that might make the operation fail. + pub fn open(params: T) -> RedisResult { + Ok(Client { + connection_info: params.into_connection_info()?, + }) + } + + /// Instructs the client to actually connect to redis and returns a + /// connection object. 
The connection object can be used to send
+    /// commands to the server. This can fail with a variety of errors
+    /// (like unreachable host) so it's important that you handle those
+    /// errors.
+    pub fn get_connection(
+        &self,
+        _push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    ) -> RedisResult<Connection> {
+        connect(&self.connection_info, None)
+    }
+
+    /// Instructs the client to actually connect to redis with the specified
+    /// timeout and returns a connection object. The connection object
+    /// can be used to send commands to the server. This can fail with
+    /// a variety of errors (like unreachable host) so it's important
+    /// that you handle those errors.
+    pub fn get_connection_with_timeout(&self, timeout: Duration) -> RedisResult<Connection> {
+        connect(&self.connection_info, Some(timeout))
+    }
+
+    /// Returns a reference to the client's connection info object.
+    pub fn get_connection_info(&self) -> &ConnectionInfo {
+        &self.connection_info
+    }
+}
+
+/// Glide-specific connection options
+#[derive(Clone, Default)]
+pub struct GlideConnectionOptions {
+    /// Queue for RESP3 notifications
+    pub push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    #[cfg(feature = "aio")]
+    /// Passive disconnect notifier
+    pub disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
+}
+
+/// To enable async support you need to choose one of the supported runtimes and activate its
+/// corresponding feature: `tokio-comp` or `async-std-comp`
+#[cfg(feature = "aio")]
+#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
+impl Client {
+    /// Returns an async connection from the client.
+    #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
+    #[deprecated(
+        note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead."
+    )]
+    #[allow(deprecated)]
+    pub async fn get_async_connection(
+        &self,
+        _push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    ) -> RedisResult<crate::aio::Connection> {
+        let (con, _ip) = match Runtime::locate() {
+            #[cfg(feature = "tokio-comp")]
+            Runtime::Tokio => {
+                self.get_simple_async_connection::<crate::aio::tokio::Tokio>(None)
+                    .await?
+            }
+            #[cfg(feature = "async-std-comp")]
+            Runtime::AsyncStd => {
+                self.get_simple_async_connection::<crate::aio::async_std::AsyncStd>(None)
+                    .await?
+            }
+        };
+
+        crate::aio::Connection::new(&self.connection_info.redis, con).await
+    }
+
+    /// Returns an async connection from the client.
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    #[deprecated(
+        note = "aio::Connection is deprecated. Use client::get_multiplexed_tokio_connection instead."
+    )]
+    #[allow(deprecated)]
+    pub async fn get_tokio_connection(&self) -> RedisResult<crate::aio::Connection> {
+        use crate::aio::RedisRuntime;
+        Ok(
+            crate::aio::connect::<crate::aio::tokio::Tokio>(&self.connection_info, None)
+                .await?
+                .map(RedisRuntime::boxed),
+        )
+    }
+
+    /// Returns an async connection from the client.
+    #[cfg(feature = "async-std-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
+    #[deprecated(
+        note = "aio::Connection is deprecated. Use client::get_multiplexed_async_std_connection instead."
+    )]
+    #[allow(deprecated)]
+    pub async fn get_async_std_connection(&self) -> RedisResult<crate::aio::Connection> {
+        use crate::aio::RedisRuntime;
+        Ok(
+            crate::aio::connect::<crate::aio::async_std::AsyncStd>(&self.connection_info, None)
+                .await?
+                .map(RedisRuntime::boxed),
+        )
+    }
+
+    /// Returns an async connection from the client.
+ #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_async_connection_with_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async connection from the client. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection_with_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + rt @ Runtime::Tokio => { + rt.timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await + } + #[cfg(feature = "async-std-comp")] + rt @ Runtime::AsyncStd => { + rt.timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await + } + }; + + match result { + Ok(Ok(connection)) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + .map(|(conn, _ip)| conn) + } + + /// For TCP connections: returns (async connection, Some(the direct IP address)) + /// For Unix connections, returns (async connection, None) + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection_and_ip( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> { + match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => { + self.get_multiplexed_async_connection_inner::( + Duration::MAX, + None, + glide_connection_options, + ) + .await + } + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => { + self.get_multiplexed_async_connection_inner::( + Duration::MAX, + None, + glide_connection_options, + ) + .await + } + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn get_multiplexed_tokio_connection_with_response_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = Runtime::locate() + .timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await; + + match result { + Ok(Ok((connection, _ip))) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). 
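+    ///
+    /// A cloning sketch (illustrative):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(con: redis::aio::MultiplexedConnection) {
+    /// let mut a = con.clone();
+    /// let mut b = con;
+    /// // `a` and `b` can now issue commands concurrently over the same socket.
+    /// # }
+    /// ```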
+ #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn get_multiplexed_tokio_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_tokio_connection_with_response_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn get_multiplexed_async_std_connection_with_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = Runtime::locate() + .timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await; + + match result { + Ok(Ok((connection, _ip))) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn get_multiplexed_async_std_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_async_std_connection_with_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn create_multiplexed_tokio_connection_with_response_timeout( + &self, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ) + .await + .map(|(conn, driver, _ip)| (conn, driver)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). 
+ #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn create_multiplexed_tokio_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_tokio_connection_with_response_timeout( + std::time::Duration::MAX, + glide_connection_options, + ) + .await + .map(|conn_res| (conn_res.0, conn_res.1)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn create_multiplexed_async_std_connection_with_response_timeout( + &self, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ) + .await + .map(|(conn, driver, _ip)| (conn, driver)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn create_multiplexed_async_std_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_std_connection_with_response_timeout( + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager instead")] + pub async fn get_tokio_connection_manager(&self) -> RedisResult { + crate::aio::ConnectionManager::new(self.clone()).await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. 
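The `create_multiplexed_*` constructors above hand back the connection together with a driver future, and no request completes unless that future is polled. A minimal sketch (tokio runtime assumed) spawns it once; the single driver then services every clone of the connection:

```rust
use redis::{aio::MultiplexedConnection, Client, GlideConnectionOptions, RedisResult};

// Illustrative helper around create_multiplexed_tokio_connection.
async fn connect_and_drive(client: &Client) -> RedisResult<MultiplexedConnection> {
    let (con, driver) = client
        .create_multiplexed_tokio_connection(GlideConnectionOptions::default())
        .await?;
    // Without this, requests submitted to `con` would never make progress.
    tokio::spawn(driver);
    Ok(con)
}
```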
+ /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager(&self) -> RedisResult { + crate::aio::ConnectionManager::new(self.clone()).await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager_with_backoff instead")] + pub async fn get_tokio_connection_manager_with_backoff( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + self.get_connection_manager_with_backoff_and_timeouts( + exponent_base, + factor, + number_of_retries, + std::time::Duration::MAX, + std::time::Duration::MAX, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). 
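Under the `connection-manager` feature, the reconnect-on-error behavior described above can be exercised with a sketch like the following (key names are illustrative; a failed command still surfaces its error while the replacement connection is built in the background):

```rust
use redis::{Client, RedisResult};

// Illustrative usage: clones share the same underlying multiplexed connection.
async fn with_manager(client: &Client) -> RedisResult<String> {
    let mut manager = client.get_connection_manager().await?;
    let mut clone = manager.clone(); // cheap clone, same connection underneath
    let _: () = redis::cmd("SET")
        .arg("greeting")
        .arg("hello")
        .query_async(&mut manager)
        .await?;
    redis::cmd("GET").arg("greeting").query_async(&mut clone).await
}
```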
+ /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager_with_backoff_and_timeouts instead")] + pub async fn get_tokio_connection_manager_with_backoff_and_timeouts( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff_and_timeouts( + self.clone(), + exponent_base, + factor, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager_with_backoff_and_timeouts( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff_and_timeouts( + self.clone(), + exponent_base, + factor, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). 
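The backoff arguments pass straight through to `ConnectionManager::new_with_backoff_and_timeouts`. The values below are illustrative only; the exact retry-delay schedule (delays growing from `factor` by powers of `exponent_base`, for at most `number_of_retries` attempts) is defined by the manager, not by this file:

```rust
use std::time::Duration;

use redis::{aio::ConnectionManager, Client, RedisResult};

// Sketch: a manager that retries reconnects with exponential backoff and
// bounds both individual requests and connection establishment.
async fn manager_with_backoff(client: &Client) -> RedisResult<ConnectionManager> {
    client
        .get_connection_manager_with_backoff_and_timeouts(
            2,   // exponent_base
            100, // factor
            6,   // number_of_retries
            Duration::from_secs(1), // response_timeout
            Duration::from_secs(5), // connection_timeout
        )
        .await
}
```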
+ /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager_with_backoff( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff( + self.clone(), + exponent_base, + factor, + number_of_retries, + ) + .await + } + + pub(crate) async fn get_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> + where + T: crate::aio::RedisRuntime, + { + let (connection, driver, ip) = self + .create_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ) + .await?; + T::spawn(driver); + Ok((connection, ip)) + } + + async fn create_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (con, ip) = self.get_simple_async_connection::(socket_addr).await?; + crate::aio::MultiplexedConnection::new_with_response_timeout( + &self.connection_info, + con, + response_timeout, + glide_connection_options, + ) + .await + .map(|res| (res.0, res.1, ip)) + } + + async fn get_simple_async_connection( + &self, + socket_addr: Option, + ) -> RedisResult<( + Pin>, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (conn, ip) = + crate::aio::connect_simple::(&self.connection_info, socket_addr).await?; + Ok((conn.boxed(), ip)) + } + + #[cfg(feature = "connection-manager")] + pub(crate) fn connection_info(&self) -> &ConnectionInfo { + &self.connection_info + } + + /// Constructs a new `Client` with parameters necessary to create a TLS connection. + /// + /// - `conn_info` - URL using the `rediss://` scheme. + /// - `tls_certs` - `TlsCertificates` structure containing: + /// -- `client_tls` - Optional `ClientTlsConfig` containing byte streams for + /// --- `client_cert` - client's byte stream containing client certificate in PEM format + /// --- `client_key` - client's byte stream containing private key in PEM format + /// -- `root_cert` - Optional byte stream yielding PEM formatted file for root certificates. + /// + /// If `ClientTlsConfig` ( cert+key pair ) is not provided, then client-side authentication is not enabled. + /// If `root_cert` is not provided, then system root certificates are used instead. 
+ /// + /// # Examples + /// + /// ```no_run + /// use std::{fs::File, io::{BufReader, Read}}; + /// + /// use redis::{Client, AsyncCommands as _, TlsCertificates, ClientTlsConfig}; + /// + /// async fn do_redis_code( + /// url: &str, + /// root_cert_file: &str, + /// cert_file: &str, + /// key_file: &str + /// ) -> redis::RedisResult<()> { + /// let root_cert_file = File::open(root_cert_file).expect("cannot open private cert file"); + /// let mut root_cert_vec = Vec::new(); + /// BufReader::new(root_cert_file) + /// .read_to_end(&mut root_cert_vec) + /// .expect("Unable to read ROOT cert file"); + /// + /// let cert_file = File::open(cert_file).expect("cannot open private cert file"); + /// let mut client_cert_vec = Vec::new(); + /// BufReader::new(cert_file) + /// .read_to_end(&mut client_cert_vec) + /// .expect("Unable to read client cert file"); + /// + /// let key_file = File::open(key_file).expect("cannot open private key file"); + /// let mut client_key_vec = Vec::new(); + /// BufReader::new(key_file) + /// .read_to_end(&mut client_key_vec) + /// .expect("Unable to read client key file"); + /// + /// let client = Client::build_with_tls( + /// url, + /// TlsCertificates { + /// client_tls: Some(ClientTlsConfig{ + /// client_cert: client_cert_vec, + /// client_key: client_key_vec, + /// }), + /// root_cert: Some(root_cert_vec), + /// } + /// ) + /// .expect("Unable to build client"); + /// + /// let connection_info = client.get_connection_info(); + /// + /// println!(">>> connection info: {connection_info:?}"); + /// + /// let mut con = client.get_async_connection(None).await?; + /// + /// con.set("key1", b"foo").await?; + /// + /// redis::cmd("SET") + /// .arg(&["key2", "bar"]) + /// .query_async(&mut con) + /// .await?; + /// + /// let result = redis::cmd("MGET") + /// .arg(&["key1", "key2"]) + /// .query_async(&mut con) + /// .await; + /// assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + /// println!("Result from MGET: {result:?}"); + /// + /// Ok(()) + /// } + /// ``` + #[cfg(feature = "tls-rustls")] + pub fn build_with_tls( + conn_info: C, + tls_certs: TlsCertificates, + ) -> RedisResult { + let connection_info = conn_info.into_connection_info()?; + + inner_build_with_tls(connection_info, tls_certs) + } + + /// Returns an async receiver for pub-sub messages. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + // TODO - do we want to type-erase pubsub using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_pubsub(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_pubsub()) + } + + /// Returns an async receiver for monitor messages. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + // TODO - do we want to type-erase monitor using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_monitor(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_monitor()) + } +} + +#[cfg(feature = "aio")] +use crate::aio::Runtime; + +impl ConnectionLike for Client { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.get_connection(None)?.req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.get_connection(None)? 
+ .req_packed_commands(cmd, offset, count) + } + + fn get_db(&self) -> i64 { + self.connection_info.redis.db + } + + fn check_connection(&mut self) -> bool { + if let Ok(mut conn) = self.get_connection(None) { + conn.check_connection() + } else { + false + } + } + + fn is_open(&self) -> bool { + if let Ok(conn) = self.get_connection(None) { + conn.is_open() + } else { + false + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn regression_293_parse_ipv6_with_interface() { + assert!(Client::open(("fe80::cafe:beef%eno1", 6379)).is_ok()); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster.rs b/glide-core/redis-rs/redis/src/cluster.rs new file mode 100644 index 0000000000..f9c76f5161 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster.rs @@ -0,0 +1,1076 @@ +//! This module extends the library to support Redis Cluster. +//! +//! Note that this module does not currently provide pubsub +//! functionality. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::ClusterClient; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let _: () = connection.set("test", "test_data").unwrap(); +//! let rv: String = connection.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! # Pipelining +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::{cluster_pipe, ClusterClient}; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let key = "test"; +//! +//! let _: () = cluster_pipe() +//! .rpush(key, "123").ignore() +//! .ltrim(key, -10, -1).ignore() +//! .expire(key, 60).ignore() +//! .query(&mut connection).unwrap(); +//! 
``` +use std::cell::RefCell; +use std::collections::HashSet; +use std::str::FromStr; +use std::thread; +use std::time::Duration; + +use rand::{seq::IteratorRandom, thread_rng}; + +use crate::cluster_pipeline::UNROUTABLE_ERROR; +use crate::cluster_routing::{ + MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, +}; +use crate::cluster_slotmap::SlotMap; +use crate::cluster_topology::parse_and_count_slots; +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo, +}; +use crate::parser::parse_redis_value; +use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, Value}; +pub use crate::TlsMode; // Pub for backwards compatibility +use crate::{ + cluster_client::ClusterParams, + cluster_routing::{Redirect, Route, RoutingInfo}, + IntoConnectionInfo, PushInfo, +}; + +pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; +pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; + +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[derive(Clone)] +enum Input<'a> { + Slice { + cmd: &'a [u8], + routable: Value, + }, + Cmd(&'a Cmd), + Commands { + cmd: &'a [u8], + route: SingleNodeRoutingInfo, + offset: usize, + count: usize, + }, +} + +impl<'a> Input<'a> { + fn send(&'a self, connection: &mut impl ConnectionLike) -> RedisResult { + match self { + Input::Slice { cmd, routable: _ } => { + connection.req_packed_command(cmd).map(Output::Single) + } + Input::Cmd(cmd) => connection.req_command(cmd).map(Output::Single), + Input::Commands { + cmd, + route: _, + offset, + count, + } => connection + .req_packed_commands(cmd, *offset, *count) + .map(Output::Multi), + } + } +} + +impl<'a> Routable for Input<'a> { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + match self { + Input::Slice { cmd: _, routable } => routable.arg_idx(idx), + Input::Cmd(cmd) => cmd.arg_idx(idx), + Input::Commands { .. } => None, + } + } + + fn position(&self, candidate: &[u8]) -> Option { + match self { + Input::Slice { cmd: _, routable } => routable.position(candidate), + Input::Cmd(cmd) => cmd.position(candidate), + Input::Commands { .. } => None, + } + } +} + +enum Output { + Single(Value), + Multi(Vec), +} + +impl From for Value { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => value, + Output::Multi(values) => Value::Array(values), + } + } +} + +impl From for Vec { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => vec![value], + Output::Multi(values) => values, + } + } +} + +/// Implements the process of connecting to a Redis server +/// and obtaining and configuring a connection handle. +pub trait Connect: Sized { + /// Connect to a node, returning handle for command execution. + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo; + + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()>; + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. 
It is an error to pass the zero `Duration` to this + /// method. + fn set_write_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_read_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + fn recv_response(&mut self) -> RedisResult; +} + +impl Connect for Connection { + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + connect(&info.into_connection_info()?, timeout) + } + + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + Self::send_packed_command(self, cmd) + } + + fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_write_timeout(self, dur) + } + + fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_read_timeout(self, dur) + } + + fn recv_response(&mut self) -> RedisResult { + Self::recv_response(self) + } +} + +/// This represents a Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. +pub struct ClusterConnection { + initial_nodes: Vec, + connections: RefCell>, + slots: RefCell, + auto_reconnect: RefCell, + read_timeout: RefCell>, + write_timeout: RefCell>, + cluster_params: ClusterParams, +} + +impl ClusterConnection +where + C: ConnectionLike + Connect, +{ + pub(crate) fn new( + cluster_params: ClusterParams, + initial_nodes: Vec, + _push_sender: Option>, + ) -> RedisResult { + let connection = Self { + connections: RefCell::new(HashMap::new()), + slots: RefCell::new(SlotMap::new(vec![], cluster_params.read_from_replicas)), + auto_reconnect: RefCell::new(true), + cluster_params, + read_timeout: RefCell::new(None), + write_timeout: RefCell::new(None), + initial_nodes: initial_nodes.to_vec(), + }; + connection.create_initial_connections()?; + + Ok(connection) + } + + /// Set an auto reconnect attribute. + /// Default value is true; + pub fn set_auto_reconnect(&self, value: bool) { + let mut auto_reconnect = self.auto_reconnect.borrow_mut(); + *auto_reconnect = value; + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. + if dur.is_some() && dur.unwrap().is_zero() { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Duration should be None or non-zero.", + ))); + } + + let mut t = self.write_timeout.borrow_mut(); + *t = dur; + let connections = self.connections.borrow(); + for conn in connections.values() { + conn.set_write_timeout(dur)?; + } + Ok(()) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. 
+ if dur.is_some() && dur.unwrap().is_zero() { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Duration should be None or non-zero.", + ))); + } + + let mut t = self.read_timeout.borrow_mut(); + *t = dur; + let connections = self.connections.borrow(); + for conn in connections.values() { + conn.set_read_timeout(dur)?; + } + Ok(()) + } + + /// Check that all connections it has are available (`PING` internally). + #[doc(hidden)] + pub fn check_connection(&mut self) -> bool { + ::check_connection(self) + } + + pub(crate) fn execute_pipeline(&mut self, pipe: &ClusterPipeline) -> RedisResult> { + self.send_recv_and_retry_cmds(pipe.commands()) + } + + /// Returns the connection status. + /// + /// The connection is open until any `read_response` call recieved an + /// invalid response from the server (most likely a closed or dropped + /// connection, otherwise a Redis protocol error). When using unix + /// sockets the connection is open until writing a command failed with a + /// `BrokenPipe` error. + fn create_initial_connections(&self) -> RedisResult<()> { + let mut connections = HashMap::with_capacity(self.initial_nodes.len()); + + for info in self.initial_nodes.iter() { + let addr = info.addr.to_string(); + + if let Ok(mut conn) = self.connect(&addr) { + if conn.check_connection() { + connections.insert(addr, conn); + break; + } + } + } + + if connections.is_empty() { + return Err(RedisError::from(( + ErrorKind::IoError, + "It failed to check startup nodes.", + ))); + } + + *self.connections.borrow_mut() = connections; + self.refresh_slots()?; + Ok(()) + } + + // Query a node to discover slot-> master mappings. + fn refresh_slots(&self) -> RedisResult<()> { + let mut slots = self.slots.borrow_mut(); + *slots = self.create_new_slots()?; + + let mut nodes = slots.values().flatten().collect::>(); + nodes.sort_unstable(); + nodes.dedup(); + + let mut connections = self.connections.borrow_mut(); + *connections = nodes + .into_iter() + .filter_map(|addr| { + if connections.contains_key(addr) { + let mut conn = connections.remove(addr).unwrap(); + if conn.check_connection() { + return Some((addr.to_string(), conn)); + } + } + + if let Ok(mut conn) = self.connect(addr) { + if conn.check_connection() { + return Some((addr.to_string(), conn)); + } + } + + None + }) + .collect(); + + Ok(()) + } + + fn create_new_slots(&self) -> RedisResult { + let mut connections = self.connections.borrow_mut(); + let mut rng = thread_rng(); + let len = connections.len(); + let samples = connections.iter_mut().choose_multiple(&mut rng, len); + let mut result = Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error.", + "didn't get any slots from server".to_string(), + ))); + for (addr, conn) in samples { + let value = conn.req_command(&slot_cmd())?; + let addr = addr.split(':').next().ok_or(RedisError::from(( + ErrorKind::ClientError, + "can't parse node address", + )))?; + match parse_and_count_slots(&value, self.cluster_params.tls, addr).map(|slots_data| { + SlotMap::new(slots_data.1, self.cluster_params.read_from_replicas) + }) { + Ok(new_slots) => { + result = Ok(new_slots); + break; + } + Err(err) => result = Err(err), + } + } + result + } + + fn connect(&self, node: &str) -> RedisResult { + let info = get_connection_info(node, self.cluster_params.clone())?; + + let mut conn = C::connect(info, Some(self.cluster_params.connection_timeout))?; + if self.cluster_params.read_from_replicas + != crate::cluster_slotmap::ReadFromReplicaStrategy::AlwaysFromPrimary + { + // If READONLY 
is sent to primary nodes, it will have no effect + cmd("READONLY").query(&mut conn)?; + } + conn.set_read_timeout(*self.read_timeout.borrow())?; + conn.set_write_timeout(*self.write_timeout.borrow())?; + Ok(conn) + } + + fn get_connection<'a>( + &self, + connections: &'a mut HashMap, + route: &Route, + ) -> RedisResult<(String, &'a mut C)> { + let slots = self.slots.borrow(); + if let Some(addr) = slots.slot_addr_for_route(route) { + Ok(( + addr.to_string(), + self.get_connection_by_addr(connections, addr)?, + )) + } else { + // try a random node next. This is safe if slots are involved + // as a wrong node would reject the request. + Ok(get_random_connection(connections)) + } + } + + fn get_connection_by_addr<'a>( + &self, + connections: &'a mut HashMap, + addr: &str, + ) -> RedisResult<&'a mut C> { + if connections.contains_key(addr) { + Ok(connections.get_mut(addr).unwrap()) + } else { + // Create new connection. + // TODO: error handling + let conn = self.connect(addr)?; + Ok(connections.entry(addr.to_string()).or_insert(conn)) + } + } + + fn get_addr_for_cmd(&self, cmd: &Cmd) -> RedisResult { + let slots = self.slots.borrow(); + + let addr_for_slot = |route: Route| -> RedisResult { + let slot_addr = slots + .slot_addr_for_route(&route) + .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; + Ok(slot_addr.to_string()) + }; + + match RoutingInfo::for_routable(cmd) { + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + | Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::RandomPrimary)) => { + Ok(addr_for_slot(Route::new_random_primary())?) + } + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => { + Ok(addr_for_slot(route)?) + } + _ => fail!(UNROUTABLE_ERROR), + } + } + + fn map_cmds_to_nodes(&self, cmds: &[Cmd]) -> RedisResult> { + let mut cmd_map: HashMap = HashMap::new(); + + for (idx, cmd) in cmds.iter().enumerate() { + let addr = self.get_addr_for_cmd(cmd)?; + let nc = cmd_map + .entry(addr.clone()) + .or_insert_with(|| NodeCmd::new(addr)); + nc.indexes.push(idx); + cmd.write_packed_command(&mut nc.pipe); + } + + let mut result = Vec::new(); + for (_, v) in cmd_map.drain() { + result.push(v); + } + Ok(result) + } + + fn execute_on_all<'a>( + &'a self, + input: Input, + addresses: HashSet<&'a str>, + connections: &'a mut HashMap, + ) -> Vec> { + addresses + .into_iter() + .map(|addr| { + let connection = self.get_connection_by_addr(connections, addr)?; + match input { + Input::Slice { cmd, routable: _ } => connection.req_packed_command(cmd), + Input::Cmd(cmd) => connection.req_command(cmd), + Input::Commands { + cmd: _, + route: _, + offset: _, + count: _, + } => Err(( + ErrorKind::ClientError, + "req_packed_commands isn't supported with multiple nodes", + ) + .into()), + } + .map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_all_nodes<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec> { + self.execute_on_all(input, slots.addresses_for_all_nodes(), connections) + } + + fn execute_on_all_primaries<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec> { + self.execute_on_all(input, slots.addresses_for_all_primaries(), connections) + } + + fn execute_multi_slot<'a, 'b>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + routes: &'b [(Route, Vec)], + ) -> Vec> + where + 'b: 'a, + { + slots + .addresses_for_multi_slot(routes) + .enumerate() + .map(|(index, addr)| { + let addr = 
addr.ok_or(RedisError::from(( + ErrorKind::IoError, + "Couldn't find connection", + )))?; + let connection = self.get_connection_by_addr(connections, addr)?; + let (_, indices) = routes.get(index).unwrap(); + let cmd = + crate::cluster_routing::command_for_multi_slot_indices(&input, indices.iter()); + connection.req_command(&cmd).map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_multiple_nodes( + &self, + input: Input, + routing: MultipleNodeRoutingInfo, + response_policy: Option, + ) -> RedisResult { + let mut connections = self.connections.borrow_mut(); + let mut slots = self.slots.borrow_mut(); + + let results = match &routing { + MultipleNodeRoutingInfo::MultiSlot(routes) => { + self.execute_multi_slot(input, &mut slots, &mut connections, routes) + } + MultipleNodeRoutingInfo::AllMasters => { + self.execute_on_all_primaries(input, &mut slots, &mut connections) + } + MultipleNodeRoutingInfo::AllNodes => { + self.execute_on_all_nodes(input, &mut slots, &mut connections) + } + }; + + match response_policy { + Some(ResponsePolicy::AllSucceeded) => { + for result in results { + result?; + } + + Ok(Value::Okay) + } + Some(ResponsePolicy::OneSucceeded) => { + let mut last_failure = None; + + for result in results { + match result { + Ok((_, val)) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + + Err(last_failure + .unwrap_or_else(|| (ErrorKind::IoError, "Couldn't find a connection").into())) + } + Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => { + // Attempt to return the first result that isn't `Nil` or an error. + // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`. + // If we received only errors, return the last received error. + // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty, + // thus we return the last received error instead of `Nil`. 
+ let mut last_failure = None; + let num_of_results = results.len(); + let mut nil_counter = 0; + for result in results { + match result.map(|(_, res)| res) { + Ok(Value::Nil) => nil_counter += 1, + Ok(val) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + if nil_counter == num_of_results { + Ok(Value::Nil) + } else { + Err(last_failure.unwrap_or_else(|| { + (ErrorKind::IoError, "Couldn't find a connection").into() + })) + } + } + Some(ResponsePolicy::Aggregate(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::aggregate(results, op) + } + Some(ResponsePolicy::AggregateLogical(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::logical_aggregate(results, op) + } + Some(ResponsePolicy::CombineArrays) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + match routing { + MultipleNodeRoutingInfo::MultiSlot(vec) => { + crate::cluster_routing::combine_and_sort_array_results( + results, + vec.iter().map(|(_, indices)| indices), + ) + } + _ => crate::cluster_routing::combine_array_results(results), + } + } + Some(ResponsePolicy::CombineMaps) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::combine_map_results(results) + } + Some(ResponsePolicy::Special) | None => { + // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. + // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. + let results = results + .into_iter() + .map(|result| { + result.map(|(addr, val)| (Value::BulkString(addr.as_bytes().to_vec()), val)) + }) + .collect::>>()?; + Ok(Value::Map(results)) + } + } + } + + #[allow(clippy::unnecessary_unwrap)] + fn request(&self, input: Input) -> RedisResult { + let route_option = match &input { + Input::Slice { cmd: _, routable } => RoutingInfo::for_routable(routable), + Input::Cmd(cmd) => RoutingInfo::for_routable(*cmd), + Input::Commands { + cmd: _, + route, + offset: _, + count: _, + } => Some(RoutingInfo::SingleNode(route.clone())), + }; + let single_node_routing = match route_option { + Some(RoutingInfo::SingleNode(single_node_routing)) => single_node_routing, + Some(RoutingInfo::MultiNode((multi_node_routing, response_policy))) => { + return self + .execute_on_multiple_nodes(input, multi_node_routing, response_policy) + .map(Output::Single); + } + None => fail!(UNROUTABLE_ERROR), + }; + + let mut retries = 0; + let mut redirected = None::; + + loop { + // Get target address and response. + let (addr, rv) = { + let mut connections = self.connections.borrow_mut(); + let (addr, conn) = if let Some(redirected) = redirected.take() { + let (addr, is_asking) = match redirected { + Redirect::Moved(addr) => (addr, false), + Redirect::Ask(addr) => (addr, true), + }; + let conn = self.get_connection_by_addr(&mut connections, &addr)?; + if is_asking { + // if we are in asking mode we want to feed a single + // ASKING command into the connection before what we + // actually want to execute. 
+                        conn.req_packed_command(&b"*1\r\n$6\r\nASKING\r\n"[..])?;
+                    }
+                    (addr.to_string(), conn)
+                } else {
+                    match &single_node_routing {
+                        SingleNodeRoutingInfo::Random => get_random_connection(&mut connections),
+                        SingleNodeRoutingInfo::SpecificNode(route) => {
+                            self.get_connection(&mut connections, route)?
+                        }
+                        SingleNodeRoutingInfo::RandomPrimary => {
+                            self.get_connection(&mut connections, &Route::new_random_primary())?
+                        }
+                        SingleNodeRoutingInfo::ByAddress { host, port } => {
+                            let address = format!("{host}:{port}");
+                            let conn = self.get_connection_by_addr(&mut connections, &address)?;
+                            (address, conn)
+                        }
+                    }
+                };
+                (addr, input.send(conn))
+            };
+
+            match rv {
+                Ok(rv) => return Ok(rv),
+                Err(err) => {
+                    if retries == self.cluster_params.retry_params.number_of_retries {
+                        return Err(err);
+                    }
+                    retries += 1;
+
+                    match err.retry_method() {
+                        crate::types::RetryMethod::AskRedirect => {
+                            redirected = err
+                                .redirect_node()
+                                .map(|(node, _slot)| Redirect::Ask(node.to_string()));
+                        }
+                        crate::types::RetryMethod::MovedRedirect => {
+                            // Refresh slots.
+                            self.refresh_slots()?;
+                            // Request again.
+                            redirected = err
+                                .redirect_node()
+                                .map(|(node, _slot)| Redirect::Moved(node.to_string()));
+                        }
+                        crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica
+                        | crate::types::RetryMethod::WaitAndRetry => {
+                            // Sleep and retry.
+                            let sleep_time = self
+                                .cluster_params
+                                .retry_params
+                                .wait_time_for_retry(retries);
+                            thread::sleep(sleep_time);
+                        }
+                        crate::types::RetryMethod::Reconnect => {
+                            if *self.auto_reconnect.borrow() {
+                                if let Ok(mut conn) = self.connect(&addr) {
+                                    if conn.check_connection() {
+                                        self.connections.borrow_mut().insert(addr, conn);
+                                    }
+                                }
+                            }
+                        }
+                        crate::types::RetryMethod::NoRetry => {
+                            return Err(err);
+                        }
+                        crate::types::RetryMethod::RetryImmediately => {}
+                    }
+                }
+            }
+        }
+    }
+
+    fn send_recv_and_retry_cmds(&self, cmds: &[Cmd]) -> RedisResult<Vec<Value>> {
+        // Vector to hold the results, pre-populated with `Nil` values. This allows the original
+        // cmd ordering to be re-established by inserting the response directly into the result
+        // vector (e.g., results[10] = response).
+        let mut results = vec![Value::Nil; cmds.len()];
+
+        let to_retry = self
+            .send_all_commands(cmds)
+            .and_then(|node_cmds| self.recv_all_commands(&mut results, &node_cmds))?;
+
+        if to_retry.is_empty() {
+            return Ok(results);
+        }
+
+        // Refresh the slots to ensure that we have a clean slate for the retry attempts.
+        self.refresh_slots()?;
+
+        // Given that there are commands that need to be retried, it means something in the cluster
+        // topology changed. Execute each command separately to take advantage of the existing
+        // retry logic that handles these cases.
+        for retry_idx in to_retry {
+            let cmd = &cmds[retry_idx];
+            results[retry_idx] = self.request(Input::Cmd(cmd))?.into();
+        }
+        Ok(results)
+    }
+
+    // Build up a pipeline per node, then send it
+    fn send_all_commands(&self, cmds: &[Cmd]) -> RedisResult<Vec<NodeCmd>> {
+        let mut connections = self.connections.borrow_mut();
+
+        let node_cmds = self.map_cmds_to_nodes(cmds)?;
+        for nc in &node_cmds {
+            self.get_connection_by_addr(&mut connections, &nc.addr)?
+                .send_packed_command(&nc.pipe)?;
+        }
+        Ok(node_cmds)
+    }
+
+    // Receive from each node, keeping track of which commands need to be retried.
+ fn recv_all_commands( + &self, + results: &mut [Value], + node_cmds: &[NodeCmd], + ) -> RedisResult> { + let mut to_retry = Vec::new(); + let mut connections = self.connections.borrow_mut(); + let mut first_err = None; + + for nc in node_cmds { + for cmd_idx in &nc.indexes { + match self + .get_connection_by_addr(&mut connections, &nc.addr)? + .recv_response() + { + Ok(item) => results[*cmd_idx] = item, + Err(err) if err.is_cluster_error() => to_retry.push(*cmd_idx), + Err(err) => first_err = first_err.or(Some(err)), + } + } + } + match first_err { + Some(err) => Err(err), + None => Ok(to_retry), + } + } +} + +const MULTI: &[u8] = "*1\r\n$5\r\nMULTI\r\n".as_bytes(); +impl ConnectionLike for ClusterConnection { + fn supports_pipelining(&self) -> bool { + false + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.request(Input::Cmd(cmd)).map(|res| res.into()) + } + + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + self.request(Input::Slice { + cmd, + routable: value, + }) + .map(|res| res.into()) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + let route = match RoutingInfo::for_routable(&value) { + Some(RoutingInfo::MultiNode(_)) => None, + Some(RoutingInfo::SingleNode(route)) => Some(route), + None => None, + } + .unwrap_or(SingleNodeRoutingInfo::Random); + self.request(Input::Commands { + cmd, + offset, + count, + route, + }) + .map(|res| res.into()) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_open(&self) -> bool { + let connections = self.connections.borrow(); + for conn in connections.values() { + if !conn.is_open() { + return false; + } + } + true + } + + fn check_connection(&mut self) -> bool { + let mut connections = self.connections.borrow_mut(); + for conn in connections.values_mut() { + if !conn.check_connection() { + return false; + } + } + true + } +} + +#[derive(Debug)] +struct NodeCmd { + // The original command indexes + indexes: Vec, + pipe: Vec, + addr: String, +} + +impl NodeCmd { + fn new(a: String) -> NodeCmd { + NodeCmd { + indexes: vec![], + pipe: vec![], + addr: a, + } + } +} + +// TODO: This function can panic and should probably +// return an Option instead: +fn get_random_connection( + connections: &mut HashMap, +) -> (String, &mut C) { + let addr = connections + .keys() + .choose(&mut thread_rng()) + .expect("Connections is empty") + .to_string(); + let con = connections.get_mut(&addr).expect("Connections is empty"); + (addr, con) +} + +// The node string passed to this function will always be in the format host:port as it is either: +// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode) +// - Returned from redis via the ASK/MOVED response +pub(crate) fn get_connection_info( + node: &str, + cluster_params: ClusterParams, +) -> RedisResult { + let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string"); + + let (host, port) = node + .rsplit_once(':') + .and_then(|(host, port)| { + Some(host.trim_start_matches('[').trim_end_matches(']')) + .filter(|h| !h.is_empty()) + .zip(u16::from_str(port).ok()) + }) + .ok_or_else(invalid_error)?; + + Ok(ConnectionInfo { + addr: get_connection_addr( + host.to_string(), + port, + 
cluster_params.tls, + cluster_params.tls_params, + ), + redis: RedisConnectionInfo { + password: cluster_params.password, + username: cluster_params.username, + client_name: cluster_params.client_name, + protocol: cluster_params.protocol, + db: 0, + pubsub_subscriptions: cluster_params.pubsub_subscriptions, + }, + }) +} + +pub(crate) fn get_connection_addr( + host: String, + port: u16, + tls: Option, + tls_params: Option, +) -> ConnectionAddr { + match tls { + Some(TlsMode::Secure) => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params, + }, + Some(TlsMode::Insecure) => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + tls_params, + }, + _ => ConnectionAddr::Tcp(host, port), + } +} + +pub(crate) fn slot_cmd() -> Cmd { + let mut cmd = Cmd::new(); + cmd.arg("CLUSTER").arg("SLOTS"); + cmd +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_cluster_node_host_port() { + let cases = vec![ + ( + "127.0.0.1:6379", + ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379u16), + ), + ( + "localhost.localdomain:6379", + ConnectionAddr::Tcp("localhost.localdomain".to_string(), 6379u16), + ), + ( + "dead::cafe:beef:30001", + ConnectionAddr::Tcp("dead::cafe:beef".to_string(), 30001u16), + ), + ( + "[fe80::cafe:beef%en1]:30001", + ConnectionAddr::Tcp("fe80::cafe:beef%en1".to_string(), 30001u16), + ), + ]; + + for (input, expected) in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!(res.unwrap().addr, expected); + } + + let cases = vec![":0", "[]:6379"]; + for input in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!( + res.err(), + Some(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Invalid node string", + ))), + ); + } + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/LICENSE b/glide-core/redis-rs/redis/src/cluster_async/LICENSE new file mode 100644 index 0000000000..aaa71a1638 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/LICENSE @@ -0,0 +1,7 @@ +Copyright 2019 Atsushi Koge, Markus Westerlind + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
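To make the node-string handling in `get_connection_info` above concrete, here is a standalone sketch of the same parsing rule, mirroring the cases in `parse_cluster_node_host_port`: split on the last `:` so IPv6 literals keep their inner colons, trim surrounding brackets, and reject empty hosts:

```rust
// Standalone re-implementation for illustration; the real code path builds a
// full ConnectionInfo and returns a RedisResult instead of an Option.
fn parse_node(node: &str) -> Option<(String, u16)> {
    let (host, port) = node.rsplit_once(':')?;
    let host = host.trim_start_matches('[').trim_end_matches(']');
    if host.is_empty() {
        return None;
    }
    Some((host.to_string(), port.parse().ok()?))
}

fn main() {
    assert_eq!(parse_node("127.0.0.1:6379"), Some(("127.0.0.1".into(), 6379)));
    assert_eq!(parse_node("dead::cafe:beef:30001"), Some(("dead::cafe:beef".into(), 30001)));
    assert_eq!(parse_node("[fe80::cafe:beef%en1]:30001"), Some(("fe80::cafe:beef%en1".into(), 30001)));
    assert_eq!(parse_node(":0"), None); // empty host is rejected
    assert_eq!(parse_node("[]:6379"), None);
}
```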
diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs new file mode 100644 index 0000000000..2bfbb8b934 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs @@ -0,0 +1,881 @@ +use crate::cluster_async::ConnectionFuture; +use crate::cluster_routing::{Route, SlotAddr}; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap, SlotMapValue}; +use crate::cluster_topology::TopologyHash; +use dashmap::DashMap; +use futures::FutureExt; +use rand::seq::IteratorRandom; +use std::net::IpAddr; + +/// A struct that encapsulates a network connection along with its associated IP address. +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ConnectionWithIp { + /// The actual connection + pub conn: Connection, + /// The IP associated with the connection + pub ip: Option, +} + +impl ConnectionWithIp +where + Connection: Clone + Send + 'static, +{ + /// Consumes the current instance and returns a new `ConnectionWithIp` + /// where the connection is wrapped in a future. + #[doc(hidden)] + pub fn into_future(self) -> ConnectionWithIp> { + ConnectionWithIp { + conn: async { self.conn }.boxed().shared(), + ip: self.ip, + } + } +} + +impl From<(Connection, Option)> for ConnectionWithIp { + fn from(val: (Connection, Option)) -> Self { + ConnectionWithIp { + conn: val.0, + ip: val.1, + } + } +} + +impl From> for (Connection, Option) { + fn from(val: ConnectionWithIp) -> Self { + (val.conn, val.ip) + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ClusterNode { + pub user_connection: ConnectionWithIp, + pub management_connection: Option>, +} + +impl ClusterNode +where + Connection: Clone, +{ + pub fn new( + user_connection: ConnectionWithIp, + management_connection: Option>, + ) -> Self { + Self { + user_connection, + management_connection, + } + } + + pub(crate) fn get_connection(&self, conn_type: &ConnectionType) -> Connection { + match conn_type { + ConnectionType::User => self.user_connection.conn.clone(), + ConnectionType::PreferManagement => self.management_connection.as_ref().map_or_else( + || self.user_connection.conn.clone(), + |management_conn| management_conn.conn.clone(), + ), + } + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] + +pub(crate) enum ConnectionType { + User, + PreferManagement, +} + +pub(crate) struct ConnectionsMap(pub(crate) DashMap>); + +impl std::fmt::Display for ConnectionsMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for item in self.0.iter() { + let (address, node) = (item.key(), item.value()); + match node.user_connection.ip { + Some(ip) => writeln!(f, "{address} - {ip}")?, + None => writeln!(f, "{address}")?, + }; + } + Ok(()) + } +} + +pub(crate) struct ConnectionsContainer { + connection_map: DashMap>, + pub(crate) slot_map: SlotMap, + read_from_replica_strategy: ReadFromReplicaStrategy, + topology_hash: TopologyHash, +} + +impl Default for ConnectionsContainer { + fn default() -> Self { + Self { + connection_map: Default::default(), + slot_map: Default::default(), + read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary, + topology_hash: 0, + } + } +} + +pub(crate) type ConnectionAndAddress = (String, Connection); + +impl ConnectionsContainer +where + Connection: Clone, +{ + pub(crate) fn new( + slot_map: SlotMap, + connection_map: ConnectionsMap, + read_from_replica_strategy: ReadFromReplicaStrategy, + topology_hash: TopologyHash, + ) -> Self { + Self { + connection_map: 
connection_map.0, + slot_map, + read_from_replica_strategy, + topology_hash, + } + } + + // Extends the current connection map with the provided one + pub(crate) fn extend_connection_map( + &mut self, + other_connection_map: ConnectionsMap, + ) { + self.connection_map.extend(other_connection_map.0); + } + + /// Returns true if the address represents a known primary node. + pub(crate) fn is_primary(&self, address: &String) -> bool { + self.connection_for_address(address).is_some() + && self + .slot_map + .values() + .any(|slot_addrs| slot_addrs.primary.as_str() == address) + } + + fn round_robin_read_from_replica( + &self, + slot_map_value: &SlotMapValue, + ) -> Option> { + let addrs = &slot_map_value.addrs; + let initial_index = slot_map_value + .latest_used_replica + .load(std::sync::atomic::Ordering::Relaxed); + let mut check_count = 0; + loop { + check_count += 1; + + // Looped through all replicas, no connected replica was found. + if check_count > addrs.replicas.len() { + return self.connection_for_address(addrs.primary.as_str()); + } + let index = (initial_index + check_count) % addrs.replicas.len(); + if let Some(connection) = self.connection_for_address(addrs.replicas[index].as_str()) { + let _ = slot_map_value.latest_used_replica.compare_exchange_weak( + initial_index, + index, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ); + return Some(connection); + } + } + } + + fn lookup_route(&self, route: &Route) -> Option> { + let slot_map_value = self.slot_map.slot_value_for_route(route)?; + let addrs = &slot_map_value.addrs; + if addrs.replicas.is_empty() { + return self.connection_for_address(addrs.primary.as_str()); + } + + match route.slot_addr() { + SlotAddr::Master => self.connection_for_address(addrs.primary.as_str()), + SlotAddr::ReplicaOptional => match self.read_from_replica_strategy { + ReadFromReplicaStrategy::AlwaysFromPrimary => { + self.connection_for_address(addrs.primary.as_str()) + } + ReadFromReplicaStrategy::RoundRobin => { + self.round_robin_read_from_replica(slot_map_value) + } + }, + SlotAddr::ReplicaRequired => self.round_robin_read_from_replica(slot_map_value), + } + } + + pub(crate) fn connection_for_route( + &self, + route: &Route, + ) -> Option> { + self.lookup_route(route).or_else(|| { + if route.slot_addr() != SlotAddr::Master { + self.lookup_route(&Route::new(route.slot(), SlotAddr::Master)) + } else { + None + } + }) + } + + pub(crate) fn all_node_connections( + &self, + ) -> impl Iterator> + '_ { + self.connection_map.iter().map(move |item| { + let (node, address) = (item.key(), item.value()); + (node.clone(), address.user_connection.conn.clone()) + }) + } + + pub(crate) fn all_primary_connections( + &self, + ) -> impl Iterator> + '_ { + self.slot_map + .addresses_for_all_primaries() + .into_iter() + .flat_map(|addr| self.connection_for_address(addr)) + } + + pub(crate) fn node_for_address(&self, address: &str) -> Option> { + self.connection_map + .get(address) + .map(|item| item.value().clone()) + } + + pub(crate) fn connection_for_address( + &self, + address: &str, + ) -> Option> { + self.connection_map.get(address).map(|item| { + let (address, conn) = (item.key(), item.value()); + (address.clone(), conn.user_connection.conn.clone()) + }) + } + + pub(crate) fn random_connections( + &self, + amount: usize, + conn_type: ConnectionType, + ) -> impl Iterator> + '_ { + self.connection_map + .iter() + .choose_multiple(&mut rand::thread_rng(), amount) + .into_iter() + .map(move |item| { + let (address, node) = (item.key(), 
item.value()); + let conn = node.get_connection(&conn_type); + (address.clone(), conn) + }) + } + + pub(crate) fn replace_or_add_connection_for_address( + &self, + address: impl Into, + node: ClusterNode, + ) -> String { + let address = address.into(); + self.connection_map.insert(address.clone(), node); + address + } + + pub(crate) fn remove_node(&self, address: &String) -> Option> { + self.connection_map + .remove(address) + .map(|(_key, value)| value) + } + + pub(crate) fn len(&self) -> usize { + self.connection_map.len() + } + + pub(crate) fn get_current_topology_hash(&self) -> TopologyHash { + self.topology_hash + } + + /// Returns true if the connections container contains no connections. + pub(crate) fn is_empty(&self) -> bool { + self.connection_map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use crate::cluster_routing::Slot; + + use super::*; + impl ClusterNode + where + Connection: Clone, + { + pub(crate) fn new_only_with_user_conn(user_connection: Connection) -> Self { + let ip = None; + Self { + user_connection: (user_connection, ip).into(), + management_connection: None, + } + } + } + fn remove_nodes(container: &ConnectionsContainer, addresses: &[&str]) { + for address in addresses { + container.remove_node(&(*address).into()); + } + } + + fn remove_all_connections(container: &ConnectionsContainer) { + remove_nodes( + container, + &[ + "primary1", + "primary2", + "primary3", + "replica2-1", + "replica3-1", + "replica3-2", + ], + ); + } + + fn one_of( + connection: Option>, + expected_connections: &[usize], + ) -> bool { + let found = connection.unwrap().1; + expected_connections.contains(&found) + } + fn create_cluster_node( + connection: usize, + use_management_connections: bool, + ) -> ClusterNode { + let ip = None; + ClusterNode::new( + (connection, ip).into(), + if use_management_connections { + Some((connection * 10, ip).into()) + } else { + None + }, + ) + } + + fn create_container_with_strategy( + stragey: ReadFromReplicaStrategy, + use_management_connections: bool, + ) -> ConnectionsContainer { + let slot_map = SlotMap::new( + vec![ + Slot::new(1, 1000, "primary1".to_owned(), Vec::new()), + Slot::new( + 1002, + 2000, + "primary2".to_owned(), + vec!["replica2-1".to_owned()], + ), + Slot::new( + 2001, + 3000, + "primary3".to_owned(), + vec!["replica3-1".to_owned(), "replica3-2".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, // this argument shouldn't matter, since we overload the RFR strategy. 
+ ); + let connection_map = DashMap::new(); + connection_map.insert( + "primary1".into(), + create_cluster_node(1, use_management_connections), + ); + connection_map.insert( + "primary2".into(), + create_cluster_node(2, use_management_connections), + ); + connection_map.insert( + "primary3".into(), + create_cluster_node(3, use_management_connections), + ); + connection_map.insert( + "replica2-1".into(), + create_cluster_node(21, use_management_connections), + ); + connection_map.insert( + "replica3-1".into(), + create_cluster_node(31, use_management_connections), + ); + connection_map.insert( + "replica3-2".into(), + create_cluster_node(32, use_management_connections), + ); + + ConnectionsContainer { + slot_map, + connection_map, + read_from_replica_strategy: stragey, + topology_hash: 0, + } + } + + fn create_container() -> ConnectionsContainer { + create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, false) + } + + #[test] + fn get_connection_for_primary_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::Master)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_for_replica_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)), + &[31, 32], + )); + } + + #[test] + fn get_primary_connection_for_replica_route_if_no_replicas_were_added() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_some_but_not_all_replicas_were_removed() { + let container = create_container(); + container.remove_node(&"replica3-2".into()); + + assert_eq!( + 31, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_replica_is_required_even_if_strategy_is_always_from_primary( + ) { + let container = + create_container_with_strategy(ReadFromReplicaStrategy::AlwaysFromPrimary, false); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)), + &[31, 32], + )); + } + + #[test] + fn 
get_primary_connection_for_replica_route_if_all_replicas_were_removed() { + let container = create_container(); + remove_nodes(&container, &["replica2-1", "replica3-1", "replica3-2"]); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_by_address() { + let container = create_container(); + + assert!(container.connection_for_address("foobar").is_none()); + + assert_eq!(1, container.connection_for_address("primary1").unwrap().1); + assert_eq!(2, container.connection_for_address("primary2").unwrap().1); + assert_eq!(3, container.connection_for_address("primary3").unwrap().1); + assert_eq!( + 21, + container.connection_for_address("replica2-1").unwrap().1 + ); + assert_eq!( + 31, + container.connection_for_address("replica3-1").unwrap().1 + ); + assert_eq!( + 32, + container.connection_for_address("replica3-2").unwrap().1 + ); + } + + #[test] + fn get_connection_by_address_returns_none_if_connection_was_removed() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + assert!(container.connection_for_address("primary1").is_none()); + } + + #[test] + fn get_connection_by_address_returns_added_connection() { + let container = create_container(); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + assert_eq!( + (address, 4), + container.connection_for_address("foobar").unwrap() + ); + } + + #[test] + fn get_random_connections_without_repetitions() { + let container = create_container(); + + let random_connections: HashSet<_> = container + .random_connections(3, ConnectionType::User) + .map(|pair| pair.1) + .collect(); + + assert_eq!(random_connections.len(), 3); + assert!(random_connections + .iter() + .all(|connection| [1, 2, 3, 21, 31, 32].contains(connection))); + } + + #[test] + fn get_random_connections_returns_none_if_all_connections_were_removed() { + let container = create_container(); + remove_all_connections(&container); + + assert_eq!( + 0, + container + .random_connections(1, ConnectionType::User) + .count() + ); + } + + #[test] + fn get_random_connections_returns_added_connection() { + let container = create_container(); + remove_all_connections(&container); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + let random_connections: Vec<_> = container + .random_connections(1, ConnectionType::User) + .collect(); + + assert_eq!(vec![(address, 4)], random_connections); + } + + #[test] + fn get_random_connections_is_bound_by_the_number_of_connections_in_the_map() { + let container = create_container(); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::User) + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![1, 2, 3, 21, 31, 32]); + } + + #[test] + fn get_random_management_connections() { + let container = create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, true); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::PreferManagement) + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + 
assert_eq!(random_connections, vec![10, 20, 30, 210, 310, 320]); + } + + #[test] + fn get_all_user_connections() { + let container = create_container(); + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_returns_added_connection() { + let container = create_container(); + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 4, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_primaries() { + let container = create_container(); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3], connections); + } + + #[test] + fn get_all_primaries_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3], connections); + } + + #[test] + fn len_is_adjusted_on_removals_and_additions() { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"primary1".into()); + assert_eq!(container.len(), 5); + + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn len_is_not_adjusted_on_removals_of_nonexisting_connections_or_additions_of_existing_connections( + ) { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"foobar".into()); + assert_eq!(container.len(), 6); + + container.replace_or_add_connection_for_address( + "primary1", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn remove_node_returns_connection_if_it_exists() { + let container = create_container(); + + let connection = container.remove_node(&"primary1".into()); + assert_eq!(connection, Some(ClusterNode::new_only_with_user_conn(1))); + + let non_connection = container.remove_node(&"foobar".into()); + assert_eq!(non_connection, None); + } + + #[test] + fn test_is_empty() { + let container = create_container(); + + assert!(!container.is_empty()); + container.remove_node(&"primary1".into()); + assert!(!container.is_empty()); + container.remove_node(&"primary2".into()); + container.remove_node(&"primary3".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica2-1".into()); + container.remove_node(&"replica3-1".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica3-2".into()); + assert!(container.is_empty()); + } + + #[test] + fn is_primary_returns_true_for_known_primary() { + let container = create_container(); + + assert!(container.is_primary(&"primary1".into())); + } + + #[test] + fn 
is_primary_returns_false_for_known_replica() {
+        let container = create_container();
+
+        assert!(!container.is_primary(&"replica2-1".into()));
+    }
+
+    #[test]
+    fn is_primary_returns_false_for_removed_node() {
+        let container = create_container();
+        let address = "primary1".into();
+        container.remove_node(&address);
+
+        assert!(!container.is_primary(&address));
+    }
+
+    #[test]
+    fn test_extend_connection_map() {
+        let mut container = create_container();
+        let mut current_addresses: Vec<_> = container
+            .all_node_connections()
+            .map(|conn| conn.0)
+            .collect();
+
+        let new_node = "new_primary1".to_string();
+        // Check that `new_node` does not yet exist in the current connection map
+        assert!(container.connection_for_address(&new_node).is_none());
+        // Create a new connection map
+        let new_connection_map = DashMap::new();
+        new_connection_map.insert(new_node.clone(), create_cluster_node(1, false));
+
+        // Extend the current connection map
+        container.extend_connection_map(ConnectionsMap(new_connection_map));
+
+        // Check that the new addresses vector contains both the new node and all previous nodes
+        let mut new_addresses: Vec<_> = container
+            .all_node_connections()
+            .map(|conn| conn.0)
+            .collect();
+        current_addresses.push(new_node);
+        current_addresses.sort();
+        new_addresses.sort();
+        assert_eq!(current_addresses, new_addresses);
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs
new file mode 100644
index 0000000000..7de2493000
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs
@@ -0,0 +1,481 @@
+use std::net::SocketAddr;
+
+use super::{
+    connections_container::{ClusterNode, ConnectionWithIp},
+    Connect,
+};
+use crate::{
+    aio::{ConnectionLike, DisconnectNotifier, Runtime},
+    client::GlideConnectionOptions,
+    cluster::get_connection_info,
+    cluster_client::ClusterParams,
+    ErrorKind, RedisError, RedisResult,
+};
+
+use futures::prelude::*;
+use futures_util::{future::BoxFuture, join};
+use tracing::warn;
+
+pub(crate) type ConnectionFuture<C> = futures::future::Shared<BoxFuture<'static, C>>;
+/// Cluster node for async connections
+#[doc(hidden)]
+pub type AsyncClusterNode<C> = ClusterNode<ConnectionFuture<C>>;
+
+#[doc(hidden)]
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum RefreshConnectionType {
+    // Refresh only user connections
+    OnlyUserConnection,
+    // Refresh only management connections
+    OnlyManagementConnection,
+    // Refresh all connections: both management and user connections.
+    AllConnections,
+}
+
+fn failed_management_connection<C>(
+    addr: &str,
+    user_conn: ConnectionWithIp<ConnectionFuture<C>>,
+    err: RedisError,
+) -> ConnectAndCheckResult<C>
+where
+    C: ConnectionLike + Send + Clone + Sync + Connect + 'static,
+{
+    warn!(
+        "Failed to create management connection for node `{:?}`. Error: `{:?}`",
+        addr, err
+    );
+    ConnectAndCheckResult::ManagementConnectionFailed {
+        node: AsyncClusterNode::new(user_conn, None),
+        err,
+    }
+}
+
+pub(crate) async fn get_or_create_conn<C>(
+    addr: &str,
+    node: Option<AsyncClusterNode<C>>,
+    params: &ClusterParams,
+    conn_type: RefreshConnectionType,
+    glide_connection_options: GlideConnectionOptions,
+) -> RedisResult<AsyncClusterNode<C>>
+where
+    C: ConnectionLike + Send + Clone + Sync + Connect + 'static,
+{
+    if let Some(node) = node {
+        // We won't check whether the DNS address of this node has changed and now points to a new IP.
+        // Instead, we depend on managed Redis services to close the connection for refresh if the node has changed.
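+        // `check_node_connections` probes the requested connection kinds on the
+        // existing node. It returns `None` when they are all healthy, in which case
+        // the node is reused as-is; otherwise it reports which kind is unhealthy so
+        // that only that kind is re-established via `connect_and_check` below.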
+ match check_node_connections(&node, params, conn_type, addr).await { + None => Ok(node), + Some(conn_type) => connect_and_check( + addr, + params.clone(), + None, + conn_type, + Some(node), + glide_connection_options, + ) + .await + .get_node(), + } + } else { + connect_and_check( + addr, + params.clone(), + None, + conn_type, + None, + glide_connection_options, + ) + .await + .get_node() + } +} + +fn create_async_node( + user_conn: ConnectionWithIp, + management_conn: Option>, +) -> AsyncClusterNode +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + AsyncClusterNode::new( + user_conn.into_future(), + management_conn.map(|conn| conn.into_future()), + ) +} + +pub(crate) async fn connect_and_check_all_connections( + addr: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match future::join( + create_connection( + addr, + params.clone(), + socket_addr, + false, + glide_connection_options.clone(), + ), + create_connection( + addr, + params.clone(), + socket_addr, + true, + glide_connection_options, + ), + ) + .await + { + (Ok(conn_1), Ok(conn_2)) => { + // Both connections were successfully established + let mut user_conn: ConnectionWithIp = conn_1; + let mut management_conn: ConnectionWithIp = conn_2; + if let Err(err) = setup_user_connection(&mut user_conn.conn, params).await { + return err.into(); + } + match setup_management_connection(&mut management_conn.conn).await { + Ok(_) => ConnectAndCheckResult::Success(create_async_node( + user_conn, + Some(management_conn), + )), + Err(err) => failed_management_connection(addr, user_conn.into_future(), err), + } + } + (Ok(mut connection), Err(err)) | (Err(err), Ok(mut connection)) => { + // Only a single connection was successfully established. Use it for the user connection + match setup_user_connection(&mut connection.conn, params).await { + Ok(_) => failed_management_connection(addr, connection.into_future(), err), + Err(err) => err.into(), + } + } + (Err(err_1), Err(err_2)) => { + // Neither of the connections succeeded. + RedisError::from(( + ErrorKind::IoError, + "Failed to refresh both connections", + format!( + "Node: {:?} received errors: `{:?}`, `{:?}`", + addr, err_1, err_2 + ), + )) + .into() + } + } +} + +async fn connect_and_check_only_management_conn( + addr: &str, + params: ClusterParams, + socket_addr: Option, + prev_node: AsyncClusterNode, + disconnect_notifier: Option>, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match create_connection::( + addr, + params.clone(), + socket_addr, + true, + GlideConnectionOptions { + push_sender: None, + disconnect_notifier, + }, + ) + .await + { + Err(conn_err) => failed_management_connection(addr, prev_node.user_connection, conn_err), + + Ok(mut connection) => { + if let Err(err) = setup_management_connection(&mut connection.conn).await { + return failed_management_connection(addr, prev_node.user_connection, err); + } + + ConnectAndCheckResult::Success(ClusterNode { + user_connection: prev_node.user_connection, + management_connection: Some(connection.into_future()), + }) + } + } +} + +#[doc(hidden)] +#[must_use] +pub enum ConnectAndCheckResult { + // Returns a node that was fully connected according to the request. + Success(AsyncClusterNode), + // Returns a node that failed to create a management connection, but has a working user connection. 
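+    // The returned node still has a working user connection, so the caller may keep
+    // using it; `err` explains why the management connection could not be created.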
+    ManagementConnectionFailed {
+        node: AsyncClusterNode<C>,
+        err: RedisError,
+    },
+    // Request failed completely, could not return a node with any working connection.
+    Failed(RedisError),
+}
+
+impl<C> ConnectAndCheckResult<C> {
+    pub fn get_node(self) -> RedisResult<AsyncClusterNode<C>> {
+        match self {
+            ConnectAndCheckResult::Success(node) => Ok(node),
+            ConnectAndCheckResult::ManagementConnectionFailed { node, .. } => Ok(node),
+            ConnectAndCheckResult::Failed(err) => Err(err),
+        }
+    }
+
+    pub fn get_error(self) -> Option<RedisError> {
+        match self {
+            ConnectAndCheckResult::Success(_) => None,
+            ConnectAndCheckResult::ManagementConnectionFailed { err, .. } => Some(err),
+            ConnectAndCheckResult::Failed(err) => Some(err),
+        }
+    }
+}
+
+impl<C> From<RedisError> for ConnectAndCheckResult<C> {
+    fn from(value: RedisError) -> Self {
+        ConnectAndCheckResult::Failed(value)
+    }
+}
+
+impl<C> From<AsyncClusterNode<C>> for ConnectAndCheckResult<C> {
+    fn from(value: AsyncClusterNode<C>) -> Self {
+        ConnectAndCheckResult::Success(value)
+    }
+}
+
+impl<C> From<RedisResult<AsyncClusterNode<C>>> for ConnectAndCheckResult<C> {
+    fn from(value: RedisResult<AsyncClusterNode<C>>) -> Self {
+        match value {
+            Ok(value) => value.into(),
+            Err(err) => err.into(),
+        }
+    }
+}
+
+#[doc(hidden)]
+pub async fn connect_and_check<C>(
+    addr: &str,
+    params: ClusterParams,
+    socket_addr: Option<SocketAddr>,
+    conn_type: RefreshConnectionType,
+    node: Option<AsyncClusterNode<C>>,
+    glide_connection_options: GlideConnectionOptions,
+) -> ConnectAndCheckResult<C>
+where
+    C: ConnectionLike + Connect + Send + Sync + 'static + Clone,
+{
+    match conn_type {
+        RefreshConnectionType::OnlyUserConnection => {
+            let user_conn = match create_and_setup_user_connection(
+                addr,
+                params.clone(),
+                socket_addr,
+                glide_connection_options,
+            )
+            .await
+            {
+                Ok(tuple) => tuple,
+                Err(err) => return err.into(),
+            };
+            let management_conn = node.and_then(|node| node.management_connection);
+            AsyncClusterNode::new(user_conn.into_future(), management_conn).into()
+        }
+        RefreshConnectionType::OnlyManagementConnection => {
+            // Refreshing only the management connection requires the node to exist alongside a user connection. Otherwise, refresh all connections.
+ match node { + Some(node) => { + connect_and_check_only_management_conn( + addr, + params, + socket_addr, + node, + glide_connection_options.disconnect_notifier, + ) + .await + } + None => { + connect_and_check_all_connections( + addr, + params, + socket_addr, + glide_connection_options, + ) + .await + } + } + } + RefreshConnectionType::AllConnections => { + connect_and_check_all_connections(addr, params, socket_addr, glide_connection_options) + .await + } + } +} + +async fn create_and_setup_user_connection( + node: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let mut connection: ConnectionWithIp = create_connection( + node, + params.clone(), + socket_addr, + false, + glide_connection_options, + ) + .await?; + setup_user_connection(&mut connection.conn, params).await?; + Ok(connection) +} + +async fn setup_user_connection(conn: &mut C, params: ClusterParams) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let read_from_replicas = params.read_from_replicas + != crate::cluster_slotmap::ReadFromReplicaStrategy::AlwaysFromPrimary; + let connection_timeout = params.connection_timeout; + check_connection(conn, connection_timeout).await?; + if read_from_replicas { + // If READONLY is sent to primary nodes, it will have no effect + crate::cmd("READONLY").query_async(conn).await?; + } + Ok(()) +} + +#[doc(hidden)] +pub const MANAGEMENT_CONN_NAME: &str = "glide_management_connection"; + +async fn setup_management_connection(conn: &mut C) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + crate::cmd("CLIENT") + .arg(&["SETNAME", MANAGEMENT_CONN_NAME]) + .query_async(conn) + .await?; + Ok(()) +} + +async fn create_connection( + node: &str, + mut params: ClusterParams, + socket_addr: Option, + is_management: bool, + mut glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let connection_timeout = params.connection_timeout; + let response_timeout = params.response_timeout; + // ignore pubsub subscriptions and push notifications for management connections + if is_management { + params.pubsub_subscriptions = None; + } + let info = get_connection_info(node, params)?; + // management connection does not require notifications or disconnect notifications + if is_management { + glide_connection_options.disconnect_notifier = None; + } + C::connect( + info, + response_timeout, + connection_timeout, + socket_addr, + glide_connection_options, + ) + .await + .map(|conn| conn.into()) +} + +/// The function returns None if the checked connection/s are healthy. Otherwise, it returns the type of the unhealthy connection/s. 
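+/// For example, when `conn_type` is `RefreshConnectionType::AllConnections` and only the
+/// user connection fails its `PING` check, `Some(RefreshConnectionType::OnlyUserConnection)`
+/// is returned, so only that connection gets refreshed.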
+#[allow(dead_code)] +#[doc(hidden)] +pub async fn check_node_connections( + node: &AsyncClusterNode, + params: &ClusterParams, + conn_type: RefreshConnectionType, + address: &str, +) -> Option +where + C: ConnectionLike + Send + 'static + Clone, +{ + let timeout = params.connection_timeout; + let (check_mgmt_connection, check_user_connection) = match conn_type { + RefreshConnectionType::OnlyUserConnection => (false, true), + RefreshConnectionType::OnlyManagementConnection => (true, false), + RefreshConnectionType::AllConnections => (true, true), + }; + let check = |conn, timeout, conn_type| async move { + match check_connection(&mut conn.await, timeout).await { + Ok(_) => false, + Err(err) => { + warn!( + "The {} connection for node {} is unhealthy. Error: {:?}", + conn_type, address, err + ); + true + } + } + }; + let (mgmt_failed, user_failed) = join!( + async { + if !check_mgmt_connection { + return false; + } + match node.management_connection.clone() { + Some(connection) => check(connection.conn, timeout, "management").await, + None => { + warn!("The management connection for node {} isn't set", address); + true + } + } + }, + async { + if !check_user_connection { + return false; + } + let conn = node.user_connection.conn.clone(); + check(conn, timeout, "user").await + }, + ); + + match (mgmt_failed, user_failed) { + (true, true) => Some(RefreshConnectionType::AllConnections), + (true, false) => Some(RefreshConnectionType::OnlyManagementConnection), + (false, true) => Some(RefreshConnectionType::OnlyUserConnection), + (false, false) => None, + } +} + +async fn check_connection(conn: &mut C, timeout: std::time::Duration) -> RedisResult<()> +where + C: ConnectionLike + Send + 'static, +{ + Runtime::locate() + .timeout(timeout, crate::cmd("PING").query_async::<_, String>(conn)) + .await??; + Ok(()) +} + +/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned. +/// [addr] should be in the following format: ":". +pub(crate) fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> { + let parts: Vec<&str> = addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + let host = parts.first().unwrap(); + let port = parts.get(1).unwrap(); + port.parse::().ok().map(|port| (*host, port)) +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs new file mode 100644 index 0000000000..be7beb79b7 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -0,0 +1,2656 @@ +//! This module provides async functionality for Redis Cluster. +//! +//! By default, [`ClusterConnection`] makes use of [`MultiplexedConnection`] and maintains a pool +//! of connections to each node in the cluster. While it generally behaves similarly to +//! the sync cluster module, certain commands do not route identically, due most notably to +//! a current lack of support for routing commands to multiple nodes. +//! +//! Also note that pubsub functionality is not currently provided by this module. +//! +//! # Example +//! ```rust,no_run +//! use redis::cluster::ClusterClient; +//! use redis::AsyncCommands; +//! +//! async fn fetch_an_integer() -> String { +//! let nodes = vec!["redis://127.0.0.1/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_async_connection(None).await.unwrap(); +//! let _: () = connection.set("test", "test_data").await.unwrap(); +//! let rv: String = connection.get("test").await.unwrap(); +//! return rv; +//! 
} +//! ``` + +mod connections_container; +mod connections_logic; +/// Exposed only for testing. +pub mod testing { + pub use super::connections_container::ConnectionWithIp; + pub use super::connections_logic::*; +} +use crate::{ + client::GlideConnectionOptions, + cluster_routing::{Routable, RoutingInfo}, + cluster_slotmap::SlotMap, + cluster_topology::SLOT_SIZE, + cmd, + commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, + FromRedisValue, InfoDict, ToRedisArgs, +}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use async_std::task::{spawn, JoinHandle}; +use dashmap::DashMap; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use futures::executor::block_on; +use std::{ + collections::{HashMap, HashSet}, + fmt, io, mem, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::{ + atomic::{self, AtomicUsize, Ordering}, + Arc, Mutex, + }, + task::{self, Poll}, + time::SystemTime, +}; +#[cfg(feature = "tokio-comp")] +use tokio::task::JoinHandle; + +#[cfg(feature = "tokio-comp")] +use crate::aio::DisconnectNotifier; + +use crate::{ + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, + cluster::slot_cmd, + cluster_async::connections_logic::{ + get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, + }, + cluster_client::{ClusterParams, RetryParams}, + cluster_routing::{ + self, MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, SingleNodeRoutingInfo, + SlotAddr, + }, + cluster_topology::{ + calculate_topology, get_slot, SlotRefreshState, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES, + DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL, DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL, + }, + connection::{PubSubSubscriptionInfo, PubSubSubscriptionKind}, + push_manager::PushInfo, + Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult, + Value, +}; +use futures::stream::{FuturesUnordered, StreamExt}; +use std::time::Duration; + +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use crate::aio::{async_std::AsyncStd, RedisRuntime}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use backoff_std_async::future::retry; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use backoff_std_async::{Error as BackoffError, ExponentialBackoff}; + +#[cfg(feature = "tokio-comp")] +use async_trait::async_trait; +#[cfg(feature = "tokio-comp")] +use backoff_tokio::future::retry; +#[cfg(feature = "tokio-comp")] +use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; +#[cfg(feature = "tokio-comp")] +use tokio::{sync::Notify, time::timeout}; + +use dispose::{Disposable, Dispose}; +use futures::{future::BoxFuture, prelude::*, ready}; +use pin_project_lite::pin_project; +use tokio::sync::{ + mpsc, + oneshot::{self, Receiver}, + RwLock, +}; +use tracing::{debug, info, trace, warn}; + +use self::{ + connections_container::{ConnectionAndAddress, ConnectionType, ConnectionsMap}, + connections_logic::connect_and_check, +}; + +/// This represents an async Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. 
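+///
+/// Cloning a `ClusterConnection` is cheap: each clone holds a sender to the same
+/// underlying task, so clones share connections and cluster state.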
+#[derive(Clone)] +pub struct ClusterConnection(mpsc::Sender>); + +impl ClusterConnection +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + pub(crate) async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + push_sender: Option>, + ) -> RedisResult> { + ClusterConnInner::new(initial_nodes, cluster_params, push_sender) + .await + .map(|inner| { + let (tx, mut rx) = mpsc::channel::>(100); + let stream = async move { + let _ = stream::poll_fn(move |cx| rx.poll_recv(cx)) + .map(Ok) + .forward(inner) + .await; + }; + #[cfg(feature = "tokio-comp")] + tokio::spawn(stream); + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + AsyncStd::spawn(stream); + + ClusterConnection(tx) + }) + } + + /// Special handling for `SCAN` command, using `cluster_scan`. + /// If you wish to use a match pattern, use [`cluster_scan_with_pattern`]. + /// Perform a `SCAN` command on a Redis cluster, using scan state object in order to handle changes in topology + /// and make sure that all keys that were in the cluster from start to end of the scan are scanned. + /// In order to make sure all keys in the cluster scanned, topology refresh occurs more frequently and may affect performance. + /// + /// # Arguments + /// + /// * `scan_state_rc` - A reference to the scan state, For initiating new scan send [`ScanStateRC::new()`], + /// for each subsequent iteration use the returned [`ScanStateRC`]. + /// * `count` - An optional count of keys requested, + /// the amount returned can vary and not obligated to return exactly count. + /// * `object_type` - An optional [`ObjectType`] enum of requested key redis type. + /// + /// # Returns + /// + /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan. + /// structure of returned value: + /// `Ok((ScanStateRC, Vec))` + /// + /// When the scan is finished [`ScanStateRC`] will be None, and can be checked by calling `scan_state_wrapper.is_finished()`. + /// + /// # Example + /// ```rust,no_run + /// use redis::cluster::ClusterClient; + /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType}; + /// + /// async fn scan_all_cluster() -> Vec { + /// let nodes = vec!["redis://127.0.0.1/"]; + /// let client = ClusterClient::new(nodes).unwrap(); + /// let mut connection = client.get_async_connection(None).await.unwrap(); + /// let mut scan_state_rc = ScanStateRC::new(); + /// let mut keys: Vec = vec![]; + /// loop { + /// let (next_cursor, scan_keys): (ScanStateRC, Vec) = + /// connection.cluster_scan(scan_state_rc, None, None).await.unwrap(); + /// scan_state_rc = next_cursor; + /// let mut scan_keys = scan_keys + /// .into_iter() + /// .map(|v| from_redis_value(&v).unwrap()) + /// .collect::>(); // Change the type of `keys` to `Vec` + /// keys.append(&mut scan_keys); + /// if scan_state_rc.is_finished() { + /// break; + /// } + /// } + /// keys + /// } + /// ``` + pub async fn cluster_scan( + &mut self, + scan_state_rc: ScanStateRC, + count: Option, + object_type: Option, + ) -> RedisResult<(ScanStateRC, Vec)> { + let cluster_scan_args = ClusterScanArgs::new(scan_state_rc, None, count, object_type); + self.route_cluster_scan(cluster_scan_args).await + } + + /// Special handling for `SCAN` command, using `cluster_scan_with_pattern`. + /// It is a special case of [`cluster_scan`], with an additional match pattern. 
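+    /// The `match_pattern` is a glob-style pattern that the server applies to key names
+    /// while scanning, so only matching keys are returned.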
+ /// Perform a `SCAN` command on a Redis cluster, using scan state object in order to handle changes in topology + /// and make sure that all keys that were in the cluster from start to end of the scan are scanned. + /// In order to make sure all keys in the cluster scanned, topology refresh occurs more frequently and may affect performance. + /// + /// # Arguments + /// + /// * `scan_state_rc` - A reference to the scan state, For initiating new scan send [`ScanStateRC::new()`], + /// for each subsequent iteration use the returned [`ScanStateRC`]. + /// * `match_pattern` - A match pattern of requested keys. + /// * `count` - An optional count of keys requested, + /// the amount returned can vary and not obligated to return exactly count. + /// * `object_type` - An optional [`ObjectType`] enum of requested key redis type. + /// + /// # Returns + /// + /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan. + /// structure of returned value: + /// `Ok((ScanStateRC, Vec))` + /// + /// When the scan is finished [`ScanStateRC`] will be None, and can be checked by calling `scan_state_wrapper.is_finished()`. + /// + /// # Example + /// ```rust,no_run + /// use redis::cluster::ClusterClient; + /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType}; + /// + /// async fn scan_all_cluster() -> Vec { + /// let nodes = vec!["redis://127.0.0.1/"]; + /// let client = ClusterClient::new(nodes).unwrap(); + /// let mut connection = client.get_async_connection(None).await.unwrap(); + /// let mut scan_state_rc = ScanStateRC::new(); + /// let mut keys: Vec = vec![]; + /// loop { + /// let (next_cursor, scan_keys): (ScanStateRC, Vec) = + /// connection.cluster_scan_with_pattern(scan_state_rc, b"my_key", None, None).await.unwrap(); + /// scan_state_rc = next_cursor; + /// let mut scan_keys = scan_keys + /// .into_iter() + /// .map(|v| from_redis_value(&v).unwrap()) + /// .collect::>(); // Change the type of `keys` to `Vec` + /// keys.append(&mut scan_keys); + /// if scan_state_rc.is_finished() { + /// break; + /// } + /// } + /// keys + /// } + /// ``` + pub async fn cluster_scan_with_pattern( + &mut self, + scan_state_rc: ScanStateRC, + match_pattern: K, + count: Option, + object_type: Option, + ) -> RedisResult<(ScanStateRC, Vec)> { + let cluster_scan_args = ClusterScanArgs::new( + scan_state_rc, + Some(match_pattern.to_redis_args().concat()), + count, + object_type, + ); + self.route_cluster_scan(cluster_scan_args).await + } + + /// Route cluster scan to be handled by internal cluster_scan command + async fn route_cluster_scan( + &mut self, + cluster_scan_args: ClusterScanArgs, + ) -> RedisResult<(ScanStateRC, Vec)> { + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::ClusterScan { cluster_scan_args }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::ClusterScanResult(new_scan_state_ref, key) => (new_scan_state_ref, key), + Response::Single(_) => unreachable!(), + Response::Multiple(_) => unreachable!(), + }) + } + + /// Send a command to the given `routing`. If `routing` is [None], it will be computed from `cmd`. 
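+    ///
+    /// # Example
+    /// A minimal sketch of routing a single command to a random node (this assumes a
+    /// locally reachable cluster, as in the other examples in this module):
+    /// ```rust,no_run
+    /// use redis::cluster::ClusterClient;
+    /// use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo};
+    ///
+    /// async fn ping_random_node() -> redis::RedisResult<redis::Value> {
+    ///     let nodes = vec!["redis://127.0.0.1/"];
+    ///     let client = ClusterClient::new(nodes).unwrap();
+    ///     let mut connection = client.get_async_connection(None).await.unwrap();
+    ///     // Route a PING to one random node in the cluster.
+    ///     connection
+    ///         .route_command(
+    ///             &redis::cmd("PING"),
+    ///             RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random),
+    ///         )
+    ///         .await
+    /// }
+    /// ```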
+ pub async fn route_command( + &mut self, + cmd: &Cmd, + routing: cluster_routing::RoutingInfo, + ) -> RedisResult { + trace!("route_command"); + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::Cmd { + cmd: Arc::new(cmd.clone()), + routing: routing.into(), + }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }) + } + + /// Send commands in `pipeline` to the given `route`. If `route` is [None], it will be computed from `pipeline`. + pub async fn route_pipeline<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + route: SingleNodeRoutingInfo, + ) -> RedisResult> { + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::Pipeline { + pipeline: Arc::new(pipeline.clone()), + offset, + count, + route: route.into(), + }, + sender, + }) + .await + .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + receiver + .await + .unwrap_or_else(|_| Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))) + .map(|response| match response { + Response::Multiple(values) => values, + Response::Single(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }) + } +} + +#[cfg(feature = "tokio-comp")] +#[derive(Clone)] +struct TokioDisconnectNotifier { + disconnect_notifier: Arc, +} + +#[cfg(feature = "tokio-comp")] +#[async_trait] +impl DisconnectNotifier for TokioDisconnectNotifier { + fn notify_disconnect(&mut self) { + self.disconnect_notifier.notify_one(); + } + + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) { + let _ = timeout(*max_wait, async { + self.disconnect_notifier.notified().await; + }) + .await; + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(feature = "tokio-comp")] +impl TokioDisconnectNotifier { + fn new() -> TokioDisconnectNotifier { + TokioDisconnectNotifier { + disconnect_notifier: Arc::new(Notify::new()), + } + } +} + +type ConnectionMap = connections_container::ConnectionsMap>; +type ConnectionsContainer = + self::connections_container::ConnectionsContainer>; + +pub(crate) struct InnerCore { + pub(crate) conn_lock: RwLock>, + cluster_params: ClusterParams, + pending_requests: Mutex>>, + slot_refresh_state: SlotRefreshState, + initial_nodes: Vec, + subscriptions_by_address: RwLock>, + unassigned_subscriptions: RwLock, + glide_connection_options: GlideConnectionOptions, +} + +pub(crate) type Core = Arc>; + +impl InnerCore +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + // return address of node for slot + pub(crate) async fn get_address_from_slot( + &self, + slot: u16, + slot_addr: SlotAddr, + ) -> Option { + self.conn_lock + .read() + .await + .slot_map + .get_node_address_for_slot(slot, slot_addr) + } + + // return epoch of node + pub(crate) async fn get_address_epoch(&self, node_address: &str) -> Result { + let command = cmd("CLUSTER").arg("INFO").to_owned(); + let node_conn = self + .conn_lock + .read() + .await + .connection_for_address(node_address) + .ok_or(RedisError::from(( + 
ErrorKind::ResponseError,
+                "Failed to parse cluster info",
+            )))?;
+
+        let cluster_info = node_conn.1.await.req_packed_command(&command).await;
+        match cluster_info {
+            Ok(value) => {
+                let info_dict: Result<InfoDict, RedisError> =
+                    FromRedisValue::from_redis_value(&value);
+                if let Ok(info_dict) = info_dict {
+                    let epoch = info_dict.get("cluster_my_epoch");
+                    if let Some(epoch) = epoch {
+                        Ok(epoch)
+                    } else {
+                        Err(RedisError::from((
+                            ErrorKind::ResponseError,
+                            "Failed to get epoch from cluster info",
+                        )))
+                    }
+                } else {
+                    Err(RedisError::from((
+                        ErrorKind::ResponseError,
+                        "Failed to parse cluster info",
+                    )))
+                }
+            }
+            Err(redis_error) => Err(redis_error),
+        }
+    }
+
+    // return slots of node
+    pub(crate) async fn get_slots_of_address(&self, node_address: &str) -> Vec<u16> {
+        self.conn_lock
+            .read()
+            .await
+            .slot_map
+            .get_slots_of_node(node_address)
+    }
+}
+
+pub(crate) struct ClusterConnInner<C> {
+    pub(crate) inner: Core<C>,
+    state: ConnectionState,
+    #[allow(clippy::complexity)]
+    in_flight_requests: stream::FuturesUnordered<Pin<Box<Request<C>>>>,
+    refresh_error: Option<RedisError>,
+    // Handler of the periodic check task.
+    periodic_checks_handler: Option<JoinHandle<()>>,
+    // Handler of fast connection validation task
+    connections_validation_handler: Option<JoinHandle<()>>,
+}
+
+impl<C> Dispose for ClusterConnInner<C> {
+    fn dispose(self) {
+        if let Some(handle) = self.periodic_checks_handler {
+            #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+            block_on(handle.cancel());
+            #[cfg(feature = "tokio-comp")]
+            handle.abort()
+        }
+        if let Some(handle) = self.connections_validation_handler {
+            #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+            block_on(handle.cancel());
+            #[cfg(feature = "tokio-comp")]
+            handle.abort()
+        }
+    }
+}
+
+#[derive(Clone)]
+pub(crate) enum InternalRoutingInfo<C> {
+    SingleNode(InternalSingleNodeRouting<C>),
+    MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
+}
+
+#[derive(PartialEq, Clone, Debug)]
+/// Represents different policies for refreshing the cluster slots.
+pub(crate) enum RefreshPolicy {
+    /// `Throttable` indicates that the refresh operation can be throttled,
+    /// meaning it can be delayed or rate-limited if necessary.
+    Throttable,
+    /// `NotThrottable` indicates that the refresh operation should not be throttled,
+    /// meaning it should be executed immediately without any delay or rate-limiting.
+ NotThrottable, +} + +impl From for InternalRoutingInfo { + fn from(value: cluster_routing::RoutingInfo) -> Self { + match value { + cluster_routing::RoutingInfo::SingleNode(route) => { + InternalRoutingInfo::SingleNode(route.into()) + } + cluster_routing::RoutingInfo::MultiNode(routes) => { + InternalRoutingInfo::MultiNode(routes) + } + } + } +} + +impl From> for InternalRoutingInfo { + fn from(value: InternalSingleNodeRouting) -> Self { + InternalRoutingInfo::SingleNode(value) + } +} + +#[derive(Clone)] +pub(crate) enum InternalSingleNodeRouting { + Random, + SpecificNode(Route), + ByAddress(String), + Connection { + address: String, + conn: ConnectionFuture, + }, + Redirect { + redirect: Redirect, + previous_routing: Box>, + }, +} + +impl Default for InternalSingleNodeRouting { + fn default() -> Self { + Self::Random + } +} + +impl From for InternalSingleNodeRouting { + fn from(value: SingleNodeRoutingInfo) -> Self { + match value { + SingleNodeRoutingInfo::Random => InternalSingleNodeRouting::Random, + SingleNodeRoutingInfo::SpecificNode(route) => { + InternalSingleNodeRouting::SpecificNode(route) + } + SingleNodeRoutingInfo::RandomPrimary => { + InternalSingleNodeRouting::SpecificNode(Route::new_random_primary()) + } + SingleNodeRoutingInfo::ByAddress { host, port } => { + InternalSingleNodeRouting::ByAddress(format!("{host}:{port}")) + } + } + } +} + +#[derive(Clone)] +enum CmdArg { + Cmd { + cmd: Arc, + routing: InternalRoutingInfo, + }, + Pipeline { + pipeline: Arc, + offset: usize, + count: usize, + route: InternalSingleNodeRouting, + }, + ClusterScan { + // struct containing the arguments for the cluster scan command - scan state cursor, match pattern, count and object type. + cluster_scan_args: ClusterScanArgs, + }, +} + +fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult> { + fn route_for_command(cmd: &Cmd) -> Option { + match cluster_routing::RoutingInfo::for_routable(cmd) { + Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None, + Some(cluster_routing::RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(route), + )) => Some(route), + Some(cluster_routing::RoutingInfo::SingleNode( + SingleNodeRoutingInfo::RandomPrimary, + )) => Some(Route::new_random_primary()), + Some(cluster_routing::RoutingInfo::MultiNode(_)) => None, + Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + .. + })) => None, + None => None, + } + } + + // Find first specific slot and send to it. There's no need to check If later commands + // should be routed to a different slot, since the server will return an error indicating this. 
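+    // When two commands map to the same slot, prefer a primary or replica-required
+    // route over a replica-optional one, so the chosen node can serve every command
+    // in the pipeline.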
+ pipeline.cmd_iter().map(route_for_command).try_fold( + None, + |chosen_route, next_cmd_route| match (chosen_route, next_cmd_route) { + (None, _) => Ok(next_cmd_route), + (_, None) => Ok(chosen_route), + (Some(chosen_route), Some(next_cmd_route)) => { + if chosen_route.slot() != next_cmd_route.slot() { + Err((ErrorKind::CrossSlot, "Received crossed slots in pipeline").into()) + } else if chosen_route.slot_addr() == SlotAddr::ReplicaOptional { + Ok(Some(next_cmd_route)) + } else { + Ok(Some(chosen_route)) + } + } + }, + ) +} + +fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> { + #[cfg(feature = "tokio-comp")] + return Box::pin(tokio::time::sleep(duration)); + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + return Box::pin(async_std::task::sleep(duration)); +} + +pub(crate) enum Response { + Single(Value), + ClusterScanResult(ScanStateRC, Vec), + Multiple(Vec), +} + +pub(crate) enum OperationTarget { + Node { address: String }, + FanOut, + NotFound, +} +type OperationResult = Result; + +impl From for OperationTarget { + fn from(address: String) -> Self { + OperationTarget::Node { address } + } +} + +struct Message { + cmd: CmdArg, + sender: oneshot::Sender>, +} + +enum RecoverFuture { + RecoverSlots(BoxFuture<'static, RedisResult<()>>), + Reconnect(BoxFuture<'static, ()>), +} + +enum ConnectionState { + PollComplete, + Recover(RecoverFuture), +} + +impl fmt::Debug for ConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + ConnectionState::PollComplete => "PollComplete", + ConnectionState::Recover(_) => "Recover", + } + ) + } +} + +#[derive(Clone)] +struct RequestInfo { + cmd: CmdArg, +} + +impl RequestInfo { + fn set_redirect(&mut self, redirect: Option) { + if let Some(redirect) = redirect { + match &mut self.cmd { + CmdArg::Cmd { routing, .. } => match routing { + InternalRoutingInfo::SingleNode(route) => { + let redirect = InternalSingleNodeRouting::Redirect { + redirect, + previous_routing: Box::new(std::mem::take(route)), + } + .into(); + *routing = redirect; + } + InternalRoutingInfo::MultiNode(_) => { + panic!("Cannot redirect multinode requests") + } + }, + CmdArg::Pipeline { route, .. } => { + let redirect = InternalSingleNodeRouting::Redirect { + redirect, + previous_routing: Box::new(std::mem::take(route)), + }; + *route = redirect; + } + // cluster_scan is sent as a normal command internally so we will not reach that point. + CmdArg::ClusterScan { .. } => { + unreachable!() + } + } + } + } + + fn reset_routing(&mut self) { + let fix_route = |route: &mut InternalSingleNodeRouting| { + match route { + InternalSingleNodeRouting::Redirect { + previous_routing, .. + } => { + let previous_routing = std::mem::take(previous_routing.as_mut()); + *route = previous_routing; + } + // If a specific connection is specified, then reconnecting without resetting the routing + // will mean that the request is still routed to the old connection. + InternalSingleNodeRouting::Connection { address, .. } => { + *route = InternalSingleNodeRouting::ByAddress(address.to_string()); + } + _ => {} + } + }; + match &mut self.cmd { + CmdArg::Cmd { routing, .. } => { + if let InternalRoutingInfo::SingleNode(route) = routing { + fix_route(route); + } + } + CmdArg::Pipeline { route, .. } => { + fix_route(route); + } + // cluster_scan is sent as a normal command internally so we will not reach that point. + CmdArg::ClusterScan { .. } => { + unreachable!() + } + } + } +} + +pin_project! 
{ + #[project = RequestStateProj] + enum RequestState { + None, + Future { + #[pin] + future: F, + }, + Sleep { + #[pin] + sleep: BoxFuture<'static, ()>, + }, + } +} + +struct PendingRequest { + retry: u32, + sender: oneshot::Sender>, + info: RequestInfo, +} + +pin_project! { + struct Request { + retry_params: RetryParams, + request: Option>, + #[pin] + future: RequestState>, + } +} + +#[must_use] +enum Next { + Retry { + request: PendingRequest, + }, + RetryBusyLoadingError { + request: PendingRequest, + address: String, + }, + Reconnect { + // if not set, then a reconnect should happen without sending a request afterwards + request: Option>, + target: String, + }, + RefreshSlots { + // if not set, then a slot refresh should happen without sending a request afterwards + request: Option>, + sleep_duration: Option, + }, + ReconnectToInitialNodes { + // if not set, then a reconnect should happen without sending a request afterwards + request: Option>, + }, + Done, +} + +impl Future for Request { + type Output = Next; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll { + let mut this = self.as_mut().project(); + // If the sender is closed, the caller is no longer waiting for the reply, and it is ambiguous + // whether they expect the side-effect of the request to happen or not. + if this.request.is_none() || this.request.as_ref().unwrap().sender.is_closed() { + return Poll::Ready(Next::Done); + } + let future = match this.future.as_mut().project() { + RequestStateProj::Future { future } => future, + RequestStateProj::Sleep { sleep } => { + ready!(sleep.poll(cx)); + return Next::Retry { + request: self.project().request.take().unwrap(), + } + .into(); + } + _ => panic!("Request future must be Some"), + }; + match ready!(future.poll(cx)) { + Ok(item) => { + self.respond(Ok(item)); + Next::Done.into() + } + Err((target, err)) => { + let request = this.request.as_mut().unwrap(); + // TODO - would be nice if we didn't need to repeat this code twice, with & without retries. + if request.retry >= this.retry_params.number_of_retries { + let next = if err.kind() == ErrorKind::AllConnectionsUnavailable { + Next::ReconnectToInitialNodes { request: None }.into() + } else if matches!(err.retry_method(), crate::types::RetryMethod::MovedRedirect) + || matches!(target, OperationTarget::NotFound) + { + Next::RefreshSlots { + request: None, + sleep_duration: None, + } + .into() + } else if matches!(err.retry_method(), crate::types::RetryMethod::Reconnect) { + if let OperationTarget::Node { address } = target { + Next::Reconnect { + request: None, + target: address, + } + .into() + } else { + Next::Done.into() + } + } else { + Next::Done.into() + }; + self.respond(Err(err)); + return next; + } + request.retry = request.retry.saturating_add(1); + + if err.kind() == ErrorKind::AllConnectionsUnavailable { + return Next::ReconnectToInitialNodes { + request: Some(this.request.take().unwrap()), + } + .into(); + } + + let sleep_duration = this.retry_params.wait_time_for_retry(request.retry); + + let address = match target { + OperationTarget::Node { address } => address, + OperationTarget::FanOut => { + trace!("Request error `{}` multi-node request", err); + + // Fanout operation are retried per internal request, and don't need additional retries. + self.respond(Err(err)); + return Next::Done.into(); + } + OperationTarget::NotFound => { + // TODO - this is essentially a repeat of the retirable error. probably can remove duplication. 
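+                        // `NotFound` means the slot map currently has no node for this
+                        // route: drop any redirect, refresh the slots after a backoff,
+                        // and then retry the request.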
+ let mut request = this.request.take().unwrap(); + request.info.reset_routing(); + return Next::RefreshSlots { + request: Some(request), + sleep_duration: Some(sleep_duration), + } + .into(); + } + }; + + warn!("Received request error {} on node {:?}.", err, address); + + match err.retry_method() { + crate::types::RetryMethod::AskRedirect => { + let mut request = this.request.take().unwrap(); + request.info.set_redirect( + err.redirect_node() + .map(|(node, _slot)| Redirect::Ask(node.to_string())), + ); + Next::Retry { request }.into() + } + crate::types::RetryMethod::MovedRedirect => { + let mut request = this.request.take().unwrap(); + request.info.set_redirect( + err.redirect_node() + .map(|(node, _slot)| Redirect::Moved(node.to_string())), + ); + Next::RefreshSlots { + request: Some(request), + sleep_duration: None, + } + .into() + } + crate::types::RetryMethod::WaitAndRetry => { + let sleep_duration = this.retry_params.wait_time_for_retry(request.retry); + // Sleep and retry. + this.future.set(RequestState::Sleep { + sleep: boxed_sleep(sleep_duration), + }); + self.poll(cx) + } + crate::types::RetryMethod::Reconnect => { + let mut request = this.request.take().unwrap(); + // TODO should we reset the redirect here? + request.info.reset_routing(); + warn!("disconnected from {:?}", address); + Next::Reconnect { + request: Some(request), + target: address, + } + .into() + } + crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => { + Next::RetryBusyLoadingError { + request: this.request.take().unwrap(), + address, + } + .into() + } + crate::types::RetryMethod::RetryImmediately => Next::Retry { + request: this.request.take().unwrap(), + } + .into(), + crate::types::RetryMethod::NoRetry => { + self.respond(Err(err)); + Next::Done.into() + } + } + } + } + } +} + +impl Request { + fn respond(self: Pin<&mut Self>, msg: RedisResult) { + // If `send` errors the receiver has dropped and thus does not care about the message + let _ = self + .project() + .request + .take() + .expect("Result should only be sent once") + .sender + .send(msg); + } +} + +enum ConnectionCheck { + Found((String, ConnectionFuture)), + OnlyAddress(String), + RandomConnection, +} + +impl ClusterConnInner +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + push_sender: Option>, + ) -> RedisResult> { + let disconnect_notifier = { + #[cfg(feature = "tokio-comp")] + { + Some::>(Box::new(TokioDisconnectNotifier::new())) + } + #[cfg(not(feature = "tokio-comp"))] + None + }; + + let glide_connection_options = GlideConnectionOptions { + push_sender, + disconnect_notifier, + }; + + let connections = Self::create_initial_connections( + initial_nodes, + &cluster_params, + glide_connection_options.clone(), + ) + .await?; + + let topology_checks_interval = cluster_params.topology_checks_interval; + let slots_refresh_rate_limiter = cluster_params.slots_refresh_rate_limit; + let inner = Arc::new(InnerCore { + conn_lock: RwLock::new(ConnectionsContainer::new( + Default::default(), + connections, + cluster_params.read_from_replicas, + 0, + )), + cluster_params: cluster_params.clone(), + pending_requests: Mutex::new(Vec::new()), + slot_refresh_state: SlotRefreshState::new(slots_refresh_rate_limiter), + initial_nodes: initial_nodes.to_vec(), + unassigned_subscriptions: RwLock::new( + if let Some(subs) = cluster_params.pubsub_subscriptions { + subs.clone() + } else { + PubSubSubscriptionInfo::new() + }, + ), + 
subscriptions_by_address: RwLock::new(Default::default()), + glide_connection_options, + }); + let mut connection = ClusterConnInner { + inner, + in_flight_requests: Default::default(), + refresh_error: None, + state: ConnectionState::PollComplete, + periodic_checks_handler: None, + connections_validation_handler: None, + }; + Self::refresh_slots_and_subscriptions_with_retries( + connection.inner.clone(), + &RefreshPolicy::NotThrottable, + ) + .await?; + + if let Some(duration) = topology_checks_interval { + let periodic_task = + ClusterConnInner::periodic_topology_check(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.periodic_checks_handler = Some(tokio::spawn(periodic_task)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + connection.periodic_checks_handler = Some(spawn(periodic_task)); + } + } + + let connections_validation_interval = cluster_params.connections_validation_interval; + if let Some(duration) = connections_validation_interval { + let connections_validation_handler = + ClusterConnInner::connections_validation_task(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.connections_validation_handler = + Some(tokio::spawn(connections_validation_handler)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + connection.connections_validation_handler = + Some(spawn(connections_validation_handler)); + } + } + + Ok(Disposable::new(connection)) + } + + /// Go through each of the initial nodes and attempt to retrieve all IP entries from them. + /// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list. + /// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved. + pub(crate) async fn try_to_expand_initial_nodes( + initial_nodes: &[ConnectionInfo], + ) -> Vec<(String, Option)> { + stream::iter(initial_nodes) + .fold( + Vec::with_capacity(initial_nodes.len()), + |mut acc, info| async { + let (host, port) = match &info.addr { + crate::ConnectionAddr::Tcp(host, port) => (host, port), + crate::ConnectionAddr::TcpTls { + host, + port, + insecure: _, + tls_params: _, + } => (host, port), + crate::ConnectionAddr::Unix(_) => { + // We don't support multiple addresses for a Unix address. 
Store the initial node address and continue + acc.push((info.addr.to_string(), None)); + return acc; + } + }; + match get_socket_addrs(host, *port).await { + Ok(socket_addrs) => { + for addr in socket_addrs { + acc.push((info.addr.to_string(), Some(addr))); + } + } + Err(_) => { + // Couldn't find socket addresses, store the initial node address and continue + acc.push((info.addr.to_string(), None)); + } + }; + acc + }, + ) + .await + } + + async fn create_initial_connections( + initial_nodes: &[ConnectionInfo], + params: &ClusterParams, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult> { + let initial_nodes: Vec<(String, Option)> = + Self::try_to_expand_initial_nodes(initial_nodes).await; + let connections = stream::iter(initial_nodes.iter().cloned()) + .map(|(node_addr, socket_addr)| { + let mut params: ClusterParams = params.clone(); + let glide_connection_options = glide_connection_options.clone(); + // set subscriptions to none, they will be applied upon the topology discovery + params.pubsub_subscriptions = None; + + async move { + let result = connect_and_check( + &node_addr, + params, + socket_addr, + RefreshConnectionType::AllConnections, + None, + glide_connection_options, + ) + .await + .get_node(); + let node_address = if let Some(socket_addr) = socket_addr { + socket_addr.to_string() + } else { + node_addr + }; + result.map(|node| (node_address, node)) + } + }) + .buffer_unordered(initial_nodes.len()) + .fold( + ( + ConnectionsMap(DashMap::with_capacity(initial_nodes.len())), + None, + ), + |connections: (ConnectionMap, Option), addr_conn_res| async move { + match addr_conn_res { + Ok((addr, node)) => { + connections.0 .0.insert(addr, node); + (connections.0, None) + } + Err(e) => (connections.0, Some(e.to_string())), + } + }, + ) + .await; + if connections.0 .0.is_empty() { + return Err(RedisError::from(( + ErrorKind::IoError, + "Failed to create initial connections", + connections.1.unwrap_or("".to_string()), + ))); + } + info!("Connected to initial nodes:\n{}", connections.0); + Ok(connections.0) + } + + fn reconnect_to_initial_nodes(&mut self) -> impl Future { + let inner = self.inner.clone(); + async move { + let connection_map = match Self::create_initial_connections( + &inner.initial_nodes, + &inner.cluster_params, + inner.glide_connection_options.clone(), + ) + .await + { + Ok(map) => map, + Err(err) => { + warn!("Can't reconnect to initial nodes: `{err}`"); + return; + } + }; + let mut write_lock = inner.conn_lock.write().await; + write_lock.extend_connection_map(connection_map); + drop(write_lock); + if let Err(err) = Self::refresh_slots_and_subscriptions_with_retries( + inner.clone(), + &RefreshPolicy::Throttable, + ) + .await + { + warn!("Can't refresh slots with initial nodes: `{err}`"); + }; + } + } + + // Validate all existing user connections and try to reconnect if nessesary. + // In addition, as a safety measure, drop nodes that do not have any assigned slots. + // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently. + // The function does not discover the topology from the cluster and assumes the cached topology is valid. + // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. 
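+    // It is invoked from connections_validation_task(), either upon a disconnect
+    // notification or once per configured validation interval.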
+ async fn validate_all_user_connections(inner: Arc>) { + let mut all_valid_conns = HashMap::new(); + // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; + + let all_nodes_with_slots: HashSet = connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .map(|addr| String::from(*addr)) + .collect(); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); + + for addr in &nodes_to_delete { + connections_container.remove_node(addr); + } + + drop(connections_container); + + // identify nodes with closed connection + let mut addrs_to_refresh = Vec::new(); + for (addr, con_fut) in &all_valid_conns { + let con = con_fut.clone().await; + // connection object might be present despite the transport being closed + if con.is_closed() { + // transport is closed, need to refresh + addrs_to_refresh.push(addr.clone()); + } + } + + // identify missing nodes + addrs_to_refresh.extend( + all_nodes_with_slots + .iter() + .filter(|addr| !all_valid_conns.contains_key(*addr)) + .cloned(), + ); + + if !addrs_to_refresh.is_empty() { + // dont try existing nodes since we know a. it does not exist. b. exist but its connection is closed + Self::refresh_connections( + inner.clone(), + addrs_to_refresh, + RefreshConnectionType::AllConnections, + false, + ) + .await; + } + } + + async fn refresh_connections( + inner: Arc>, + addresses: Vec, + conn_type: RefreshConnectionType, + check_existing_conn: bool, + ) { + info!("Started refreshing connections to {:?}", addresses); + let connections_container = inner.conn_lock.read().await; + let cluster_params = &inner.cluster_params; + let subscriptions_by_address = &inner.subscriptions_by_address; + let glide_connection_optons = &inner.glide_connection_options; + + stream::iter(addresses.into_iter()) + .fold( + &*connections_container, + |connections_container, address| async move { + let node_option = if check_existing_conn { + connections_container.remove_node(&address) + } else { + None + }; + + // override subscriptions for this connection + let mut cluster_params = cluster_params.clone(); + let subs_guard = subscriptions_by_address.read().await; + cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned(); + drop(subs_guard); + let node = get_or_create_conn( + &address, + node_option, + &cluster_params, + conn_type, + glide_connection_optons.clone(), + ) + .await; + match node { + Ok(node) => { + connections_container + .replace_or_add_connection_for_address(address, node); + } + Err(err) => { + warn!( + "Failed to refresh connection for node {}. 
Error: `{:?}`", + address, err + ); + } + } + connections_container + }, + ) + .await; + info!("refresh connections completed"); + } + + async fn aggregate_results( + receivers: Vec<(Option, oneshot::Receiver>)>, + routing: &MultipleNodeRoutingInfo, + response_policy: Option, + ) -> RedisResult { + let extract_result = |response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }; + + let convert_result = |res: Result, _>| { + res.map_err(|_| RedisError::from((ErrorKind::ResponseError, "request wasn't handled due to internal failure"))) // this happens only if the result sender is dropped before usage. + .and_then(|res| res.map(extract_result)) + }; + + let get_receiver = |(_, receiver): (_, oneshot::Receiver>)| async { + convert_result(receiver.await) + }; + + // TODO - once Value::Error will be merged, these will need to be updated to handle this new value. + match response_policy { + Some(ResponsePolicy::AllSucceeded) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .map(|mut results| results.pop().unwrap()) // unwrap is safe, since at least one function succeeded + } + Some(ResponsePolicy::OneSucceeded) => future::select_ok( + receivers + .into_iter() + .map(|tuple| Box::pin(get_receiver(tuple))), + ) + .await + .map(|(result, _)| result), + Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => { + // Attempt to return the first result that isn't `Nil` or an error. + // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`. + // If we received only errors, return the last received error. + // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty, + // thus we return the last received error instead of `Nil`. 
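+                // For example: [Nil, "x"] yields "x"; [Nil, Nil] yields Nil;
+                // [Err, Nil] yields the error, since the errored shard may not have been empty.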
+ let num_of_results: usize = receivers.len(); + let mut futures = receivers + .into_iter() + .map(get_receiver) + .collect::>(); + let mut nil_counter = 0; + let mut last_err = None; + while let Some(result) = futures.next().await { + match result { + Ok(Value::Nil) => nil_counter += 1, + Ok(val) => return Ok(val), + Err(e) => last_err = Some(e), + } + } + + if nil_counter == num_of_results { + // All received results are `Nil` + Ok(Value::Nil) + } else { + Err(last_err.unwrap_or_else(|| { + ( + ErrorKind::AllConnectionsUnavailable, + "Couldn't find any connection", + ) + .into() + })) + } + } + Some(ResponsePolicy::Aggregate(op)) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| crate::cluster_routing::aggregate(results, op)) + } + Some(ResponsePolicy::AggregateLogical(op)) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| crate::cluster_routing::logical_aggregate(results, op)) + } + Some(ResponsePolicy::CombineArrays) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| match routing { + MultipleNodeRoutingInfo::MultiSlot(vec) => { + crate::cluster_routing::combine_and_sort_array_results( + results, + vec.iter().map(|(_, indices)| indices), + ) + } + _ => crate::cluster_routing::combine_array_results(results), + }) + } + Some(ResponsePolicy::CombineMaps) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(crate::cluster_routing::combine_map_results) + } + Some(ResponsePolicy::Special) | None => { + // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. + // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. + future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move { + let result = convert_result(receiver.await)?; + // The unwrap here is possible, because if `addr` is None, an error should have been sent on the receiver. + Ok((Value::BulkString(addr.unwrap().as_bytes().to_vec()), result)) + })) + .await + .map(Value::Map) + } + } + } + + // Query a node to discover slot-> master mappings with retries + async fn refresh_slots_and_subscriptions_with_retries( + inner: Arc>, + policy: &RefreshPolicy, + ) -> RedisResult<()> { + let SlotRefreshState { + in_progress, + last_run, + rate_limiter, + } = &inner.slot_refresh_state; + // Ensure only a single slot refresh operation occurs at a time + if in_progress + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + return Ok(()); + } + let mut skip_slots_refresh = false; + if *policy == RefreshPolicy::Throttable { + // Check if the current slot refresh is triggered before the wait duration has passed + let last_run_rlock = last_run.read().await; + if let Some(last_run_time) = *last_run_rlock { + let passed_time = SystemTime::now() + .duration_since(last_run_time) + .unwrap_or_else(|err| { + warn!( + "Failed to get the duration since the last slot refresh, received error: {:?}", + err + ); + // Setting the passed time to 0 will force the current refresh to continue and reset the stored last_run timestamp with the current one + Duration::from_secs(0) + }); + let wait_duration = rate_limiter.wait_duration(); + if passed_time <= wait_duration { + debug!("Skipping slot refresh as the wait duration hasn't yet passed. 
Passed time = {:?}, + Wait duration = {:?}", passed_time, wait_duration); + skip_slots_refresh = true; + } + } + } + + let mut res = Ok(()); + if !skip_slots_refresh { + let retry_strategy = ExponentialBackoff { + initial_interval: DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL, + max_interval: DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL, + max_elapsed_time: None, + ..Default::default() + }; + let retries_counter = AtomicUsize::new(0); + res = retry(retry_strategy, || { + let curr_retry = retries_counter.fetch_add(1, atomic::Ordering::Relaxed); + Self::refresh_slots(inner.clone(), curr_retry) + }) + .await; + } + in_progress.store(false, Ordering::Relaxed); + + Self::refresh_pubsub_subscriptions(inner).await; + + res + } + + pub(crate) async fn check_topology_and_refresh_if_diff( + inner: Arc>, + policy: &RefreshPolicy, + ) -> bool { + let topology_changed = Self::check_for_topology_diff(inner.clone()).await; + if topology_changed { + let _ = Self::refresh_slots_and_subscriptions_with_retries(inner.clone(), policy).await; + } + topology_changed + } + + async fn periodic_topology_check(inner: Arc>, interval_duration: Duration) { + loop { + let _ = boxed_sleep(interval_duration).await; + let topology_changed = + Self::check_topology_and_refresh_if_diff(inner.clone(), &RefreshPolicy::Throttable) + .await; + if !topology_changed { + // This serves as a safety measure for validating pubsub subsctiptions state in case it has drifted + // while topology stayed the same. + // For example, a failed attempt to refresh a connection which is triggered from refresh_pubsub_subscriptions(), + // might leave a node unconnected indefinitely in case topology is stable and no request are attempted to this node. + Self::refresh_pubsub_subscriptions(inner.clone()).await; + } + } + } + + async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { + loop { + if let Some(disconnect_notifier) = + inner.glide_connection_options.disconnect_notifier.clone() + { + disconnect_notifier + .wait_for_disconnect_with_timeout(&interval_duration) + .await; + } else { + let _ = boxed_sleep(interval_duration).await; + } + + Self::validate_all_user_connections(inner.clone()).await; + } + } + + async fn refresh_pubsub_subscriptions(inner: Arc>) { + if inner.cluster_params.protocol != crate::types::ProtocolVersion::RESP3 { + return; + } + + let mut addrs_to_refresh: HashSet = HashSet::new(); + let mut subs_by_address_guard = inner.subscriptions_by_address.write().await; + let mut unassigned_subs_guard = inner.unassigned_subscriptions.write().await; + let conns_read_guard = inner.conn_lock.read().await; + + // validate active subscriptions location + subs_by_address_guard.retain(|current_address, address_subs| { + address_subs.retain(|kind, channels_patterns| { + channels_patterns.retain(|channel_pattern| { + let new_slot = get_slot(channel_pattern); + let mut valid = false; + if let Some((new_address, _)) = conns_read_guard + .connection_for_route(&Route::new(new_slot, SlotAddr::Master)) + { + if *new_address == *current_address { + valid = true; + } + } + // no new address or new address differ - move to unassigned and store this address for connection reset + if !valid { + // need to drop the original connection for clearing the subscription in the server, avoiding possible double-receivers + if conns_read_guard + .connection_for_address(current_address) + .is_some() + { + addrs_to_refresh.insert(current_address.clone()); + } + + unassigned_subs_guard + .entry(*kind) + .and_modify(|channels_patterns| { + 
+                                channels_patterns.insert(channel_pattern.clone());
+                            })
+                            .or_insert(HashSet::from([channel_pattern.clone()]));
+                    }
+                    valid
+                });
+                !channels_patterns.is_empty()
+            });
+            !address_subs.is_empty()
+        });
+
+        // try to assign new addresses
+        unassigned_subs_guard.retain(|kind: &PubSubSubscriptionKind, channels_patterns| {
+            channels_patterns.retain(|channel_pattern| {
+                let new_slot = get_slot(channel_pattern);
+                if let Some((new_address, _)) =
+                    conns_read_guard.connection_for_route(&Route::new(new_slot, SlotAddr::Master))
+                {
+                    // need to drop the new connection so the subscription will be picked up in setup_connection()
+                    addrs_to_refresh.insert(new_address.clone());
+
+                    let e = subs_by_address_guard
+                        .entry(new_address.clone())
+                        .or_insert(PubSubSubscriptionInfo::new());
+
+                    e.entry(*kind)
+                        .or_insert(HashSet::new())
+                        .insert(channel_pattern.clone());
+
+                    return false;
+                }
+                true
+            });
+            !channels_patterns.is_empty()
+        });
+
+        drop(conns_read_guard);
+        drop(unassigned_subs_guard);
+        drop(subs_by_address_guard);
+
+        if !addrs_to_refresh.is_empty() {
+            // immediately trigger connection reestablishment
+            Self::refresh_connections(
+                inner.clone(),
+                addrs_to_refresh.into_iter().collect(),
+                RefreshConnectionType::AllConnections,
+                false,
+            )
+            .await;
+        }
+    }
+
+    /// Queries log2(n) nodes (where n is the number of cluster nodes) to determine whether their
+    /// topology view differs from the one currently stored in the connection manager.
+    /// Returns true if a change was detected, otherwise false.
+    async fn check_for_topology_diff(inner: Arc<Core<C>>) -> bool {
+        let read_guard = inner.conn_lock.read().await;
+        let num_of_nodes: usize = read_guard.len();
+        // TODO: Starting from Rust v1.67, integers have logarithm support.
+        // When we no longer need to support Rust versions < 1.67, remove fast_math and transition to the ilog2 function.
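+        // For example, a 16-node cluster samples max(log2(16), 1) = 4 nodes per check,
+        // while a one- or two-node cluster still samples at least one node.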
+ let num_of_nodes_to_query = + std::cmp::max(fast_math::log2_raw(num_of_nodes as f32) as usize, 1); + let (res, failed_connections) = calculate_topology_from_random_nodes( + &inner, + num_of_nodes_to_query, + &read_guard, + DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES, + ) + .await; + + if let Ok((_, found_topology_hash)) = res { + if read_guard.get_current_topology_hash() != found_topology_hash { + return true; + } + } + drop(read_guard); + + if !failed_connections.is_empty() { + Self::refresh_connections( + inner, + failed_connections, + RefreshConnectionType::OnlyManagementConnection, + true, + ) + .await; + } + + false + } + + async fn refresh_slots( + inner: Arc>, + curr_retry: usize, + ) -> Result<(), BackoffError> { + // Update the slot refresh last run timestamp + let now = SystemTime::now(); + let mut last_run_wlock = inner.slot_refresh_state.last_run.write().await; + *last_run_wlock = Some(now); + drop(last_run_wlock); + Self::refresh_slots_inner(inner, curr_retry) + .await + .map_err(|err| { + if curr_retry > DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES { + BackoffError::Permanent(err) + } else { + BackoffError::from(err) + } + }) + } + + pub(crate) fn check_if_all_slots_covered(slot_map: &SlotMap) -> bool { + let mut slots_covered = 0; + for (end, slots) in slot_map.slots.iter() { + slots_covered += end.saturating_sub(slots.start).saturating_add(1); + } + slots_covered == SLOT_SIZE + } + + // Query a node to discover slot-> master mappings + async fn refresh_slots_inner(inner: Arc>, curr_retry: usize) -> RedisResult<()> { + let read_guard = inner.conn_lock.read().await; + let num_of_nodes = read_guard.len(); + const MAX_REQUESTED_NODES: usize = 10; + let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES); + let (new_slots, topology_hash) = calculate_topology_from_random_nodes( + &inner, + num_of_nodes_to_query, + &read_guard, + curr_retry, + ) + .await + .0?; + let connections = &*read_guard; + // Create a new connection vector of the found nodes + let mut nodes = new_slots.values().flatten().collect::>(); + nodes.sort_unstable(); + nodes.dedup(); + let nodes_len = nodes.len(); + let addresses_and_connections_iter = stream::iter(nodes) + .fold( + Vec::with_capacity(nodes_len), + |mut addrs_and_conns, addr| async move { + if let Some(node) = connections.node_for_address(addr.as_str()) { + addrs_and_conns.push((addr, Some(node))); + return addrs_and_conns; + } + // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name. + // We shall check if a connection is already exists under the resolved IP name. 
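+                    // For instance (hypothetical addresses): the topology may report
+                    // `my-cluster.example.com:6379` while the connections map is keyed by
+                    // the resolved `10.0.0.12:6379`; resolving here lets the existing
+                    // connection be reused instead of opening a new one.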
+ let (host, port) = match get_host_and_port_from_addr(addr) { + Some((host, port)) => (host, port), + None => { + addrs_and_conns.push((addr, None)); + return addrs_and_conns; + } + }; + let conn = get_socket_addrs(host, port) + .await + .ok() + .map(|mut socket_addresses| { + socket_addresses + .find_map(|addr| connections.node_for_address(&addr.to_string())) + }) + .unwrap_or(None); + addrs_and_conns.push((addr, conn)); + addrs_and_conns + }, + ) + .await; + let new_connections: ConnectionMap = stream::iter(addresses_and_connections_iter) + .fold( + ConnectionsMap(DashMap::with_capacity(nodes_len)), + |connections, (addr, node)| async { + let mut cluster_params = inner.cluster_params.clone(); + let subs_guard = inner.subscriptions_by_address.read().await; + cluster_params.pubsub_subscriptions = subs_guard.get(addr).cloned(); + drop(subs_guard); + let node = get_or_create_conn( + addr, + node, + &cluster_params, + RefreshConnectionType::AllConnections, + inner.glide_connection_options.clone(), + ) + .await; + if let Ok(node) = node { + connections.0.insert(addr.into(), node); + } + connections + }, + ) + .await; + + drop(read_guard); + info!("refresh_slots found nodes:\n{new_connections}"); + // Replace the current slot map and connection vector with the new ones + let mut write_guard = inner.conn_lock.write().await; + *write_guard = ConnectionsContainer::new( + new_slots, + new_connections, + inner.cluster_params.read_from_replicas, + topology_hash, + ); + Ok(()) + } + + async fn execute_on_multiple_nodes<'a>( + cmd: &'a Arc, + routing: &'a MultipleNodeRoutingInfo, + core: Core, + response_policy: Option, + ) -> OperationResult { + trace!("execute_on_multiple_nodes"); + let connections_container = core.conn_lock.read().await; + if connections_container.is_empty() { + return OperationResult::Err(( + OperationTarget::FanOut, + ( + ErrorKind::AllConnectionsUnavailable, + "No connections found for multi-node operation", + ) + .into(), + )); + } + + // This function maps the connections to senders & receivers of one-shot channels, and the receivers are mapped to `PendingRequest`s. + // This allows us to pass the new `PendingRequest`s to `try_request`, while letting `execute_on_multiple_nodes` wait on the receivers + // for all of the individual requests to complete. + #[allow(clippy::type_complexity)] // The return value is complex, but indentation and linebreaks make it human readable. 
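+        // Each route is paired with a oneshot channel: the receiver half is handed to
+        // aggregate_results(), while the sender travels inside the PendingRequest. A route
+        // with no live connection gets an error pushed through its sender right away, so
+        // every receiver still resolves exactly once.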
+ fn into_channels( + iterator: impl Iterator< + Item = Option<(Arc, ConnectionAndAddress>)>, + >, + ) -> ( + Vec<(Option, Receiver>)>, + Vec>>, + ) { + iterator + .map(|tuple_opt| { + let (sender, receiver) = oneshot::channel(); + if let Some((cmd, conn, address)) = + tuple_opt.map(|(cmd, (address, conn))| (cmd, conn, address)) + { + ( + (Some(address.clone()), receiver), + Some(PendingRequest { + retry: 0, + sender, + info: RequestInfo { + cmd: CmdArg::Cmd { + cmd, + routing: InternalSingleNodeRouting::Connection { + address, + conn, + } + .into(), + }, + }, + }), + ) + } else { + let _ = sender.send(Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Connection not found", + ) + .into())); + ((None, receiver), None) + } + }) + .unzip() + } + + let (receivers, requests): (Vec<_>, Vec<_>) = match routing { + MultipleNodeRoutingInfo::AllNodes => into_channels( + connections_container + .all_node_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::AllMasters => into_channels( + connections_container + .all_primary_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::MultiSlot(slots) => { + into_channels(slots.iter().map(|(route, indices)| { + connections_container + .connection_for_route(route) + .map(|tuple| { + let new_cmd = crate::cluster_routing::command_for_multi_slot_indices( + cmd.as_ref(), + indices.iter(), + ); + (Arc::new(new_cmd), tuple) + }) + })) + } + }; + + drop(connections_container); + core.pending_requests + .lock() + .unwrap() + .extend(requests.into_iter().flatten()); + + Self::aggregate_results(receivers, routing, response_policy) + .await + .map(Response::Single) + .map_err(|err| (OperationTarget::FanOut, err)) + } + + pub(crate) async fn try_cmd_request( + cmd: Arc, + routing: InternalRoutingInfo, + core: Core, + ) -> OperationResult { + let routing = match routing { + // commands that are sent to multiple nodes are handled here. + InternalRoutingInfo::MultiNode((multi_node_routing, response_policy)) => { + return Self::execute_on_multiple_nodes( + &cmd, + &multi_node_routing, + core, + response_policy, + ) + .await; + } + + InternalRoutingInfo::SingleNode(routing) => routing, + }; + trace!("route request to single node"); + + // if we reached this point, we're sending the command only to single node, and we need to find the + // right connection to the node. + let (address, mut conn) = Self::get_connection(routing, core, Some(cmd.clone())) + .await + .map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_command(&cmd) + .await + .map(Response::Single) + .map_err(|err| (address.into(), err)) + } + + async fn try_pipeline_request( + pipeline: Arc, + offset: usize, + count: usize, + conn: impl Future>, + ) -> OperationResult { + trace!("try_pipeline_request"); + let (address, mut conn) = conn.await.map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_commands(&pipeline, offset, count) + .await + .map(Response::Multiple) + .map_err(|err| (OperationTarget::Node { address }, err)) + } + + async fn try_request(info: RequestInfo, core: Core) -> OperationResult { + match info.cmd { + CmdArg::Cmd { cmd, routing } => Self::try_cmd_request(cmd, routing, core).await, + CmdArg::Pipeline { + pipeline, + offset, + count, + route, + } => { + Self::try_pipeline_request( + pipeline, + offset, + count, + Self::get_connection(route, core, None), + ) + .await + } + CmdArg::ClusterScan { + cluster_scan_args, .. 
+ } => { + let core = core; + let scan_result = cluster_scan(core, cluster_scan_args).await; + match scan_result { + Ok((scan_state_ref, values)) => { + Ok(Response::ClusterScanResult(scan_state_ref, values)) + } + // TODO: After routing issues with sending to random node on not-key based commands are resolved, + // this error should be handled in the same way as other errors and not fan-out. + Err(err) => Err((OperationTarget::FanOut, err)), + } + } + } + } + + async fn get_connection( + routing: InternalSingleNodeRouting, + core: Core, + cmd: Option>, + ) -> RedisResult<(String, C)> { + let read_guard = core.conn_lock.read().await; + let mut asking = false; + + let conn_check = match routing { + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Moved(moved_addr), + .. + } => read_guard + .connection_for_address(moved_addr.as_str()) + .map_or( + ConnectionCheck::OnlyAddress(moved_addr), + ConnectionCheck::Found, + ), + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Ask(ask_addr), + .. + } => { + asking = true; + read_guard.connection_for_address(ask_addr.as_str()).map_or( + ConnectionCheck::OnlyAddress(ask_addr), + ConnectionCheck::Found, + ) + } + InternalSingleNodeRouting::SpecificNode(route) => { + match read_guard.connection_for_route(&route) { + Some((conn, address)) => ConnectionCheck::Found((conn, address)), + None => { + // No connection is found for the given route: + // - For key-based commands, attempt redirection to a random node, + // hopefully to be redirected afterwards by a MOVED error. + // - For non-key-based commands, avoid attempting redirection to a random node + // as it wouldn't result in MOVED hints and can lead to unwanted results + // (e.g., sending management command to a different node than the user asked for); instead, raise the error. + let routable_cmd = cmd.and_then(|cmd| Routable::command(&*cmd)); + if routable_cmd.is_some() + && !RoutingInfo::is_key_routing_command(&routable_cmd.unwrap()) + { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found for route", + format!("{route:?}"), + ) + .into()); + } else { + warn!("No connection found for route `{route:?}`. 
Attempting redirection to a random node."); + ConnectionCheck::RandomConnection + } + } + } + } + InternalSingleNodeRouting::Random => ConnectionCheck::RandomConnection, + InternalSingleNodeRouting::Connection { address, conn } => { + return Ok((address, conn.await)); + } + InternalSingleNodeRouting::ByAddress(address) => { + if let Some((address, conn)) = read_guard.connection_for_address(&address) { + return Ok((address, conn.await)); + } else { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found", + address, + ) + .into()); + } + } + }; + drop(read_guard); + + let (address, mut conn) = match conn_check { + ConnectionCheck::Found((address, connection)) => (address, connection.await), + ConnectionCheck::OnlyAddress(addr) => { + let mut this_conn_params = core.cluster_params.clone(); + let subs_guard = core.subscriptions_by_address.read().await; + this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned(); + drop(subs_guard); + match connect_and_check::( + &addr, + this_conn_params, + None, + RefreshConnectionType::AllConnections, + None, + core.glide_connection_options.clone(), + ) + .await + .get_node() + { + Ok(node) => { + let connection_clone = node.user_connection.conn.clone().await; + let connections = core.conn_lock.read().await; + let address = connections.replace_or_add_connection_for_address(addr, node); + drop(connections); + (address, connection_clone) + } + Err(err) => { + return Err(err); + } + } + } + ConnectionCheck::RandomConnection => { + let read_guard = core.conn_lock.read().await; + let (random_address, random_conn_future) = read_guard + .random_connections(1, ConnectionType::User) + .next() + .ok_or(RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No random connection found", + )))?; + return Ok((random_address, random_conn_future.await)); + } + }; + + if asking { + let _ = conn.req_packed_command(&crate::cmd::cmd("ASKING")).await; + } + Ok((address, conn)) + } + + fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll> { + let recover_future = match &mut self.state { + ConnectionState::PollComplete => return Poll::Ready(Ok(())), + ConnectionState::Recover(future) => future, + }; + match recover_future { + RecoverFuture::RecoverSlots(ref mut future) => match ready!(future.as_mut().poll(cx)) { + Ok(_) => { + trace!("Recovered!"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + Err(err) => { + trace!("Recover slots failed!"); + *future = Box::pin(Self::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + )); + Poll::Ready(Err(err)) + } + }, + RecoverFuture::Reconnect(ref mut future) => { + ready!(future.as_mut().poll(cx)); + trace!("Reconnected connections"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + } + } + + async fn handle_loading_error( + core: Core, + info: RequestInfo, + address: String, + retry: u32, + ) -> OperationResult { + let is_primary = core.conn_lock.read().await.is_primary(&address); + + if !is_primary { + // If the connection is a replica, remove the connection and retry. + // The connection will be established again on the next call to refresh slots once the replica is no longer in loading state. 
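+            // Until that refresh happens, requests for the removed replica's slots are
+            // served by the remaining eligible connections.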
+ core.conn_lock.read().await.remove_node(&address); + } else { + // If the connection is primary, just sleep and retry + let sleep_duration = core.cluster_params.retry_params.wait_time_for_retry(retry); + boxed_sleep(sleep_duration).await; + } + + Self::try_request(info, core).await + } + + fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { + let mut poll_flush_action = PollFlushAction::None; + + let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap(); + if !pending_requests_guard.is_empty() { + let mut pending_requests = mem::take(&mut *pending_requests_guard); + for request in pending_requests.drain(..) { + // Drop the request if none is waiting for a response to free up resources for + // requests callers care about (load shedding). It will be ambiguous whether the + // request actually goes through regardless. + if request.sender.is_closed() { + continue; + } + + let future = Self::try_request(request.info.clone(), self.inner.clone()).boxed(); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { future }, + })); + } + *pending_requests_guard = pending_requests; + } + drop(pending_requests_guard); + + loop { + let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { + Poll::Ready(Some(result)) => result, + Poll::Ready(None) | Poll::Pending => break, + }; + match result { + Next::Done => {} + Next::Retry { request } => { + let future = Self::try_request(request.info.clone(), self.inner.clone()); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RetryBusyLoadingError { request, address } => { + // TODO - do we also want to try and reconnect to replica if it is loading? + let future = Self::handle_loading_error( + self.inner.clone(), + request.info.clone(), + address, + request.retry, + ); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RefreshSlots { + request, + sleep_duration, + } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::RebuildSlots); + if let Some(request) = request { + let future: RequestState< + Pin + Send>>, + > = match sleep_duration { + Some(sleep_duration) => RequestState::Sleep { + sleep: boxed_sleep(sleep_duration), + }, + None => RequestState::Future { + future: Box::pin(Self::try_request( + request.info.clone(), + self.inner.clone(), + )), + }, + }; + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future, + })); + } + } + Next::Reconnect { + request, target, .. 
+ } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target])); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + Next::ReconnectToInitialNodes { request } => { + poll_flush_action = poll_flush_action + .change_state(PollFlushAction::ReconnectFromInitialConnections); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + } + } + + if matches!(poll_flush_action, PollFlushAction::None) { + if self.in_flight_requests.is_empty() { + Poll::Ready(poll_flush_action) + } else { + Poll::Pending + } + } else { + Poll::Ready(poll_flush_action) + } + } + + fn send_refresh_error(&mut self) { + if self.refresh_error.is_some() { + if let Some(mut request) = Pin::new(&mut self.in_flight_requests) + .iter_pin_mut() + .find(|request| request.request.is_some()) + { + (*request) + .as_mut() + .respond(Err(self.refresh_error.take().unwrap())); + } else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() { + let _ = request.sender.send(Err(self.refresh_error.take().unwrap())); + } + } + } +} + +enum PollFlushAction { + None, + RebuildSlots, + Reconnect(Vec), + ReconnectFromInitialConnections, +} + +impl PollFlushAction { + fn change_state(self, next_state: PollFlushAction) -> PollFlushAction { + match (self, next_state) { + (PollFlushAction::None, next_state) => next_state, + (next_state, PollFlushAction::None) => next_state, + (PollFlushAction::ReconnectFromInitialConnections, _) + | (_, PollFlushAction::ReconnectFromInitialConnections) => { + PollFlushAction::ReconnectFromInitialConnections + } + + (PollFlushAction::RebuildSlots, _) | (_, PollFlushAction::RebuildSlots) => { + PollFlushAction::RebuildSlots + } + + (PollFlushAction::Reconnect(mut addrs), PollFlushAction::Reconnect(new_addrs)) => { + addrs.extend(new_addrs); + Self::Reconnect(addrs) + } + } + } +} + +impl Sink> for Disposable> +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut task::Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { + let Message { cmd, sender } = msg; + + let info = RequestInfo { cmd }; + + self.inner + .pending_requests + .lock() + .unwrap() + .push(PendingRequest { + retry: 0, + sender, + info, + }); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + trace!("poll_flush: {:?}", self.state); + loop { + self.send_refresh_error(); + + if let Err(err) = ready!(self.as_mut().poll_recover(cx)) { + // We failed to reconnect, while we will try again we will report the + // error if we can to avoid getting trapped in an infinite loop of + // trying to reconnect + self.refresh_error = Some(err); + + // Give other tasks a chance to progress before we try to recover + // again. 
Since the future may not have registered a wake up we do so + // now so the task is not forgotten + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + match ready!(self.poll_complete(cx)) { + PollFlushAction::None => return Poll::Ready(Ok(())), + PollFlushAction::RebuildSlots => { + self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( + ClusterConnInner::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + ), + ))); + } + PollFlushAction::Reconnect(addresses) => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::refresh_connections( + self.inner.clone(), + addresses, + RefreshConnectionType::OnlyUserConnection, + true, + ), + ))); + } + PollFlushAction::ReconnectFromInitialConnections => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + self.reconnect_to_initial_nodes(), + ))); + } + } + } + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // Try to drive any in flight requests to completion + match self.poll_complete(cx) { + Poll::Ready(PollFlushAction::None) => (), + Poll::Ready(_) => Err(())?, + Poll::Pending => (), + }; + // If we no longer have any requests in flight we are done (skips any reconnection + // attempts) + if self.in_flight_requests.is_empty() { + return Poll::Ready(Ok(())); + } + + self.poll_flush(cx) + } +} + +async fn calculate_topology_from_random_nodes<'a, C>( + inner: &Core, + num_of_nodes_to_query: usize, + read_guard: &tokio::sync::RwLockReadGuard<'a, ConnectionsContainer>, + curr_retry: usize, +) -> ( + RedisResult<( + crate::cluster_slotmap::SlotMap, + crate::cluster_topology::TopologyHash, + )>, + Vec, +) +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + let requested_nodes = + read_guard.random_connections(num_of_nodes_to_query, ConnectionType::PreferManagement); + let topology_join_results = + futures::future::join_all(requested_nodes.map(|(addr, conn)| async move { + let mut conn: C = conn.await; + let res = conn.req_packed_command(&slot_cmd()).await; + (addr, res) + })) + .await; + let failed_addresses = topology_join_results + .iter() + .filter_map(|(address, res)| match res { + Err(err) if err.is_unrecoverable_error() => Some(address.clone()), + _ => None, + }) + .collect(); + let topology_values = topology_join_results.iter().filter_map(|(addr, res)| { + res.as_ref() + .ok() + .and_then(|value| get_host_and_port_from_addr(addr).map(|(host, _)| (host, value))) + }); + ( + calculate_topology( + topology_values, + curr_retry, + inner.cluster_params.tls, + num_of_nodes_to_query, + inner.cluster_params.read_from_replicas, + ), + failed_addresses, + ) +} + +impl ConnectionLike for ClusterConnection +where + C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static, +{ + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + let routing = cluster_routing::RoutingInfo::for_routable(cmd).unwrap_or( + cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random), + ); + self.route_command(cmd, routing).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + async move { + let route = route_for_pipeline(pipeline)?; + self.route_pipeline(pipeline, offset, count, route.into()) + .await + } + .boxed() + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +/// 
Implements the process of connecting to a Redis server +/// and obtaining a connection handle. +pub trait Connect: Sized { + /// Connect to a node. + /// For TCP connections, returning a tuple of handle for command execution and the node's IP address. + /// For UNIX connections, returning a tuple of handle for command execution and None. + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a; +} + +impl Connect for MultiplexedConnection { + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (MultiplexedConnection, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + async move { + let connection_info = info.into_connection_info()?; + let client = crate::Client::open(connection_info)?; + + match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + rt @ Runtime::Tokio => { + rt.timeout( + connection_timeout, + client.get_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ), + ) + .await? + } + #[cfg(feature = "async-std-comp")] + rt @ Runtime::AsyncStd => { + rt.timeout(connection_timeout,client + .get_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + )) + .await? + } + } + } + .boxed() + } +} + +#[cfg(test)] +mod pipeline_routing_tests { + use super::route_for_pipeline; + use crate::{ + cluster_routing::{Route, SlotAddr}, + cmd, + }; + + #[test] + fn test_first_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .get("foo") // route to slot 12182 + .add_command(cmd("EVAL")); // route randomly + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::ReplicaOptional))) + ); + } + + #[test] + fn test_return_none_if_no_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL")); // route randomly + + assert_eq!(route_for_pipeline(&pipeline), Ok(None)); + } + + #[test] + fn test_prefer_primary_route_over_replica() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .get("foo") // route to replica of slot 12182 + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL"))// route randomly + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar"); // route to primary of slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } + + #[test] + fn test_raise_cross_slot_error_on_conflicting_slots() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .set("baz", "bar") // route to slot 4813 + .get("foo"); // route to slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline).unwrap_err().kind(), + crate::ErrorKind::CrossSlot + ); + } + + #[test] + fn unkeyed_commands_dont_affect_route() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .set("{foo}bar", "baz") // route to primary of slot 12182 + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar") // route to primary of slot 12182 + .cmd("DEBUG").arg("PAUSE").arg("100") // unkeyed 
command + .cmd("ECHO").arg("hello world"); // unkeyed command + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_client.rs b/glide-core/redis-rs/redis/src/cluster_client.rs new file mode 100644 index 0000000000..5815bede1e --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_client.rs @@ -0,0 +1,752 @@ +use crate::cluster_slotmap::ReadFromReplicaStrategy; +#[cfg(feature = "cluster-async")] +use crate::cluster_topology::{ + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION, +}; +use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; +use crate::types::{ErrorKind, ProtocolVersion, RedisError, RedisResult}; +use crate::{cluster, cluster::TlsMode}; +use crate::{PubSubSubscriptionInfo, PushInfo}; +use rand::Rng; +#[cfg(feature = "cluster-async")] +use std::ops::Add; +use std::time::Duration; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[cfg(feature = "cluster-async")] +use crate::cluster_async; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{retrieve_tls_certificates, TlsCertificates}; + +use tokio::sync::mpsc; + +/// Parameters specific to builder, so that +/// builder parameters may have different types +/// than final ClusterParams +#[derive(Default)] +struct BuilderParams { + password: Option, + username: Option, + read_from_replicas: ReadFromReplicaStrategy, + tls: Option, + #[cfg(feature = "tls-rustls")] + certs: Option, + retries_configuration: RetryParams, + connection_timeout: Option, + #[cfg(feature = "cluster-async")] + topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + connections_validation_interval: Option, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: SlotsRefreshRateLimit, + client_name: Option, + response_timeout: Option, + protocol: ProtocolVersion, + pubsub_subscriptions: Option, +} + +#[derive(Clone)] +pub(crate) struct RetryParams { + pub(crate) number_of_retries: u32, + max_wait_time: u64, + min_wait_time: u64, + exponent_base: u64, + factor: u64, +} + +impl Default for RetryParams { + fn default() -> Self { + const DEFAULT_RETRIES: u32 = 16; + const DEFAULT_MAX_RETRY_WAIT_TIME: u64 = 655360; + const DEFAULT_MIN_RETRY_WAIT_TIME: u64 = 1280; + const DEFAULT_EXPONENT_BASE: u64 = 2; + const DEFAULT_FACTOR: u64 = 10; + Self { + number_of_retries: DEFAULT_RETRIES, + max_wait_time: DEFAULT_MAX_RETRY_WAIT_TIME, + min_wait_time: DEFAULT_MIN_RETRY_WAIT_TIME, + exponent_base: DEFAULT_EXPONENT_BASE, + factor: DEFAULT_FACTOR, + } + } +} + +impl RetryParams { + pub(crate) fn wait_time_for_retry(&self, retry: u32) -> Duration { + let base_wait = self.exponent_base.pow(retry) * self.factor; + let clamped_wait = base_wait + .min(self.max_wait_time) + .max(self.min_wait_time + 1); + let jittered_wait = rand::thread_rng().gen_range(self.min_wait_time..clamped_wait); + Duration::from_millis(jittered_wait) + } +} + +/// Configuration for rate limiting slot refresh operations in a Redis cluster. +/// +/// This struct defines the interval duration between consecutive slot refresh +/// operations and an additional jitter to introduce randomness in the refresh intervals. +/// +/// # Fields +/// +/// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations. +/// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration. 
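+/// # Example
+///
+/// With an `interval_duration` of 5s and a `max_jitter_milli` of 1000 (illustrative
+/// values, not the defaults), the rate limiter enforces a minimum gap of 5.0 to 6.0
+/// seconds between consecutive refreshes.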
+#[cfg(feature = "cluster-async")] +#[derive(Clone, Copy)] +pub(crate) struct SlotsRefreshRateLimit { + pub(crate) interval_duration: Duration, + pub(crate) max_jitter_milli: u64, +} + +#[cfg(feature = "cluster-async")] +impl Default for SlotsRefreshRateLimit { + fn default() -> Self { + Self { + interval_duration: DEFAULT_SLOTS_REFRESH_WAIT_DURATION, + max_jitter_milli: DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, + } + } +} + +#[cfg(feature = "cluster-async")] +impl SlotsRefreshRateLimit { + pub(crate) fn wait_duration(&self) -> Duration { + let duration_jitter = match self.max_jitter_milli { + 0 => Duration::from_millis(0), + _ => Duration::from_millis(rand::thread_rng().gen_range(0..self.max_jitter_milli)), + }; + self.interval_duration.add(duration_jitter) + } +} +/// Redis cluster specific parameters. +#[derive(Default, Clone)] +#[doc(hidden)] +pub struct ClusterParams { + pub(crate) password: Option, + pub(crate) username: Option, + pub(crate) read_from_replicas: ReadFromReplicaStrategy, + /// tls indicates tls behavior of connections. + /// When Some(TlsMode), connections use tls and verify certification depends on TlsMode. + /// When None, connections do not use tls. + pub(crate) tls: Option, + pub(crate) retry_params: RetryParams, + #[cfg(feature = "cluster-async")] + pub(crate) topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + pub(crate) slots_refresh_rate_limit: SlotsRefreshRateLimit, + #[cfg(feature = "cluster-async")] + pub(crate) connections_validation_interval: Option, + pub(crate) tls_params: Option, + pub(crate) client_name: Option, + pub(crate) connection_timeout: Duration, + pub(crate) response_timeout: Duration, + pub(crate) protocol: ProtocolVersion, + pub(crate) pubsub_subscriptions: Option, +} + +impl ClusterParams { + fn from(value: BuilderParams) -> RedisResult { + #[cfg(not(feature = "tls-rustls"))] + let tls_params = None; + + #[cfg(feature = "tls-rustls")] + let tls_params = { + let retrieved_tls_params = value.certs.clone().map(retrieve_tls_certificates); + + retrieved_tls_params.transpose()? + }; + + Ok(Self { + password: value.password, + username: value.username, + read_from_replicas: value.read_from_replicas, + tls: value.tls, + retry_params: value.retries_configuration, + connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX), + #[cfg(feature = "cluster-async")] + topology_checks_interval: value.topology_checks_interval, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: value.slots_refresh_rate_limit, + #[cfg(feature = "cluster-async")] + connections_validation_interval: value.connections_validation_interval, + tls_params, + client_name: value.client_name, + response_timeout: value.response_timeout.unwrap_or(Duration::MAX), + protocol: value.protocol, + pubsub_subscriptions: value.pubsub_subscriptions, + }) + } +} + +/// Used to configure and build a [`ClusterClient`]. +pub struct ClusterClientBuilder { + initial_nodes: RedisResult>, + builder_params: BuilderParams, +} + +impl ClusterClientBuilder { + /// Creates a new `ClusterClientBuilder` with the provided initial_nodes. + /// + /// This is the same as `ClusterClient::builder(initial_nodes)`. + pub fn new( + initial_nodes: impl IntoIterator, + ) -> ClusterClientBuilder { + ClusterClientBuilder { + initial_nodes: initial_nodes + .into_iter() + .map(|x| x.into_connection_info()) + .collect(), + builder_params: Default::default(), + } + } + + /// Creates a new [`ClusterClient`] from the parameters. 
+    ///
+    /// This does not create connections to the Redis Cluster, but only performs some basic checks
+    /// on the initial nodes' URLs and passwords/usernames.
+    ///
+    /// When the `tls-rustls` feature is enabled and TLS credentials are provided, they are set for
+    /// each cluster connection.
+    ///
+    /// # Errors
+    ///
+    /// Upon failure to parse initial nodes or if the initial nodes have different passwords or
+    /// usernames, an error is returned.
+    pub fn build(self) -> RedisResult<ClusterClient> {
+        let initial_nodes = self.initial_nodes?;
+
+        let first_node = match initial_nodes.first() {
+            Some(node) => node,
+            None => {
+                return Err(RedisError::from((
+                    ErrorKind::InvalidClientConfig,
+                    "Initial nodes can't be empty.",
+                )))
+            }
+        };
+
+        let mut cluster_params = ClusterParams::from(self.builder_params)?;
+        let password = if cluster_params.password.is_none() {
+            cluster_params
+                .password
+                .clone_from(&first_node.redis.password);
+            &cluster_params.password
+        } else {
+            &None
+        };
+        let username = if cluster_params.username.is_none() {
+            cluster_params
+                .username
+                .clone_from(&first_node.redis.username);
+            &cluster_params.username
+        } else {
+            &None
+        };
+        if cluster_params.tls.is_none() {
+            cluster_params.tls = match first_node.addr {
+                ConnectionAddr::TcpTls {
+                    host: _,
+                    port: _,
+                    insecure,
+                    tls_params: _,
+                } => Some(match insecure {
+                    false => TlsMode::Secure,
+                    true => TlsMode::Insecure,
+                }),
+                _ => None,
+            };
+        }
+
+        let mut nodes = Vec::with_capacity(initial_nodes.len());
+        for mut node in initial_nodes {
+            if let ConnectionAddr::Unix(_) = node.addr {
+                return Err(RedisError::from((ErrorKind::InvalidClientConfig,
+                    "This library cannot use unix socket because Redis's cluster command returns only cluster's IP and port.")));
+            }
+
+            if password.is_some() && node.redis.password != *password {
+                return Err(RedisError::from((
+                    ErrorKind::InvalidClientConfig,
+                    "Cannot use different password among initial nodes.",
+                )));
+            }
+
+            if username.is_some() && node.redis.username != *username {
+                return Err(RedisError::from((
+                    ErrorKind::InvalidClientConfig,
+                    "Cannot use different username among initial nodes.",
+                )));
+            }
+
+            if node.redis.client_name.is_some()
+                && node.redis.client_name != cluster_params.client_name
+            {
+                return Err(RedisError::from((
+                    ErrorKind::InvalidClientConfig,
+                    "Cannot use different client_name among initial nodes.",
+                )));
+            }
+
+            node.redis.protocol = cluster_params.protocol;
+            nodes.push(node);
+        }
+
+        Ok(ClusterClient {
+            initial_nodes: nodes,
+            cluster_params,
+        })
+    }
+
+    /// Sets client name for the new ClusterClient.
+    pub fn client_name(mut self, client_name: String) -> ClusterClientBuilder {
+        self.builder_params.client_name = Some(client_name);
+        self
+    }
+
+    /// Sets password for the new ClusterClient.
+    pub fn password(mut self, password: String) -> ClusterClientBuilder {
+        self.builder_params.password = Some(password);
+        self
+    }
+
+    /// Sets username for the new ClusterClient.
+    pub fn username(mut self, username: String) -> ClusterClientBuilder {
+        self.builder_params.username = Some(username);
+        self
+    }
+
+    /// Sets number of retries for the new ClusterClient.
+    pub fn retries(mut self, retries: u32) -> ClusterClientBuilder {
+        self.builder_params.retries_configuration.number_of_retries = retries;
+        self
+    }
+
+    /// Sets the maximal wait time in milliseconds between retries for the new ClusterClient.
+    /// Sets maximal wait time in milliseconds between retries for the new ClusterClient.
+    pub fn max_retry_wait(mut self, max_wait: u64) -> ClusterClientBuilder {
+        self.builder_params.retries_configuration.max_wait_time = max_wait;
+        self
+    }
+
+    /// Sets minimal wait time in milliseconds between retries for the new ClusterClient.
+    pub fn min_retry_wait(mut self, min_wait: u64) -> ClusterClientBuilder {
+        self.builder_params.retries_configuration.min_wait_time = min_wait;
+        self
+    }
+
+    /// Sets the factor and exponent base for the retry wait time.
+    /// The formula for the wait is rand(min_retry_wait .. min(max_retry_wait, factor * exponent_base ^ retry)) ms.
+    pub fn retry_wait_formula(mut self, factor: u64, exponent_base: u64) -> ClusterClientBuilder {
+        self.builder_params.retries_configuration.factor = factor;
+        self.builder_params.retries_configuration.exponent_base = exponent_base;
+        self
+    }
+
+    /// Sets TLS mode for the new ClusterClient.
+    ///
+    /// It is extracted from the first node of initial_nodes if not set.
+    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
+    pub fn tls(mut self, tls: TlsMode) -> ClusterClientBuilder {
+        self.builder_params.tls = Some(tls);
+        self
+    }
+
+    /// Sets raw TLS certificates for the new ClusterClient.
+    ///
+    /// When set, enforces that the connection must be TLS secured.
+    ///
+    /// All certificates must be provided as byte streams loaded from PEM files; their consistency
+    /// is checked during the `build()` call.
+    ///
+    /// - `certificates` - `TlsCertificates` structure containing:
+    ///   - `client_tls` - Optional `ClientTlsConfig` containing byte streams for
+    ///     - `client_cert` - client's byte stream containing client certificate in PEM format
+    ///     - `client_key` - client's byte stream containing private key in PEM format
+    ///   - `root_cert` - Optional byte stream yielding PEM formatted file for root certificates.
+    ///
+    /// If `ClientTlsConfig` (cert + key pair) is not provided, then client-side authentication is not enabled.
+    /// If `root_cert` is not provided, then system root certificates are used instead.
+    #[cfg(feature = "tls-rustls")]
+    pub fn certs(mut self, certificates: TlsCertificates) -> ClusterClientBuilder {
+        self.builder_params.tls = Some(TlsMode::Secure);
+        self.builder_params.certs = Some(certificates);
+        self
+    }
+
+    /// Enables reading from replicas for all new connections (default is disabled).
+    ///
+    /// If enabled, then read queries will go to the replica nodes & write queries will go to the
+    /// primary nodes. If there are no replica nodes, then all queries will go to the primary nodes.
+    pub fn read_from_replicas(mut self) -> ClusterClientBuilder {
+        self.builder_params.read_from_replicas = ReadFromReplicaStrategy::RoundRobin;
+        self
+    }
+
+    /// Enables periodic topology checks for this client.
+    ///
+    /// If enabled, periodic topology checks will be executed at the configured intervals to examine whether there
+    /// have been any changes in the cluster's topology. If a change is detected, it will trigger a slot refresh.
+    /// Unlike slot refreshments, the periodic topology checks only examine a limited number of nodes to query their
+    /// topology, ensuring that the check remains quick and efficient.
+    #[cfg(feature = "cluster-async")]
+    pub fn periodic_topology_checks(mut self, interval: Duration) -> ClusterClientBuilder {
+        self.builder_params.topology_checks_interval = Some(interval);
+        self
+    }
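As a sanity check on the retry formula documented above, here is a small sketch (the function name is hypothetical) of the wait computation these parameters feed into:

```rust
use rand::Rng;
use std::cmp::min;

// Sketch of the documented formula:
// wait = rand(min_retry_wait .. min(max_retry_wait, factor * exponent_base ^ retry)) ms.
fn retry_wait_millis(min_wait: u64, max_wait: u64, factor: u64, exponent_base: u64, retry: u32) -> u64 {
    let upper = min(max_wait, factor.saturating_mul(exponent_base.saturating_pow(retry)));
    // Keep the sampling range non-empty while the exponential term is still below the minimum.
    let upper = upper.max(min_wait + 1);
    rand::thread_rng().gen_range(min_wait..upper)
}
```

With `factor = 100`, `exponent_base = 2`, the upper bound grows 100, 200, 400, ... ms per retry until it is capped by `max_retry_wait`.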
+    /// Enables periodic connection checks for this client.
+    /// If enabled, the connections to the cluster nodes will be validated periodically, per the configured interval.
+    /// In addition, for the tokio runtime, passive disconnections can be detected instantly,
+    /// triggering reestablishment, without waiting for the next periodic check.
+    #[cfg(feature = "cluster-async")]
+    pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder {
+        self.builder_params.connections_validation_interval = Some(interval);
+        self
+    }
+
+    /// Sets the rate limit for slot refresh operations in the cluster.
+    ///
+    /// This method configures the interval duration between consecutive slot
+    /// refresh operations and an additional jitter to introduce randomness
+    /// in the refresh intervals.
+    ///
+    /// # Parameters
+    ///
+    /// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations.
+    /// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration.
+    ///
+    /// # Defaults
+    ///
+    /// If not set, the slots refresh rate limit configuration is set with the default values:
+    /// ```
+    /// #[cfg(feature = "cluster-async")]
+    /// use redis::cluster_topology::{DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION};
+    /// ```
+    ///
+    /// - `interval_duration`: `DEFAULT_SLOTS_REFRESH_WAIT_DURATION`
+    /// - `max_jitter_milli`: `DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI`
+    ///
+    #[cfg(feature = "cluster-async")]
+    pub fn slots_refresh_rate_limit(
+        mut self,
+        interval_duration: Duration,
+        max_jitter_milli: u64,
+    ) -> ClusterClientBuilder {
+        self.builder_params.slots_refresh_rate_limit = SlotsRefreshRateLimit {
+            interval_duration,
+            max_jitter_milli,
+        };
+        self
+    }
+
+    /// Enables timing out on slow connection time.
+    ///
+    /// If enabled, the cluster will only wait the given time on each connection attempt to each node.
+    pub fn connection_timeout(mut self, connection_timeout: Duration) -> ClusterClientBuilder {
+        self.builder_params.connection_timeout = Some(connection_timeout);
+        self
+    }
+
+    /// Enables timing out on slow responses.
+    ///
+    /// If enabled, the cluster will only wait the given time for each response from each node.
+    pub fn response_timeout(mut self, response_timeout: Duration) -> ClusterClientBuilder {
+        self.builder_params.response_timeout = Some(response_timeout);
+        self
+    }
+
+    /// Sets the protocol with which the client should communicate with the server.
+    pub fn use_protocol(mut self, protocol: ProtocolVersion) -> ClusterClientBuilder {
+        self.builder_params.protocol = protocol;
+        self
+    }
+
+    /// Use `build()`.
+    #[deprecated(since = "0.22.0", note = "Use build()")]
+    pub fn open(self) -> RedisResult<ClusterClient> {
+        self.build()
+    }
+
+    /// Use `read_from_replicas()`.
+    #[deprecated(since = "0.22.0", note = "Use read_from_replicas()")]
+    pub fn readonly(mut self, read_from_replicas: bool) -> ClusterClientBuilder {
+        self.builder_params.read_from_replicas = if read_from_replicas {
+            ReadFromReplicaStrategy::RoundRobin
+        } else {
+            ReadFromReplicaStrategy::AlwaysFromPrimary
+        };
+        self
+    }
+
+    /// Sets the pubsub configuration for the new ClusterClient.
+    pub fn pubsub_subscriptions(
+        mut self,
+        pubsub_subscriptions: PubSubSubscriptionInfo,
+    ) -> ClusterClientBuilder {
+        self.builder_params.pubsub_subscriptions = Some(pubsub_subscriptions);
+        self
+    }
+}
+
+/// This is a Redis Cluster client.
+#[derive(Clone)]
+pub struct ClusterClient {
+    initial_nodes: Vec<ConnectionInfo>,
+    cluster_params: ClusterParams,
+}
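A sketch combining the periodic-check and rate-limit knobs above (the endpoint is hypothetical; these methods require the `cluster-async` feature):

```rust
use std::time::Duration;

fn main() -> redis::RedisResult<()> {
    let client = redis::cluster::ClusterClientBuilder::new(vec!["redis://127.0.0.1:6379/"])
        // Probe a few nodes for topology changes once a minute.
        .periodic_topology_checks(Duration::from_secs(60))
        // Validate existing connections every few seconds.
        .periodic_connections_checks(Duration::from_secs(3))
        // Allow at most one slot refresh per 10s, plus up to 500ms of jitter.
        .slots_refresh_rate_limit(Duration::from_secs(10), 500)
        .build()?;
    let _ = client;
    Ok(())
}
```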
+impl ClusterClient {
+    /// Creates a `ClusterClient` with the default parameters.
+    ///
+    /// This does not create connections to the Redis Cluster, but only performs some basic checks
+    /// on the initial nodes' URLs and passwords/usernames.
+    ///
+    /// # Errors
+    ///
+    /// Upon failure to parse initial nodes or if the initial nodes have different passwords or
+    /// usernames, an error is returned.
+    pub fn new<T: IntoConnectionInfo>(
+        initial_nodes: impl IntoIterator<Item = T>,
+    ) -> RedisResult<ClusterClient> {
+        Self::builder(initial_nodes).build()
+    }
+
+    /// Creates a [`ClusterClientBuilder`] with the provided initial_nodes.
+    pub fn builder<T: IntoConnectionInfo>(
+        initial_nodes: impl IntoIterator<Item = T>,
+    ) -> ClusterClientBuilder {
+        ClusterClientBuilder::new(initial_nodes)
+    }
+
+    /// Creates new connections to Redis Cluster nodes and returns a
+    /// [`cluster::ClusterConnection`].
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if there is a failure while creating connections or slots.
+    pub fn get_connection(
+        &self,
+        push_sender: Option<mpsc::Sender<PushInfo>>,
+    ) -> RedisResult<cluster::ClusterConnection> {
+        cluster::ClusterConnection::new(
+            self.cluster_params.clone(),
+            self.initial_nodes.clone(),
+            push_sender,
+        )
+    }
+
+    /// Creates new connections to Redis Cluster nodes and returns a
+    /// [`cluster_async::ClusterConnection`].
+    ///
+    /// # Errors
+    ///
+    /// An error is returned if there is a failure while creating connections or slots.
+    #[cfg(feature = "cluster-async")]
+    pub async fn get_async_connection(
+        &self,
+        push_sender: Option<mpsc::Sender<PushInfo>>,
+    ) -> RedisResult<cluster_async::ClusterConnection> {
+        cluster_async::ClusterConnection::new(
+            &self.initial_nodes,
+            self.cluster_params.clone(),
+            push_sender,
+        )
+        .await
+    }
+
+    #[doc(hidden)]
+    pub fn get_generic_connection<C>(
+        &self,
+        push_sender: Option<mpsc::Sender<PushInfo>>,
+    ) -> RedisResult<cluster::ClusterConnection<C>>
+    where
+        C: crate::ConnectionLike + crate::cluster::Connect + Send,
+    {
+        cluster::ClusterConnection::new(
+            self.cluster_params.clone(),
+            self.initial_nodes.clone(),
+            push_sender,
+        )
+    }
+
+    #[doc(hidden)]
+    #[cfg(feature = "cluster-async")]
+    pub async fn get_async_generic_connection<C>(
+        &self,
+    ) -> RedisResult<cluster_async::ClusterConnection<C>>
+    where
+        C: crate::aio::ConnectionLike
+            + cluster_async::Connect
+            + Clone
+            + Send
+            + Sync
+            + Unpin
+            + 'static,
+    {
+        cluster_async::ClusterConnection::new(
+            &self.initial_nodes,
+            self.cluster_params.clone(),
+            None,
+        )
+        .await
+    }
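A sketch of the async path (requires the `cluster-async` feature; passing `None` for the push sender opts out of push notifications):

```rust
async fn connect() -> redis::RedisResult<()> {
    // Hypothetical endpoint.
    let client = redis::cluster::ClusterClient::new(vec!["redis://127.0.0.1:6379/"])?;
    // Connections and the slot map are created here.
    let mut _conn = client.get_async_connection(None).await?;
    Ok(())
}
```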
+    /// Use `new()`.
+    #[deprecated(since = "0.22.0", note = "Use new()")]
+    pub fn open(initial_nodes: Vec<ConnectionInfo>) -> RedisResult<ClusterClient> {
+        Self::new(initial_nodes)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    #[cfg(feature = "cluster-async")]
+    use crate::cluster_topology::{
+        DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION,
+    };
+
+    use super::{ClusterClient, ClusterClientBuilder, ConnectionInfo, IntoConnectionInfo};
+
+    fn get_connection_data() -> Vec<ConnectionInfo> {
+        vec![
+            "redis://127.0.0.1:6379".into_connection_info().unwrap(),
+            "redis://127.0.0.1:6378".into_connection_info().unwrap(),
+            "redis://127.0.0.1:6377".into_connection_info().unwrap(),
+        ]
+    }
+
+    fn get_connection_data_with_password() -> Vec<ConnectionInfo> {
+        vec![
+            "redis://:password@127.0.0.1:6379"
+                .into_connection_info()
+                .unwrap(),
+            "redis://:password@127.0.0.1:6378"
+                .into_connection_info()
+                .unwrap(),
+            "redis://:password@127.0.0.1:6377"
+                .into_connection_info()
+                .unwrap(),
+        ]
+    }
+
+    fn get_connection_data_with_username_and_password() -> Vec<ConnectionInfo> {
+        vec![
+            "redis://user1:password@127.0.0.1:6379"
+                .into_connection_info()
+                .unwrap(),
+            "redis://user1:password@127.0.0.1:6378"
+                .into_connection_info()
+                .unwrap(),
+            "redis://user1:password@127.0.0.1:6377"
+                .into_connection_info()
+                .unwrap(),
+        ]
+    }
+
+    #[test]
+    fn give_no_password() {
+        let client = ClusterClient::new(get_connection_data()).unwrap();
+        assert_eq!(client.cluster_params.password, None);
+    }
+
+    #[test]
+    fn give_password_by_initial_nodes() {
+        let client = ClusterClient::new(get_connection_data_with_password()).unwrap();
+        assert_eq!(client.cluster_params.password, Some("password".to_string()));
+    }
+
+    #[test]
+    fn give_username_and_password_by_initial_nodes() {
+        let client = ClusterClient::new(get_connection_data_with_username_and_password()).unwrap();
+        assert_eq!(client.cluster_params.password, Some("password".to_string()));
+        assert_eq!(client.cluster_params.username, Some("user1".to_string()));
+    }
+
+    #[test]
+    fn give_different_password_by_initial_nodes() {
+        let result = ClusterClient::new(vec![
+            "redis://:password1@127.0.0.1:6379",
+            "redis://:password2@127.0.0.1:6378",
+            "redis://:password3@127.0.0.1:6377",
+        ]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn give_different_username_by_initial_nodes() {
+        let result = ClusterClient::new(vec![
+            "redis://user1:password@127.0.0.1:6379",
+            "redis://user2:password@127.0.0.1:6378",
+            "redis://user1:password@127.0.0.1:6377",
+        ]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn give_username_password_by_method() {
+        let client = ClusterClientBuilder::new(get_connection_data_with_password())
+            .password("pass".to_string())
+            .username("user1".to_string())
+            .build()
+            .unwrap();
+        assert_eq!(client.cluster_params.password, Some("pass".to_string()));
+        assert_eq!(client.cluster_params.username, Some("user1".to_string()));
+    }
+
+    #[test]
+    fn give_empty_initial_nodes() {
+        let client = ClusterClient::new(Vec::<String>::new());
+        assert!(client.is_err())
+    }
+
+    #[cfg(feature = "cluster-async")]
+    #[test]
+    fn give_slots_refresh_rate_limit_configurations() {
+        let interval_dur = std::time::Duration::from_secs(20);
+        let client = ClusterClientBuilder::new(get_connection_data())
+            .slots_refresh_rate_limit(interval_dur, 500)
+            .build()
+            .unwrap();
+        assert_eq!(
+            client
+                .cluster_params
+                .slots_refresh_rate_limit
+                .interval_duration,
+            interval_dur
+        );
+        assert_eq!(
+            client
+                .cluster_params
+                .slots_refresh_rate_limit
+                .max_jitter_milli,
+            500
+        );
+    }
+
+    #[cfg(feature = "cluster-async")]
+    #[test]
+    fn dont_give_slots_refresh_rate_limit_configurations_uses_defaults() {
+        let client = ClusterClientBuilder::new(get_connection_data())
+            .build()
+            .unwrap();
+        assert_eq!(
+            client
+                .cluster_params
+                .slots_refresh_rate_limit
+                .interval_duration,
+            DEFAULT_SLOTS_REFRESH_WAIT_DURATION
+        );
+        assert_eq!(
+            client
+                .cluster_params
+                .slots_refresh_rate_limit
+                .max_jitter_milli,
+            DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI
+        );
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/cluster_pipeline.rs b/glide-core/redis-rs/redis/src/cluster_pipeline.rs
new file mode 100644
index 0000000000..9da1fee781
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/cluster_pipeline.rs
@@ -0,0 +1,151 @@
+use crate::cluster::ClusterConnection;
+use crate::cmd::{cmd, Cmd};
+use crate::types::{
+    from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value,
+};
+
+pub(crate) const UNROUTABLE_ERROR: (ErrorKind, &str) = (
+    ErrorKind::ClientError,
+    "This command cannot be safely routed in cluster mode",
+);
+
+fn is_illegal_cmd(cmd: &str) -> bool {
+    matches!(
+        cmd,
+        "BGREWRITEAOF" | "BGSAVE" | "BITOP" | "BRPOPLPUSH" |
+        // All commands that start with "CLIENT"
+        "CLIENT" | "CLIENT GETNAME" | "CLIENT KILL" | "CLIENT LIST" | "CLIENT SETNAME" |
+        // All commands that start with "CONFIG"
+        "CONFIG" | "CONFIG GET" | "CONFIG RESETSTAT" | "CONFIG REWRITE" | "CONFIG SET" |
+        "DBSIZE" |
+        "ECHO" | "EVALSHA" |
+        "FLUSHALL" | "FLUSHDB" |
+        "INFO" |
+        "KEYS" |
+        "LASTSAVE" |
+        "MGET" | "MOVE" | "MSET" | "MSETNX" |
+        "PFMERGE" | "PFCOUNT" | "PING" | "PUBLISH" |
+        "RANDOMKEY" | "RENAME" | "RENAMENX" | "RPOPLPUSH" |
+        "SAVE" | "SCAN" |
+        // All commands that start with "SCRIPT"
+        "SCRIPT" | "SCRIPT EXISTS" | "SCRIPT FLUSH" | "SCRIPT KILL" | "SCRIPT LOAD" |
+        "SDIFF" | "SDIFFSTORE" |
+        // All commands that start with "SENTINEL"
+        "SENTINEL" | "SENTINEL GET MASTER ADDR BY NAME" | "SENTINEL MASTER" | "SENTINEL MASTERS" |
+        "SENTINEL MONITOR" | "SENTINEL REMOVE" | "SENTINEL SENTINELS" | "SENTINEL SET" |
+        "SENTINEL SLAVES" | "SHUTDOWN" | "SINTER" | "SINTERSTORE" | "SLAVEOF" |
+        // All commands that start with "SLOWLOG"
+        "SLOWLOG" | "SLOWLOG GET" | "SLOWLOG LEN" | "SLOWLOG RESET" |
+        "SMOVE" | "SORT" | "SUNION" | "SUNIONSTORE" |
+        "TIME"
+    )
+}
+
+/// Represents a Redis Cluster command pipeline.
+#[derive(Clone)]
+pub struct ClusterPipeline {
+    commands: Vec<Cmd>,
+    ignored_commands: HashSet<usize>,
+}
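A test-style sketch of how the check above behaves. `is_illegal_cmd` is private to this module, so this only illustrates the intended semantics:

```rust
#[test]
fn rejects_unroutable_commands() {
    // Multi-key and admin commands may span slots or target no single node.
    assert!(is_illegal_cmd("MGET"));
    assert!(is_illegal_cmd("CONFIG SET"));
    // Single-key commands route normally and are allowed in the pipeline.
    assert!(!is_illegal_cmd("GET"));
}
```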
+/// A cluster pipeline is almost identical to a normal [Pipeline](crate::pipeline::Pipeline), with two exceptions:
+/// * It does not support transactions
+/// * The following commands can not be used in a cluster pipeline:
+/// ```text
+/// BGREWRITEAOF, BGSAVE, BITOP, BRPOPLPUSH
+/// CLIENT GETNAME, CLIENT KILL, CLIENT LIST, CLIENT SETNAME, CONFIG GET,
+/// CONFIG RESETSTAT, CONFIG REWRITE, CONFIG SET
+/// DBSIZE
+/// ECHO, EVALSHA
+/// FLUSHALL, FLUSHDB
+/// INFO
+/// KEYS
+/// LASTSAVE
+/// MGET, MOVE, MSET, MSETNX
+/// PFMERGE, PFCOUNT, PING, PUBLISH
+/// RANDOMKEY, RENAME, RENAMENX, RPOPLPUSH
+/// SAVE, SCAN, SCRIPT EXISTS, SCRIPT FLUSH, SCRIPT KILL, SCRIPT LOAD, SDIFF, SDIFFSTORE,
+/// SENTINEL GET MASTER ADDR BY NAME, SENTINEL MASTER, SENTINEL MASTERS, SENTINEL MONITOR,
+/// SENTINEL REMOVE, SENTINEL SENTINELS, SENTINEL SET, SENTINEL SLAVES, SHUTDOWN, SINTER,
+/// SINTERSTORE, SLAVEOF, SLOWLOG GET, SLOWLOG LEN, SLOWLOG RESET, SMOVE, SORT, SUNION, SUNIONSTORE
+/// TIME
+/// ```
+impl ClusterPipeline {
+    /// Create an empty pipeline.
+    pub fn new() -> ClusterPipeline {
+        Self::with_capacity(0)
+    }
+
+    /// Creates an empty pipeline with pre-allocated capacity.
+    pub fn with_capacity(capacity: usize) -> ClusterPipeline {
+        ClusterPipeline {
+            commands: Vec::with_capacity(capacity),
+            ignored_commands: HashSet::new(),
+        }
+    }
+
+    pub(crate) fn commands(&self) -> &Vec<Cmd> {
+        &self.commands
+    }
+
+    /// Executes the pipeline and fetches the return values:
+    ///
+    /// ```rust,no_run
+    /// # let nodes = vec!["redis://127.0.0.1:6379/"];
+    /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap();
+    /// # let mut con = client.get_connection(None).unwrap();
+    /// let mut pipe = redis::cluster::cluster_pipe();
+    /// let (k1, k2) : (i32, i32) = pipe
+    ///     .cmd("SET").arg("key_1").arg(42).ignore()
+    ///     .cmd("SET").arg("key_2").arg(43).ignore()
+    ///     .cmd("GET").arg("key_1")
+    ///     .cmd("GET").arg("key_2").query(&mut con).unwrap();
+    /// ```
+    #[inline]
+    pub fn query<T: FromRedisValue>(&self, con: &mut ClusterConnection) -> RedisResult<T> {
+        for cmd in &self.commands {
+            let cmd_name = std::str::from_utf8(cmd.arg_idx(0).unwrap_or(b""))
+                .unwrap_or("")
+                .trim()
+                .to_ascii_uppercase();
+
+            if is_illegal_cmd(&cmd_name) {
+                fail!((
+                    UNROUTABLE_ERROR.0,
+                    UNROUTABLE_ERROR.1,
+                    format!("Command '{cmd_name}' can't be executed in a cluster pipeline.")
+                ))
+            }
+        }
+
+        from_owned_redis_value(if self.commands.is_empty() {
+            Value::Array(vec![])
+        } else {
+            self.make_pipeline_results(con.execute_pipeline(self)?)
+        })
+    }
+
+    /// This is a shortcut to `query()` that does not return a value and
+    /// will fail the task if the query of the pipeline fails.
+    ///
+    /// This is equivalent to a call to query like this:
+    ///
+    /// ```rust,no_run
+    /// # let nodes = vec!["redis://127.0.0.1:6379/"];
+    /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap();
+    /// # let mut con = client.get_connection(None).unwrap();
+    /// let mut pipe = redis::cluster::cluster_pipe();
+    /// let _ : () = pipe.cmd("SET").arg("key_1").arg(42).ignore().query(&mut con).unwrap();
+    /// ```
+    #[inline]
+    pub fn execute(&self, con: &mut ClusterConnection) {
+        self.query::<()>(con).unwrap();
+    }
+}
+
+/// Shortcut for creating a new cluster pipeline.
+pub fn cluster_pipe() -> ClusterPipeline {
+    ClusterPipeline::new()
+}
+
+implement_pipeline_commands!(ClusterPipeline);
diff --git a/glide-core/redis-rs/redis/src/cluster_routing.rs b/glide-core/redis-rs/redis/src/cluster_routing.rs
new file mode 100644
index 0000000000..bfe6ae2039
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/cluster_routing.rs
@@ -0,0 +1,1374 @@
+use rand::Rng;
+use std::cmp::min;
+use std::collections::HashMap;
+
+use crate::cluster_topology::get_slot;
+use crate::cmd::{Arg, Cmd};
+use crate::types::Value;
+use crate::{ErrorKind, RedisResult};
+use std::iter::Once;
+
+#[derive(Clone)]
+pub(crate) enum Redirect {
+    Moved(String),
+    Ask(String),
+}
+
+/// Logical bitwise aggregating operators.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum LogicalAggregateOp {
+    /// Aggregate by bitwise &&
+    And,
+    // Or, omitted due to dead code warnings. ATM this value isn't constructed anywhere
+}
+
+/// Numerical aggregating operators.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum AggregateOp {
+    /// Choose minimal value
+    Min,
+    /// Sum all values
+    Sum,
+    // Max, omitted due to dead code warnings. ATM this value isn't constructed anywhere
+}
+
+/// Policy defining how to combine multiple responses into one.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum ResponsePolicy {
+    /// Wait for one request to succeed and return its results. Return error if all requests fail.
+    OneSucceeded,
+    /// Returns the first succeeded non-empty result; if all results are empty, returns `Nil`; otherwise, returns the last received error.
+    FirstSucceededNonEmptyOrAllEmpty,
+    /// Waits for all requests to succeed, and then returns one of the successes. Returns the error on the first received error.
+    AllSucceeded,
+    /// Aggregate success results according to a logical bitwise operator. Return error on any failed request or on a response that doesn't conform to 0 or 1.
+    AggregateLogical(LogicalAggregateOp),
+    /// Aggregate success results according to a numeric operator. Return error on any failed request or on a response that isn't an integer.
+    Aggregate(AggregateOp),
+    /// Aggregate array responses into a single array. Return error on any failed request or on a response that isn't an array.
+    CombineArrays,
+    /// Handling is not defined by the Redis standard. Will receive a special case
+    Special,
+    /// Combines multiple map responses into a single map.
+    CombineMaps,
+}
+
+/// Defines whether a request should be routed to a single node, or multiple ones.
+#[derive(Debug, Clone, PartialEq)]
+pub enum RoutingInfo {
+    /// Route to single node
+    SingleNode(SingleNodeRoutingInfo),
+    /// Route to multiple nodes
+    MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
+}
+
+/// Defines which single node should receive a request.
+#[derive(Debug, Clone, PartialEq)]
+pub enum SingleNodeRoutingInfo {
+    /// Route to any node at random
+    Random,
+    /// Route to any *primary* node
+    RandomPrimary,
+    /// Route to the node that matches the [Route]
+    SpecificNode(Route),
+    /// Route to the node with the given address.
+    ByAddress {
+        /// DNS hostname of the node
+        host: String,
+        /// port of the node
+        port: u16,
+    },
+}
+
+impl From<Option<Route>> for SingleNodeRoutingInfo {
+    fn from(value: Option<Route>) -> Self {
+        value
+            .map(SingleNodeRoutingInfo::SpecificNode)
+            .unwrap_or(SingleNodeRoutingInfo::Random)
+    }
+}
+
+/// Defines which collection of nodes should receive a request
+#[derive(Debug, Clone, PartialEq)]
+pub enum MultipleNodeRoutingInfo {
+    /// Route to all nodes in the cluster
+    AllNodes,
+    /// Route to all primaries in the cluster
+    AllMasters,
+    /// Instructions for how to split a multi-slot command (e.g. MGET, MSET) into sub-commands. Each tuple is the route for each subcommand, and the indices of the arguments from the original command that should be copied to the subcommand.
+    MultiSlot(Vec<(Route, Vec<usize>)>),
+}
+
+/// Takes a routable and an iterator of indices, which is assumed to be created from `MultipleNodeRoutingInfo::MultiSlot`,
+/// and returns a command with the arguments matching the indices.
+pub fn command_for_multi_slot_indices<'a, 'b>(
+    original_cmd: &'a impl Routable,
+    indices: impl Iterator<Item = &'b usize> + 'a,
+) -> Cmd
+where
+    'b: 'a,
+{
+    let mut new_cmd = Cmd::new();
+    let command_length = 1; // TODO - the +1 should change if we have multi-slot commands with 2 command words.
+    new_cmd.arg(original_cmd.arg_idx(0));
+    for index in indices {
+        new_cmd.arg(original_cmd.arg_idx(index + command_length));
+    }
+    new_cmd
+}
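A sketch of the helper above in action, rebuilding the sub-command for one slot of a multi-slot `MGET` (the index set here is made up for illustration; in practice it comes from `MultiSlot` routing):

```rust
// Within this module's scope.
fn split_mget_example() {
    let mut mget = Cmd::new();
    mget.arg("MGET").arg("k1").arg("k2").arg("k3");
    // Suppose MultiSlot routing assigned argument indices [0, 2] to one slot:
    let indices = vec![0usize, 2];
    let sub_cmd = command_for_multi_slot_indices(&mget, indices.iter());
    // `sub_cmd` now encodes `MGET k1 k3`; index 0 skips the command word itself.
    let _ = sub_cmd;
}
```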
+/// Aggregate numeric responses.
+pub fn aggregate(values: Vec<Value>, op: AggregateOp) -> RedisResult<Value> {
+    let initial_value = match op {
+        AggregateOp::Min => i64::MAX,
+        AggregateOp::Sum => 0,
+    };
+    let result = values.into_iter().try_fold(initial_value, |acc, curr| {
+        let int = match curr {
+            Value::Int(int) => int,
+            _ => {
+                return RedisResult::Err(
+                    (
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into(),
+                );
+            }
+        };
+        let acc = match op {
+            AggregateOp::Min => min(acc, int),
+            AggregateOp::Sum => acc + int,
+        };
+        Ok(acc)
+    })?;
+    Ok(Value::Int(result))
+}
+
+/// Aggregate numeric responses by a boolean operator.
+pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> RedisResult<Value> {
+    let initial_value = match op {
+        LogicalAggregateOp::And => true,
+    };
+    let results = values.into_iter().try_fold(Vec::new(), |acc, curr| {
+        let values = match curr {
+            Value::Array(values) => values,
+            _ => {
+                return RedisResult::Err(
+                    (
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into(),
+                );
+            }
+        };
+        let mut acc = if acc.is_empty() {
+            vec![initial_value; values.len()]
+        } else {
+            acc
+        };
+        for (index, value) in values.into_iter().enumerate() {
+            let int = match value {
+                Value::Int(int) => int,
+                _ => {
+                    return Err((
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into());
+                }
+            };
+            acc[index] = match op {
+                LogicalAggregateOp::And => acc[index] && (int > 0),
+            };
+        }
+        Ok(acc)
+    })?;
+    Ok(Value::Array(
+        results
+            .into_iter()
+            .map(|result| Value::Int(result as i64))
+            .collect(),
+    ))
+}
+
+/// Aggregate array responses into a single map.
+pub fn combine_map_results(values: Vec<Value>) -> RedisResult<Value> {
+    let mut map: HashMap<Vec<u8>, i64> = HashMap::new();
+
+    for value in values {
+        match value {
+            Value::Array(elements) => {
+                let mut iter = elements.into_iter();
+
+                while let Some(key) = iter.next() {
+                    if let Value::BulkString(key_bytes) = key {
+                        if let Some(Value::Int(value)) = iter.next() {
+                            *map.entry(key_bytes).or_insert(0) += value;
+                        } else {
+                            return Err((ErrorKind::TypeError, "expected integer value").into());
+                        }
+                    } else {
+                        return Err((ErrorKind::TypeError, "expected string key").into());
+                    }
+                }
+            }
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    let result_vec: Vec<(Value, Value)> = map
+        .into_iter()
+        .map(|(k, v)| (Value::BulkString(k), Value::Int(v)))
+        .collect();
+
+    Ok(Value::Map(result_vec))
+}
+
+/// Aggregate array responses into a single array.
+pub fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
+    let mut results = Vec::new();
+
+    for value in values {
+        match value {
+            Value::Array(values) => results.extend(values),
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    Ok(Value::Array(results))
+}
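For instance, this is how two shard-local `DEL` counts would be merged, since `DEL` maps to `ResponsePolicy::Aggregate(AggregateOp::Sum)` (a sketch in this module's scope; the function name is illustrative):

```rust
fn merge_del_counts() -> RedisResult<Value> {
    // Two shards deleted 2 and 1 keys respectively.
    aggregate(vec![Value::Int(2), Value::Int(1)], AggregateOp::Sum)
    // -> Ok(Value::Int(3))
}
```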
+/// Combines multiple call results in the `values` field, each assumed to be an array of results,
+/// into a single array. `sorting_order` defines the order of the results in the returned array -
+/// for each array of results, `sorting_order` should contain a matching array with the indices of
+/// the results in the final array.
+pub(crate) fn combine_and_sort_array_results<'a>(
+    values: Vec<Value>,
+    sorting_order: impl ExactSizeIterator<Item = &'a Vec<usize>>,
+) -> RedisResult<Value> {
+    let mut results = Vec::new();
+    results.resize(
+        values.iter().fold(0, |acc, value| match value {
+            Value::Array(values) => values.len() + acc,
+            _ => 0,
+        }),
+        Value::Nil,
+    );
+    assert_eq!(values.len(), sorting_order.len());
+
+    for (key_indices, value) in sorting_order.into_iter().zip(values) {
+        match value {
+            Value::Array(values) => {
+                assert_eq!(values.len(), key_indices.len());
+                for (index, value) in key_indices.iter().zip(values) {
+                    results[*index] = value;
+                }
+            }
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    Ok(Value::Array(results))
+}
+
+fn get_route(is_readonly: bool, key: &[u8]) -> Route {
+    let slot = get_slot(key);
+    if is_readonly {
+        Route::new(slot, SlotAddr::ReplicaOptional)
+    } else {
+        Route::new(slot, SlotAddr::Master)
+    }
+}
+
+/// Takes the given `routable` and creates a multi-slot routing info.
+/// This is used for commands like MSET & MGET, where if the command's keys
+/// are hashed to multiple slots, the command should be split into sub-commands,
+/// each targeting a single slot. The results of these sub-commands are then
+/// usually reassembled using `combine_and_sort_array_results`. In order to do this,
+/// `MultipleNodeRoutingInfo::MultiSlot` contains the routes for each sub-command, and
+/// the indices in the final combined result for each result from the sub-command.
+///
+/// If all keys are routed to the same slot, there's no need to split the command,
+/// so a single node routing info will be returned.
+fn multi_shard<R>(
+    routable: &R,
+    cmd: &[u8],
+    first_key_index: usize,
+    has_values: bool,
+) -> Option<RoutingInfo>
+where
+    R: Routable + ?Sized,
+{
+    let is_readonly = is_readonly_cmd(cmd);
+    let mut routes = HashMap::new();
+    let mut key_index = 0;
+    while let Some(key) = routable.arg_idx(first_key_index + key_index) {
+        let route = get_route(is_readonly, key);
+        let entry = routes.entry(route);
+        let keys = entry.or_insert(Vec::new());
+        keys.push(key_index);
+
+        if has_values {
+            key_index += 1;
+            routable.arg_idx(first_key_index + key_index)?; // check that there's a value for the key
+            keys.push(key_index);
+        }
+        key_index += 1;
+    }
+
+    let mut routes: Vec<(Route, Vec<usize>)> = routes.into_iter().collect();
+    Some(if routes.len() == 1 {
+        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(routes.pop().unwrap().0))
+    } else {
+        RoutingInfo::MultiNode((
+            MultipleNodeRoutingInfo::MultiSlot(routes),
+            ResponsePolicy::for_command(cmd),
+        ))
+    })
+}
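A module-internal sketch of what `multi_shard` yields for `DEL k1 k2`, assuming the two keys hash to different slots (which they typically, but not necessarily, do):

```rust
fn del_routing_example() {
    let mut del = Cmd::new();
    del.arg("DEL").arg("k1").arg("k2");
    match multi_shard(&del, b"DEL", 1, false) {
        Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(routes), policy))) => {
            // One (Route, argument-indices) pair per slot, plus DEL's Sum policy.
            assert_eq!(routes.len(), 2); // assuming k1 and k2 land in different slots
            assert_eq!(policy, Some(ResponsePolicy::Aggregate(AggregateOp::Sum)));
        }
        other => println!("keys shared a slot: {other:?}"),
    }
}
```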
+impl ResponsePolicy {
+    /// Parse the command for the matching response policy.
+    pub fn for_command(cmd: &[u8]) -> Option<ResponsePolicy> {
+        match cmd {
+            b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)),
+
+            b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
+            | b"LATENCY RESET" | b"PUBSUB NUMPAT" => {
+                Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
+            }
+
+            b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)),
+
+            b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME"
+            | b"CLIENT SETINFO" | b"CONFIG SET" | b"CONFIG RESETSTAT" | b"CONFIG REWRITE"
+            | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH"
+            | b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"MEMORY PURGE" | b"MSET" | b"PING"
+            | b"SCRIPT FLUSH" | b"SCRIPT LOAD" | b"SLOWLOG RESET" | b"UNWATCH" | b"WATCH" => {
+                Some(ResponsePolicy::AllSucceeded)
+            }
+
+            b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => {
+                Some(ResponsePolicy::CombineArrays)
+            }
+            b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),
+
+            b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),
+
+            // This isn't based on response_tips, but on the discussion here - https://github.com/redis/redis/issues/12410
+            b"RANDOMKEY" => Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty),
+
+            b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" | b"LATENCY DOCTOR"
+            | b"LATENCY LATEST" => Some(ResponsePolicy::Special),
+
+            b"FUNCTION STATS" => Some(ResponsePolicy::Special),
+
+            b"MEMORY MALLOC-STATS" | b"MEMORY DOCTOR" | b"MEMORY STATS" => {
+                Some(ResponsePolicy::Special)
+            }
+
+            b"INFO" => Some(ResponsePolicy::Special),
+
+            _ => None,
+        }
+    }
+}
+
+enum RouteBy {
+    AllNodes,
+    AllPrimaries,
+    FirstKey,
+    MultiShardNoValues,
+    MultiShardWithValues,
+    Random,
+    SecondArg,
+    SecondArgAfterKeyCount,
+    SecondArgSlot,
+    StreamsIndex,
+    ThirdArgAfterKeyCount,
+    Undefined,
+}
+
+fn base_routing(cmd: &[u8]) -> RouteBy {
+    match cmd {
+        b"ACL SETUSER"
+        | b"ACL DELUSER"
+        | b"ACL SAVE"
+        | b"CLIENT SETNAME"
+        | b"CLIENT SETINFO"
+        | b"SLOWLOG GET"
+        | b"SLOWLOG LEN"
+        | b"SLOWLOG RESET"
+        | b"CONFIG SET"
+        | b"CONFIG RESETSTAT"
+        | b"CONFIG REWRITE"
+        | b"SCRIPT FLUSH"
+        | b"SCRIPT LOAD"
+        | b"LATENCY RESET"
+        | b"LATENCY GRAPH"
+        | b"LATENCY HISTOGRAM"
+        | b"LATENCY HISTORY"
+        | b"LATENCY DOCTOR"
+        | b"LATENCY LATEST"
+        | b"PUBSUB NUMPAT"
+        | b"PUBSUB CHANNELS"
+        | b"PUBSUB NUMSUB"
+        | b"PUBSUB SHARDCHANNELS"
+        | b"PUBSUB SHARDNUMSUB"
+        | b"SCRIPT KILL"
+        | b"FUNCTION KILL"
+        | b"FUNCTION STATS" => RouteBy::AllNodes,
+
+        b"DBSIZE"
+        | b"FLUSHALL"
+        | b"FLUSHDB"
+        | b"FUNCTION DELETE"
+        | b"FUNCTION FLUSH"
+        | b"FUNCTION LOAD"
+        | b"FUNCTION RESTORE"
+        | b"INFO"
+        | b"KEYS"
+        | b"MEMORY DOCTOR"
+        | b"MEMORY MALLOC-STATS"
+        | b"MEMORY PURGE"
+        | b"MEMORY STATS"
+        | b"PING"
+        | b"SCRIPT EXISTS"
+        | b"UNWATCH"
+        | b"WAIT"
+        | b"RANDOMKEY"
+        | b"WAITAOF" => RouteBy::AllPrimaries,
+
+        b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" | b"WATCH" => {
+            RouteBy::MultiShardNoValues
+        }
+        b"MSET" => RouteBy::MultiShardWithValues,
+
+        // TODO - special handling - b"SCAN"
+        b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" => RouteBy::Undefined,
+
+        b"BLMPOP" | b"BZMPOP" | b"EVAL" | b"EVALSHA" | b"EVALSHA_RO" | b"EVAL_RO" | b"FCALL"
+        | b"FCALL_RO" => RouteBy::ThirdArgAfterKeyCount,
+
+        b"BITOP"
+        | b"MEMORY USAGE"
+        | b"PFDEBUG"
+        | b"XGROUP CREATE"
+        | b"XGROUP CREATECONSUMER"
+        | b"XGROUP DELCONSUMER"
+        | b"XGROUP DESTROY"
+        | b"XGROUP SETID"
+        | b"XINFO CONSUMERS"
+        | b"XINFO GROUPS"
+        | b"XINFO STREAM"
+        | 
b"OBJECT ENCODING" + | b"OBJECT FREQ" + | b"OBJECT IDLETIME" + | b"OBJECT REFCOUNT" => RouteBy::SecondArg, + + b"LMPOP" | b"SINTERCARD" | b"ZDIFF" | b"ZINTER" | b"ZINTERCARD" | b"ZMPOP" | b"ZUNION" => { + RouteBy::SecondArgAfterKeyCount + } + + b"XREAD" | b"XREADGROUP" => RouteBy::StreamsIndex, + + // keyless commands with more arguments, whose arguments might be wrongly taken to be keys. + // TODO - double check these, in order to find better ways to route some of them. + b"ACL DRYRUN" + | b"ACL GENPASS" + | b"ACL GETUSER" + | b"ACL HELP" + | b"ACL LIST" + | b"ACL LOG" + | b"ACL USERS" + | b"ACL WHOAMI" + | b"AUTH" + | b"BGSAVE" + | b"CLIENT GETNAME" + | b"CLIENT GETREDIR" + | b"CLIENT ID" + | b"CLIENT INFO" + | b"CLIENT KILL" + | b"CLIENT PAUSE" + | b"CLIENT REPLY" + | b"CLIENT TRACKINGINFO" + | b"CLIENT UNBLOCK" + | b"CLIENT UNPAUSE" + | b"CLUSTER COUNT-FAILURE-REPORTS" + | b"CLUSTER INFO" + | b"CLUSTER KEYSLOT" + | b"CLUSTER MEET" + | b"CLUSTER MYSHARDID" + | b"CLUSTER NODES" + | b"CLUSTER REPLICAS" + | b"CLUSTER RESET" + | b"CLUSTER SET-CONFIG-EPOCH" + | b"CLUSTER SHARDS" + | b"CLUSTER SLOTS" + | b"COMMAND COUNT" + | b"COMMAND GETKEYS" + | b"COMMAND LIST" + | b"COMMAND" + | b"CONFIG GET" + | b"DEBUG" + | b"ECHO" + | b"FUNCTION LIST" + | b"LASTSAVE" + | b"LOLWUT" + | b"MODULE LIST" + | b"MODULE LOAD" + | b"MODULE LOADEX" + | b"MODULE UNLOAD" + | b"READONLY" + | b"READWRITE" + | b"SAVE" + | b"SCRIPT SHOW" + | b"TFCALL" + | b"TFCALLASYNC" + | b"TFUNCTION DELETE" + | b"TFUNCTION LIST" + | b"TFUNCTION LOAD" + | b"TIME" => RouteBy::Random, + + b"CLUSTER ADDSLOTS" + | b"CLUSTER COUNTKEYSINSLOT" + | b"CLUSTER DELSLOTS" + | b"CLUSTER DELSLOTSRANGE" + | b"CLUSTER GETKEYSINSLOT" + | b"CLUSTER SETSLOT" => RouteBy::SecondArgSlot, + + _ => RouteBy::FirstKey, + } +} + +impl RoutingInfo { + /// Returns true if the `cmd` should be routed to all nodes. + pub fn is_all_nodes(cmd: &[u8]) -> bool { + matches!(base_routing(cmd), RouteBy::AllNodes) + } + + /// Returns true if the `cmd` is a key-based command that triggers MOVED errors. + /// A key-based command is one that will be accepted only by the slot owner, + /// while other nodes will respond with a MOVED error redirecting to the relevant primary owner. + pub fn is_key_routing_command(cmd: &[u8]) -> bool { + match base_routing(cmd) { + RouteBy::FirstKey + | RouteBy::SecondArg + | RouteBy::SecondArgAfterKeyCount + | RouteBy::ThirdArgAfterKeyCount + | RouteBy::SecondArgSlot + | RouteBy::StreamsIndex + | RouteBy::MultiShardNoValues + | RouteBy::MultiShardWithValues => { + if matches!(cmd, b"SPUBLISH") { + // SPUBLISH does not return MOVED errors within the slot's shard. This means that even if READONLY wasn't sent to a replica, + // executing SPUBLISH FOO BAR on that replica will succeed. This behavior differs from true key-based commands, + // such as SET FOO BAR, where a non-readonly replica would return a MOVED error if READONLY is off. + // Consequently, SPUBLISH does not meet the requirement of being a command that triggers MOVED errors. + // TODO: remove this when PRIMARY_PREFERRED route for SPUBLISH is added + false + } else { + true + } + } + RouteBy::AllNodes | RouteBy::AllPrimaries | RouteBy::Random | RouteBy::Undefined => { + false + } + } + } + + /// Returns the routing info for `r`. 
+    pub fn for_routable<R>(r: &R) -> Option<RoutingInfo>
+    where
+        R: Routable + ?Sized,
+    {
+        let cmd = &r.command()?[..];
+        match base_routing(cmd) {
+            RouteBy::AllNodes => Some(RoutingInfo::MultiNode((
+                MultipleNodeRoutingInfo::AllNodes,
+                ResponsePolicy::for_command(cmd),
+            ))),
+
+            RouteBy::AllPrimaries => Some(RoutingInfo::MultiNode((
+                MultipleNodeRoutingInfo::AllMasters,
+                ResponsePolicy::for_command(cmd),
+            ))),
+
+            RouteBy::MultiShardWithValues => multi_shard(r, cmd, 1, true),
+
+            RouteBy::MultiShardNoValues => multi_shard(r, cmd, 1, false),
+
+            RouteBy::Random => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
+
+            RouteBy::ThirdArgAfterKeyCount => {
+                let key_count = r
+                    .arg_idx(2)
+                    .and_then(|x| std::str::from_utf8(x).ok())
+                    .and_then(|x| x.parse::<u64>().ok())?;
+                if key_count == 0 {
+                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                } else {
+                    r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key))
+                }
+            }
+
+            RouteBy::SecondArg => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)),
+
+            RouteBy::SecondArgAfterKeyCount => {
+                let key_count = r
+                    .arg_idx(1)
+                    .and_then(|x| std::str::from_utf8(x).ok())
+                    .and_then(|x| x.parse::<u64>().ok())?;
+                if key_count == 0 {
+                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                } else {
+                    r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key))
+                }
+            }
+
+            RouteBy::StreamsIndex => {
+                let streams_position = r.position(b"STREAMS")?;
+                r.arg_idx(streams_position + 1)
+                    .map(|key| RoutingInfo::for_key(cmd, key))
+            }
+
+            RouteBy::SecondArgSlot => r
+                .arg_idx(2)
+                .and_then(|arg| std::str::from_utf8(arg).ok())
+                .and_then(|slot| slot.parse::<u16>().ok())
+                .map(|slot| {
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        slot,
+                        SlotAddr::Master,
+                    )))
+                }),
+
+            RouteBy::FirstKey => match r.arg_idx(1) {
+                Some(key) => Some(RoutingInfo::for_key(cmd, key)),
+                None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
+            },
+
+            RouteBy::Undefined => None,
+        }
+    }
+
+    fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo {
+        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route(
+            is_readonly_cmd(cmd),
+            key,
+        )))
+    }
+}
+
+/// Returns true if the given `routable` represents a readonly command.
+pub fn is_readonly(routable: &impl Routable) -> bool {
+    match routable.command() {
+        Some(cmd) => is_readonly_cmd(cmd.as_slice()),
+        None => false,
+    }
+}
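A quick sketch of the predicate above. Note that `Routable::command` (defined later in this file) uppercases the command name before the lookup, so case does not matter:

```rust
// Within this module's scope.
fn readonly_example() {
    let mut get = Cmd::new();
    get.arg("get").arg("k"); // lowercase on purpose - command() normalizes it
    assert!(is_readonly(&get));

    let mut set = Cmd::new();
    set.arg("SET").arg("k").arg("v");
    assert!(!is_readonly(&set));
}
```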
+/// Returns `true` if the given `cmd` is a readonly command.
+pub fn is_readonly_cmd(cmd: &[u8]) -> bool {
+    matches!(
+        cmd,
+        b"BITCOUNT"
+            | b"BITFIELD_RO"
+            | b"BITPOS"
+            | b"DBSIZE"
+            | b"DUMP"
+            | b"EVAL_RO"
+            | b"EVALSHA_RO"
+            | b"EXISTS"
+            | b"EXPIRETIME"
+            | b"FCALL_RO"
+            | b"FUNCTION DUMP"
+            | b"FUNCTION KILL"
+            | b"FUNCTION LIST"
+            | b"FUNCTION STATS"
+            | b"GEODIST"
+            | b"GEOHASH"
+            | b"GEOPOS"
+            | b"GEORADIUSBYMEMBER_RO"
+            | b"GEORADIUS_RO"
+            | b"GEOSEARCH"
+            | b"GET"
+            | b"GETBIT"
+            | b"GETRANGE"
+            | b"HEXISTS"
+            | b"HGET"
+            | b"HGETALL"
+            | b"HKEYS"
+            | b"HLEN"
+            | b"HMGET"
+            | b"HRANDFIELD"
+            | b"HSCAN"
+            | b"HSTRLEN"
+            | b"HVALS"
+            | b"KEYS"
+            | b"LCS"
+            | b"LINDEX"
+            | b"LLEN"
+            | b"LOLWUT"
+            | b"LPOS"
+            | b"LRANGE"
+            | b"MEMORY USAGE"
+            | b"MGET"
+            | b"OBJECT ENCODING"
+            | b"OBJECT FREQ"
+            | b"OBJECT IDLETIME"
+            | b"OBJECT REFCOUNT"
+            | b"PEXPIRETIME"
+            | b"PFCOUNT"
+            | b"PTTL"
+            | b"RANDOMKEY"
+            | b"SCAN"
+            | b"SCARD"
+            | b"SCRIPT DEBUG"
+            | b"SCRIPT EXISTS"
+            | b"SCRIPT FLUSH"
+            | b"SCRIPT KILL"
+            | b"SCRIPT LOAD"
+            | b"SCRIPT SHOW"
+            | b"SDIFF"
+            | b"SINTER"
+            | b"SINTERCARD"
+            | b"SISMEMBER"
+            | b"SMEMBERS"
+            | b"SMISMEMBER"
+            | b"SORT_RO"
+            | b"SRANDMEMBER"
+            | b"SSCAN"
+            | b"STRLEN"
+            | b"SUBSTR"
+            | b"SUNION"
+            | b"TOUCH"
+            | b"TTL"
+            | b"TYPE"
+            | b"XINFO CONSUMERS"
+            | b"XINFO GROUPS"
+            | b"XINFO STREAM"
+            | b"XLEN"
+            | b"XPENDING"
+            | b"XRANGE"
+            | b"XREAD"
+            | b"XREVRANGE"
+            | b"ZCARD"
+            | b"ZCOUNT"
+            | b"ZDIFF"
+            | b"ZINTER"
+            | b"ZINTERCARD"
+            | b"ZLEXCOUNT"
+            | b"ZMSCORE"
+            | b"ZRANDMEMBER"
+            | b"ZRANGE"
+            | b"ZRANGEBYLEX"
+            | b"ZRANGEBYSCORE"
+            | b"ZRANK"
+            | b"ZREVRANGE"
+            | b"ZREVRANGEBYLEX"
+            | b"ZREVRANGEBYSCORE"
+            | b"ZREVRANK"
+            | b"ZSCAN"
+            | b"ZSCORE"
+            | b"ZUNION"
+    )
+}
+
+/// Objects that implement this trait define a request that can be routed by a cluster client to different nodes in the cluster.
+pub trait Routable {
+    /// Convenience function to return ascii uppercase version of
+    /// the first argument (i.e., the command).
+    fn command(&self) -> Option<Vec<u8>> {
+        let primary_command = self.arg_idx(0).map(|x| x.to_ascii_uppercase())?;
+        let mut primary_command = match primary_command.as_slice() {
+            b"XGROUP" | b"OBJECT" | b"SLOWLOG" | b"FUNCTION" | b"MODULE" | b"COMMAND"
+            | b"PUBSUB" | b"CONFIG" | b"MEMORY" | b"XINFO" | b"CLIENT" | b"ACL" | b"SCRIPT"
+            | b"CLUSTER" | b"LATENCY" => primary_command,
+            _ => {
+                return Some(primary_command);
+            }
+        };
+
+        Some(match self.arg_idx(1) {
+            Some(secondary_command) => {
+                let previous_len = primary_command.len();
+                primary_command.reserve(secondary_command.len() + 1);
+                primary_command.extend(b" ");
+                primary_command.extend(secondary_command);
+                let current_len = primary_command.len();
+                primary_command[previous_len + 1..current_len].make_ascii_uppercase();
+                primary_command
+            }
+            None => primary_command,
+        })
+    }
+
+    /// Returns a reference to the data for the argument at `idx`.
+    fn arg_idx(&self, idx: usize) -> Option<&[u8]>;
+
+    /// Returns the index of the argument that matches `candidate`, if it exists
+    fn position(&self, candidate: &[u8]) -> Option<usize>;
+}
+
+impl Routable for Cmd {
+    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
+        self.arg_idx(idx)
+    }
+
+    fn position(&self, candidate: &[u8]) -> Option<usize> {
+        self.args_iter().position(|a| match a {
+            Arg::Simple(d) => d.eq_ignore_ascii_case(candidate),
+            _ => false,
+        })
+    }
+}
+
+impl Routable for Value {
+    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
+        match self {
+            Value::Array(args) => match args.get(idx) {
+                Some(Value::BulkString(ref data)) => Some(&data[..]),
+                _ => None,
+            },
+            _ => None,
+        }
+    }
+
+    fn position(&self, candidate: &[u8]) -> Option<usize> {
+        match self {
+            Value::Array(args) => args.iter().position(|a| match a {
+                Value::BulkString(d) => d.eq_ignore_ascii_case(candidate),
+                _ => false,
+            }),
+            _ => None,
+        }
+    }
+}
+
+#[derive(Debug, Hash)]
+pub(crate) struct Slot {
+    pub(crate) start: u16,
+    pub(crate) end: u16,
+    pub(crate) master: String,
+    pub(crate) replicas: Vec<String>,
+}
+
+impl Slot {
+    pub fn new(s: u16, e: u16, m: String, r: Vec<String>) -> Self {
+        Self {
+            start: s,
+            end: e,
+            master: m,
+            replicas: r,
+        }
+    }
+
+    pub fn start(&self) -> u16 {
+        self.start
+    }
+
+    pub fn end(&self) -> u16 {
+        self.end
+    }
+
+    #[allow(dead_code)] // used in tests
+    pub(crate) fn master(&self) -> &str {
+        self.master.as_str()
+    }
+
+    #[allow(dead_code)] // used in tests
+    pub fn replicas(&self) -> Vec<String> {
+        self.replicas.clone()
+    }
+}
+
+/// What type of node should a request be routed to, assuming read from replica is enabled.
+#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
+pub enum SlotAddr {
+    /// The request must be routed to the primary node
+    Master,
+    /// The request may be routed to a replica node.
+    /// For example, a GET command can be routed either to replica or primary.
+    ReplicaOptional,
+    /// The request must be routed to a replica node, if one exists.
+    /// For example, by user requested routing.
+    ReplicaRequired,
+}
+
+/// This is just a simplified version of [`Slot`],
+/// which stores only the master and [optional] replica
+/// to avoid the need to choose a replica each time
+/// a command is executed
+#[derive(Debug, Eq, PartialEq)]
+pub(crate) struct SlotAddrs {
+    pub(crate) primary: String,
+    pub(crate) replicas: Vec<String>,
+}
+
+impl SlotAddrs {
+    pub(crate) fn new(primary: String, replicas: Vec<String>) -> Self {
+        Self { primary, replicas }
+    }
+
+    pub(crate) fn from_slot(slot: Slot) -> Self {
+        SlotAddrs::new(slot.master, slot.replicas)
+    }
+}
+
+impl<'a> IntoIterator for &'a SlotAddrs {
+    type Item = &'a String;
+    type IntoIter = std::iter::Chain<Once<&'a String>, std::slice::Iter<'a, String>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        std::iter::once(&self.primary).chain(self.replicas.iter())
+    }
+}
+
+/// Defines the slot and the [`SlotAddr`] to which
+/// a command should be sent
+#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
+pub struct Route(u16, SlotAddr);
+
+impl Route {
+    /// Returns a new Route.
+    pub fn new(slot: u16, slot_addr: SlotAddr) -> Self {
+        Self(slot, slot_addr)
+    }
+
+    /// Returns the slot number of the route.
+    pub fn slot(&self) -> u16 {
+        self.0
+    }
+
+    /// Returns the slot address of the route.
+ pub fn slot_addr(&self) -> SlotAddr { + self.1 + } + + /// Returns a new Route for a random primary node + pub fn new_random_primary() -> Self { + Self::new(random_slot(), SlotAddr::Master) + } +} + +/// Choose a random slot from `0..SLOT_SIZE` (excluding) +fn random_slot() -> u16 { + let mut rng = rand::thread_rng(); + rng.gen_range(0..crate::cluster_topology::SLOT_SIZE) +} + +#[cfg(test)] +mod tests { + use super::{ + command_for_multi_slot_indices, AggregateOp, MultipleNodeRoutingInfo, ResponsePolicy, + Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, + }; + use crate::{cluster_topology::slot, cmd, parser::parse_redis_value, Value}; + use core::panic; + + #[test] + fn test_routing_info_mixed_capatalization() { + let mut upper = cmd("XREAD"); + upper.arg("STREAMS").arg("foo").arg(0); + + let mut lower = cmd("xread"); + lower.arg("streams").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&upper).unwrap(), + RoutingInfo::for_routable(&lower).unwrap() + ); + + let mut mixed = cmd("xReAd"); + mixed.arg("StReAmS").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&lower).unwrap(), + RoutingInfo::for_routable(&mixed).unwrap() + ); + } + + #[test] + fn test_routing_info() { + let mut test_cmds = vec![]; + + // RoutingInfo::AllMasters + let mut test_cmd = cmd("FLUSHALL"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // RoutingInfo::AllNodes + test_cmd = cmd("ECHO"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // Routing key is 2nd arg ("42") + test_cmd = cmd("SET"); + test_cmd.arg("42"); + test_cmds.push(test_cmd); + + // Routing key is 3rd arg ("FOOBAR") + test_cmd = cmd("XINFO"); + test_cmd.arg("GROUPS").arg("FOOBAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd = "0" == RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("0").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd != "0" == RoutingInfo::Slot) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("4").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 3rd arg + test_cmd = cmd("XREAD"); + test_cmd.arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 4th arg + test_cmd = cmd("XREAD"); + test_cmd.arg("FOO").arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + for cmd in test_cmds { + let value = parse_redis_value(&cmd.get_packed_command()).unwrap(); + assert_eq!( + RoutingInfo::for_routable(&value).unwrap(), + RoutingInfo::for_routable(&cmd).unwrap(), + ); + } + + // Assert expected RoutingInfo explicitly: + + for cmd in [cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("PING")] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::AllSucceeded) + ))) + ); + } + + assert_eq!( + RoutingInfo::for_routable(&cmd("DBSIZE")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("SCRIPT KILL")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::OneSucceeded) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("INFO")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Special) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("KEYS")), + Some(RoutingInfo::MultiNode(( + 
MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::CombineArrays) + ))) + ); + + for cmd in vec![ + cmd("SCAN"), + cmd("SHUTDOWN"), + cmd("SLAVEOF"), + cmd("REPLICAOF"), + ] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + None, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + + for cmd in [ + cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0), + cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + ); + } + + // While FCALL with N keys is expected to be routed to a specific node + assert_eq!( + RoutingInfo::for_routable(cmd("FCALL").arg("foo").arg(1).arg("mykey")), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"mykey"), SlotAddr::Master)) + )) + ); + + for (cmd, expected) in [ + ( + cmd("EVAL") + .arg(r#"redis.call("GET, KEYS[1]");"#) + .arg(1) + .arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)), + )), + ), + ( + cmd("XGROUP") + .arg("CREATE") + .arg("mystream") + .arg("workers") + .arg("$") + .arg("MKSTREAM"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XINFO").arg("GROUPS").arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"foo"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ( + cmd("XREADGROUP") + .arg("GROUP") + .arg("wkrs") + .arg("consmrs") + .arg("STREAMS") + .arg("mystream"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XREAD") + .arg("COUNT") + .arg("2") + .arg("STREAMS") + .arg("mystream") + .arg("writers") + .arg("0-0") + .arg("0-0"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + expected, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + } + + #[test] + fn test_slot_for_packed_cmd() { + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10, + 244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::ReplicaOptional)))) if slot == 964)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241, + 197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233, + 247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), 
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210));
+    }
+
+    #[test]
+    fn test_multi_shard() {
+        let mut cmd = cmd("DEL");
+        cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz");
+        let routing = RoutingInfo::for_routable(&cmd);
+        let mut expected = std::collections::HashMap::new();
+        expected.insert(Route(4813, SlotAddr::Master), vec![2]);
+        expected.insert(Route(5061, SlotAddr::Master), vec![1, 3]);
+        expected.insert(Route(12182, SlotAddr::Master), vec![0]);
+
+        assert!(
+            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), Some(ResponsePolicy::Aggregate(AggregateOp::Sum))))) if {
+                let routes = vec.clone().into_iter().collect();
+                expected == routes
+            }),
+            "{routing:?}"
+        );
+
+        let mut cmd = crate::cmd("MGET");
+        cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz");
+        let routing = RoutingInfo::for_routable(&cmd);
+        let mut expected = std::collections::HashMap::new();
+        expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2]);
+        expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3]);
+        expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0]);
+
+        assert!(
+            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), Some(ResponsePolicy::CombineArrays)))) if {
+                let routes = vec.clone().into_iter().collect();
+                expected == routes
+            }),
+            "{routing:?}"
+        );
+    }
+
+    #[test]
+    fn test_command_creation_for_multi_shard() {
+        let mut original_cmd = cmd("DEL");
+        original_cmd
+            .arg("foo")
+            .arg("bar")
+            .arg("baz")
+            .arg("{bar}vaz");
+        let routing = RoutingInfo::for_routable(&original_cmd);
+        let expected = [vec![0], vec![1, 3], vec![2]];
+
+        let mut indices: Vec<_> = match routing {
+            Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), _))) => {
+                vec.into_iter().map(|(_, indices)| indices).collect()
+            }
+            _ => panic!("unexpected routing: {routing:?}"),
+        };
+        // Sort because `for_routable` doesn't return values in a consistent order between runs.
+        indices.sort_by(|prev, next| prev.iter().next().unwrap().cmp(next.iter().next().unwrap()));
+ + for (index, indices) in indices.into_iter().enumerate() { + let cmd = command_for_multi_slot_indices(&original_cmd, indices.iter()); + let expected_indices = &expected[index]; + assert_eq!(original_cmd.arg_idx(0), cmd.arg_idx(0)); + for (index, target_index) in expected_indices.iter().enumerate() { + let target_index = target_index + 1; + assert_eq!(original_cmd.arg_idx(target_index), cmd.arg_idx(index + 1)); + } + } + } + + #[test] + fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() { + let mut cmd = cmd("DEL"); + cmd.arg("foo").arg("{foo}bar").arg("{foo}baz"); + let routing = RoutingInfo::for_routable(&cmd); + + assert!( + matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(12182, SlotAddr::Master)) + )) + ), + "{routing:?}" + ); + } + + #[test] + fn test_combining_results_into_single_array() { + let res1 = Value::Array(vec![Value::Nil, Value::Okay]); + let res2 = Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::BulkString("4".as_bytes().to_vec()), + ]); + let res3 = Value::Array(vec![Value::SimpleString("2".to_string()), Value::Int(3)]); + let results = super::combine_and_sort_array_results( + vec![res1, res2, res3], + [vec![0, 5], vec![1, 4], vec![2, 3]].iter(), + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::Nil, + Value::BulkString("1".as_bytes().to_vec()), + Value::SimpleString("2".to_string()), + Value::Int(3), + Value::BulkString("4".as_bytes().to_vec()), + Value::Okay, + ]) + ); + } + + #[test] + fn test_combine_map_results() { + let input = vec![]; + let result = super::combine_map_results(input).unwrap(); + assert_eq!(result, Value::Map(vec![])); + + let input = vec![ + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(5), + Value::BulkString(b"key2".to_vec()), + Value::Int(10), + ]), + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(3), + Value::BulkString(b"key3".to_vec()), + Value::Int(15), + ]), + ]; + let result = super::combine_map_results(input).unwrap(); + let mut expected = vec![ + (Value::BulkString(b"key1".to_vec()), Value::Int(8)), + (Value::BulkString(b"key2".to_vec()), Value::Int(10)), + (Value::BulkString(b"key3".to_vec()), Value::Int(15)), + ]; + expected.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + let mut result_vec = match result { + Value::Map(v) => v, + _ => panic!("Expected Map"), + }; + result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + assert_eq!(result_vec, expected); + + let input = vec![Value::Int(5)]; + let result = super::combine_map_results(input); + assert!(result.is_err()); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_slotmap.rs b/glide-core/redis-rs/redis/src/cluster_slotmap.rs new file mode 100644 index 0000000000..7f1f70af98 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_slotmap.rs @@ -0,0 +1,435 @@ +use std::{ + collections::{BTreeMap, HashSet}, + fmt::Display, + sync::atomic::AtomicUsize, +}; + +use crate::cluster_routing::{Route, Slot, SlotAddr, SlotAddrs}; + +#[derive(Debug)] +pub(crate) struct SlotMapValue { + pub(crate) start: u16, + pub(crate) addrs: SlotAddrs, + pub(crate) latest_used_replica: AtomicUsize, +} + +impl SlotMapValue { + fn from_slot(slot: Slot) -> Self { + Self { + start: 
slot.start(), + addrs: SlotAddrs::from_slot(slot), + latest_used_replica: AtomicUsize::new(0), + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Copy)] +pub(crate) enum ReadFromReplicaStrategy { + #[default] + AlwaysFromPrimary, + RoundRobin, +} + +#[derive(Debug, Default)] +pub(crate) struct SlotMap { + pub(crate) slots: BTreeMap, + read_from_replica: ReadFromReplicaStrategy, +} + +fn get_address_from_slot( + slot: &SlotMapValue, + read_from_replica: ReadFromReplicaStrategy, + slot_addr: SlotAddr, +) -> &str { + if slot_addr == SlotAddr::Master || slot.addrs.replicas.is_empty() { + return slot.addrs.primary.as_str(); + } + match read_from_replica { + ReadFromReplicaStrategy::AlwaysFromPrimary => slot.addrs.primary.as_str(), + ReadFromReplicaStrategy::RoundRobin => { + let index = slot + .latest_used_replica + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % slot.addrs.replicas.len(); + slot.addrs.replicas[index].as_str() + } + } +} + +impl SlotMap { + pub(crate) fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { + let mut this = Self { + slots: BTreeMap::new(), + read_from_replica, + }; + this.slots.extend( + slots + .into_iter() + .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), + ); + this + } + + pub fn slot_value_for_route(&self, route: &Route) -> Option<&SlotMapValue> { + let slot = route.slot(); + self.slots + .range(slot..) + .next() + .and_then(|(end, slot_value)| { + if slot <= *end && slot_value.start <= slot { + Some(slot_value) + } else { + None + } + }) + } + + pub fn slot_addr_for_route(&self, route: &Route) -> Option<&str> { + self.slot_value_for_route(route).map(|slot_value| { + get_address_from_slot(slot_value, self.read_from_replica, route.slot_addr()) + }) + } + + pub fn values(&self) -> impl Iterator { + self.slots.values().map(|slot_value| &slot_value.addrs) + } + + fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { + let mut addresses = HashSet::new(); + for slot in self.values() { + addresses.insert(slot.primary.as_str()); + if !only_primaries { + addresses.extend(slot.replicas.iter().map(|str| str.as_str())); + } + } + + addresses + } + + pub fn addresses_for_all_primaries(&self) -> HashSet<&str> { + self.all_unique_addresses(true) + } + + pub fn addresses_for_all_nodes(&self) -> HashSet<&str> { + self.all_unique_addresses(false) + } + + pub fn addresses_for_multi_slot<'a, 'b>( + &'a self, + routes: &'b [(Route, Vec)], + ) -> impl Iterator> + 'a + where + 'b: 'a, + { + routes + .iter() + .map(|(route, _)| self.slot_addr_for_route(route)) + } + + // Returns the slots that are assigned to the given address. + pub(crate) fn get_slots_of_node(&self, node_address: &str) -> Vec { + let node_address = node_address.to_string(); + self.slots + .iter() + .filter_map(|(end, slot_value)| { + if slot_value.addrs.primary == node_address + || slot_value.addrs.replicas.contains(&node_address) + { + Some(slot_value.start..(*end + 1)) + } else { + None + } + }) + .flatten() + .collect() + } + + pub(crate) fn get_node_address_for_slot( + &self, + slot: u16, + slot_addr: SlotAddr, + ) -> Option { + self.slots.range(slot..).next().and_then(|(_, slot_value)| { + if slot_value.start <= slot { + Some( + get_address_from_slot(slot_value, self.read_from_replica, slot_addr) + .to_string(), + ) + } else { + None + } + }) + } +} + +impl Display for SlotMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Strategy: {:?}. 
Slot mapping:", self.read_from_replica)?; + for (end, slot_map_value) in self.slots.iter() { + writeln!( + f, + "({}-{}): primary: {}, replicas: {:?}", + slot_map_value.start, + end, + slot_map_value.addrs.primary, + slot_map_value.addrs.replicas + )?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_slot_map_retrieve_routes() { + let slot_map = SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + + assert!(slot_map + .slot_addr_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(1, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(2000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(2001, SlotAddr::Master)) + .is_none()); + } + + fn get_slot_map(read_from_replica: ReadFromReplicaStrategy) -> SlotMap { + SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + Slot::new( + 2001, + 3000, + "node3:6379".to_owned(), + vec![ + "replica4:6379".to_owned(), + "replica5:6379".to_owned(), + "replica6:6379".to_owned(), + ], + ), + Slot::new( + 3001, + 4000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + ], + read_from_replica, + ) + } + + #[test] + fn test_slot_map_get_all_primaries() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.addresses_for_all_primaries(); + assert_eq!( + addresses, + HashSet::from_iter(["node1:6379", "node2:6379", "node3:6379"]) + ); + } + + #[test] + fn test_slot_map_get_all_nodes() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.addresses_for_all_nodes(); + assert_eq!( + addresses, + HashSet::from_iter([ + "node1:6379", + "node2:6379", + "node3:6379", + "replica1:6379", + "replica2:6379", + "replica3:6379", + "replica4:6379", + "replica5:6379", + "replica6:6379" + ]) + ); + } + + #[test] + fn test_slot_map_get_multi_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::Master), vec![]), + (Route::new(2001, SlotAddr::ReplicaOptional), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert!(addresses.contains(&Some("node1:6379"))); + assert!( + addresses.contains(&Some("replica4:6379")) + || addresses.contains(&Some("replica5:6379")) + || addresses.contains(&Some("replica6:6379")) + ); + } + + /// This test is needed in 
order to verify that if the MultiSlot route finds the same node for more than a single route, + /// that node's address will appear multiple times, in the same order. + #[test] + fn test_slot_map_get_repeating_addresses_when_the_same_node_is_found_in_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2001, SlotAddr::Master), vec![]), + (Route::new(2, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + (Route::new(3, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2003, SlotAddr::Master), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert_eq!( + addresses, + vec![ + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379") + ] + ); + } + + #[test] + fn test_slot_map_get_none_when_slot_is_missing_from_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(5000, SlotAddr::Master), vec![]), + (Route::new(6000, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert_eq!( + addresses, + vec![Some("replica1:6379"), None, None, Some("node3:6379")] + ); + } + + #[test] + fn test_slot_map_rotate_read_replicas() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let route = Route::new(2001, SlotAddr::ReplicaOptional); + let mut addresses = vec![ + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + ]; + addresses.sort(); + assert_eq!( + addresses, + vec!["replica4:6379", "replica5:6379", "replica6:6379"] + ); + } + + #[test] + fn test_get_slots_of_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + assert_eq!( + slot_map.get_slots_of_node("node1:6379"), + (1..1001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("node2:6379"), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica3:6379"), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica4:6379"), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica5:6379"), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica6:6379"), + (2001..3001).collect::>() + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_topology.rs b/glide-core/redis-rs/redis/src/cluster_topology.rs new file mode 100644 index 0000000000..a2ce9ea078 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_topology.rs @@ -0,0 +1,645 @@ +//! This module provides the functionality to refresh and calculate the cluster topology for Redis Cluster. 
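+//!
+//! As a quick illustration of the key-to-slot mapping implemented below (the helper
+//! names are the ones defined in this module; the example keys are arbitrary):
+//!
+//! ```rust,ignore
+//! // Keys are hashed with CRC16 (XMODEM) modulo 16384, and a `{...}` hashtag
+//! // restricts hashing to the tagged substring, so these two keys share a slot:
+//! assert_eq!(get_slot(b"{user1}:profile"), get_slot(b"{user1}:settings"));
+//! assert!(get_slot(b"some_key") < SLOT_SIZE);
+//! ```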
+ +use crate::cluster::get_connection_addr; +#[cfg(feature = "cluster-async")] +use crate::cluster_client::SlotsRefreshRateLimit; +use crate::cluster_routing::Slot; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap}; +use crate::{cluster::TlsMode, ErrorKind, RedisError, RedisResult, Value}; +#[cfg(all(feature = "cluster-async", not(feature = "tokio-comp")))] +use async_std::sync::RwLock; +use derivative::Derivative; +use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::hash::{Hash, Hasher}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; +#[cfg(all(feature = "cluster-async", feature = "tokio-comp"))] +use tokio::sync::RwLock; +use tracing::info; + +// Exponential backoff constants for retrying a slot refresh +/// The default number of refresh topology retries in the same call +pub const DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES: usize = 3; +/// The default maximum interval between two retries of the same call for topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL: Duration = Duration::from_secs(1); +/// The default initial interval for retrying topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL: Duration = Duration::from_millis(500); + +// Constants for the intervals between two independent consecutive refresh slots calls +/// The default wait duration between two consecutive refresh slots calls +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_WAIT_DURATION: Duration = Duration::from_secs(15); +/// The default maximum jitter duration to add to the refresh slots wait duration +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI: u64 = 15 * 1000; // 15 seconds + +pub(crate) const SLOT_SIZE: u16 = 16384; +pub(crate) type TopologyHash = u64; + +/// Represents the state of slot refresh operations. +#[cfg(feature = "cluster-async")] +pub(crate) struct SlotRefreshState { + /// Indicates if a slot refresh is currently in progress + pub(crate) in_progress: AtomicBool, + /// The last slot refresh run timestamp + pub(crate) last_run: Arc>>, + pub(crate) rate_limiter: SlotsRefreshRateLimit, +} + +#[cfg(feature = "cluster-async")] +impl SlotRefreshState { + pub(crate) fn new(rate_limiter: SlotsRefreshRateLimit) -> Self { + Self { + in_progress: AtomicBool::new(false), + last_run: Arc::new(RwLock::new(None)), + rate_limiter, + } + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq)] +#[derive(Debug)] +pub(crate) struct TopologyView { + pub(crate) hash_value: TopologyHash, + #[derivative(PartialEq = "ignore")] + pub(crate) nodes_count: u16, + #[derivative(PartialEq = "ignore")] + slots_and_count: (u16, Vec), +} + +pub(crate) fn slot(key: &[u8]) -> u16 { + crc16::State::::calculate(key) % SLOT_SIZE +} + +fn get_hashtag(key: &[u8]) -> Option<&[u8]> { + let open = key.iter().position(|v| *v == b'{'); + let open = match open { + Some(open) => open, + None => return None, + }; + + let close = key[open..].iter().position(|v| *v == b'}'); + let close = match close { + Some(close) => close, + None => return None, + }; + + let rv = &key[open + 1..open + close]; + if rv.is_empty() { + None + } else { + Some(rv) + } +} + +/// Returns the slot that matches `key`. +pub fn get_slot(key: &[u8]) -> u16 { + let key = match get_hashtag(key) { + Some(tag) => tag, + None => key, + }; + + slot(key) +} + +// Parse slot data from raw redis value. 
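+// The raw value is the reply of `CLUSTER SLOTS`, which (per the server documentation)
+// is an array of entries shaped roughly like the following (values are illustrative):
+//
+//   1) 1) (integer) 0         <- start slot of the range
+//      2) (integer) 5460      <- end slot of the range
+//      3) 1) "127.0.0.1"      <- primary host
+//         2) (integer) 6379   <- primary port
+//      4) 1) "127.0.0.1"      <- first replica host
+//         2) (integer) 6380   <- first replica port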
+pub(crate) fn parse_and_count_slots(
+    raw_slot_resp: &Value,
+    tls: Option<TlsMode>,
+    // The DNS address of the node from which `raw_slot_resp` was received.
+    addr_of_answering_node: &str,
+) -> RedisResult<(u16, Vec<Slot>)> {
+    // Parse response.
+    let mut slots = Vec::with_capacity(2);
+    let mut count = 0;
+
+    if let Value::Array(items) = raw_slot_resp {
+        let mut iter = items.iter();
+        while let Some(Value::Array(item)) = iter.next() {
+            if item.len() < 3 {
+                continue;
+            }
+
+            let start = if let Value::Int(start) = item[0] {
+                start as u16
+            } else {
+                continue;
+            };
+
+            let end = if let Value::Int(end) = item[1] {
+                end as u16
+            } else {
+                continue;
+            };
+
+            let mut nodes: Vec<String> = item
+                .iter()
+                .skip(2)
+                .filter_map(|node| {
+                    if let Value::Array(node) = node {
+                        if node.len() < 2 {
+                            return None;
+                        }
+                        // According to the CLUSTER SLOTS documentation:
+                        // If the received hostname is an empty string or NULL, clients should utilize the hostname of the responding node.
+                        // However, if the received hostname is "?", it should be regarded as an indication of an unknown node.
+                        let hostname = if let Value::BulkString(ref ip) = node[0] {
+                            let hostname = String::from_utf8_lossy(ip);
+                            if hostname.is_empty() {
+                                addr_of_answering_node.into()
+                            } else if hostname == "?" {
+                                return None;
+                            } else {
+                                hostname
+                            }
+                        } else if let Value::Nil = node[0] {
+                            addr_of_answering_node.into()
+                        } else {
+                            return None;
+                        };
+                        if hostname.is_empty() {
+                            return None;
+                        }
+
+                        let port = if let Value::Int(port) = node[1] {
+                            port as u16
+                        } else {
+                            return None;
+                        };
+                        Some(
+                            get_connection_addr(hostname.into_owned(), port, tls, None).to_string(),
+                        )
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+
+            if nodes.is_empty() {
+                continue;
+            }
+            count += end - start;
+
+            let mut replicas = nodes.split_off(1);
+            // We sort the replicas, because different nodes in a cluster might return the same slot view
+            // with a different replica order, which would cause otherwise identical views to be evaluated as unequal.
+            replicas.sort_unstable();
+            slots.push(Slot::new(start, end, nodes.pop().unwrap(), replicas));
+        }
+    }
+    if slots.is_empty() {
+        return Err(RedisError::from((
+            ErrorKind::ResponseError,
+            "Error parsing slots: No healthy node found",
+            format!("Raw slot map response: {:?}", raw_slot_resp),
+        )));
+    }
+
+    Ok((count, slots))
+}
+
+fn calculate_hash<T: Hash>(t: &T) -> u64 {
+    let mut s = DefaultHasher::new();
+    t.hash(&mut s);
+    s.finish()
+}
+
+pub(crate) fn calculate_topology<'a>(
+    topology_views: impl Iterator<Item = (&'a str, &'a Value)>,
+    curr_retry: usize,
+    tls_mode: Option<TlsMode>,
+    num_of_queried_nodes: usize,
+    read_from_replica: ReadFromReplicaStrategy,
+) -> RedisResult<(SlotMap, TopologyHash)> {
+    let mut hash_view_map = HashMap::new();
+    for (host, view) in topology_views {
+        if let Ok(slots_and_count) = parse_and_count_slots(view, tls_mode, host) {
+            let hash_value = calculate_hash(&slots_and_count);
+            let topology_entry = hash_view_map.entry(hash_value).or_insert(TopologyView {
+                hash_value,
+                nodes_count: 0,
+                slots_and_count,
+            });
+            topology_entry.nodes_count += 1;
+        }
+    }
+    let mut non_unique_max_node_count = false;
+    let mut vec_iter = hash_view_map.into_values();
+    let mut most_frequent_topology = match vec_iter.next() {
+        Some(view) => view,
+        None => {
+            return Err(RedisError::from((
+                ErrorKind::ResponseError,
+                "No topology views found",
+            )));
+        }
+    };
+    // Find the most frequent topology view.
+    for curr_view in vec_iter {
+        match most_frequent_topology
+            .nodes_count
+            .cmp(&curr_view.nodes_count)
+        {
+            std::cmp::Ordering::Less => {
+                most_frequent_topology = curr_view;
+                non_unique_max_node_count = false;
+            }
+            std::cmp::Ordering::Greater => continue,
+            std::cmp::Ordering::Equal => {
+                non_unique_max_node_count = true;
+                let seen_slot_count = most_frequent_topology.slots_and_count.0;
+
+                // Between equally frequent views, prefer the one with the higher slot coverage.
+                if let std::cmp::Ordering::Less = seen_slot_count.cmp(&curr_view.slots_and_count.0)
+                {
+                    most_frequent_topology = curr_view;
+                }
+            }
+        }
+    }
+
+    let parse_and_built_result = |most_frequent_topology: TopologyView| {
+        info!(
+            "calculate_topology found topology map:\n{:?}",
+            most_frequent_topology
+        );
+        let slots_data = most_frequent_topology.slots_and_count.1;
+        Ok((
+            SlotMap::new(slots_data, read_from_replica),
+            most_frequent_topology.hash_value,
+        ))
+    };
+
+    if non_unique_max_node_count {
+        // More than one most frequent view was found.
+        // If we've reached the last retry, or if it's a 2-node cluster, return the view with the
+        // highest slot coverage among the most agreed-upon views.
+ if curr_retry >= DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES || num_of_queried_nodes < 3 { + return parse_and_built_result(most_frequent_topology); + } + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: Failed to obtain a majority in topology views", + ))); + } + + // The rate of agreement of the topology view is determined by assessing the number of nodes that share this view out of the total number queried + let agreement_rate = most_frequent_topology.nodes_count as f32 / num_of_queried_nodes as f32; + const MIN_AGREEMENT_RATE: f32 = 0.2; + if agreement_rate >= MIN_AGREEMENT_RATE { + parse_and_built_result(most_frequent_topology) + } else { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: The accuracy of the topology view is too low", + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster_routing::SlotAddrs; + + #[test] + fn test_get_hashtag() { + assert_eq!(get_hashtag(&b"foo{bar}baz"[..]), Some(&b"bar"[..])); + assert_eq!(get_hashtag(&b"foo{}{baz}"[..]), None); + assert_eq!(get_hashtag(&b"foo{{bar}}zap"[..]), Some(&b"{bar"[..])); + } + + fn slot_value_with_replicas(start: u16, end: u16, nodes: Vec<(&str, u16)>) -> Value { + let mut node_values: Vec = nodes + .iter() + .map(|(host, port)| { + Value::Array(vec![ + Value::BulkString(host.as_bytes().to_vec()), + Value::Int(*port as i64), + ]) + }) + .collect(); + let mut slot_vec = vec![Value::Int(start as i64), Value::Int(end as i64)]; + slot_vec.append(&mut node_values); + Value::Array(slot_vec) + } + + fn slot_value(start: u16, end: u16, node: &str, port: u16) -> Value { + slot_value_with_replicas(start, end, vec![(node, port)]) + } + + #[test] + fn parse_slots_with_different_replicas_order_returns_the_same_view() { + let view1 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_2", 6379), + ("replica1_3", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_1", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_3", 6379), + ("replica1_2", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ("replica2_1", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "foo").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "foo").unwrap(); + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let check = res1 + .1 + .into_iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(check); + } + + #[test] + fn parse_slots_returns_slots_with_host_name_if_missing() { + let view = Value::Array(vec![slot_value(0, 4000, "", 6379)]); + + let (slot_count, slots) = parse_and_count_slots(&view, None, "node").unwrap(); + assert_eq!(slot_count, 4000); + assert_eq!(slots[0].master(), "node:6379"); + } + + #[test] + fn 
should_parse_and_hash_regardless_of_missing_host_name_and_replicas_order() { + let view1 = Value::Array(vec![ + slot_value(0, 4000, "", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("node3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("", 6379), + ("replica3_3", 6379), + ("replica3_2", 6379), + ("replica3_1", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "node1").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "node3").unwrap(); + + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let equality_check = + res1.1.iter().zip(&res2.1).all(|(first, second)| { + first.start() == second.start() && first.end() == second.end() + }); + assert!(equality_check); + let replicas_check = res1 + .1 + .iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(replicas_check); + } + + enum ViewType { + SingleNodeViewFullCoverage, + SingleNodeViewMissingSlots, + TwoNodesViewFullCoverage, + TwoNodesViewMissingSlots, + } + fn get_view(view_type: &ViewType) -> (&str, Value) { + match view_type { + ViewType::SingleNodeViewFullCoverage => ( + "first", + Value::Array(vec![slot_value(0, 16383, "node1", 6379)]), + ), + ViewType::SingleNodeViewMissingSlots => ( + "second", + Value::Array(vec![slot_value(0, 4000, "node1", 6379)]), + ), + ViewType::TwoNodesViewFullCoverage => ( + "third", + Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 16383, "node2", 6380), + ]), + ), + ViewType::TwoNodesViewMissingSlots => ( + "fourth", + Value::Array(vec![ + slot_value(0, 3000, "node3", 6381), + slot_value(4001, 16383, "node4", 6382), + ]), + ), + } + } + + fn get_node_addr(name: &str, port: u16) -> SlotAddrs { + SlotAddrs::new(format!("{name}:{port}"), Vec::new()) + } + + #[test] + fn test_topology_calculator_4_nodes_queried_has_a_majority_success() { + // 4 nodes queried (1 error): Has a majority, single_node_view should be chosen + let queried_nodes: usize = 4; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + ]; + + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node1", 6379); + let expected: Vec<&SlotAddrs> = vec![&node_1]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_majority_has_more_retries_raise_error() { + // 3 nodes queried: No majority, should return an error + let queried_nodes = 3; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let topology_view = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + assert!(topology_view.is_err()); + } + + #[test] + fn 
test_topology_calculator_3_nodes_queried_no_majority_last_retry_success() {
+        // 3 nodes queried: No majority; on the last retry we should get the view that has full slot coverage.
+        let queried_nodes = 3;
+        let topology_results = vec![
+            get_view(&ViewType::SingleNodeViewMissingSlots),
+            get_view(&ViewType::TwoNodesViewFullCoverage),
+            get_view(&ViewType::TwoNodesViewMissingSlots),
+        ];
+        let (topology_view, _) = calculate_topology(
+            topology_results.iter().map(|(addr, value)| (*addr, value)),
+            3,
+            None,
+            queried_nodes,
+            ReadFromReplicaStrategy::AlwaysFromPrimary,
+        )
+        .unwrap();
+        let res: Vec<_> = topology_view.values().collect();
+        let node_1 = get_node_addr("node1", 6379);
+        let node_2 = get_node_addr("node2", 6380);
+        let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2];
+        assert_eq!(res, expected);
+    }
+
+    #[test]
+    fn test_topology_calculator_2_nodes_queried_no_majority_return_full_slot_coverage_view() {
+        // 2 nodes queried: No majority; should get the view that has full slot coverage.
+        let queried_nodes = 2;
+        let topology_results = [
+            get_view(&ViewType::TwoNodesViewFullCoverage),
+            get_view(&ViewType::TwoNodesViewMissingSlots),
+        ];
+        let (topology_view, _) = calculate_topology(
+            topology_results.iter().map(|(addr, value)| (*addr, value)),
+            1,
+            None,
+            queried_nodes,
+            ReadFromReplicaStrategy::AlwaysFromPrimary,
+        )
+        .unwrap();
+        let res: Vec<_> = topology_view.values().collect();
+        let node_1 = get_node_addr("node1", 6379);
+        let node_2 = get_node_addr("node2", 6380);
+        let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2];
+        assert_eq!(res, expected);
+    }
+
+    #[test]
+    fn test_topology_calculator_2_nodes_queried_no_majority_no_full_coverage_prefer_fuller_coverage(
+    ) {
+        // 2 nodes queried: No majority and no full slot coverage; should prefer the view with the fuller coverage.
+        let queried_nodes = 2;
+        let topology_results = [
+            get_view(&ViewType::SingleNodeViewMissingSlots),
+            get_view(&ViewType::TwoNodesViewMissingSlots),
+        ];
+        let (topology_view, _) = calculate_topology(
+            topology_results.iter().map(|(addr, value)| (*addr, value)),
+            1,
+            None,
+            queried_nodes,
+            ReadFromReplicaStrategy::AlwaysFromPrimary,
+        )
+        .unwrap();
+        let res: Vec<_> = topology_view.values().collect();
+        let node_1 = get_node_addr("node3", 6381);
+        let node_2 = get_node_addr("node4", 6382);
+        let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2];
+        assert_eq!(res, expected);
+    }
+
+    #[test]
+    fn test_topology_calculator_3_nodes_queried_no_full_coverage_prefer_majority() {
+        // 3 topology views, none with full slot coverage; should prefer the view agreed on by the majority.
+        let queried_nodes = 2;
+        let topology_results = vec![
+            get_view(&ViewType::SingleNodeViewMissingSlots),
+            get_view(&ViewType::TwoNodesViewMissingSlots),
+            get_view(&ViewType::SingleNodeViewMissingSlots),
+        ];
+        let (topology_view, _) = calculate_topology(
+            topology_results.iter().map(|(addr, value)| (*addr, value)),
+            1,
+            None,
+            queried_nodes,
+            ReadFromReplicaStrategy::AlwaysFromPrimary,
+        )
+        .unwrap();
+        let res: Vec<_> = topology_view.values().collect();
+        let node_1 = get_node_addr("node1", 6379);
+        let expected: Vec<&SlotAddrs> = vec![&node_1];
+        assert_eq!(res, expected);
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/cmd.rs b/glide-core/redis-rs/redis/src/cmd.rs
new file mode 100644
index 0000000000..979bc7987b
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/cmd.rs
@@ -0,0 +1,663 @@
+#[cfg(feature = "aio")]
+use futures_util::{
+    future::BoxFuture,
+    task::{Context, Poll},
+    Stream, StreamExt,
+};
+#[cfg(feature = "aio")]
+use std::pin::Pin;
+use
std::{fmt, io}; + +use crate::connection::ConnectionLike; +use crate::pipeline::Pipeline; +use crate::types::{from_owned_redis_value, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs}; + +/// An argument to a redis command +#[derive(Clone)] +pub enum Arg { + /// A normal argument + Simple(D), + /// A cursor argument created from `cursor_arg()` + Cursor, +} + +/// Represents redis commands. +#[derive(Clone)] +pub struct Cmd { + data: Vec, + // Arg::Simple contains the offset that marks the end of the argument + args: Vec>, + cursor: Option, + // If it's true command's response won't be read from socket. Useful for Pub/Sub. + no_response: bool, +} + +/// Represents a redis iterator. +pub struct Iter<'a, T: FromRedisValue> { + batch: std::vec::IntoIter, + cursor: u64, + con: &'a mut (dyn ConnectionLike + 'a), + cmd: Cmd, +} + +impl<'a, T: FromRedisValue> Iterator for Iter<'a, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. + loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if self.cursor == 0 { + return None; + } + + let pcmd = self.cmd.get_packed_command_with_cursor(self.cursor)?; + let rv = self.con.req_packed_command(&pcmd).ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cursor = cur; + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +use crate::aio::ConnectionLike as AsyncConnection; + +/// The inner future of AsyncIter +#[cfg(feature = "aio")] +struct AsyncIterInner<'a, T: FromRedisValue + 'a> { + batch: std::vec::IntoIter, + con: &'a mut (dyn AsyncConnection + Send + 'a), + cmd: Cmd, +} + +/// Represents the state of AsyncIter +#[cfg(feature = "aio")] +enum IterOrFuture<'a, T: FromRedisValue + 'a> { + Iter(AsyncIterInner<'a, T>), + Future(BoxFuture<'a, (AsyncIterInner<'a, T>, Option)>), + Empty, +} + +/// Represents a redis iterator that can be used with async connections. +#[cfg(feature = "aio")] +pub struct AsyncIter<'a, T: FromRedisValue + 'a> { + inner: IterOrFuture<'a, T>, +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a> AsyncIterInner<'a, T> { + #[inline] + pub async fn next_item(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. 
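+        // SCAN contract: every reply carries a new cursor plus a batch of items, and a
+        // cursor of 0 returned by the server signals that the iteration is complete.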
+ loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if let Some(cursor) = self.cmd.cursor { + if cursor == 0 { + return None; + } + } else { + return None; + } + + let rv = self.con.req_packed_command(&self.cmd).await.ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cmd.cursor = Some(cur); + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a + Unpin + Send> AsyncIter<'a, T> { + /// ```rust,no_run + /// # use redis::AsyncCommands; + /// # async fn scan_set() -> redis::RedisResult<()> { + /// # let client = redis::Client::open("redis://127.0.0.1/")?; + /// # let mut con = client.get_async_connection(None).await?; + /// con.sadd("my_set", 42i32).await?; + /// con.sadd("my_set", 43i32).await?; + /// let mut iter: redis::AsyncIter = con.sscan("my_set").await?; + /// while let Some(element) = iter.next_item().await { + /// assert!(element == 42 || element == 43); + /// } + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub async fn next_item(&mut self) -> Option { + StreamExt::next(self).await + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + Unpin + Send + 'a> Stream for AsyncIter<'a, T> { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let inner = std::mem::replace(&mut this.inner, IterOrFuture::Empty); + match inner { + IterOrFuture::Iter(mut iter) => { + let fut = async move { + let next_item = iter.next_item().await; + (iter, next_item) + }; + this.inner = IterOrFuture::Future(Box::pin(fut)); + Pin::new(this).poll_next(cx) + } + IterOrFuture::Future(mut fut) => match fut.as_mut().poll(cx) { + Poll::Pending => { + this.inner = IterOrFuture::Future(fut); + Poll::Pending + } + Poll::Ready((iter, value)) => { + this.inner = IterOrFuture::Iter(iter); + Poll::Ready(value) + } + }, + IterOrFuture::Empty => unreachable!(), + } + } +} + +fn countdigits(mut v: usize) -> usize { + let mut result = 1; + loop { + if v < 10 { + return result; + } + if v < 100 { + return result + 1; + } + if v < 1000 { + return result + 2; + } + if v < 10000 { + return result + 3; + } + + v /= 10000; + result += 4; + } +} + +#[inline] +fn bulklen(len: usize) -> usize { + 1 + countdigits(len) + 2 + len + 2 +} + +fn args_len<'a, I>(args: I, cursor: u64) -> usize +where + I: IntoIterator> + ExactSizeIterator, +{ + let mut totlen = 1 + countdigits(args.len()) + 2; + for item in args { + totlen += bulklen(match item { + Arg::Cursor => countdigits(cursor as usize), + Arg::Simple(val) => val.len(), + }); + } + totlen +} + +pub(crate) fn cmd_len(cmd: &Cmd) -> usize { + args_len(cmd.args_iter(), cmd.cursor.unwrap_or(0)) +} + +fn encode_command<'a, I>(args: I, cursor: u64) -> Vec +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut cmd = Vec::new(); + write_command_to_vec(&mut cmd, args, cursor); + cmd +} + +fn write_command_to_vec<'a, I>(cmd: &mut Vec, args: I, cursor: u64) +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let totlen = args_len(args.clone(), cursor); + + cmd.reserve(totlen); + + write_command(cmd, args, cursor).unwrap() +} + +fn write_command<'a, I>(cmd: &mut (impl ?Sized + io::Write), args: I, cursor: u64) -> io::Result<()> +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut buf = ::itoa::Buffer::new(); + + cmd.write_all(b"*")?; + let s = buf.format(args.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + let mut cursor_bytes = itoa::Buffer::new(); + for item in 
args { + let bytes = match item { + Arg::Cursor => cursor_bytes.format(cursor).as_bytes(), + Arg::Simple(val) => val, + }; + + cmd.write_all(b"$")?; + let s = buf.format(bytes.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + cmd.write_all(bytes)?; + cmd.write_all(b"\r\n")?; + } + Ok(()) +} + +impl RedisWrite for Cmd { + fn write_arg(&mut self, arg: &[u8]) { + self.data.extend_from_slice(arg); + self.args.push(Arg::Simple(self.data.len())); + } + + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + use std::io::Write; + write!(self.data, "{arg}").unwrap(); + self.args.push(Arg::Simple(self.data.len())); + } +} + +impl Default for Cmd { + fn default() -> Cmd { + Cmd::new() + } +} + +/// A command acts as a builder interface to creating encoded redis +/// requests. This allows you to easiy assemble a packed command +/// by chaining arguments together. +/// +/// Basic example: +/// +/// ```rust +/// redis::Cmd::new().arg("SET").arg("my_key").arg(42); +/// ``` +/// +/// There is also a helper function called `cmd` which makes it a +/// tiny bit shorter: +/// +/// ```rust +/// redis::cmd("SET").arg("my_key").arg(42); +/// ``` +/// +/// Because Rust currently does not have an ideal system +/// for lifetimes of temporaries, sometimes you need to hold on to +/// the initially generated command: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let mut cmd = redis::cmd("SMEMBERS"); +/// let mut iter : redis::Iter = cmd.arg("my_set").clone().iter(&mut con).unwrap(); +/// ``` +impl Cmd { + /// Creates a new empty command. + pub fn new() -> Cmd { + Cmd { + data: vec![], + args: vec![], + cursor: None, + no_response: false, + } + } + + /// Creates a new empty command, with at least the requested capcity. + pub fn with_capacity(arg_count: usize, size_of_data: usize) -> Cmd { + Cmd { + data: Vec::with_capacity(size_of_data), + args: Vec::with_capacity(arg_count), + cursor: None, + no_response: false, + } + } + + /// Get the capacities for the internal buffers. + #[cfg(test)] + #[allow(dead_code)] + pub(crate) fn capacity(&self) -> (usize, usize) { + (self.args.capacity(), self.data.capacity()) + } + + /// Appends an argument to the command. The argument passed must + /// be a type that implements `ToRedisArgs`. Most primitive types as + /// well as vectors of primitive types implement it. + /// + /// For instance all of the following are valid: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// redis::cmd("SET").arg(&["my_key", "my_value"]); + /// redis::cmd("SET").arg("my_key").arg(42); + /// redis::cmd("SET").arg("my_key").arg(b"my_value"); + /// ``` + #[inline] + pub fn arg(&mut self, arg: T) -> &mut Cmd { + arg.write_redis_args(self); + self + } + + /// Works similar to `arg` but adds a cursor argument. This is always + /// an integer and also flips the command implementation to support a + /// different mode for the iterators where the iterator will ask for + /// another batch of items when the local data is exhausted. 
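+    /// The cursor placeholder is substituted with the current numeric cursor each
+    /// time the command is re-packed (see `get_packed_command_with_cursor`), which is
+    /// how the iterator resumes from where the previous batch ended.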
+ /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut cmd = redis::cmd("SSCAN"); + /// let mut iter : redis::Iter = + /// cmd.arg("my_set").cursor_arg(0).clone().iter(&mut con).unwrap(); + /// for x in iter { + /// // do something with the item + /// } + /// ``` + #[inline] + pub fn cursor_arg(&mut self, cursor: u64) -> &mut Cmd { + assert!(!self.in_scan_mode()); + self.cursor = Some(cursor); + self.args.push(Arg::Cursor); + self + } + + /// Returns the packed command as a byte vector. + #[inline] + pub fn get_packed_command(&self) -> Vec { + let mut cmd = Vec::new(); + self.write_packed_command(&mut cmd); + cmd + } + + pub(crate) fn write_packed_command(&self, cmd: &mut Vec) { + write_command_to_vec(cmd, self.args_iter(), self.cursor.unwrap_or(0)) + } + + pub(crate) fn write_packed_command_preallocated(&self, cmd: &mut Vec) { + write_command(cmd, self.args_iter(), self.cursor.unwrap_or(0)).unwrap() + } + + /// Like `get_packed_command` but replaces the cursor with the + /// provided value. If the command is not in scan mode, `None` + /// is returned. + #[inline] + fn get_packed_command_with_cursor(&self, cursor: u64) -> Option> { + if !self.in_scan_mode() { + None + } else { + Some(encode_command(self.args_iter(), cursor)) + } + } + + /// Returns true if the command is in scan mode. + #[inline] + pub fn in_scan_mode(&self) -> bool { + self.cursor.is_some() + } + + /// Sends the command as query to the connection and converts the + /// result to the target redis value. This is the general way how + /// you can retrieve data. + #[inline] + pub fn query(&self, con: &mut dyn ConnectionLike) -> RedisResult { + match con.req_command(self) { + Ok(val) => from_owned_redis_value(val), + Err(e) => Err(e), + } + } + + /// Async version of `query`. + #[inline] + #[cfg(feature = "aio")] + pub async fn query_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let val = con.req_packed_command(self).await?; + from_owned_redis_value(val) + } + + /// Similar to `query()` but returns an iterator over the items of the + /// bulk result or iterator. In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the iterator will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others. + /// + /// One speciality of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[inline] + pub fn iter(self, con: &mut dyn ConnectionLike) -> RedisResult> { + let rv = con.req_command(&self)?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + + Ok(Iter { + batch: batch.into_iter(), + cursor, + con, + cmd: self, + }) + } + + /// Similar to `iter()` but returns an AsyncIter over the items of the + /// bulk result or iterator. A [futures::Stream](https://docs.rs/futures/0.3.3/futures/stream/trait.Stream.html) + /// is implemented on AsyncIter. 
In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the stream will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others in async contexts. + /// + /// One speciality of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[cfg(feature = "aio")] + #[inline] + pub async fn iter_async<'a, T: FromRedisValue + 'a>( + mut self, + con: &'a mut (dyn AsyncConnection + Send), + ) -> RedisResult> { + let rv = con.req_packed_command(&self).await?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + if cursor == 0 { + self.cursor = None; + } else { + self.cursor = Some(cursor); + } + + Ok(AsyncIter { + inner: IterOrFuture::Iter(AsyncIterInner { + batch: batch.into_iter(), + con, + cmd: self, + }), + }) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query fails because of an error. This is + /// mainly useful in examples and for simple commands like setting + /// keys. + /// + /// This is equivalent to a call of query like this: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let _ : () = redis::cmd("PING").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn execute(&self, con: &mut dyn ConnectionLike) { + self.query::<()>(con).unwrap(); + } + + /// Returns an iterator over the arguments in this command (including the command name itself) + pub fn args_iter(&self) -> impl Clone + ExactSizeIterator> { + let mut prev = 0; + self.args.iter().map(move |arg| match *arg { + Arg::Simple(i) => { + let arg = Arg::Simple(&self.data[prev..i]); + prev = i; + arg + } + + Arg::Cursor => Arg::Cursor, + }) + } + + // Get a reference to the argument at `idx` + #[cfg(feature = "cluster")] + pub(crate) fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + if idx >= self.args.len() { + return None; + } + + let start = if idx == 0 { + 0 + } else { + match self.args[idx - 1] { + Arg::Simple(n) => n, + _ => 0, + } + }; + let end = match self.args[idx] { + Arg::Simple(n) => n, + _ => 0, + }; + if start == 0 && end == 0 { + return None; + } + Some(&self.data[start..end]) + } + + /// Client won't read and wait for results. Currently only used for Pub/Sub commands in RESP3. + #[inline] + pub fn set_no_response(&mut self, nr: bool) -> &mut Cmd { + self.no_response = nr; + self + } + + /// Check whether command's result will be waited for. + #[inline] + pub fn is_no_response(&self) -> bool { + self.no_response + } +} + +impl fmt::Debug for Cmd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let res = self + .args_iter() + .map(|arg| { + let bytes = match arg { + Arg::Cursor => b"", + Arg::Simple(val) => val, + }; + std::str::from_utf8(bytes).unwrap_or_default() + }) + .collect::>(); + f.debug_struct("Cmd").field("args", &res).finish() + } +} + +/// Shortcut function to creating a command with a single argument. 
+/// +/// The first argument of a redis command is always the name of the command +/// which needs to be a string. This is the recommended way to start a +/// command pipe. +/// +/// ```rust +/// redis::cmd("PING"); +/// ``` +pub fn cmd(name: &str) -> Cmd { + let mut rv = Cmd::new(); + rv.arg(name); + rv +} + +/// Packs a bunch of commands into a request. This is generally a quite +/// useless function as this functionality is nicely wrapped through the +/// `Cmd` object, but in some cases it can be useful. The return value +/// of this can then be send to the low level `ConnectionLike` methods. +/// +/// Example: +/// +/// ```rust +/// # use redis::ToRedisArgs; +/// let mut args = vec![]; +/// args.extend("SET".to_redis_args()); +/// args.extend("my_key".to_redis_args()); +/// args.extend(42.to_redis_args()); +/// let cmd = redis::pack_command(&args); +/// assert_eq!(cmd, b"*3\r\n$3\r\nSET\r\n$6\r\nmy_key\r\n$2\r\n42\r\n".to_vec()); +/// ``` +pub fn pack_command(args: &[Vec]) -> Vec { + encode_command(args.iter().map(|x| Arg::Simple(&x[..])), 0) +} + +/// Shortcut for creating a new pipeline. +pub fn pipe() -> Pipeline { + Pipeline::new() +} + +#[cfg(test)] +#[cfg(feature = "cluster")] +mod tests { + use super::Cmd; + + #[test] + fn test_cmd_arg_idx() { + let mut c = Cmd::new(); + assert_eq!(c.arg_idx(0), None); + + c.arg("SET"); + assert_eq!(c.arg_idx(0), Some(&b"SET"[..])); + assert_eq!(c.arg_idx(1), None); + + c.arg("foo").arg("42"); + assert_eq!(c.arg_idx(1), Some(&b"foo"[..])); + assert_eq!(c.arg_idx(2), Some(&b"42"[..])); + assert_eq!(c.arg_idx(3), None); + assert_eq!(c.arg_idx(4), None); + } +} diff --git a/glide-core/redis-rs/redis/src/commands/cluster_scan.rs b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs new file mode 100644 index 0000000000..97f10577ac --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs @@ -0,0 +1,720 @@ +use crate::aio::ConnectionLike; +use crate::cluster_async::{ + ClusterConnInner, Connect, Core, InternalRoutingInfo, InternalSingleNodeRouting, RefreshPolicy, + Response, +}; +use crate::cluster_routing::SlotAddr; +use crate::cluster_topology::SLOT_SIZE; +use crate::{cmd, from_redis_value, Cmd, ErrorKind, RedisError, RedisResult, Value}; +use async_trait::async_trait; +use std::sync::Arc; +use strum_macros::Display; + +/// This module contains the implementation of scanning operations in a Redis cluster. +/// +/// The [`ClusterScanArgs`] struct represents the arguments for a cluster scan operation, +/// including the scan state reference, match pattern, count, and object type. +/// +/// The [[`ScanStateRC`]] struct is a wrapper for managing the state of a scan operation in a cluster. +/// It holds a reference to the scan state and provides methods for accessing the state. +/// +/// The [[`ClusterInScan`]] trait defines the methods for interacting with a Redis cluster during scanning, +/// including retrieving address information, refreshing slot mapping, and routing commands to specific address. +/// +/// The [[`ScanState`]] struct represents the state of a scan operation in a Redis cluster. +/// It holds information about the current scan state, including the cursor position, scanned slots map, +/// address being scanned, and address's epoch. 
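+///
+/// As a small illustration of the bookkeeping (the constants are the ones defined just
+/// below; the slot number is an arbitrary example), marking slot `s` as scanned is:
+///
+/// ```rust,ignore
+/// let s: u16 = 1234;
+/// let word = s as usize / BITS_PER_U64; // which u64 word holds this slot's bit
+/// let bit = s as usize % BITS_PER_U64;  // bit position inside that word
+/// scanned_slots_map[word] |= 1 << bit;
+/// ```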
+ +const BITS_PER_U64: usize = u64::BITS as usize; +const NUM_OF_SLOTS: usize = SLOT_SIZE as usize; +const BITS_ARRAY_SIZE: usize = NUM_OF_SLOTS / BITS_PER_U64; +const END_OF_SCAN: u16 = NUM_OF_SLOTS as u16 + 1; +type SlotsBitsArray = [u64; BITS_ARRAY_SIZE]; + +#[derive(Clone)] +pub(crate) struct ClusterScanArgs { + pub(crate) scan_state_cursor: ScanStateRC, + match_pattern: Option>, + count: Option, + object_type: Option, +} + +#[derive(Debug, Clone, Display)] +/// Represents the type of an object in Redis. +pub enum ObjectType { + /// Represents a string object in Redis. + String, + /// Represents a list object in Redis. + List, + /// Represents a set object in Redis. + Set, + /// Represents a sorted set object in Redis. + ZSet, + /// Represents a hash object in Redis. + Hash, + /// Represents a stream object in Redis. + Stream, +} + +impl ClusterScanArgs { + pub(crate) fn new( + scan_state_cursor: ScanStateRC, + match_pattern: Option>, + count: Option, + object_type: Option, + ) -> Self { + Self { + scan_state_cursor, + match_pattern, + count, + object_type, + } + } +} + +#[derive(PartialEq, Debug, Clone, Default)] +pub enum ScanStateStage { + #[default] + Initiating, + InProgress, + Finished, +} + +#[derive(Debug, Clone, Default)] +/// A wrapper struct for managing the state of a scan operation in a cluster. +/// It holds a reference to the scan state and provides methods for accessing the state. +/// The `status` field indicates the status of the scan operation. +pub struct ScanStateRC { + scan_state_rc: Arc>, + status: ScanStateStage, +} + +impl ScanStateRC { + /// Creates a new instance of [`ScanStateRC`] from a given [`ScanState`]. + fn from_scan_state(scan_state: ScanState) -> Self { + Self { + scan_state_rc: Arc::new(Some(scan_state)), + status: ScanStateStage::InProgress, + } + } + + /// Creates a new instance of [`ScanStateRC`]. + /// + /// This method initializes the [`ScanStateRC`] with a reference to a [`ScanState`] that is initially set to `None`. + /// An empty ScanState is equivalent to a 0 cursor. + pub fn new() -> Self { + Self { + scan_state_rc: Arc::new(None), + status: ScanStateStage::Initiating, + } + } + /// create a new instance of [`ScanStateRC`] with finished state and empty scan state. + fn create_finished() -> Self { + Self { + scan_state_rc: Arc::new(None), + status: ScanStateStage::Finished, + } + } + /// Returns `true` if the scan state is finished. + pub fn is_finished(&self) -> bool { + self.status == ScanStateStage::Finished + } + + /// Returns a clone of the scan state, if it exist. + pub(crate) fn get_state_from_wrapper(&self) -> Option { + if self.status == ScanStateStage::Initiating || self.status == ScanStateStage::Finished { + None + } else { + self.scan_state_rc.as_ref().clone() + } + } +} + +/// This trait defines the methods for interacting with a Redis cluster during scanning. +#[async_trait] +pub(crate) trait ClusterInScan { + /// Retrieves the address associated with a given slot in the cluster. + async fn get_address_by_slot(&self, slot: u16) -> RedisResult; + + /// Retrieves the epoch of a given address in the cluster. + /// The epoch represents the version of the address, which is updated when a failover occurs or slots migrate in. + async fn get_address_epoch(&self, address: &str) -> Result; + + /// Retrieves the slots assigned to a given address in the cluster. + async fn get_slots_of_address(&self, address: &str) -> Vec; + + /// Routes a Redis command to a specific address in the cluster. 
+ async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult; + + /// Check if all slots are covered by the cluster + async fn are_all_slots_covered(&self) -> bool; + + /// Check if the topology of the cluster has changed and refresh the slots if needed + async fn refresh_if_topology_changed(&self); +} + +/// Represents the state of a scan operation in a Redis cluster. +/// +/// This struct holds information about the current scan state, including the cursor position, +/// the scanned slots map, the address being scanned, and the address's epoch. +#[derive(PartialEq, Debug, Clone)] +pub(crate) struct ScanState { + // the real cursor in the scan operation + cursor: u64, + // a map of the slots that have been scanned + scanned_slots_map: SlotsBitsArray, + // the address that is being scanned currently, based on the next slot set to 0 in the scanned_slots_map, and the address that "owns" the slot + // in the SlotMap + pub(crate) address_in_scan: String, + // epoch represent the version of the address, when a failover happens or slots migrate in the epoch will be updated to +1 + address_epoch: u64, + // the status of the scan operation + scan_status: ScanStateStage, +} + +impl ScanState { + /// Create a new instance of ScanState. + /// + /// # Arguments + /// + /// * `cursor` - The cursor position. + /// * `scanned_slots_map` - The scanned slots map. + /// * `address_in_scan` - The address being scanned. + /// * `address_epoch` - The epoch of the address being scanned. + /// * `scan_status` - The status of the scan operation. + /// + /// # Returns + /// + /// A new instance of ScanState. + pub fn new( + cursor: u64, + scanned_slots_map: SlotsBitsArray, + address_in_scan: String, + address_epoch: u64, + scan_status: ScanStateStage, + ) -> Self { + Self { + cursor, + scanned_slots_map, + address_in_scan, + address_epoch, + scan_status, + } + } + + fn create_finished_state() -> Self { + Self { + cursor: 0, + scanned_slots_map: [0; BITS_ARRAY_SIZE], + address_in_scan: String::new(), + address_epoch: 0, + scan_status: ScanStateStage::Finished, + } + } + + /// Initialize a new scan operation. + /// This method creates a new scan state with the cursor set to 0, the scanned slots map initialized to 0, + /// and the address set to the address associated with slot 0. + /// The address epoch is set to the epoch of the address. + /// If the address epoch cannot be retrieved, the method returns an error. + async fn initiate_scan(connection: &C) -> RedisResult { + let new_scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + let new_cursor = 0; + let address = connection.get_address_by_slot(0).await?; + let address_epoch = connection.get_address_epoch(&address).await.unwrap_or(0); + Ok(ScanState::new( + new_cursor, + new_scanned_slots_map, + address, + address_epoch, + ScanStateStage::InProgress, + )) + } + + /// Get the next slot to be scanned based on the scanned slots map. + /// If all slots have been scanned, the method returns [`END_OF_SCAN`]. + fn get_next_slot(&self, scanned_slots_map: &SlotsBitsArray) -> Option { + let all_slots_scanned = scanned_slots_map.iter().all(|&word| word == u64::MAX); + if all_slots_scanned { + return Some(END_OF_SCAN); + } + for (i, slot) in scanned_slots_map.iter().enumerate() { + let mut mask = 1; + for j in 0..BITS_PER_U64 { + if (slot & mask) == 0 { + return Some((i * BITS_PER_U64 + j) as u16); + } + mask <<= 1; + } + } + None + } + + /// Update the scan state without updating the scanned slots map. 
+ /// This method is used when the address epoch has changed, and we can't determine which slots are new. + /// In this case, we skip updating the scanned slots map and only update the address and cursor. + async fn creating_state_without_slot_changes( + &self, + connection: &C, + ) -> RedisResult { + let next_slot = self.get_next_slot(&self.scanned_slots_map).unwrap_or(0); + let new_address = if next_slot == END_OF_SCAN { + return Ok(ScanState::create_finished_state()); + } else { + connection.get_address_by_slot(next_slot).await + }; + match new_address { + Ok(address) => { + let new_epoch = connection.get_address_epoch(&address).await.unwrap_or(0); + Ok(ScanState::new( + 0, + self.scanned_slots_map, + address, + new_epoch, + ScanStateStage::InProgress, + )) + } + Err(err) => Err(err), + } + } + + /// Update the scan state and get the next address to scan. + /// This method is called when the cursor reaches 0, indicating that the current address has been scanned. + /// This method updates the scan state based on the scanned slots map and retrieves the next address to scan. + /// If the address epoch has changed, the method skips updating the scanned slots map and only updates the address and cursor. + /// If the address epoch has not changed, the method updates the scanned slots map with the slots owned by the address. + /// The method returns the new scan state with the updated cursor, scanned slots map, address, and epoch. + async fn create_updated_scan_state_for_completed_address( + &mut self, + connection: &C, + ) -> RedisResult { + let _ = connection.refresh_if_topology_changed().await; + let mut scanned_slots_map = self.scanned_slots_map; + // If the address epoch changed it mean that some slots in the address are new, so we cant know which slots been there from the beginning and which are new, or out and in later. + // In this case we will skip updating the scanned_slots_map and will just update the address and the cursor + let new_address_epoch = connection + .get_address_epoch(&self.address_in_scan) + .await + .unwrap_or(0); + if new_address_epoch != self.address_epoch { + return self.creating_state_without_slot_changes(connection).await; + } + // If epoch wasn't changed, the slots owned by the address after the refresh are all valid as slots that been scanned + // So we will update the scanned_slots_map with the slots owned by the address + let slots_scanned = connection.get_slots_of_address(&self.address_in_scan).await; + for slot in slots_scanned { + let slot_index = slot as usize / BITS_PER_U64; + let slot_bit = slot as usize % BITS_PER_U64; + scanned_slots_map[slot_index] |= 1 << slot_bit; + } + // Get the next address to scan and its param base on the next slot set to 0 in the scanned_slots_map + let next_slot = self.get_next_slot(&scanned_slots_map).unwrap_or(0); + let new_address = if next_slot == END_OF_SCAN { + return Ok(ScanState::create_finished_state()); + } else { + connection.get_address_by_slot(next_slot).await + }; + match new_address { + Ok(new_address) => { + let new_epoch = connection + .get_address_epoch(&new_address) + .await + .unwrap_or(0); + let new_cursor = 0; + Ok(ScanState::new( + new_cursor, + scanned_slots_map, + new_address, + new_epoch, + ScanStateStage::InProgress, + )) + } + Err(err) => Err(err), + } + } +} + +// Implement the [`ClusterInScan`] trait for [`InnerCore`] of async cluster connection. 
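+// `Core` holds the shared inner state of the async cluster connection; the impl below
+// simply delegates to it. Keeping the scan logic behind the `ClusterInScan` trait means
+// the state machine above can, in principle, be exercised against a mock implementation.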
+// Implement the [`ClusterInScan`] trait for [`InnerCore`] of async cluster connection.
+#[async_trait]
+impl<C> ClusterInScan for Core<C>
+where
+    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
+{
+    async fn get_address_by_slot(&self, slot: u16) -> RedisResult<String> {
+        let address = self
+            .get_address_from_slot(slot, SlotAddr::ReplicaRequired)
+            .await;
+        match address {
+            Some(addr) => Ok(addr),
+            None => {
+                if self.are_all_slots_covered().await {
+                    Err(RedisError::from((
+                        ErrorKind::IoError,
+                        "Failed to get a connection to the node covering the slot, please check the cluster configuration",
+                    )))
+                } else {
+                    Err(RedisError::from((
+                        ErrorKind::NotAllSlotsCovered,
+                        "Not all slots are covered by the cluster, please check the cluster configuration",
+                    )))
+                }
+            }
+        }
+    }
+
+    async fn get_address_epoch(&self, address: &str) -> Result<u64, RedisError> {
+        self.as_ref().get_address_epoch(address).await
+    }
+    async fn get_slots_of_address(&self, address: &str) -> Vec<u16> {
+        self.as_ref().get_slots_of_address(address).await
+    }
+    async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult<Value> {
+        let routing = InternalRoutingInfo::SingleNode(InternalSingleNodeRouting::ByAddress(
+            address.to_string(),
+        ));
+        let core = self.to_owned();
+        let response = ClusterConnInner::<C>::try_cmd_request(Arc::new(cmd), routing, core)
+            .await
+            .map_err(|err| err.1)?;
+        match response {
+            Response::Single(value) => Ok(value),
+            _ => Err(RedisError::from((
+                ErrorKind::ClientError,
+                "Expected single response, got unexpected response",
+            ))),
+        }
+    }
+    async fn are_all_slots_covered(&self) -> bool {
+        ClusterConnInner::<C>::check_if_all_slots_covered(&self.conn_lock.read().await.slot_map)
+    }
+    async fn refresh_if_topology_changed(&self) {
+        ClusterConnInner::check_topology_and_refresh_if_diff(
+            self.to_owned(),
+            // The cluster SCAN implementation must refresh the slots when a topology change is found
+            // to ensure the scan logic is correct.
+            &RefreshPolicy::NotThrottable,
+        )
+        .await;
+    }
+}
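+// A minimal sketch of how a caller can drive `cluster_scan` (defined below): feed the
+// returned `ScanStateRC` back in until it reports completion. The literal
+// `ClusterScanArgs` construction and the `core`/`keys` handling here are illustrative
+// placeholders only; real call sites may build the arguments differently.
+//
+//     let mut cursor = ScanStateRC::new();
+//     loop {
+//         let args = ClusterScanArgs {
+//             scan_state_cursor: cursor,
+//             match_pattern: None,
+//             count: None,
+//             object_type: None,
+//         };
+//         let (next_cursor, keys) = cluster_scan(core.clone(), args).await?;
+//         // ... consume `keys` ...
+//         if next_cursor.is_finished() {
+//             break;
+//         }
+//         cursor = next_cursor;
+//     }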
+/// Perform a cluster scan operation.
+/// This function performs a scan operation in a Redis cluster using the given [`ClusterInScan`] connection.
+/// It scans the cluster for keys based on the given `ClusterScanArgs` arguments.
+/// The function returns a tuple containing the new scan state cursor and the keys found in the scan operation.
+/// If the scan operation fails, an error is returned.
+///
+/// # Arguments
+/// * `core` - The connection to the Redis cluster.
+/// * `cluster_scan_args` - The arguments for the cluster scan operation.
+///
+/// # Returns
+/// A tuple containing the new scan state cursor and the keys found in the scan operation.
+/// If the scan operation fails, an error is returned.
+pub(crate) async fn cluster_scan<C>(
+    core: C,
+    cluster_scan_args: ClusterScanArgs,
+) -> RedisResult<(ScanStateRC, Vec<Value>)>
+where
+    C: ClusterInScan,
+{
+    let ClusterScanArgs {
+        scan_state_cursor,
+        match_pattern,
+        count,
+        object_type,
+    } = cluster_scan_args;
+    // If the cursor holds no state, we are starting a new scan
+    let scan_state = match scan_state_cursor.get_state_from_wrapper() {
+        Some(state) => state,
+        None => match ScanState::initiate_scan(&core).await {
+            Ok(state) => state,
+            Err(err) => {
+                return Err(err);
+            }
+        },
+    };
+    // Send the actual scan command to the address in the scan_state
+    let scan_result = send_scan(
+        &scan_state,
+        &core,
+        match_pattern.clone(),
+        count,
+        object_type.clone(),
+    )
+    .await;
+    let ((new_cursor, new_keys), mut scan_state): ((u64, Vec<Value>), ScanState) = match scan_result
+    {
+        Ok(scan_result) => (from_redis_value(&scan_result)?, scan_state.clone()),
+        Err(err) => match err.kind() {
+            // The scan command may fail to route because the address is no longer part of the
+            // cluster, or because the connection to the address cannot be reached for various
+            // reasons. We check whether the failure is one we can recover from, such as a
+            // failover, a scale-down, or a transient network issue, by retrying the scan
+            // command against the address that owns the next slot to be scanned.
+            ErrorKind::IoError
+            | ErrorKind::AllConnectionsUnavailable
+            | ErrorKind::ConnectionNotFoundForRoute => {
+                let retry =
+                    retry_scan(&scan_state, &core, match_pattern, count, object_type).await?;
+                (from_redis_value(&retry.0?)?, retry.1)
+            }
+            _ => return Err(err),
+        },
+    };
+
+    // If the cursor is 0, we finished scanning this address,
+    // so update the scan state to get the next address to scan
+    if new_cursor == 0 {
+        scan_state = scan_state
+            .create_updated_scan_state_for_completed_address(&core)
+            .await?;
+    }
+
+    // If the scan state is finished, we have scanned all the addresses
+    if scan_state.scan_status == ScanStateStage::Finished {
+        return Ok((ScanStateRC::create_finished(), new_keys));
+    }
+
+    scan_state = ScanState::new(
+        new_cursor,
+        scan_state.scanned_slots_map,
+        scan_state.address_in_scan,
+        scan_state.address_epoch,
+        ScanStateStage::InProgress,
+    );
+    Ok((ScanStateRC::from_scan_state(scan_state), new_keys))
+}
+
+// Send the scan command to the address in the scan_state
+async fn send_scan<C>(
+    scan_state: &ScanState,
+    core: &C,
+    match_pattern: Option<Vec<u8>>,
+    count: Option<usize>,
+    object_type: Option<ObjectType>,
+) -> RedisResult<Value>
+where
+    C: ClusterInScan,
+{
+    let mut scan_command = cmd("SCAN");
+    scan_command.arg(scan_state.cursor);
+    if let Some(match_pattern) = match_pattern {
+        scan_command.arg("MATCH").arg(match_pattern);
+    }
+    if let Some(count) = count {
+        scan_command.arg("COUNT").arg(count);
+    }
+    if let Some(object_type) = object_type {
+        scan_command.arg("TYPE").arg(object_type.to_string());
+    }
+
+    core.route_command(scan_command, &scan_state.address_in_scan)
+        .await
+}
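+// For illustration: with a state whose cursor is 42 scanning "10.0.0.1:6379" (all values
+// here are hypothetical), send_scan(&state, &core, Some(b"user:*".to_vec()), Some(100), None)
+// builds the command `SCAN 42 MATCH user:* COUNT 100` and routes it to that node only.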
+// If the scan command failed to route to the address, we first refresh the slots and check
+// whether all slots are covered by the cluster. If they are, we try to get a new address to
+// scan, to handle the case of a failover.
+// If not all slots are covered by the cluster, we return an error indicating that the
+// cluster is not well configured.
+// If all slots are covered but we failed to get a new address to scan, we return an error
+// indicating that we failed to get a new address to scan.
+// If we got a new address to scan but the scan command failed to route to it, we return an
+// error indicating that we failed to route the command.
+async fn retry_scan<C>(
+    scan_state: &ScanState,
+    core: &C,
+    match_pattern: Option<Vec<u8>>,
+    count: Option<usize>,
+    object_type: Option<ObjectType>,
+) -> RedisResult<(RedisResult<Value>, ScanState)>
+where
+    C: ClusterInScan,
+{
+    // TODO: This mechanism of refreshing on failure to route to an address should be part of the routing mechanism.
+    // After the routing mechanism is updated to handle this case, the refresh below should be removed.
+    core.refresh_if_topology_changed().await;
+    if !core.are_all_slots_covered().await {
+        return Err(RedisError::from((
+            ErrorKind::NotAllSlotsCovered,
+            "Not all slots are covered by the cluster, please check the cluster configuration",
+        )));
+    }
+    // If for some reason we failed to reach the address, we don't know whether it was a
+    // scale-down or a failover. Since it might be a scale-down, we can't keep going with the
+    // current cursor, as we may not be at the same point in the new address. Instead, we find
+    // the address that owns the next slot that hasn't been scanned and start from the
+    // beginning of that address.
+    let next_slot = scan_state
+        .get_next_slot(&scan_state.scanned_slots_map)
+        .unwrap_or(0);
+    let address = core.get_address_by_slot(next_slot).await?;
+
+    let new_epoch = core.get_address_epoch(&address).await.unwrap_or(0);
+    let scan_state = &ScanState::new(
+        0,
+        scan_state.scanned_slots_map,
+        address,
+        new_epoch,
+        ScanStateStage::InProgress,
+    );
+    let res = (
+        send_scan(scan_state, core, match_pattern, count, object_type).await,
+        scan_state.clone(),
+    );
+    Ok(res)
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_creation_of_empty_scan_wrapper() {
+        let scan_state_wrapper = ScanStateRC::new();
+        assert!(scan_state_wrapper.status == ScanStateStage::Initiating);
+    }
+
+    #[test]
+    fn test_creation_of_scan_state_wrapper_from() {
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map: [0; BITS_ARRAY_SIZE],
+            address_in_scan: String::from("address1"),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+
+        let scan_state_wrapper = ScanStateRC::from_scan_state(scan_state);
+        assert!(!scan_state_wrapper.is_finished());
+    }
+
+    // Test the get_next_slot method
+    #[test]
+    fn test_scan_state_get_next_slot() {
+        let scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map,
+            address_in_scan: String::from("address1"),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(0));
+        // Mark the first slot as scanned
+        let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        scanned_slots_map[0] = 1;
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map,
+            address_in_scan: String::from("address1"),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(1));
+    }
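+    // Note: END_OF_SCAN (16385) is a sentinel above the valid slot range (0..=16383),
+    // which is why the "all slots scanned" assertion in `test_get_next_slot` below
+    // expects `Some(16385)` rather than a real slot number.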
address == "mock_address" { + vec![3, 4, 5] + } else { + vec![0, 1, 2] + } + } + async fn route_command(&self, _: Cmd, _: &str) -> RedisResult { + unimplemented!() + } + async fn are_all_slots_covered(&self) -> bool { + true + } + } + // Test the initiate_scan function + #[tokio::test] + async fn test_initiate_scan() { + let connection = MockConnection; + let scan_state = ScanState::initiate_scan(&connection).await.unwrap(); + + // Assert that the scan state is initialized correctly + assert_eq!(scan_state.cursor, 0); + assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE]); + assert_eq!(scan_state.address_in_scan, "mock_address"); + assert_eq!(scan_state.address_epoch, 0); + } + + // Test the get_next_slot function + #[test] + fn test_get_next_slot() { + let scan_state = ScanState { + cursor: 0, + scanned_slots_map: [0; BITS_ARRAY_SIZE], + address_in_scan: "".to_string(), + address_epoch: 0, + scan_status: ScanStateStage::InProgress, + }; + // Test when all first bits of each u6 are set to 1, the next slots should be 1 + let scanned_slots_map: SlotsBitsArray = [1; BITS_ARRAY_SIZE]; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(1)); + + // Test when all slots are scanned, the next slot should be 0 + let scanned_slots_map: SlotsBitsArray = [u64::MAX; BITS_ARRAY_SIZE]; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(16385)); + + // Test when first, second, fourth, sixth and eighth slots scanned, the next slot should be 2 + let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + scanned_slots_map[0] = 171; // 10101011 + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(2)); + } + + // Test the update_scan_state_and_get_next_address function + #[tokio::test] + async fn test_update_scan_state_and_get_next_address() { + let connection = MockConnection; + let scan_state = ScanState::initiate_scan(&connection).await; + let updated_scan_state = scan_state + .unwrap() + .create_updated_scan_state_for_completed_address(&connection) + .await + .unwrap(); + + // cursor should be reset to 0 + assert_eq!(updated_scan_state.cursor, 0); + + // address_in_scan should be updated to the new address + assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + + // address_epoch should be updated to the new address epoch + assert_eq!(updated_scan_state.address_epoch, 0); + } + + #[tokio::test] + async fn test_update_scan_state_without_updating_scanned_map() { + let connection = MockConnection; + let scan_state = ScanState::new( + 0, + [0; BITS_ARRAY_SIZE], + "address".to_string(), + 0, + ScanStateStage::InProgress, + ); + let scanned_slots_map = scan_state.scanned_slots_map; + let updated_scan_state = scan_state + .creating_state_without_slot_changes(&connection) + .await + .unwrap(); + assert_eq!(updated_scan_state.scanned_slots_map, scanned_slots_map); + assert_eq!(updated_scan_state.cursor, 0); + assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + assert_eq!(updated_scan_state.address_epoch, 0); + } +} diff --git a/glide-core/redis-rs/redis/src/commands/json.rs b/glide-core/redis-rs/redis/src/commands/json.rs new file mode 100644 index 0000000000..d63f70c86f --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/json.rs @@ -0,0 +1,390 @@ +use crate::cmd::{cmd, Cmd}; +use crate::connection::ConnectionLike; +use crate::pipeline::Pipeline; +use crate::types::{FromRedisValue, RedisResult, ToRedisArgs}; +use crate::RedisError; + 
+#[cfg(feature = "cluster")] +use crate::commands::ClusterPipeline; + +use serde::ser::Serialize; + +macro_rules! implement_json_commands { + ( + $lifetime: lifetime + $( + $(#[$attr:meta])+ + fn $name:ident<$($tyargs:ident : $ty:ident),*>( + $($argname:ident: $argty:ty),*) $body:block + )* + ) => ( + + /// Implements RedisJSON commands for connection like objects. This + /// allows you to send commands straight to a connection or client. It + /// is also implemented for redis results of clients which makes for + /// very convenient access in some basic cases. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// use redis::JsonCommands; + /// use serde_json::json; + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).execute(&mut con); + /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query(&mut con), Ok(String::from(r#"[{"item":42}]"#))); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::JsonCommands; + /// use serde_json::json; + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string())?; + /// assert_eq!(con.json_get("my_key", "$"), Ok(String::from(r#"[{"item":42}]"#))); + /// assert_eq!(con.json_get("my_key", "$.item"), Ok(String::from(r#"[42]"#))); + /// # Ok(()) } + /// ``` + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + pub trait JsonCommands : ConnectionLike + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>( + &mut self $(, $argname: $argty)*) -> RedisResult + { Cmd::$name($($argname),*)?.query(self) } + )* + } + + impl Cmd { + $( + $(#[$attr])* + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> RedisResult { + $body + } + )* + } + + /// Implements RedisJSON commands over asynchronous connections. This + /// allows you to send commands straight to a connection or client. + /// + /// This allows you to use nicer syntax for some common operations. 
+ /// For instance this code: + /// + /// ```rust,no_run + /// use redis::JsonAsyncCommands; + /// use serde_json::json; + /// # async fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).query_async(&mut con).await?; + /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query_async(&mut con).await, Ok(String::from(r#"[{"item":42}]"#))); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::JsonAsyncCommands; + /// use serde_json::json; + /// # async fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string()).await?; + /// assert_eq!(con.json_get("my_key", "$").await, Ok(String::from(r#"[{"item":42}]"#))); + /// assert_eq!(con.json_get("my_key", "$.item").await, Ok(String::from(r#"[42]"#))); + /// # Ok(()) } + /// ``` + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + /// + #[cfg(feature = "aio")] + pub trait JsonAsyncCommands : crate::aio::ConnectionLike + Send + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>( + & $lifetime mut self + $(, $argname: $argty)* + ) -> $crate::types::RedisFuture<'a, RV> + where + RV: FromRedisValue, + { + Box::pin(async move { + $body?.query_async(self).await + }) + } + )* + } + + /// Implements RedisJSON commands for pipelines. Unlike the regular + /// commands trait, this returns the pipeline rather than a result + /// directly. Other than that it works the same however. + impl Pipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> RedisResult<&mut Self> { + self.add_command($body?); + Ok(self) + } + )* + } + + /// Implements RedisJSON commands for cluster pipelines. Unlike the regular + /// commands trait, this returns the cluster pipeline rather than a result + /// directly. Other than that it works the same however. + #[cfg(feature = "cluster")] + impl ClusterPipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> RedisResult<&mut Self> { + self.add_command($body?); + Ok(self) + } + )* + } + + ) +} + +implement_json_commands! { + 'a + + /// Append the JSON `value` to the array at `path` after the last element in it. 
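+    ///
+    /// A minimal illustration (key name and values are placeholders, and a running
+    /// server with RedisJSON is assumed):
+    ///
+    /// ```rust,no_run
+    /// use redis::JsonCommands;
+    /// use serde_json::json;
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: () = con.json_set("my_list", "$", &json!([1, 2]))?;
+    /// let _: () = con.json_arr_append("my_list", "$", &json!(3))?;
+    /// assert_eq!(con.json_get("my_list", "$"), Ok(String::from("[[1,2,3]]")));
+    /// # Ok(()) }
+    /// ```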
+    fn json_arr_append<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRAPPEND");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Index array at `path`, returns first occurrence of `value`
+    fn json_arr_index<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRINDEX");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Same as `json_arr_index` except it takes a `start` and a `stop` value;
+    /// setting these to `0` means they have no effect on the query
+    ///
+    /// The default values for `start` and `stop` are `0`, so pass those in if you want them to have no effect
+    fn json_arr_index_ss<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V, start: &'a isize, stop: &'a isize) {
+        let mut cmd = cmd("JSON.ARRINDEX");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(serde_json::to_string(value)?)
+           .arg(start)
+           .arg(stop);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Inserts the JSON `value` in the array at `path` before the `index` (shifts to the right).
+    ///
+    /// `index` must be within the array's range.
+    fn json_arr_insert<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, index: i64, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRINSERT");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(index)
+           .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Reports the length of the JSON Array at `path` in `key`.
+    fn json_arr_len<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P) {
+        let mut cmd = cmd("JSON.ARRLEN");
+
+        cmd.arg(key)
+           .arg(path);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Removes and returns an element from the `index` in the array.
+    ///
+    /// `index` defaults to `-1` (the end of the array).
+    fn json_arr_pop<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P, index: i64) {
+        let mut cmd = cmd("JSON.ARRPOP");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(index);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Trims an array so that it contains only the specified inclusive range of elements.
+    ///
+    /// This command is extremely forgiving, and using it with out-of-range indexes will not produce an error.
+    /// There are a few differences between how RedisJSON v2.0 and legacy versions handle out-of-range indexes.
+    fn json_arr_trim<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P, start: i64, stop: i64) {
+        let mut cmd = cmd("JSON.ARRTRIM");
+
+        cmd.arg(key)
+           .arg(path)
+           .arg(start)
+           .arg(stop);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Clears container values (Arrays/Objects), and sets numeric values to 0.
+    fn json_clear<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P) {
+        let mut cmd = cmd("JSON.CLEAR");
+
+        cmd.arg(key)
+           .arg(path);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Deletes a value at `path`.
+    fn json_del<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P) {
+        let mut cmd = cmd("JSON.DEL");
+
+        cmd.arg(key)
+           .arg(path);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Gets JSON Value(s) at `path`.
+    ///
+    /// Runs `JSON.GET` if key is singular, `JSON.MGET` if there are multiple keys.
+    ///
+    /// With RedisJSON commands, you have to note that all results will be wrapped
+    /// in square brackets (or empty brackets if not found). If you want to deserialize it
+    /// with e.g. `serde_json` you have to use `Vec<T>` for your output type instead of `T`.
+    fn json_get<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P) {
+        let mut cmd = cmd(if key.is_single_arg() { "JSON.GET" } else { "JSON.MGET" });
+
+        cmd.arg(key)
+           .arg(path);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Increments the number value stored at `path` by `number`.
+ fn json_num_incr_by(key: K, path: P, value: i64) { + let mut cmd = cmd("JSON.NUMINCRBY"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Returns the keys in the object that's referenced by `path`. + fn json_obj_keys(key: K, path: P) { + let mut cmd = cmd("JSON.OBJKEYS"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the number of keys in the JSON Object at `path` in `key`. + fn json_obj_len(key: K, path: P) { + let mut cmd = cmd("JSON.OBJLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Sets the JSON Value at `path` in `key`. + fn json_set(key: K, path: P, value: &'a V) { + let mut cmd = cmd("JSON.SET"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + } + + /// Appends the `json-string` values to the string at `path`. + fn json_str_append(key: K, path: P, value: V) { + let mut cmd = cmd("JSON.STRAPPEND"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the length of the JSON String at `path` in `key`. + fn json_str_len(key: K, path: P) { + let mut cmd = cmd("JSON.STRLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Toggle a `boolean` value stored at `path`. + fn json_toggle(key: K, path: P) { + let mut cmd = cmd("JSON.TOGGLE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the type of JSON value at `path`. + fn json_type(key: K, path: P) { + let mut cmd = cmd("JSON.TYPE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } +} + +impl JsonCommands for T where T: ConnectionLike {} + +#[cfg(feature = "aio")] +impl JsonAsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {} diff --git a/glide-core/redis-rs/redis/src/commands/macros.rs b/glide-core/redis-rs/redis/src/commands/macros.rs new file mode 100644 index 0000000000..9e7d4373c0 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/macros.rs @@ -0,0 +1,275 @@ +macro_rules! implement_commands { + ( + $lifetime: lifetime + $( + $(#[$attr:meta])+ + fn $name:ident<$($tyargs:ident : $ty:ident),*>( + $($argname:ident: $argty:ty),*) $body:block + )* + ) => + ( + /// Implements common redis commands for connection like objects. This + /// allows you to send commands straight to a connection or client. It + /// is also implemented for redis results of clients which makes for + /// very convenient access in some basic cases. + /// + /// This allows you to use nicer syntax for some common operations. 
+ /// For instance this code: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// redis::cmd("SET").arg("my_key").arg(42).execute(&mut con); + /// assert_eq!(redis::cmd("GET").arg("my_key").query(&mut con), Ok(42)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// con.set("my_key", 42)?; + /// assert_eq!(con.get("my_key"), Ok(42)); + /// # Ok(()) } + /// ``` + pub trait Commands : ConnectionLike+Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>( + &mut self $(, $argname: $argty)*) -> RedisResult + { Cmd::$name($($argname),*).query(self) } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate the keys space for keys matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate set elements. + #[inline] + fn sscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate sorted set elements. + #[inline] + fn zscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. + #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + } + + impl Cmd { + $( + $(#[$attr])* + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> Self { + ::std::mem::replace($body, Cmd::new()) + } + )* + } + + /// Implements common redis commands over asynchronous connections. This + /// allows you to send commands straight to a connection or client. + /// + /// This allows you to use nicer syntax for some common operations. 
+ /// For instance this code: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// redis::cmd("SET").arg("my_key").arg(42i32).query_async(&mut con).await?; + /// assert_eq!(redis::cmd("GET").arg("my_key").query_async(&mut con).await, Ok(42i32)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// con.set("my_key", 42i32).await?; + /// assert_eq!(con.get("my_key").await, Ok(42i32)); + /// # Ok(()) } + /// ``` + #[cfg(feature = "aio")] + pub trait AsyncCommands : crate::aio::ConnectionLike + Send + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>( + & $lifetime mut self + $(, $argname: $argty)* + ) -> crate::types::RedisFuture<'a, RV> + where + RV: FromRedisValue, + { + Box::pin(async move { ($body).query_async(self).await }) + } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements. + #[inline] + fn sscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements. + #[inline] + fn zscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. 
+ #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + } + + /// Implements common redis commands for pipelines. Unlike the regular + /// commands trait, this returns the pipeline rather than a result + /// directly. Other than that it works the same however. + impl Pipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + + // Implements common redis commands for cluster pipelines. Unlike the regular + // commands trait, this returns the cluster pipeline rather than a result + // directly. Other than that it works the same however. + #[cfg(feature = "cluster")] + impl ClusterPipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + ) +} diff --git a/glide-core/redis-rs/redis/src/commands/mod.rs b/glide-core/redis-rs/redis/src/commands/mod.rs new file mode 100644 index 0000000000..d5c937fa70 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/mod.rs @@ -0,0 +1,2190 @@ +use crate::cmd::{cmd, Cmd, Iter}; +use crate::connection::{Connection, ConnectionLike, Msg}; +use crate::pipeline::Pipeline; +use crate::types::{ + ExistenceCheck, Expiry, FromRedisValue, NumericBehavior, RedisResult, RedisWrite, SetExpiry, + ToRedisArgs, +}; + +#[macro_use] +mod macros; + +#[cfg(feature = "json")] +#[cfg_attr(docsrs, doc(cfg(feature = "json")))] +mod json; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub(crate) mod cluster_scan; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ObjectType; + +#[cfg(feature = "json")] +pub use json::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use json::JsonAsyncCommands; + +#[cfg(feature = "cluster")] +use crate::cluster_pipeline::ClusterPipeline; + +#[cfg(feature = "geospatial")] +use crate::geo; + +#[cfg(feature = "streams")] +use crate::streams; + +#[cfg(feature = "acl")] +use crate::acl; +use crate::RedisConnectionInfo; + +implement_commands! { + 'a + // most common operations + + /// Get the value of a key. If key is a vec this becomes an `MGET`. + fn get(key: K) { + cmd(if key.is_single_arg() { "GET" } else { "MGET" }).arg(key) + } + + /// Get values of keys + fn mget(key: K){ + cmd("MGET").arg(key) + } + + /// Gets all keys matching pattern + fn keys(key: K) { + cmd("KEYS").arg(key) + } + + /// Set the string value of a key. + fn set(key: K, value: V) { + cmd("SET").arg(key).arg(value) + } + + /// Set the string value of a key with options. + fn set_options(key: K, value: V, options: SetOptions) { + cmd("SET").arg(key).arg(value).arg(options) + } + + /// Sets multiple keys to their values. + #[allow(deprecated)] + #[deprecated(since = "0.22.4", note = "Renamed to mset() to reflect Redis name")] + fn set_multiple(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Sets multiple keys to their values. + fn mset(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Set the value and expiration of a key. 
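+    ///
+    /// A sketch of a call (key, value, and TTL are placeholders); this sends
+    /// `SETEX my_key 60 value`:
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # use redis::Commands;
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: () = con.set_ex("my_key", "value", 60)?;
+    /// # Ok(()) }
+    /// ```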
+ fn set_ex(key: K, value: V, seconds: u64) { + cmd("SETEX").arg(key).arg(seconds).arg(value) + } + + /// Set the value and expiration in milliseconds of a key. + fn pset_ex(key: K, value: V, milliseconds: u64) { + cmd("PSETEX").arg(key).arg(milliseconds).arg(value) + } + + /// Set the value of a key, only if the key does not exist + fn set_nx(key: K, value: V) { + cmd("SETNX").arg(key).arg(value) + } + + /// Sets multiple keys to their values failing if at least one already exists. + fn mset_nx(items: &'a [(K, V)]) { + cmd("MSETNX").arg(items) + } + + /// Set the string value of a key and return its old value. + fn getset(key: K, value: V) { + cmd("GETSET").arg(key).arg(value) + } + + /// Get a range of bytes/substring from the value of a key. Negative values provide an offset from the end of the value. + fn getrange(key: K, from: isize, to: isize) { + cmd("GETRANGE").arg(key).arg(from).arg(to) + } + + /// Overwrite the part of the value stored in key at the specified offset. + fn setrange(key: K, offset: isize, value: V) { + cmd("SETRANGE").arg(key).arg(offset).arg(value) + } + + /// Delete one or more keys. + fn del(key: K) { + cmd("DEL").arg(key) + } + + /// Determine if a key exists. + fn exists(key: K) { + cmd("EXISTS").arg(key) + } + + /// Determine the type of a key. + fn key_type(key: K) { + cmd("TYPE").arg(key) + } + + /// Set a key's time to live in seconds. + fn expire(key: K, seconds: i64) { + cmd("EXPIRE").arg(key).arg(seconds) + } + + /// Set the expiration for a key as a UNIX timestamp. + fn expire_at(key: K, ts: i64) { + cmd("EXPIREAT").arg(key).arg(ts) + } + + /// Set a key's time to live in milliseconds. + fn pexpire(key: K, ms: i64) { + cmd("PEXPIRE").arg(key).arg(ms) + } + + /// Set the expiration for a key as a UNIX timestamp in milliseconds. + fn pexpire_at(key: K, ts: i64) { + cmd("PEXPIREAT").arg(key).arg(ts) + } + + /// Remove the expiration from a key. + fn persist(key: K) { + cmd("PERSIST").arg(key) + } + + /// Get the expiration time of a key. + fn ttl(key: K) { + cmd("TTL").arg(key) + } + + /// Get the expiration time of a key in milliseconds. + fn pttl(key: K) { + cmd("PTTL").arg(key) + } + + /// Get the value of a key and set expiration + fn get_ex(key: K, expire_at: Expiry) { + let (option, time_arg) = match expire_at { + Expiry::EX(sec) => ("EX", Some(sec)), + Expiry::PX(ms) => ("PX", Some(ms)), + Expiry::EXAT(timestamp_sec) => ("EXAT", Some(timestamp_sec)), + Expiry::PXAT(timestamp_ms) => ("PXAT", Some(timestamp_ms)), + Expiry::PERSIST => ("PERSIST", None), + }; + + cmd("GETEX").arg(key).arg(option).arg(time_arg) + } + + /// Get the value of a key and delete it + fn get_del(key: K) { + cmd("GETDEL").arg(key) + } + + /// Rename a key. + fn rename(key: K, new_key: N) { + cmd("RENAME").arg(key).arg(new_key) + } + + /// Rename a key, only if the new key does not exist. + fn rename_nx(key: K, new_key: N) { + cmd("RENAMENX").arg(key).arg(new_key) + } + + /// Unlink one or more keys. + fn unlink(key: K) { + cmd("UNLINK").arg(key) + } + + // common string operations + + /// Append a value to a key. + fn append(key: K, value: V) { + cmd("APPEND").arg(key).arg(value) + } + + /// Increment the numeric value of a key by the given amount. This + /// issues a `INCRBY` or `INCRBYFLOAT` depending on the type. + fn incr(key: K, delta: V) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "INCRBYFLOAT" + } else { + "INCRBY" + }).arg(key).arg(delta) + } + + /// Decrement the numeric value of a key by the given amount. 
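+    ///
+    /// This issues `DECRBY`; for example (key and delta are placeholders):
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # use redis::Commands;
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: () = con.set("counter", 10)?;
+    /// let new_value: i64 = con.decr("counter", 3)?; // DECRBY counter 3
+    /// assert_eq!(new_value, 7);
+    /// # Ok(()) }
+    /// ```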
+ fn decr(key: K, delta: V) { + cmd("DECRBY").arg(key).arg(delta) + } + + /// Sets or clears the bit at offset in the string value stored at key. + fn setbit(key: K, offset: usize, value: bool) { + cmd("SETBIT").arg(key).arg(offset).arg(i32::from(value)) + } + + /// Returns the bit value at offset in the string value stored at key. + fn getbit(key: K, offset: usize) { + cmd("GETBIT").arg(key).arg(offset) + } + + /// Count set bits in a string. + fn bitcount(key: K) { + cmd("BITCOUNT").arg(key) + } + + /// Count set bits in a string in a range. + fn bitcount_range(key: K, start: usize, end: usize) { + cmd("BITCOUNT").arg(key).arg(start).arg(end) + } + + /// Perform a bitwise AND between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_and(dstkey: D, srckeys: S) { + cmd("BITOP").arg("AND").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise OR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_or(dstkey: D, srckeys: S) { + cmd("BITOP").arg("OR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise XOR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_xor(dstkey: D, srckeys: S) { + cmd("BITOP").arg("XOR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise NOT of the key (containing string values) + /// and store the result in the destination key. + fn bit_not(dstkey: D, srckey: S) { + cmd("BITOP").arg("NOT").arg(dstkey).arg(srckey) + } + + /// Get the length of the value stored in a key. + fn strlen(key: K) { + cmd("STRLEN").arg(key) + } + + // hash operations + + /// Gets a single (or multiple) fields from a hash. + fn hget(key: K, field: F) { + cmd(if field.is_single_arg() { "HGET" } else { "HMGET" }).arg(key).arg(field) + } + + /// Deletes a single (or multiple) fields from a hash. + fn hdel(key: K, field: F) { + cmd("HDEL").arg(key).arg(field) + } + + /// Sets a single field in a hash. + fn hset(key: K, field: F, value: V) { + cmd("HSET").arg(key).arg(field).arg(value) + } + + /// Sets a single field in a hash if it does not exist. + fn hset_nx(key: K, field: F, value: V) { + cmd("HSETNX").arg(key).arg(field).arg(value) + } + + /// Sets a multiple fields in a hash. + fn hset_multiple(key: K, items: &'a [(F, V)]) { + cmd("HMSET").arg(key).arg(items) + } + + /// Increments a value. + fn hincr(key: K, field: F, delta: D) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "HINCRBYFLOAT" + } else { + "HINCRBY" + }).arg(key).arg(field).arg(delta) + } + + /// Checks if a field in a hash exists. + fn hexists(key: K, field: F) { + cmd("HEXISTS").arg(key).arg(field) + } + + /// Gets all the keys in a hash. + fn hkeys(key: K) { + cmd("HKEYS").arg(key) + } + + /// Gets all the values in a hash. + fn hvals(key: K) { + cmd("HVALS").arg(key) + } + + /// Gets all the fields and values in a hash. + fn hgetall(key: K) { + cmd("HGETALL").arg(key) + } + + /// Gets the length of a hash. + fn hlen(key: K) { + cmd("HLEN").arg(key) + } + + // list operations + + /// Pop an element from a list, push it to another list + /// and return it; or block until one is available + fn blmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction, timeout: f64) { + cmd("BLMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir).arg(timeout) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names; or blocks until one is available. 
+ fn blmpop(timeout: f64, numkeys: usize, key: K, dir: Direction, count: usize){ + cmd("BLMPOP").arg(timeout).arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Remove and get the first element in a list, or block until one is available. + fn blpop(key: K, timeout: f64) { + cmd("BLPOP").arg(key).arg(timeout) + } + + /// Remove and get the last element in a list, or block until one is available. + fn brpop(key: K, timeout: f64) { + cmd("BRPOP").arg(key).arg(timeout) + } + + /// Pop a value from a list, push it to another list and return it; + /// or block until one is available. + fn brpoplpush(srckey: S, dstkey: D, timeout: f64) { + cmd("BRPOPLPUSH").arg(srckey).arg(dstkey).arg(timeout) + } + + /// Get an element from a list by its index. + fn lindex(key: K, index: isize) { + cmd("LINDEX").arg(key).arg(index) + } + + /// Insert an element before another element in a list. + fn linsert_before( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("BEFORE").arg(pivot).arg(value) + } + + /// Insert an element after another element in a list. + fn linsert_after( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("AFTER").arg(pivot).arg(value) + } + + /// Returns the length of the list stored at key. + fn llen(key: K) { + cmd("LLEN").arg(key) + } + + /// Pop an element a list, push it to another list and return it + fn lmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction) { + cmd("LMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names. + fn lmpop( numkeys: usize, key: K, dir: Direction, count: usize) { + cmd("LMPOP").arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Removes and returns the up to `count` first elements of the list stored at key. + /// + /// If `count` is not specified, then defaults to first element. + fn lpop(key: K, count: Option) { + cmd("LPOP").arg(key).arg(count) + } + + /// Returns the index of the first matching value of the list stored at key. + fn lpos(key: K, value: V, options: LposOptions) { + cmd("LPOS").arg(key).arg(value).arg(options) + } + + /// Insert all the specified values at the head of the list stored at key. + fn lpush(key: K, value: V) { + cmd("LPUSH").arg(key).arg(value) + } + + /// Inserts a value at the head of the list stored at key, only if key + /// already exists and holds a list. + fn lpush_exists(key: K, value: V) { + cmd("LPUSHX").arg(key).arg(value) + } + + /// Returns the specified elements of the list stored at key. + fn lrange(key: K, start: isize, stop: isize) { + cmd("LRANGE").arg(key).arg(start).arg(stop) + } + + /// Removes the first count occurrences of elements equal to value + /// from the list stored at key. + fn lrem(key: K, count: isize, value: V) { + cmd("LREM").arg(key).arg(count).arg(value) + } + + /// Trim an existing list so that it will contain only the specified + /// range of elements specified. + fn ltrim(key: K, start: isize, stop: isize) { + cmd("LTRIM").arg(key).arg(start).arg(stop) + } + + /// Sets the list element at index to value + fn lset(key: K, index: isize, value: V) { + cmd("LSET").arg(key).arg(index).arg(value) + } + + /// Removes and returns the up to `count` last elements of the list stored at key + /// + /// If `count` is not specified, then defaults to last element. + fn rpop(key: K, count: Option) { + cmd("RPOP").arg(key).arg(count) + } + + /// Pop a value from a list, push it to another list and return it. 
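+    ///
+    /// For example (list names are placeholders):
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # use redis::Commands;
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: () = con.rpush("src", &["a", "b", "c"])?;
+    /// let moved: String = con.rpoplpush("src", "dst")?; // moves and returns "c"
+    /// assert_eq!(moved, "c");
+    /// # Ok(()) }
+    /// ```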
+ fn rpoplpush(key: K, dstkey: D) { + cmd("RPOPLPUSH").arg(key).arg(dstkey) + } + + /// Insert all the specified values at the tail of the list stored at key. + fn rpush(key: K, value: V) { + cmd("RPUSH").arg(key).arg(value) + } + + /// Inserts value at the tail of the list stored at key, only if key + /// already exists and holds a list. + fn rpush_exists(key: K, value: V) { + cmd("RPUSHX").arg(key).arg(value) + } + + // set commands + + /// Add one or more members to a set. + fn sadd(key: K, member: M) { + cmd("SADD").arg(key).arg(member) + } + + /// Get the number of members in a set. + fn scard(key: K) { + cmd("SCARD").arg(key) + } + + /// Subtract multiple sets. + fn sdiff(keys: K) { + cmd("SDIFF").arg(keys) + } + + /// Subtract multiple sets and store the resulting set in a key. + fn sdiffstore(dstkey: D, keys: K) { + cmd("SDIFFSTORE").arg(dstkey).arg(keys) + } + + /// Intersect multiple sets. + fn sinter(keys: K) { + cmd("SINTER").arg(keys) + } + + /// Intersect multiple sets and store the resulting set in a key. + fn sinterstore(dstkey: D, keys: K) { + cmd("SINTERSTORE").arg(dstkey).arg(keys) + } + + /// Determine if a given value is a member of a set. + fn sismember(key: K, member: M) { + cmd("SISMEMBER").arg(key).arg(member) + } + + /// Determine if given values are members of a set. + fn smismember(key: K, members: M) { + cmd("SMISMEMBER").arg(key).arg(members) + } + + /// Get all the members in a set. + fn smembers(key: K) { + cmd("SMEMBERS").arg(key) + } + + /// Move a member from one set to another. + fn smove(srckey: S, dstkey: D, member: M) { + cmd("SMOVE").arg(srckey).arg(dstkey).arg(member) + } + + /// Remove and return a random member from a set. + fn spop(key: K) { + cmd("SPOP").arg(key) + } + + /// Get one random member from a set. + fn srandmember(key: K) { + cmd("SRANDMEMBER").arg(key) + } + + /// Get multiple random members from a set. + fn srandmember_multiple(key: K, count: usize) { + cmd("SRANDMEMBER").arg(key).arg(count) + } + + /// Remove one or more members from a set. + fn srem(key: K, member: M) { + cmd("SREM").arg(key).arg(member) + } + + /// Add multiple sets. + fn sunion(keys: K) { + cmd("SUNION").arg(keys) + } + + /// Add multiple sets and store the resulting set in a key. + fn sunionstore(dstkey: D, keys: K) { + cmd("SUNIONSTORE").arg(dstkey).arg(keys) + } + + // sorted set commands + + /// Add one member to a sorted set, or update its score if it already exists. + fn zadd(key: K, member: M, score: S) { + cmd("ZADD").arg(key).arg(score).arg(member) + } + + /// Add multiple members to a sorted set, or update its score if it already exists. + fn zadd_multiple(key: K, items: &'a [(S, M)]) { + cmd("ZADD").arg(key).arg(items) + } + + /// Get the number of members in a sorted set. + fn zcard(key: K) { + cmd("ZCARD").arg(key) + } + + /// Count the members in a sorted set with scores within the given values. + fn zcount(key: K, min: M, max: MM) { + cmd("ZCOUNT").arg(key).arg(min).arg(max) + } + + /// Increments the member in a sorted set at key by delta. + /// If the member does not exist, it is added with delta as its score. + fn zincr(key: K, member: M, delta: D) { + cmd("ZINCRBY").arg(key).arg(delta).arg(member) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. 
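+    ///
+    /// The destination and source keys below are placeholders; this sends
+    /// `ZINTERSTORE out 2 zset1 zset2`:
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # use redis::Commands;
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: u64 = con.zinterstore("out", &["zset1", "zset2"])?;
+    /// # Ok(()) }
+    /// ```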
+ fn zinterstore(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zinterstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zinterstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zinterstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + /// Count the number of members in a sorted set between a given lexicographical range. + fn zlexcount(key: K, min: M, max: MM) { + cmd("ZLEXCOUNT").arg(key).arg(min).arg(max) + } + + /// Removes and returns the member with the highest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmax(key: K, timeout: f64) { + cmd("BZPOPMAX").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the highest scores in a sorted set + fn zpopmax(key: K, count: isize) { + cmd("ZPOPMAX").arg(key).arg(count) + } + + /// Removes and returns the member with the lowest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmin(key: K, timeout: f64) { + cmd("BZPOPMIN").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the lowest scores in a sorted set + fn zpopmin(key: K, count: isize) { + cmd("ZPOPMIN").arg(key).arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_max(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. 
+ fn zmpop_max(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_min(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + fn zmpop_min(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Return up to count random members in a sorted set (or 1 if `count == None`) + fn zrandmember(key: K, count: Option) { + cmd("ZRANDMEMBER").arg(key).arg(count) + } + + /// Return up to count random members in a sorted set with scores + fn zrandmember_withscores(key: K, count: isize) { + cmd("ZRANDMEMBER").arg(key).arg(count).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by index + fn zrange(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index with scores. + fn zrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by lexicographical range. + fn zrangebylex(key: K, min: M, max: MM) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrangebylex_limit( + key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by lexicographical range. + fn zrevrangebylex(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrevrangebylex_limit( + key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score. + fn zrangebyscore(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrangebyscore_withscores(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrangebyscore_limit + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrangebyscore_limit_withscores + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set. + fn zrank(key: K, member: M) { + cmd("ZRANK").arg(key).arg(member) + } + + /// Remove one or more members from a sorted set. 
+ fn zrem(key: K, members: M) { + cmd("ZREM").arg(key).arg(members) + } + + /// Remove all members in a sorted set between the given lexicographical range. + fn zrembylex(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Remove all members in a sorted set within the given indexes. + fn zremrangebyrank(key: K, start: isize, stop: isize) { + cmd("ZREMRANGEBYRANK").arg(key).arg(start).arg(stop) + } + + /// Remove all members in a sorted set within the given scores. + fn zrembyscore(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score. + fn zrevrangebyscore(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrevrangebyscore_withscores(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrevrangebyscore_limit + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrevrangebyscore_limit_withscores + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set, with scores ordered from high to low. + fn zrevrank(key: K, member: M) { + cmd("ZREVRANK").arg(key).arg(member) + } + + /// Get the score associated with the given member in a sorted set. + fn zscore(key: K, member: M) { + cmd("ZSCORE").arg(key).arg(member) + } + + /// Get the scores associated with multiple members in a sorted set. + fn zscore_multiple(key: K, members: &'a [M]) { + cmd("ZMSCORE").arg(key).arg(members) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. + fn zunionstore(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zunionstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zunionstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zunionstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. 
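+    ///
+    /// For example, pairing each source key with its weight (keys and weights are
+    /// placeholders); this sends `ZUNIONSTORE out 2 z1 z2 WEIGHTS 2 3`:
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # use redis::Commands;
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let _: u64 = con.zunionstore_weights("out", &[("z1", 2), ("z2", 3)])?;
+    /// # Ok(()) }
+    /// ```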
+ fn zunionstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + // hyperloglog commands + + /// Adds the specified elements to the specified HyperLogLog. + fn pfadd(key: K, element: E) { + cmd("PFADD").arg(key).arg(element) + } + + /// Return the approximated cardinality of the set(s) observed by the + /// HyperLogLog at key(s). + fn pfcount(key: K) { + cmd("PFCOUNT").arg(key) + } + + /// Merge N different HyperLogLogs into a single one. + fn pfmerge(dstkey: D, srckeys: S) { + cmd("PFMERGE").arg(dstkey).arg(srckeys) + } + + /// Posts a message to the given channel. + fn publish(channel: K, message: E) { + cmd("PUBLISH").arg(channel).arg(message) + } + + // Object commands + + /// Returns the encoding of a key. + fn object_encoding(key: K) { + cmd("OBJECT").arg("ENCODING").arg(key) + } + + /// Returns the time in seconds since the last access of a key. + fn object_idletime(key: K) { + cmd("OBJECT").arg("IDLETIME").arg(key) + } + + /// Returns the logarithmic access frequency counter of a key. + fn object_freq(key: K) { + cmd("OBJECT").arg("FREQ").arg(key) + } + + /// Returns the reference count of a key. + fn object_refcount(key: K) { + cmd("OBJECT").arg("REFCOUNT").arg(key) + } + + // ACL commands + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will reload the ACLs from the file, + /// replacing all the current ACL rules with the ones defined in the file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_load<>() { + cmd("ACL").arg("LOAD") + } + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will save the currently defined + /// ACLs from the server memory to the ACL file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_save<>() { + cmd("ACL").arg("SAVE") + } + + /// Shows the currently active ACL rules in the Redis server. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_list<>() { + cmd("ACL").arg("LIST") + } + + /// Shows a list of all the usernames of the currently configured users in + /// the Redis ACL system. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_users<>() { + cmd("ACL").arg("USERS") + } + + /// Returns all the rules defined for an existing ACL user. 
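+    ///
+    /// A minimal usage sketch (requires the `acl` feature; the reply is parsed
+    /// here into a generic `Value`, though a typed ACL reply works as well):
+    ///
+    /// ```rust,no_run
+    /// use redis::{Commands, RedisResult, Value};
+    ///
+    /// fn default_user_rules(con: &mut redis::Connection) -> RedisResult<Value> {
+    ///     con.acl_getuser("default")
+    /// }
+    /// ```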
+ #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_getuser(username: K) { + cmd("ACL").arg("GETUSER").arg(username) + } + + /// Creates an ACL user without any privilege. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser(username: K) { + cmd("ACL").arg("SETUSER").arg(username) + } + + /// Creates an ACL user with the specified rules or modify the rules of + /// an existing user. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser_rules(username: K, rules: &'a [acl::Rule]) { + cmd("ACL").arg("SETUSER").arg(username).arg(rules) + } + + /// Delete all the specified ACL users and terminate all the connections + /// that are authenticated with such users. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_deluser(usernames: &'a [K]) { + cmd("ACL").arg("DELUSER").arg(usernames) + } + + /// Shows the available ACL categories. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat<>() { + cmd("ACL").arg("CAT") + } + + /// Shows all the Redis commands in the specified category. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat_categoryname(categoryname: K) { + cmd("ACL").arg("CAT").arg(categoryname) + } + + /// Generates a 256-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass<>() { + cmd("ACL").arg("GENPASS") + } + + /// Generates a 1-to-1024-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass_bits<>(bits: isize) { + cmd("ACL").arg("GENPASS").arg(bits) + } + + /// Returns the username the current connection is authenticated with. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_whoami<>() { + cmd("ACL").arg("WHOAMI") + } + + /// Shows a list of recent ACL security events + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log<>(count: isize) { + cmd("ACL").arg("LOG").arg(count) + + } + + /// Clears the ACL log. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log_reset<>() { + cmd("ACL").arg("LOG").arg("RESET") + } + + /// Returns a helpful text describing the different subcommands. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_help<>() { + cmd("ACL").arg("HELP") + } + + // + // geospatial commands + // + + /// Adds the specified geospatial items to the specified key. + /// + /// Every member has to be written as a tuple of `(longitude, latitude, + /// member_name)`. It can be a single tuple, or a vector of tuples. + /// + /// `longitude, latitude` can be set using [`redis::geo::Coord`][1]. + /// + /// [1]: ./geo/struct.Coord.html + /// + /// Returns the number of elements added to the sorted set, not including + /// elements already existing for which the score was updated. 
+ /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, Connection, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn add_point(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", (Coord::lon_lat(13.361389, 38.115556), "Palermo")) + /// } + /// + /// fn add_point_with_tuples(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", ("13.361389", "38.115556", "Palermo")) + /// } + /// + /// fn add_many_points(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", &[ + /// ("13.361389", "38.115556", "Palermo"), + /// ("15.087269", "37.502669", "Catania") + /// ]) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_add(key: K, members: M) { + cmd("GEOADD").arg(key).arg(members) + } + + /// Return the distance between two members in the geospatial index + /// represented by the sorted set. + /// + /// If one or both the members are missing, the command returns NULL, so + /// it may be convenient to parse its response as either `Option` or + /// `Option`. + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Unit; + /// + /// fn get_dists(con: &mut redis::Connection) { + /// let x: RedisResult = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Catania", + /// Unit::Kilometers + /// ); + /// // x is Ok(166.2742) + /// + /// let x: RedisResult> = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Atlantis", + /// Unit::Meters + /// ); + /// // x is Ok(None) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_dist( + key: K, + member1: M1, + member2: M2, + unit: geo::Unit + ) { + cmd("GEODIST") + .arg(key) + .arg(member1) + .arg(member2) + .arg(unit) + } + + /// Return valid [Geohash][1] strings representing the position of one or + /// more members of the geospatial index represented by the sorted set at + /// key. + /// + /// [1]: https://en.wikipedia.org/wiki/Geohash + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// + /// fn get_hash(con: &mut redis::Connection) { + /// let x: RedisResult> = con.geo_hash("my_gis", "Palermo"); + /// // x is vec!["sqc8b49rny0"] + /// + /// let x: RedisResult> = con.geo_hash("my_gis", &["Palermo", "Catania"]); + /// // x is vec!["sqc8b49rny0", "sqdtr74hyu0"] + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_hash(key: K, members: M) { + cmd("GEOHASH").arg(key).arg(members) + } + + /// Return the positions of all the specified members of the geospatial + /// index represented by the sorted set at key. + /// + /// Every position is a pair of `(longitude, latitude)`. [`redis::geo::Coord`][1] + /// can be used to convert these value in a struct. 
+ /// + /// [1]: ./geo/struct.Coord.html + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn get_position(con: &mut redis::Connection) { + /// let x: RedisResult>> = con.geo_pos("my_gis", &["Palermo", "Catania"]); + /// // x is [ [ 13.361389, 38.115556 ], [ 15.087269, 37.502669 ] ]; + /// + /// let x: Vec> = con.geo_pos("my_gis", "Palermo").unwrap(); + /// // x[0].longitude is 13.361389 + /// // x[0].latitude is 38.115556 + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_pos(key: K, members: M) { + cmd("GEOPOS").arg(key).arg(members) + } + + /// Return the members of a sorted set populated with geospatial information + /// using [`geo_add`](#method.geo_add), which are within the borders of the area + /// specified with the center location and the maximum distance from the center + /// (the radius). + /// + /// Every item in the result can be read with [`redis::geo::RadiusSearchResult`][1], + /// which support the multiple formats returned by `GEORADIUS`. + /// + /// [1]: ./geo/struct.RadiusSearchResult.html + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::{RadiusOptions, RadiusSearchResult, RadiusOrder, Unit}; + /// + /// fn radius(con: &mut redis::Connection) -> Vec { + /// let opts = RadiusOptions::default().with_dist().order(RadiusOrder::Asc); + /// con.geo_radius("my_gis", 15.90, 37.21, 51.39, Unit::Kilometers, opts).unwrap() + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius( + key: K, + longitude: f64, + latitude: f64, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUS") + .arg(key) + .arg(longitude) + .arg(latitude) + .arg(radius) + .arg(unit) + .arg(options) + } + + /// Retrieve members selected by distance with the center of `member`. The + /// member itself is always contained in the results. + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius_by_member( + key: K, + member: M, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUSBYMEMBER") + .arg(key) + .arg(member) + .arg(radius) + .arg(unit) + .arg(options) + } + + // + // streams commands + // + + /// Ack pending stream messages checked out by a consumer. + /// + /// ```text + /// XACK ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xack( + key: K, + group: G, + ids: &'a [I]) { + cmd("XACK") + .arg(key) + .arg(group) + .arg(ids) + } + + + /// Add a stream message by `key`. Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [field value] [field value] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd( + key: K, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD").arg(key).arg(id).arg(items) + } + + + /// BTreeMap variant for adding a stream message by `key`. + /// Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_map( + key: K, + id: ID, + map: BTM + ) { + cmd("XADD").arg(key).arg(id).arg(map) + } + + /// Add a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [field value] [field value] ... 
+ /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen< + K: ToRedisArgs, + ID: ToRedisArgs, + F: ToRedisArgs, + V: ToRedisArgs + >( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(items) + } + + + /// BTreeMap variant for adding a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen_map( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + map: BTM + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(map) + } + + + + /// Claim pending, unacked messages, after some period of time, + /// currently checked out by another consumer. + /// + /// This method only accepts the must-have arguments for claiming messages. + /// If optional arguments are required, see `xclaim_options` below. + /// + /// ```text + /// XCLAIM [ ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xclaim( + key: K, + group: G, + consumer: C, + min_idle_time: MIT, + ids: &'a [ID] + ) { + cmd("XCLAIM") + .arg(key) + .arg(group) + .arg(consumer) + .arg(min_idle_time) + .arg(ids) + } + + /// This is the optional arguments version for claiming unacked, pending messages + /// currently checked out by another consumer. + /// + /// ```no_run + /// use redis::{Connection,Commands,RedisResult}; + /// use redis::streams::{StreamClaimOptions,StreamClaimReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Claim all pending messages for key "k1", + /// // from group "g1", checked out by consumer "c1" + /// // for 10ms with RETRYCOUNT 2 and FORCE + /// + /// let opts = StreamClaimOptions::default() + /// .with_force() + /// .retry(2); + /// let results: RedisResult = + /// con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts); + /// + /// // All optional arguments return a `Result` with one exception: + /// // Passing JUSTID returns only the message `id` and omits the HashMap for each message. + /// + /// let opts = StreamClaimOptions::default() + /// .with_justid(); + /// let results: RedisResult> = + /// con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts); + /// ``` + /// + /// ```text + /// XCLAIM + /// [IDLE ] [TIME ] [RETRYCOUNT ] + /// [FORCE] [JUSTID] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xclaim_options< + K: ToRedisArgs, + G: ToRedisArgs, + C: ToRedisArgs, + MIT: ToRedisArgs, + ID: ToRedisArgs + >( + key: K, + group: G, + consumer: C, + min_idle_time: MIT, + ids: &'a [ID], + options: streams::StreamClaimOptions + ) { + cmd("XCLAIM") + .arg(key) + .arg(group) + .arg(consumer) + .arg(min_idle_time) + .arg(ids) + .arg(options) + } + + + /// Deletes a list of `id`s for a given stream `key`. + /// + /// ```text + /// XDEL [ ... ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xdel( + key: K, + ids: &'a [ID] + ) { + cmd("XDEL").arg(key).arg(ids) + } + + + /// This command is used for creating a consumer `group`. It expects the stream key + /// to already exist. Otherwise, use `xgroup_create_mkstream` if it doesn't. + /// The `id` is the starting message id all consumers should read from. 
Use `$` if you want
+    /// all consumers to read from the last message added to stream.
+    ///
+    /// ```text
+    /// XGROUP CREATE <key> <groupname> <id or $>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xgroup_create<K: ToRedisArgs, G: ToRedisArgs, ID: ToRedisArgs>(
+        key: K,
+        group: G,
+        id: ID
+    ) {
+        cmd("XGROUP")
+            .arg("CREATE")
+            .arg(key)
+            .arg(group)
+            .arg(id)
+    }
+
+
+    /// This is the alternate version for creating a consumer `group`
+    /// which makes the stream if it doesn't exist.
+    ///
+    /// ```text
+    /// XGROUP CREATE <key> <groupname> <id or $> [MKSTREAM]
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xgroup_create_mkstream<
+        K: ToRedisArgs,
+        G: ToRedisArgs,
+        ID: ToRedisArgs
+    >(
+        key: K,
+        group: G,
+        id: ID
+    ) {
+        cmd("XGROUP")
+            .arg("CREATE")
+            .arg(key)
+            .arg(group)
+            .arg(id)
+            .arg("MKSTREAM")
+    }
+
+
+    /// Alter which `id` you want consumers to begin reading from an existing
+    /// consumer `group`.
+    ///
+    /// ```text
+    /// XGROUP SETID <key> <groupname> <id or $>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xgroup_setid<K: ToRedisArgs, G: ToRedisArgs, ID: ToRedisArgs>(
+        key: K,
+        group: G,
+        id: ID
+    ) {
+        cmd("XGROUP")
+            .arg("SETID")
+            .arg(key)
+            .arg(group)
+            .arg(id)
+    }
+
+
+    /// Destroy an existing consumer `group` for a given stream `key`.
+    ///
+    /// ```text
+    /// XGROUP DESTROY <key> <groupname>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xgroup_destroy<K: ToRedisArgs, G: ToRedisArgs>(
+        key: K,
+        group: G
+    ) {
+        cmd("XGROUP").arg("DESTROY").arg(key).arg(group)
+    }
+
+    /// This deletes a `consumer` from an existing consumer `group`
+    /// for a given stream `key`.
+    ///
+    /// ```text
+    /// XGROUP DELCONSUMER <key> <groupname> <consumername>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xgroup_delconsumer<K: ToRedisArgs, G: ToRedisArgs, C: ToRedisArgs>(
+        key: K,
+        group: G,
+        consumer: C
+    ) {
+        cmd("XGROUP")
+            .arg("DELCONSUMER")
+            .arg(key)
+            .arg(group)
+            .arg(consumer)
+    }
+
+
+    /// This returns all info details about
+    /// which consumers have read messages for a given consumer `group`.
+    /// Take note of the StreamInfoConsumersReply return type.
+    ///
+    /// *It's possible this return value might not contain new fields
+    /// added by Redis in future versions.*
+    ///
+    /// ```text
+    /// XINFO CONSUMERS <key> <group>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xinfo_consumers<K: ToRedisArgs, G: ToRedisArgs>(
+        key: K,
+        group: G
+    ) {
+        cmd("XINFO")
+            .arg("CONSUMERS")
+            .arg(key)
+            .arg(group)
+    }
+
+
+    /// Returns all consumer `group`s created for a given stream `key`.
+    /// Take note of the StreamInfoGroupsReply return type.
+    ///
+    /// *It's possible this return value might not contain new fields
+    /// added by Redis in future versions.*
+    ///
+    /// ```text
+    /// XINFO GROUPS <key>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xinfo_groups<K: ToRedisArgs>(key: K) {
+        cmd("XINFO").arg("GROUPS").arg(key)
+    }
+
+
+    /// Returns info about high-level stream details
+    /// (first & last message `id`, length, number of groups, etc.)
+    /// Take note of the StreamInfoStreamReply return type.
+    ///
+    /// *It's possible this return value might not contain new fields
+    /// added by Redis in future versions.*
+    ///
+    /// ```text
+    /// XINFO STREAM <key>
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xinfo_stream<K: ToRedisArgs>(key: K) {
+        cmd("XINFO").arg("STREAM").arg(key)
+    }
+
+    /// Returns the number of messages for a given stream `key`.
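+    ///
+    /// A minimal usage sketch (the stream name `k1` is arbitrary and assumed
+    /// to exist):
+    ///
+    /// ```no_run
+    /// use redis::{Commands, RedisResult};
+    ///
+    /// fn stream_len(con: &mut redis::Connection) -> RedisResult<usize> {
+    ///     con.xlen("k1")
+    /// }
+    /// ```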
+ /// + /// ```text + /// XLEN + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xlen(key: K) { + cmd("XLEN").arg(key) + } + + + /// This is a basic version of making XPENDING command calls which only + /// passes a stream `key` and consumer `group` and it + /// returns details about which consumers have pending messages + /// that haven't been acked. + /// + /// You can use this method along with + /// `xclaim` or `xclaim_options` for determining which messages + /// need to be retried. + /// + /// Take note of the StreamPendingReply return type. + /// + /// ```text + /// XPENDING [ []] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending( + key: K, + group: G + ) { + cmd("XPENDING").arg(key).arg(group) + } + + + /// This XPENDING version returns a list of all messages over the range. + /// You can use this for paginating pending messages (but without the message HashMap). + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. + /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + } + + + /// An alternate version of `xpending_count` which filters by `consumer` name. + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. + /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_consumer_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs, + CN: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C, + consumer: CN + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + .arg(consumer) + } + + /// Returns a range of messages in a given stream `key`. + /// + /// Set `start` to `-` to begin at the first message. + /// Set `end` to `+` to end the most recent message. + /// You can pass message `id` to both `start` and `end`. + /// + /// Take note of the StreamRangeReply return type. + /// + /// ```text + /// XRANGE key start end + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange( + key: K, + start: S, + end: E + ) { + cmd("XRANGE").arg(key).arg(start).arg(end) + } + + + /// A helper method for automatically returning all messages in a stream by `key`. + /// **Use with caution!** + /// + /// ```text + /// XRANGE key - + + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_all(key: K) { + cmd("XRANGE").arg(key).arg("-").arg("+") + } + + + /// A method for paginating a stream by `key`. 
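+    ///
+    /// A paging sketch (stream `k1` is arbitrary; `StreamRangeReply` is the
+    /// usual reply type for range queries):
+    ///
+    /// ```no_run
+    /// use redis::{Commands, RedisResult};
+    /// use redis::streams::StreamRangeReply;
+    ///
+    /// fn first_page(con: &mut redis::Connection) -> RedisResult<StreamRangeReply> {
+    ///     // At most 10 entries from the start; pass the id following the last
+    ///     // one returned as the next `start` to fetch the following page.
+    ///     con.xrange_count("k1", "-", "+", 10)
+    /// }
+    /// ```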
+ /// + /// ```text + /// XRANGE key start end [COUNT ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_count( + key: K, + start: S, + end: E, + count: C + ) { + cmd("XRANGE") + .arg(key) + .arg(start) + .arg(end) + .arg("COUNT") + .arg(count) + } + + + /// Read a list of `id`s for each stream `key`. + /// This is the basic form of reading streams. + /// For more advanced control, like blocking, limiting, or reading by consumer `group`, + /// see `xread_options`. + /// + /// ```text + /// XREAD STREAMS key_1 key_2 ... key_N ID_1 ID_2 ... ID_N + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xread( + keys: &'a [K], + ids: &'a [ID] + ) { + cmd("XREAD").arg("STREAMS").arg(keys).arg(ids) + } + + /// This method handles setting optional arguments for + /// `XREAD` or `XREADGROUP` Redis commands. + /// ```no_run + /// use redis::{Connection,RedisResult,Commands}; + /// use redis::streams::{StreamReadOptions,StreamReadReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Read 10 messages from the start of the stream, + /// // without registering as a consumer group. + /// + /// let opts = StreamReadOptions::default() + /// .count(10); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &["0"], &opts); + /// + /// // Read all undelivered messages for a given + /// // consumer group. Be advised: the consumer group must already + /// // exist before making this call. Also note: we're passing + /// // '>' as the id here, which means all undelivered messages. + /// + /// let opts = StreamReadOptions::default() + /// .group("group-1", "consumer-1"); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &[">"], &opts); + /// ``` + /// + /// ```text + /// XREAD [BLOCK ] [COUNT ] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... ID_N + /// + /// XREADGROUP [GROUP group-name consumer-name] [BLOCK ] [COUNT ] [NOACK] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... ID_N + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xread_options( + keys: &'a [K], + ids: &'a [ID], + options: &'a streams::StreamReadOptions + ) { + cmd(if options.read_only() { + "XREAD" + } else { + "XREADGROUP" + }) + .arg(options) + .arg("STREAMS") + .arg(keys) + .arg(ids) + } + + /// This is the reverse version of `xrange`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key end start + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrevrange( + key: K, + end: E, + start: S + ) { + cmd("XREVRANGE").arg(key).arg(end).arg(start) + } + + /// This is the reverse version of `xrange_all`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key + - + /// ``` + fn xrevrange_all(key: K) { + cmd("XREVRANGE").arg(key).arg("+").arg("-") + } + + /// This is the reverse version of `xrange_count`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key end start [COUNT ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrevrange_count( + key: K, + end: E, + start: S, + count: C + ) { + cmd("XREVRANGE") + .arg(key) + .arg(end) + .arg(start) + .arg("COUNT") + .arg(count) + } + + + /// Trim a stream `key` to a MAXLEN count. 
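+    ///
+    /// A minimal usage sketch (stream `k1` is arbitrary): approximate trimming
+    /// (`~`) lets the server trim lazily and is cheaper than exact trimming.
+    ///
+    /// ```no_run
+    /// use redis::{Commands, RedisResult};
+    /// use redis::streams::StreamMaxlen;
+    ///
+    /// fn cap_stream(con: &mut redis::Connection) -> RedisResult<usize> {
+    ///     con.xtrim("k1", StreamMaxlen::Approx(1000))
+    /// }
+    /// ```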
+    ///
+    /// ```text
+    /// XTRIM <key> MAXLEN [~|=] <count>  (Same as XADD MAXLEN option)
+    /// ```
+    #[cfg(feature = "streams")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+    fn xtrim<K: ToRedisArgs>(
+        key: K,
+        maxlen: streams::StreamMaxlen
+    ) {
+        cmd("XTRIM").arg(key).arg(maxlen)
+    }
+}
+
+/// Allows pubsub callbacks to stop receiving messages.
+///
+/// Arbitrary data may be returned from `Break`.
+pub enum ControlFlow<U> {
+    /// Continues.
+    Continue,
+    /// Breaks with a value.
+    Break(U),
+}
+
+/// The PubSub trait allows subscribing to one or more channels
+/// and receiving a callback whenever a message arrives.
+///
+/// Each method handles subscribing to the list of keys, waiting for
+/// messages, and unsubscribing from the same list of channels once
+/// a ControlFlow::Break is encountered.
+///
+/// Once (p)subscribe returns Ok(U), the connection is again safe to use
+/// for calling other methods.
+///
+/// # Examples
+///
+/// ```rust,no_run
+/// # fn do_something() -> redis::RedisResult<()> {
+/// use redis::{PubSubCommands, ControlFlow};
+/// let client = redis::Client::open("redis://127.0.0.1/")?;
+/// let mut con = client.get_connection(None)?;
+/// let mut count = 0;
+/// con.subscribe(&["foo"], |msg| {
+///     // do something with message
+///     assert_eq!(msg.get_channel(), Ok(String::from("foo")));
+///
+///     // increment messages seen counter
+///     count += 1;
+///     match count {
+///         // stop after receiving 10 messages
+///         10 => ControlFlow::Break(()),
+///         _ => ControlFlow::Continue,
+///     }
+/// })?;
+/// # Ok(()) }
+/// ```
+// TODO In the future, it would be nice to implement Try such that `?` will work
+// within the closure.
+pub trait PubSubCommands: Sized {
+    /// Subscribe to a list of channels using SUBSCRIBE and run the provided
+    /// closure for each message received.
+    ///
+    /// For every `Msg` passed to the provided closure, either
+    /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This
+    /// method will not return until `ControlFlow::Break` is observed.
+    fn subscribe<C, F, U>(&mut self, _: C, _: F) -> RedisResult<U>
+    where
+        F: FnMut(Msg) -> ControlFlow<U>,
+        C: ToRedisArgs;
+
+    /// Subscribe to a list of channels using PSUBSCRIBE and run the provided
+    /// closure for each message received.
+    ///
+    /// For every `Msg` passed to the provided closure, either
+    /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This
+    /// method will not return until `ControlFlow::Break` is observed.
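+    ///
+    /// A sketch of a pattern subscription (the channel pattern `news.*` is an
+    /// arbitrary example; the connection setup mirrors the trait-level example
+    /// above):
+    ///
+    /// ```rust,no_run
+    /// # fn run() -> redis::RedisResult<()> {
+    /// use redis::{PubSubCommands, ControlFlow};
+    /// let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// let mut con = client.get_connection(None)?;
+    /// con.psubscribe(&["news.*"], |msg| {
+    ///     // Stop after the first message on any matching channel.
+    ///     println!("{}", msg.get_channel_name());
+    ///     ControlFlow::Break(())
+    /// })?;
+    /// # Ok(()) }
+    /// ```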
+    fn psubscribe<P, F, U>(&mut self, _: P, _: F) -> RedisResult<U>
+    where
+        F: FnMut(Msg) -> ControlFlow<U>,
+        P: ToRedisArgs;
+}
+
+impl<T> Commands for T where T: ConnectionLike {}
+
+#[cfg(feature = "aio")]
+impl<T> AsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {}
+
+impl PubSubCommands for Connection {
+    fn subscribe<C, F, U>(&mut self, channels: C, mut func: F) -> RedisResult<U>
+    where
+        F: FnMut(Msg) -> ControlFlow<U>,
+        C: ToRedisArgs,
+    {
+        let mut pubsub = self.as_pubsub();
+        pubsub.subscribe(channels)?;
+
+        loop {
+            let msg = pubsub.get_message()?;
+            match func(msg) {
+                ControlFlow::Continue => continue,
+                ControlFlow::Break(value) => return Ok(value),
+            }
+        }
+    }
+
+    fn psubscribe<P, F, U>(&mut self, patterns: P, mut func: F) -> RedisResult<U>
+    where
+        F: FnMut(Msg) -> ControlFlow<U>,
+        P: ToRedisArgs,
+    {
+        let mut pubsub = self.as_pubsub();
+        pubsub.psubscribe(patterns)?;
+
+        loop {
+            let msg = pubsub.get_message()?;
+            match func(msg) {
+                ControlFlow::Continue => continue,
+                ControlFlow::Break(value) => return Ok(value),
+            }
+        }
+    }
+}
+
+/// Options for the [LPOS](https://redis.io/commands/lpos) command
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use redis::{Commands, RedisResult, LposOptions};
+/// fn fetch_list_position(
+///     con: &mut redis::Connection,
+///     key: &str,
+///     value: &str,
+///     count: usize,
+///     rank: isize,
+///     maxlen: usize,
+/// ) -> RedisResult<Vec<usize>> {
+///     let opts = LposOptions::default()
+///         .count(count)
+///         .rank(rank)
+///         .maxlen(maxlen);
+///     con.lpos(key, value, opts)
+/// }
+/// ```
+#[derive(Default)]
+pub struct LposOptions {
+    count: Option<usize>,
+    maxlen: Option<usize>,
+    rank: Option<isize>,
+}
+
+impl LposOptions {
+    /// Limit the results to the first N matching items.
+    pub fn count(mut self, n: usize) -> Self {
+        self.count = Some(n);
+        self
+    }
+
+    /// Start returning matches from the Nth match; a negative N counts
+    /// matches from the tail of the list.
+    pub fn rank(mut self, n: isize) -> Self {
+        self.rank = Some(n);
+        self
+    }
+
+    /// Limit the search to N items in the list.
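+    ///
+    /// A small sketch (key `jobs` and value `x` are arbitrary): scan at most
+    /// the first 100 list entries when looking for the value.
+    ///
+    /// ```rust,no_run
+    /// use redis::{Commands, LposOptions, RedisResult};
+    ///
+    /// fn first_match(con: &mut redis::Connection) -> RedisResult<Option<usize>> {
+    ///     con.lpos("jobs", "x", LposOptions::default().maxlen(100))
+    /// }
+    /// ```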
+ pub fn maxlen(mut self, n: usize) -> Self { + self.maxlen = Some(n); + self + } +} + +impl ToRedisArgs for LposOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg_fmt(n); + } + + if let Some(n) = self.rank { + out.write_arg(b"RANK"); + out.write_arg_fmt(n); + } + + if let Some(n) = self.maxlen { + out.write_arg(b"MAXLEN"); + out.write_arg_fmt(n); + } + } + + fn is_single_arg(&self) -> bool { + false + } +} + +/// Enum for the LEFT | RIGHT args used by some commands +pub enum Direction { + /// Targets the first element (head) of the list + Left, + /// Targets the last element (tail) of the list + Right, +} + +impl ToRedisArgs for Direction { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let s: &[u8] = match self { + Direction::Left => b"LEFT", + Direction::Right => b"RIGHT", + }; + out.write_arg(s); + } +} + +/// Options for the [SET](https://redis.io/commands/set) command +/// +/// # Example +/// ```rust,no_run +/// use redis::{Commands, RedisResult, SetOptions, SetExpiry, ExistenceCheck}; +/// fn set_key_value( +/// con: &mut redis::Connection, +/// key: &str, +/// value: &str, +/// ) -> RedisResult> { +/// let opts = SetOptions::default() +/// .conditional_set(ExistenceCheck::NX) +/// .get(true) +/// .with_expiration(SetExpiry::EX(60)); +/// con.set_options(key, value, opts) +/// } +/// ``` +#[derive(Clone, Copy, Default)] +pub struct SetOptions { + conditional_set: Option, + get: bool, + expiration: Option, +} + +impl SetOptions { + /// Set the existence check for the SET command + pub fn conditional_set(mut self, existence_check: ExistenceCheck) -> Self { + self.conditional_set = Some(existence_check); + self + } + + /// Set the GET option for the SET command + pub fn get(mut self, get: bool) -> Self { + self.get = get; + self + } + + /// Set the expiration for the SET command + pub fn with_expiration(mut self, expiration: SetExpiry) -> Self { + self.expiration = Some(expiration); + self + } +} + +impl ToRedisArgs for SetOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref conditional_set) = self.conditional_set { + match conditional_set { + ExistenceCheck::NX => { + out.write_arg(b"NX"); + } + ExistenceCheck::XX => { + out.write_arg(b"XX"); + } + } + } + if self.get { + out.write_arg(b"GET"); + } + if let Some(ref expiration) = self.expiration { + match expiration { + SetExpiry::EX(secs) => { + out.write_arg(b"EX"); + out.write_arg(format!("{}", secs).as_bytes()); + } + SetExpiry::PX(millis) => { + out.write_arg(b"PX"); + out.write_arg(format!("{}", millis).as_bytes()); + } + SetExpiry::EXAT(unix_time) => { + out.write_arg(b"EXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::PXAT(unix_time) => { + out.write_arg(b"PXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::KEEPTTL => { + out.write_arg(b"KEEPTTL"); + } + } + } + } +} + +/// Creates HELLO command for RESP3 with RedisConnectionInfo +pub fn resp3_hello(connection_info: &RedisConnectionInfo) -> Cmd { + let mut hello_cmd = cmd("HELLO"); + hello_cmd.arg("3"); + if connection_info.password.is_some() { + let username: &str = match connection_info.username.as_ref() { + None => "default", + Some(username) => username, + }; + hello_cmd + .arg("AUTH") + .arg(username) + .arg(connection_info.password.as_ref().unwrap()); + } + hello_cmd +} diff --git 
a/glide-core/redis-rs/redis/src/connection.rs b/glide-core/redis-rs/redis/src/connection.rs new file mode 100644 index 0000000000..f75b9df494 --- /dev/null +++ b/glide-core/redis-rs/redis/src/connection.rs @@ -0,0 +1,1997 @@ +use std::collections::{HashSet, VecDeque}; +use std::fmt; +use std::io::{self, Write}; +use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs}; +use std::ops::DerefMut; +use std::path::PathBuf; +use std::str::{from_utf8, FromStr}; +use std::time::Duration; + +use crate::cmd::{cmd, pipe, Cmd}; +use crate::parser::Parser; +use crate::pipeline::Pipeline; +use crate::types::{ + from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult, + ToRedisArgs, Value, +}; +use crate::{from_owned_redis_value, ProtocolVersion}; + +#[cfg(unix)] +use std::os::unix::net::UnixStream; +use std::vec::IntoIter; + +use crate::commands::resp3_hello; +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::{TlsConnector, TlsStream}; + +#[cfg(feature = "tls-rustls")] +use rustls::{RootCertStore, StreamOwned}; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; + +use crate::push_manager::PushManager; +use crate::PushInfo; + +#[cfg(all( + feature = "tls-rustls", + not(feature = "tls-native-tls"), + not(feature = "tls-rustls-webpki-roots") +))] +use rustls_native_certs::load_native_certs; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +// Non-exhaustive to prevent construction outside this crate +#[cfg(not(feature = "tls-rustls"))] +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct TlsConnParams; + +static DEFAULT_PORT: u16 = 6379; + +#[inline(always)] +fn connect_tcp(addr: (&str, u16)) -> io::Result { + let socket = TcpStream::connect(addr)?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let socket2: socket2::Socket = socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + Ok(socket2.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +#[inline(always)] +fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { + let socket = TcpStream::connect_timeout(addr, timeout)?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let socket2: socket2::Socket = socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + Ok(socket2.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +/// This function takes a redis URL string and parses it into a URL +/// as used by rust-url. This is necessary as the default parser does +/// not understand how redis URLs function. +pub fn parse_redis_url(input: &str) -> Option { + match url::Url::parse(input) { + Ok(result) => match result.scheme() { + "redis" | "rediss" | "redis+unix" | "unix" => Some(result), + _ => None, + }, + Err(_) => None, + } +} + +/// TlsMode indicates use or do not use verification of certification. +/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more. +#[derive(Clone, Copy)] +pub enum TlsMode { + /// Secure verify certification. + Secure, + /// Insecure do not verify certification. + Insecure, +} + +/// Defines the connection address. 
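+///
+/// For instance, a plain TCP address can be constructed directly (host and
+/// port here are arbitrary):
+///
+/// ```rust
+/// let addr = redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379);
+/// assert!(addr.is_supported());
+/// ```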
+/// +/// Not all connection addresses are supported on all platforms. For instance +/// to connect to a unix socket you need to run this on an operating system +/// that supports them. +#[derive(Clone, Debug)] +pub enum ConnectionAddr { + /// Format for this is `(host, port)`. + Tcp(String, u16), + /// Format for this is `(host, port)`. + TcpTls { + /// Hostname + host: String, + /// Port + port: u16, + /// Disable hostname verification when connecting. + /// + /// # Warning + /// + /// You should think very carefully before you use this method. If hostname + /// verification is not used, any valid certificate for any site will be + /// trusted for use from any other. This introduces a significant + /// vulnerability to man-in-the-middle attacks. + insecure: bool, + + /// TLS certificates and client key. + tls_params: Option, + }, + /// Format for this is the path to the unix socket. + Unix(PathBuf), +} + +impl PartialEq for ConnectionAddr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => { + host1 == host2 && port1 == port2 + } + ( + ConnectionAddr::TcpTls { + host: host1, + port: port1, + insecure: insecure1, + tls_params: _, + }, + ConnectionAddr::TcpTls { + host: host2, + port: port2, + insecure: insecure2, + tls_params: _, + }, + ) => port1 == port2 && host1 == host2 && insecure1 == insecure2, + (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2, + _ => false, + } + } +} + +impl Eq for ConnectionAddr {} + +impl ConnectionAddr { + /// Checks if this address is supported. + /// + /// Because not all platforms support all connection addresses this is a + /// quick way to figure out if a connection method is supported. Currently + /// this only affects unix connections which are only supported on unix + /// platforms and on older versions of rust also require an explicit feature + /// to be enabled. + pub fn is_supported(&self) -> bool { + match *self { + ConnectionAddr::Tcp(_, _) => true, + ConnectionAddr::TcpTls { .. } => { + cfg!(any(feature = "tls-native-tls", feature = "tls-rustls")) + } + ConnectionAddr::Unix(_) => cfg!(unix), + } + } +} + +impl fmt::Display for ConnectionAddr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Cluster::get_connection_info depends on the return value from this function + match *self { + ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"), + ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"), + ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()), + } + } +} + +/// Holds the connection information that redis should use for connecting. +#[derive(Clone, Debug)] +pub struct ConnectionInfo { + /// A connection address for where to connect to. + pub addr: ConnectionAddr, + + /// A boxed connection address for where to connect to. + pub redis: RedisConnectionInfo, +} + +/// Types of pubsub subscriptions +/// See for more details +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] +pub enum PubSubSubscriptionKind { + /// Exact channel name. + /// Receives messages which are published to a specific channel using PUBLISH command. + Exact = 0, + /// Pattern-based channel name. + /// Receives messages which are published to channels matched by glob pattern using PUBLISH command. + Pattern = 1, + /// Sharded pubsub mode. + /// Receives messages which are published to a specific channel using SPUBLISH command. 
+    Sharded = 2,
+}
+
+impl From<PubSubSubscriptionKind> for usize {
+    fn from(val: PubSubSubscriptionKind) -> Self {
+        val as usize
+    }
+}
+
+/// Type for pubsub channels/patterns
+pub type PubSubChannelOrPattern = Vec<u8>;
+
+/// Type mapping each pubsub subscription kind to its set of channels/patterns
+pub type PubSubSubscriptionInfo = HashMap<PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>>;
+
+/// Redis specific/connection independent information used to establish a connection to redis.
+#[derive(Clone, Debug, Default)]
+pub struct RedisConnectionInfo {
+    /// The database number to use. This is usually `0`.
+    pub db: i64,
+    /// Optionally a username that should be used for connection.
+    pub username: Option<String>,
+    /// Optionally a password that should be used for connection.
+    pub password: Option<String>,
+    /// Version of the protocol to use.
+    pub protocol: ProtocolVersion,
+    /// Optionally a client name that should be used for connection.
+    pub client_name: Option<String>,
+    /// Optional pubsub subscriptions that should be set up for this connection.
+    pub pubsub_subscriptions: Option<PubSubSubscriptionInfo>,
+}
+
+impl FromStr for ConnectionInfo {
+    type Err = RedisError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        s.into_connection_info()
+    }
+}
+
+/// Converts an object into a connection info struct. This allows the
+/// constructor of the client to accept connection information in a
+/// range of different formats.
+pub trait IntoConnectionInfo {
+    /// Converts the object into a connection info object.
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
+}
+
+impl IntoConnectionInfo for ConnectionInfo {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        Ok(self)
+    }
+}
+
+/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
+///
+/// - Basic: `redis://127.0.0.1:6379`
+/// - Username & Password: `redis://user:password@127.0.0.1:6379`
+/// - Password only: `redis://:password@127.0.0.1:6379`
+/// - Specifying DB: `redis://127.0.0.1:6379/0`
+/// - Enabling TLS: `rediss://127.0.0.1:6379`
+/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
+impl<'a> IntoConnectionInfo for &'a str {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        match parse_redis_url(self) {
+            Some(u) => u.into_connection_info(),
+            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
+        }
+    }
+}
+
+impl<T> IntoConnectionInfo for (T, u16)
+where
+    T: Into<String>,
+{
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        Ok(ConnectionInfo {
+            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
+            redis: RedisConnectionInfo::default(),
+        })
+    }
+}
+
+/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
+///
+/// - Basic: `redis://127.0.0.1:6379`
+/// - Username & Password: `redis://user:password@127.0.0.1:6379`
+/// - Password only: `redis://:password@127.0.0.1:6379`
+/// - Specifying DB: `redis://127.0.0.1:6379/0`
+/// - Enabling TLS: `rediss://127.0.0.1:6379`
+/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
+impl IntoConnectionInfo for String {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        match parse_redis_url(&self) {
+            Some(u) => u.into_connection_info(),
+            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
+        }
+    }
+}
+
+fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
+    let host = match url.host() {
+        Some(host) => {
+            // Here we manually match host's enum arms and call their to_string().
+ // Because url.host().to_string() will add `[` and `]` for ipv6: + // https://docs.rs/url/latest/src/url/host.rs.html#170 + // And these brackets will break host.parse::() when + // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`: + // https://doc.rust-lang.org/src/std/net/addr.rs.html#963 + // https://doc.rust-lang.org/src/std/net/parser.rs.html#158 + // IpAddr string with brackets can ONLY parse to SocketAddrV6: + // https://doc.rust-lang.org/src/std/net/parser.rs.html#255 + // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets: + // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755 + match host { + url::Host::Domain(path) => path.to_string(), + url::Host::Ipv4(v4) => v4.to_string(), + url::Host::Ipv6(v6) => v6.to_string(), + } + } + None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")), + }; + let port = url.port().unwrap_or(DEFAULT_PORT); + let addr = if url.scheme() == "rediss" { + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + { + match url.fragment() { + Some("insecure") => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + tls_params: None, + }, + Some(_) => fail!(( + ErrorKind::InvalidClientConfig, + "only #insecure is supported as URL fragment" + )), + _ => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params: None, + }, + } + } + + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + fail!(( + ErrorKind::InvalidClientConfig, + "can't connect with TLS, the feature is not enabled" + )); + } else { + ConnectionAddr::Tcp(host, port) + }; + let query: HashMap<_, _> = url.query_pairs().collect(); + Ok(ConnectionInfo { + addr, + redis: RedisConnectionInfo { + db: match url.path().trim_matches('/') { + "" => 0, + path => path.parse::().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Invalid database number").into() + })?, + }, + username: if url.username().is_empty() { + None + } else { + match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() { + Ok(decoded) => Some(decoded.into_owned()), + Err(_) => fail!(( + ErrorKind::InvalidClientConfig, + "Username is not valid UTF-8 string" + )), + } + }, + password: match url.password() { + Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() { + Ok(decoded) => Some(decoded.into_owned()), + Err(_) => fail!(( + ErrorKind::InvalidClientConfig, + "Password is not valid UTF-8 string" + )), + }, + None => None, + }, + protocol: match query.get("resp3") { + Some(v) => { + if v == "true" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } + } + _ => ProtocolVersion::RESP2, + }, + client_name: None, + pubsub_subscriptions: None, + }, + }) +} + +#[cfg(unix)] +fn url_to_unix_connection_info(url: url::Url) -> RedisResult { + let query: HashMap<_, _> = url.query_pairs().collect(); + Ok(ConnectionInfo { + addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Missing path").into() + })?), + redis: RedisConnectionInfo { + db: match query.get("db") { + Some(db) => db.parse::().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Invalid database number").into() + })?, + + None => 0, + }, + username: query.get("user").map(|username| username.to_string()), + password: query.get("pass").map(|password| password.to_string()), + protocol: match query.get("resp3") { + Some(v) => { + if v == "true" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } + } + _ => 
ProtocolVersion::RESP2, + }, + client_name: None, + pubsub_subscriptions: None, + }, + }) +} + +#[cfg(not(unix))] +fn url_to_unix_connection_info(_: url::Url) -> RedisResult { + fail!(( + ErrorKind::InvalidClientConfig, + "Unix sockets are not available on this platform." + )); +} + +impl IntoConnectionInfo for url::Url { + fn into_connection_info(self) -> RedisResult { + match self.scheme() { + "redis" | "rediss" => url_to_tcp_connection_info(self), + "unix" | "redis+unix" => url_to_unix_connection_info(self), + _ => fail!(( + ErrorKind::InvalidClientConfig, + "URL provided is not a redis URL" + )), + } + } +} + +struct TcpConnection { + reader: TcpStream, + open: bool, +} + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +struct TcpNativeTlsConnection { + reader: TlsStream, + open: bool, +} + +#[cfg(feature = "tls-rustls")] +struct TcpRustlsConnection { + reader: StreamOwned, + open: bool, +} + +#[cfg(unix)] +struct UnixConnection { + sock: UnixStream, + open: bool, +} + +enum ActualConnection { + Tcp(TcpConnection), + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + TcpNativeTls(Box), + #[cfg(feature = "tls-rustls")] + TcpRustls(Box), + #[cfg(unix)] + Unix(UnixConnection), +} + +#[cfg(feature = "tls-rustls-insecure")] +struct NoCertificateVerification { + supported: rustls::crypto::WebPkiSupportedAlgorithms, +} + +#[cfg(feature = "tls-rustls-insecure")] +impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls_pki_types::CertificateDer<'_>, + _intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported.supported_schemes() + } +} + +#[cfg(feature = "tls-rustls-insecure")] +impl fmt::Debug for NoCertificateVerification { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NoCertificateVerification").finish() + } +} + +/// Represents a stateful redis TCP connection. +pub struct Connection { + con: ActualConnection, + parser: Parser, + db: i64, + + /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`. + /// + /// This flag is checked when attempting to send a command, and if it's raised, we attempt to + /// exit the pubsub state before executing the new request. + pubsub: bool, + + // Field indicating which protocol to use for server communications. + protocol: ProtocolVersion, + + /// `PushManager` instance for the connection. + /// This is used to manage Push messages in RESP3 mode. + push_manager: PushManager, +} + +/// Represents a pubsub connection. +pub struct PubSub<'a> { + con: &'a mut Connection, + waiting_messages: VecDeque, +} + +/// Represents a pubsub message. 
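+///
+/// A small sketch of the usual accessors (assuming the `get_channel_name` and
+/// `get_payload` helpers defined on `Msg` later in this module):
+///
+/// ```rust,no_run
+/// fn handle(msg: &redis::Msg) -> redis::RedisResult<()> {
+///     let channel = msg.get_channel_name(); // channel the message arrived on
+///     let payload: String = msg.get_payload()?; // payload decoded into a String
+///     println!("{channel}: {payload}");
+///     Ok(())
+/// }
+/// ```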
+#[derive(Debug)] +pub struct Msg { + payload: Value, + channel: Value, + pattern: Option, +} + +impl ActualConnection { + pub fn new(addr: &ConnectionAddr, timeout: Option) -> RedisResult { + Ok(match *addr { + ConnectionAddr::Tcp(ref host, ref port) => { + let addr = (host.as_str(), *port); + let tcp = match timeout { + None => connect_tcp(addr)?, + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in addr.to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tcp, + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::Tcp(TcpConnection { + reader: tcp, + open: true, + }) + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + .. + } => { + let tls_connector = if insecure { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + .build()? + } else { + TlsConnector::new()? + }; + let addr = (host.as_str(), port); + let tls = match timeout { + None => { + let tcp = connect_tcp(addr)?; + match tls_connector.connect(host, tcp) { + Ok(res) => res, + Err(e) => { + fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string())); + } + } + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host.as_str(), port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection { + reader: tls, + open: true, + })) + } + #[cfg(feature = "tls-rustls")] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + ref tls_params, + } => { + let host: &str = host; + let config = create_rustls_config(insecure, tls_params.clone())?; + let conn = rustls::ClientConnection::new( + Arc::new(config), + rustls_pki_types::ServerName::try_from(host)?.to_owned(), + )?; + let reader = match timeout { + None => { + let tcp = connect_tcp((host, port))?; + StreamOwned::new(conn, tcp) + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host, port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => StreamOwned::new(conn, tcp), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + + ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true })) + } + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + ConnectionAddr::TcpTls { .. 
} => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to TCP with TLS without the tls feature" + )); + } + #[cfg(unix)] + ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection { + sock: UnixStream::connect(path)?, + open: true, + }), + #[cfg(not(unix))] + ConnectionAddr::Unix(ref _path) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to unix sockets \ + on this platform" + )); + } + }) + } + + pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + match *self { + ActualConnection::Tcp(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let result = connection.sock.write_all(bytes).map_err(RedisError::from); + match result { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + } + } + + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_write_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_write_timeout(dur)?; + } + } + Ok(()) + } + + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_read_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_read_timeout(dur)?; + } + } + Ok(()) + } + + pub fn is_open(&self) -> bool { + match *self { + ActualConnection::Tcp(TcpConnection { open, .. 
}) => open, + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { open, .. }) => open, + } + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) fn create_rustls_config( + insecure: bool, + tls_params: Option, +) -> RedisResult { + use crate::tls::ClientTlsParams; + + #[allow(unused_mut)] + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "tls-rustls-webpki-roots")] + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + #[cfg(all( + feature = "tls-rustls", + not(feature = "tls-native-tls"), + not(feature = "tls-rustls-webpki-roots") + ))] + for cert in load_native_certs()? { + root_store.add(cert)?; + } + + let config = rustls::ClientConfig::builder(); + let config = if let Some(tls_params) = tls_params { + let config_builder = + config.with_root_certificates(tls_params.root_cert_store.unwrap_or(root_store)); + + if let Some(ClientTlsParams { + client_cert_chain: client_cert, + client_key, + }) = tls_params.client_tls_params + { + config_builder + .with_client_auth_cert(client_cert, client_key) + .map_err(|err| { + RedisError::from(( + ErrorKind::InvalidClientConfig, + "Unable to build client with TLS parameters provided.", + err.to_string(), + )) + })? + } else { + config_builder.with_no_client_auth() + } + } else { + config + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + match (insecure, cfg!(feature = "tls-rustls-insecure")) { + #[cfg(feature = "tls-rustls-insecure")] + (true, true) => { + let mut config = config; + config.enable_sni = false; + // nosemgrep + config + .dangerous() + .set_certificate_verifier(Arc::new(NoCertificateVerification { + supported: rustls::crypto::ring::default_provider() + .signature_verification_algorithms, + })); + + Ok(config) + } + (true, false) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot create insecure client without tls-rustls-insecure feature" + )); + } + _ => Ok(config), + } +} + +fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> { + let mut command = cmd("AUTH"); + if let Some(username) = &connection_info.username { + command.arg(username); + } + let password = connection_info.password.as_ref().unwrap(); + let err = match command.arg(password).query::(con) { + Ok(Value::Okay) => return Ok(()), + Ok(_) => { + fail!(( + ErrorKind::ResponseError, + "Redis server refused to authenticate, returns Ok() != Value::Okay" + )); + } + Err(e) => e, + }; + let err_msg = err.detail().ok_or(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + ))?; + if !err_msg.contains("wrong number of arguments for 'auth' command") { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )); + } + + // fallback to AUTH version <= 5 + let mut command = cmd("AUTH"); + match command.arg(password).query::(con) { + Ok(Value::Okay) => Ok(()), + _ => fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )), + } +} + +pub fn connect( + connection_info: &ConnectionInfo, + timeout: Option, +) -> RedisResult { + let con = ActualConnection::new(&connection_info.addr, timeout)?; + setup_connection(con, &connection_info.redis) +} + +#[cfg(not(feature = "disable-client-setinfo"))] +pub(crate) fn client_set_info_pipeline() -> 
Pipeline { + let mut pipeline = crate::pipe(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-NAME") + .arg(std::env!("GLIDE_NAME")) + .ignore(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-VER") + .arg(std::env!("GLIDE_VERSION")) + .ignore(); + pipeline +} + +fn setup_connection( + con: ActualConnection, + connection_info: &RedisConnectionInfo, +) -> RedisResult { + let mut rv = Connection { + con, + parser: Parser::new(), + db: connection_info.db, + pubsub: false, + protocol: connection_info.protocol, + push_manager: PushManager::new(), + }; + + if connection_info.protocol != ProtocolVersion::RESP2 { + let hello_cmd = resp3_hello(connection_info); + let val: RedisResult = hello_cmd.query(&mut rv); + if let Err(err) = val { + return Err(get_resp3_hello_command_error(err)); + } + } else if connection_info.password.is_some() { + connect_auth(&mut rv, connection_info)?; + } + if connection_info.db != 0 { + match cmd("SELECT") + .arg(connection_info.db) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to switch database" + )), + } + } + + if connection_info.client_name.is_some() { + match cmd("CLIENT") + .arg("SETNAME") + .arg(connection_info.client_name.as_ref().unwrap()) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + + // result is ignored, as per the command's instructions. + // https://redis.io/commands/client-setinfo/ + #[cfg(not(feature = "disable-client-setinfo"))] + let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv); + + Ok(rv) +} + +/// Implements the "stateless" part of the connection interface that is used by the +/// different objects in redis-rs. Primarily it obviously applies to `Connection` +/// object but also some other objects implement the interface (for instance +/// whole clients or certain redis results). +/// +/// Generally clients and connections (as well as redis results of those) implement +/// this trait. Actual connections provide more functionality which can be used +/// to implement things like `PubSub` but they also can modify the intrinsic +/// state of the TCP connection. This is not possible with `ConnectionLike` +/// implementors because that functionality is not exposed. +pub trait ConnectionLike { + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult; + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + /// Important - this function is meant for internal usage, since it's + /// easy to pass incorrect `offset` & `count` parameters, which might + /// cause the connection to enter an erroneous state. Users shouldn't + /// call it, instead using the Pipeline::query function. + #[doc(hidden)] + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult>; + + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + self.req_packed_command(&pcmd) + } + + /// Returns the database this connection is bound to. 
Note that this + /// information might be unreliable because it's initially cached and + /// also might be incorrect if the connection like object is not + /// actually connected. + fn get_db(&self) -> i64; + + /// Does this connection support pipelining? + #[doc(hidden)] + fn supports_pipelining(&self) -> bool { + true + } + + /// Check that all connections it has are available (`PING` internally). + fn check_connection(&mut self) -> bool; + + /// Returns the connection status. + /// + /// The connection is open until any `read_response` call recieved an + /// invalid response from the server (most likely a closed or dropped + /// connection, otherwise a Redis protocol error). When using unix + /// sockets the connection is open until writing a command failed with a + /// `BrokenPipe` error. + fn is_open(&self) -> bool; +} + +/// A connection is an object that represents a single redis connection. It +/// provides basic support for sending encoded commands into a redis connection +/// and to read a response from it. It's bound to a single database and can +/// only be created from the client. +/// +/// You generally do not much with this object other than passing it to +/// `Cmd` objects. +impl Connection { + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + self.send_bytes(cmd)?; + Ok(()) + } + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + pub fn recv_response(&mut self) -> RedisResult { + self.read_response() + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_write_timeout(dur) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_read_timeout(dur) + } + + /// Creates a [`PubSub`] instance for this connection. + pub fn as_pubsub(&mut self) -> PubSub<'_> { + // NOTE: The pubsub flag is intentionally not raised at this time since + // running commands within the pubsub state should not try and exit from + // the pubsub state. + PubSub::new(self) + } + + fn exit_pubsub(&mut self) -> RedisResult<()> { + let res = self.clear_active_subscriptions(); + if res.is_ok() { + self.pubsub = false; + } else { + // Raise the pubsub flag to indicate the connection is "stuck" in that state. + self.pubsub = true; + } + + res + } + + /// Get the inner connection out of a PubSub + /// + /// Any active subscriptions are unsubscribed. In the event of an error, the connection is + /// dropped. + fn clear_active_subscriptions(&mut self) -> RedisResult<()> { + // Responses to unsubscribe commands return in a 3-tuple with values + // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs). + // The "count of remaining subs" includes both pattern subscriptions and non pattern + // subscriptions. 
Thus, to accurately drain all unsubscribe messages received from the + // server, both commands need to be executed at once. + { + // Prepare both unsubscribe commands + let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command(); + let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command(); + + // Execute commands + self.send_bytes(&unsubscribe)?; + self.send_bytes(&punsubscribe)?; + } + + // Receive responses + // + // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe + // commands. There may be more responses if there are active subscriptions. In this case, + // messages are received until the _subscription count_ in the responses reach zero. + let mut received_unsub = false; + let mut received_punsub = false; + if self.protocol != ProtocolVersion::RESP2 { + while let Value::Push { kind, data } = from_owned_redis_value(self.recv_response()?)? { + if data.len() >= 2 { + if let Value::Int(num) = data[1] { + if resp3_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &kind, + num as isize, + ) { + break; + } + } + } + } + } else { + loop { + let res: (Vec, (), isize) = from_owned_redis_value(self.recv_response()?)?; + if resp2_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &res.0, + res.2, + ) { + break; + } + } + } + + // Finally, the connection is back in its normal state since all subscriptions were + // cancelled *and* all unsubscribe messages were received. + Ok(()) + } + + /// Fetches a single response from the connection. + fn read_response(&mut self) -> RedisResult { + let result = match self.con { + ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => { + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref mut sock, .. 
}) => { + let result = self.parser.parse_value(sock); + self.push_manager.try_send(&result); + result + } + }; + // shutdown connection on protocol error + if let Err(e) = &result { + let shutdown = match e.as_io_error() { + Some(e) => e.kind() == io::ErrorKind::UnexpectedEof, + None => false, + }; + if shutdown { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + match self.con { + ActualConnection::Tcp(ref mut connection) => { + let _ = connection.reader.shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let _ = connection.reader.shutdown(); + connection.open = false; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let _ = connection.sock.shutdown(net::Shutdown::Both); + connection.open = false; + } + } + } + } + result + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } + + fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + let result = self.con.send_bytes(bytes); + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + result + } +} + +impl ConnectionLike for Connection { + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(&pcmd)?; + if cmd.is_no_response() { + return Ok(Value::Nil); + } + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(cmd)?; + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + if self.pubsub { + self.exit_pubsub()?; + } + self.send_bytes(cmd)?; + let mut rv = vec![]; + let mut first_err = None; + let mut count = count; + let mut idx = 0; + while idx < (offset + count) { + // When processing a transaction, some responses may be errors. + // We need to keep processing the rest of the responses in that case, + // so bailing early with `?` would not be correct. 
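+            // (Illustrative note, not in the original patch: reading every
+            // reply, even failed ones, keeps this connection aligned with the
+            // response stream; only the first error seen is reported back to
+            // the caller once all expected replies have been consumed.)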
+ // See: https://github.com/redis-rs/redis-rs/issues/436 + let response = self.read_response(); + match response { + Ok(item) => { + // RESP3 can insert push data between command replies + if let Value::Push { + kind: _kind, + data: _data, + } = item + { + // if that is the case we have to extend the loop and handle push data + count += 1; + } else if idx >= offset { + rv.push(item); + } + } + Err(err) => { + if first_err.is_none() { + first_err = Some(err); + } + } + } + idx += 1; + } + + first_err.map_or(Ok(rv), Err) + } + + fn get_db(&self) -> i64 { + self.db + } + + fn check_connection(&mut self) -> bool { + cmd("PING").query::(self).is_ok() + } + + fn is_open(&self) -> bool { + self.con.is_open() + } +} + +impl ConnectionLike for T +where + C: ConnectionLike, + T: DerefMut, +{ + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.deref_mut().req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.deref_mut().req_packed_commands(cmd, offset, count) + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.deref_mut().req_command(cmd) + } + + fn get_db(&self) -> i64 { + self.deref().get_db() + } + + fn supports_pipelining(&self) -> bool { + self.deref().supports_pipelining() + } + + fn check_connection(&mut self) -> bool { + self.deref_mut().check_connection() + } + + fn is_open(&self) -> bool { + self.deref().is_open() + } +} + +/// The pubsub object provides convenient access to the redis pubsub +/// system. Once created you can subscribe and unsubscribe from channels +/// and listen in on messages. +/// +/// Example: +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// let client = redis::Client::open("redis://127.0.0.1/")?; +/// let mut con = client.get_connection(None)?; +/// let mut pubsub = con.as_pubsub(); +/// pubsub.subscribe("channel_1")?; +/// pubsub.subscribe("channel_2")?; +/// +/// loop { +/// let msg = pubsub.get_message()?; +/// let payload : String = msg.get_payload()?; +/// println!("channel '{}': {}", msg.get_channel_name(), payload); +/// } +/// # } +/// ``` +impl<'a> PubSub<'a> { + fn new(con: &'a mut Connection) -> Self { + Self { + con, + waiting_messages: VecDeque::new(), + } + } + + fn cache_messages_until_received_response(&mut self, cmd: &mut Cmd) -> RedisResult<()> { + if self.con.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + let mut response = cmd.query(self.con)?; + loop { + if let Some(msg) = Msg::from_value(&response) { + self.waiting_messages.push_back(msg); + } else { + return Ok(()); + } + response = self.con.recv_response()?; + } + } + + /// Subscribes to a new channel. + pub fn subscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel)) + } + + /// Subscribes to a new channel with a pattern. + pub fn psubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel)) + } + + /// Unsubscribes from a channel. + pub fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel)) + } + + /// Unsubscribes from a channel with a pattern. + pub fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel)) + } + + /// Fetches the next message from the pubsub connection. 
+    /// Blocks until a message becomes available. This currently does not
+    /// provide a way to wait without blocking.
+    ///
+    /// The message itself is still generic and can be converted into an
+    /// appropriate type through the helper methods on it.
+    pub fn get_message(&mut self) -> RedisResult<Msg> {
+        if let Some(msg) = self.waiting_messages.pop_front() {
+            return Ok(msg);
+        }
+        loop {
+            if let Some(msg) = Msg::from_value(&self.con.recv_response()?) {
+                return Ok(msg);
+            } else {
+                continue;
+            }
+        }
+    }
+
+    /// Sets the read timeout for the connection.
+    ///
+    /// If the provided value is `None`, then `get_message` call will
+    /// block indefinitely. It is an error to pass the zero `Duration` to this
+    /// method.
+    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
+        self.con.set_read_timeout(dur)
+    }
+}
+
+impl<'a> Drop for PubSub<'a> {
+    fn drop(&mut self) {
+        let _ = self.con.exit_pubsub();
+    }
+}
+
+/// This holds the data that comes from listening to a pubsub
+/// connection. It only contains actual message data.
+impl Msg {
+    /// Tries to convert provided [`Value`] into [`Msg`].
+    #[allow(clippy::unnecessary_to_owned)]
+    pub fn from_value(value: &Value) -> Option<Msg> {
+        let mut pattern = None;
+        let payload;
+        let channel;
+
+        if let Value::Push { kind, data } = value {
+            let mut iter: IntoIter<Value> = data.to_vec().into_iter();
+            if kind == &PushKind::Message || kind == &PushKind::SMessage {
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else if kind == &PushKind::PMessage {
+                pattern = Some(iter.next()?);
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else {
+                return None;
+            }
+        } else {
+            let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
+            let mut iter = raw_msg.into_iter();
+            let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
+            if msg_type == "message" {
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else if msg_type == "pmessage" {
+                pattern = Some(iter.next()?);
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else {
+                return None;
+            }
+        };
+        Some(Msg {
+            payload,
+            channel,
+            pattern,
+        })
+    }
+
+    /// Tries to convert provided [`PushInfo`] into [`Msg`].
+    pub fn from_push_info(push_info: &PushInfo) -> Option<Msg> {
+        let mut pattern = None;
+        let payload;
+        let channel;
+
+        let mut iter = push_info.data.iter().cloned();
+        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
+            channel = iter.next()?;
+            payload = iter.next()?;
+        } else if push_info.kind == PushKind::PMessage {
+            pattern = Some(iter.next()?);
+            channel = iter.next()?;
+            payload = iter.next()?;
+        } else {
+            return None;
+        }
+
+        Some(Msg {
+            payload,
+            channel,
+            pattern,
+        })
+    }
+
+    /// Returns the channel this message came on.
+    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
+        from_redis_value(&self.channel)
+    }
+
+    /// Convenience method to get a string version of the channel. Unless
+    /// your channel contains non utf-8 bytes you can always use this
+    /// method. If the channel is not a valid string (which really should
+    /// not happen) then the return value is `"?"`.
+    pub fn get_channel_name(&self) -> &str {
+        match self.channel {
+            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
+            _ => "?",
+        }
+    }
+
+    /// Returns the message's payload in a specific format.
+    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
+        from_redis_value(&self.payload)
+    }
+
+    /// Returns the bytes that are the message's payload. This can be used
+    /// as an alternative to the `get_payload` function if you are interested
+    /// in the raw bytes in it.
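+    ///
+    /// A minimal sketch of reading the raw payload (the message here is
+    /// assumed to come from `PubSub::get_message`; not part of the original
+    /// patch):
+    ///
+    /// ```rust,no_run
+    /// # fn handle(msg: &redis::Msg) {
+    /// let raw: &[u8] = msg.get_payload_bytes();
+    /// println!("received {} payload bytes", raw.len());
+    /// # }
+    /// ```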
+    pub fn get_payload_bytes(&self) -> &[u8] {
+        match self.payload {
+            Value::BulkString(ref bytes) => bytes,
+            _ => b"",
+        }
+    }
+
+    /// Returns true if the message was constructed from a pattern
+    /// subscription.
+    #[allow(clippy::wrong_self_convention)]
+    pub fn from_pattern(&self) -> bool {
+        self.pattern.is_some()
+    }
+
+    /// If the message was constructed from a message pattern this can be
+    /// used to find out which one. It's recommended to match against
+    /// an `Option<String>` so that you do not need to use `from_pattern`
+    /// to figure out if a pattern was set.
+    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
+        match self.pattern {
+            None => from_redis_value(&Value::Nil),
+            Some(ref x) => from_redis_value(x),
+        }
+    }
+}
+
+/// This function simplifies transaction management slightly. What it
+/// does is automatically watch keys and then go into a transaction
+/// loop until it succeeds. Once it goes through, the results are
+/// returned.
+///
+/// To use the transaction two pieces of information are needed: a list
+/// of all the keys that need to be watched for modifications and a
+/// closure with the code that should be executed in the context of the
+/// transaction. The closure is invoked with a fresh pipeline in atomic
+/// mode. To use the transaction the function needs to return the result
+/// from querying the pipeline with the connection.
+///
+/// The end result of the transaction is then available as the return
+/// value from the function call.
+///
+/// Example:
+///
+/// ```rust,no_run
+/// use redis::Commands;
+/// # fn do_something() -> redis::RedisResult<()> {
+/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+/// # let mut con = client.get_connection(None).unwrap();
+/// let key = "the_key";
+/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
+///     let old_val : isize = con.get(key)?;
+///     pipe
+///         .set(key, old_val + 1).ignore()
+///         .get(key).query(con)
+/// })?;
+/// println!("The incremented number is: {}", new_val);
+/// # Ok(()) }
+/// ```
+pub fn transaction<
+    C: ConnectionLike,
+    K: ToRedisArgs,
+    T,
+    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
+>(
+    con: &mut C,
+    keys: &[K],
+    func: F,
+) -> RedisResult<T> {
+    let mut func = func;
+    loop {
+        cmd("WATCH").arg(keys).query::<()>(con)?;
+        let mut p = pipe();
+        let response: Option<T> = func(con, p.atomic())?;
+        match response {
+            None => {
+                continue;
+            }
+            Some(response) => {
+                // make sure no watch is left in the connection, even if
+                // someone forgot to use the pipeline.
+                cmd("UNWATCH").query::<()>(con)?;
+                return Ok(response);
+            }
+        }
+    }
+}
+//TODO: support sharded channels in both clearing-logic helpers below.
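+
+// A small illustrative check (not part of the original patch) of how the
+// clearing helpers below are driven: the pub/sub state counts as cleared only
+// once both an "unsubscribe" and a "punsubscribe" reply have arrived and the
+// remaining subscription count has reached zero.
+#[cfg(test)]
+mod clear_subscriptions_sketch {
+    #[test]
+    fn resp2_clears_after_both_kinds() {
+        let mut received_unsub = false;
+        let mut received_punsub = false;
+        // First reply: "unsubscribe" with one subscription still active.
+        assert!(!super::resp2_is_pub_sub_state_cleared(
+            &mut received_unsub,
+            &mut received_punsub,
+            b"unsubscribe",
+            1
+        ));
+        // Second reply: "punsubscribe" with zero remaining subscriptions.
+        assert!(super::resp2_is_pub_sub_state_cleared(
+            &mut received_unsub,
+            &mut received_punsub,
+            b"punsubscribe",
+            0
+        ));
+    }
+}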
+ +/// Common logic for clearing subscriptions in RESP2 async/sync +pub fn resp2_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &[u8], + num: isize, +) -> bool { + match kind.first() { + Some(&b'u') => *received_unsub = true, + Some(&b'p') => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for clearing subscriptions in RESP3 async/sync +pub fn resp3_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &PushKind, + num: isize, +) -> bool { + match kind { + PushKind::Unsubscribe => *received_unsub = true, + PushKind::PUnsubscribe => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for checking real cause of hello3 command error +pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError { + if let Some(detail) = err.detail() { + if detail.starts_with("unknown command `HELLO`") { + return ( + ErrorKind::RESP3NotSupported, + "Redis Server doesn't support HELLO command therefore resp3 cannot be used", + ) + .into(); + } + } + err +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_redis_url() { + let cases = vec![ + ("redis://127.0.0.1", true), + ("redis://[::1]", true), + ("redis+unix:///run/redis.sock", true), + ("unix:///run/redis.sock", true), + ("http://127.0.0.1", false), + ("tcp://127.0.0.1", false), + ]; + for (url, expected) in cases.into_iter() { + let res = parse_redis_url(url); + assert_eq!( + res.is_some(), + expected, + "Parsed result of `{url}` is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info() { + let cases = vec![ + ( + url::Url::parse("redis://127.0.0.1").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://[::1]").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("::1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("example.com".to_string(), 6379), + redis: RedisConnectionInfo { + db: 2, + username: Some("%johndoe%".to_string()), + password: Some("#@<>$".to_string()), + ..Default::default() + }, + }, + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url.clone()).unwrap(); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); + assert_eq!( + res.redis.db, expected.redis.db, + "db of {url} is not expected", + ); + assert_eq!( + res.redis.username, expected.redis.username, + "username of {url} is not expected", + ); + assert_eq!( + res.redis.password, expected.redis.password, + "password of {url} is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info_failed() { + let cases = vec![ + (url::Url::parse("redis://").unwrap(), "Missing hostname"), + ( + url::Url::parse("redis://127.0.0.1/db").unwrap(), + "Invalid database number", + ), + ( + url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(), + "Username is not valid UTF-8 string", + ), + ( + url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(), + "Password is not valid UTF-8 string", + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url).unwrap_err(); + assert_eq!( + res.kind(), + crate::ErrorKind::InvalidClientConfig, + "{}", + &res, + ); + #[allow(deprecated)] + let desc = 
+                std::error::Error::description(&res);
+            assert_eq!(desc, expected, "{}", &res);
+            assert_eq!(res.detail(), None, "{}", &res);
+        }
+    }
+
+    #[test]
+    #[cfg(unix)]
+    fn test_url_to_unix_connection_info() {
+        let cases = vec![
+            (
+                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 0,
+                        username: None,
+                        password: None,
+                        protocol: ProtocolVersion::RESP2,
+                        client_name: None,
+                        pubsub_subscriptions: None,
+                    },
+                },
+            ),
+            (
+                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 1,
+                        ..Default::default()
+                    },
+                },
+            ),
+            (
+                url::Url::parse(
+                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
+                )
+                .unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/example.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 2,
+                        username: Some("%johndoe%".to_string()),
+                        password: Some("#@<>$".to_string()),
+                        ..Default::default()
+                    },
+                },
+            ),
+            (
+                url::Url::parse(
+                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
+                )
+                .unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/example.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 2,
+                        username: Some("%johndoe%".to_string()),
+                        password: Some("&?= *+".to_string()),
+                        ..Default::default()
+                    },
+                },
+            ),
+        ];
+        for (url, expected) in cases.into_iter() {
+            assert_eq!(
+                ConnectionAddr::Unix(url.to_file_path().unwrap()),
+                expected.addr,
+                "addr of {url} is not expected",
+            );
+            let res = url_to_unix_connection_info(url.clone()).unwrap();
+            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
+            assert_eq!(
+                res.redis.db, expected.redis.db,
+                "db of {url} is not expected",
+            );
+            assert_eq!(
+                res.redis.username, expected.redis.username,
+                "username of {url} is not expected",
+            );
+            assert_eq!(
+                res.redis.password, expected.redis.password,
+                "password of {url} is not expected",
+            );
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/geo.rs b/glide-core/redis-rs/redis/src/geo.rs
new file mode 100644
index 0000000000..6195264a7c
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/geo.rs
@@ -0,0 +1,361 @@
+//! Defines types to use with the geospatial commands.
+
+use super::{ErrorKind, RedisResult};
+use crate::types::{FromRedisValue, RedisWrite, ToRedisArgs, Value};
+
+macro_rules! invalid_type_error {
+    ($v:expr, $det:expr) => {{
+        fail!((
+            ErrorKind::TypeError,
+            "Response was of incompatible type",
+            format!("{:?} (response was {:?})", $det, $v)
+        ));
+    }};
+}
+
+/// Units used by [`geo_dist`][1] and [`geo_radius`][2].
+///
+/// [1]: ../trait.Commands.html#method.geo_dist
+/// [2]: ../trait.Commands.html#method.geo_radius
+pub enum Unit {
+    /// Represents meters.
+    Meters,
+    /// Represents kilometers.
+    Kilometers,
+    /// Represents miles.
+    Miles,
+    /// Represents feet.
+    Feet,
+}
+
+impl ToRedisArgs for Unit {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        let unit = match *self {
+            Unit::Meters => "m",
+            Unit::Kilometers => "km",
+            Unit::Miles => "mi",
+            Unit::Feet => "ft",
+        };
+        out.write_arg(unit.as_bytes());
+    }
+}
+
+/// A coordinate (longitude, latitude). Can be used with [`geo_pos`][1]
+/// to parse response from Redis.
+///
+/// [1]: ../trait.Commands.html#method.geo_pos
+///
+/// `T` is the type of every value.
+///
+/// * You may want to use either `f64` or `f32` if you want to perform mathematical operations.
+/// * To keep the raw value from Redis, use `String`.
+#[allow(clippy::derive_partial_eq_without_eq)] // allow f32/f64 here, which don't implement Eq
+#[derive(Debug, PartialEq)]
+pub struct Coord<T> {
+    /// Longitude
+    pub longitude: T,
+    /// Latitude
+    pub latitude: T,
+}
+
+impl<T> Coord<T> {
+    /// Create a new Coord with the (longitude, latitude)
+    pub fn lon_lat(longitude: T, latitude: T) -> Coord<T> {
+        Coord {
+            longitude,
+            latitude,
+        }
+    }
+}
+
+impl<T: FromRedisValue> FromRedisValue for Coord<T> {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        let values: Vec<T> = FromRedisValue::from_redis_value(v)?;
+        let mut values = values.into_iter();
+        let (longitude, latitude) = match (values.next(), values.next(), values.next()) {
+            (Some(longitude), Some(latitude), None) => (longitude, latitude),
+            _ => invalid_type_error!(v, "Expect a pair of numbers"),
+        };
+        Ok(Coord {
+            longitude,
+            latitude,
+        })
+    }
+}
+
+impl<T: ToRedisArgs> ToRedisArgs for Coord<T> {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        ToRedisArgs::write_redis_args(&self.longitude, out);
+        ToRedisArgs::write_redis_args(&self.latitude, out);
+    }
+
+    fn is_single_arg(&self) -> bool {
+        false
+    }
+}
+
+/// Options to sort results from [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands
+///
+/// [1]: https://redis.io/commands/georadius
+/// [2]: https://redis.io/commands/georadiusbymember
+#[derive(Default)]
+pub enum RadiusOrder {
+    /// Don't sort the results
+    #[default]
+    Unsorted,
+
+    /// Sort returned items from the nearest to the farthest, relative to the center.
+    Asc,
+
+    /// Sort returned items from the farthest to the nearest, relative to the center.
+    Desc,
+}
+
+/// Options for the [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands
+///
+/// [1]: https://redis.io/commands/georadius
+/// [2]: https://redis.io/commands/georadiusbymember
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use redis::{Commands, RedisResult};
+/// use redis::geo::{RadiusSearchResult, RadiusOptions, RadiusOrder, Unit};
+/// fn nearest_in_radius(
+///     con: &mut redis::Connection,
+///     key: &str,
+///     longitude: f64,
+///     latitude: f64,
+///     meters: f64,
+///     limit: usize,
+/// ) -> RedisResult<Vec<RadiusSearchResult>> {
+///     let opts = RadiusOptions::default()
+///         .order(RadiusOrder::Asc)
+///         .limit(limit);
+///     con.geo_radius(key, longitude, latitude, meters, Unit::Meters, opts)
+/// }
+/// ```
+#[derive(Default)]
+pub struct RadiusOptions {
+    with_coord: bool,
+    with_dist: bool,
+    count: Option<usize>,
+    order: RadiusOrder,
+    store: Option<Vec<Vec<u8>>>,
+    store_dist: Option<Vec<Vec<u8>>>,
+}
+
+impl RadiusOptions {
+    /// Limit the results to the first N matching items.
+    pub fn limit(mut self, n: usize) -> Self {
+        self.count = Some(n);
+        self
+    }
+
+    /// Return the distance of the returned items from the specified center.
+    /// The distance is returned in the same unit as the unit specified as the
+    /// radius argument of the command.
+    pub fn with_dist(mut self) -> Self {
+        self.with_dist = true;
+        self
+    }
+
+    /// Return the `longitude, latitude` coordinates of the matching items.
+    pub fn with_coord(mut self) -> Self {
+        self.with_coord = true;
+        self
+    }
+
+    /// Sort the returned items
+    pub fn order(mut self, o: RadiusOrder) -> Self {
+        self.order = o;
+        self
+    }
+
+    /// Store the results in a sorted set at `key`, instead of returning them.
+    ///
+    /// This feature can't be used with any `with_*` method.
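+    ///
+    /// A minimal sketch of the builder usage (the key name here is
+    /// illustrative, not part of the original patch):
+    ///
+    /// ```rust,no_run
+    /// use redis::geo::RadiusOptions;
+    /// let opts = RadiusOptions::default().limit(10).store("nearby_stations");
+    /// ```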
+    pub fn store<K: ToRedisArgs>(mut self, key: K) -> Self {
+        self.store = Some(ToRedisArgs::to_redis_args(&key));
+        self
+    }
+
+    /// Store the results in a sorted set at `key`, with the distance from the
+    /// center as its score. This feature can't be used with any `with_*` method.
+    pub fn store_dist<K: ToRedisArgs>(mut self, key: K) -> Self {
+        self.store_dist = Some(ToRedisArgs::to_redis_args(&key));
+        self
+    }
+}
+
+impl ToRedisArgs for RadiusOptions {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        if self.with_coord {
+            out.write_arg(b"WITHCOORD");
+        }
+
+        if self.with_dist {
+            out.write_arg(b"WITHDIST");
+        }
+
+        if let Some(n) = self.count {
+            out.write_arg(b"COUNT");
+            out.write_arg_fmt(n);
+        }
+
+        match self.order {
+            RadiusOrder::Asc => out.write_arg(b"ASC"),
+            RadiusOrder::Desc => out.write_arg(b"DESC"),
+            _ => (),
+        };
+
+        if let Some(ref store) = self.store {
+            out.write_arg(b"STORE");
+            for i in store {
+                out.write_arg(i);
+            }
+        }
+
+        if let Some(ref store_dist) = self.store_dist {
+            out.write_arg(b"STOREDIST");
+            for i in store_dist {
+                out.write_arg(i);
+            }
+        }
+    }
+
+    fn is_single_arg(&self) -> bool {
+        false
+    }
+}
+
+/// Contain an item returned by [`geo_radius`][1] and [`geo_radius_by_member`][2].
+///
+/// [1]: ../trait.Commands.html#method.geo_radius
+/// [2]: ../trait.Commands.html#method.geo_radius_by_member
+pub struct RadiusSearchResult {
+    /// The name that was found.
+    pub name: String,
+    /// The coordinate if available.
+    pub coord: Option<Coord<f64>>,
+    /// The distance if available.
+    pub dist: Option<f64>,
+}
+
+impl FromRedisValue for RadiusSearchResult {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        // If we receive only the member name, it will be a plain string
+        if let Ok(name) = FromRedisValue::from_redis_value(v) {
+            return Ok(RadiusSearchResult {
+                name,
+                coord: None,
+                dist: None,
+            });
+        }
+
+        // Try to parse the result from multiple values
+        if let Value::Array(ref items) = *v {
+            if let Some(result) = RadiusSearchResult::parse_multi_values(items) {
+                return Ok(result);
+            }
+        }
+
+        invalid_type_error!(v, "Response type not RadiusSearchResult compatible.");
+    }
+}
+
+impl RadiusSearchResult {
+    fn parse_multi_values(items: &[Value]) -> Option<Self> {
+        let mut iter = items.iter();
+
+        // First item is always the member name
+        let name: String = match iter.next().map(FromRedisValue::from_redis_value) {
+            Some(Ok(n)) => n,
+            _ => return None,
+        };
+
+        let mut next = iter.next();
+
+        // Next element, if present, will be the distance.
+        let dist = match next.map(FromRedisValue::from_redis_value) {
+            Some(Ok(c)) => {
+                next = iter.next();
+                Some(c)
+            }
+            _ => None,
+        };
+
+        // Finally, if present, the last item will be the coordinates
+
+        let coord = match next.map(FromRedisValue::from_redis_value) {
+            Some(Ok(c)) => Some(c),
+            _ => None,
+        };
+
+        Some(RadiusSearchResult { name, coord, dist })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Coord, RadiusOptions, RadiusOrder};
+    use crate::types::ToRedisArgs;
+    use std::str;
+
+    macro_rules!
assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } + } + + #[test] + fn test_coord_to_args() { + let member = ("Palermo", Coord::lon_lat("13.361389", "38.115556")); + assert_args!(&member, "Palermo", "13.361389", "38.115556"); + } + + #[test] + fn test_radius_options() { + // Without options, should not generate any argument + let empty = RadiusOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + // Some combinations with WITH* options + let opts = RadiusOptions::default; + + assert_args!(opts().with_coord().with_dist(), "WITHCOORD", "WITHDIST"); + + assert_args!(opts().limit(50), "COUNT", "50"); + + assert_args!(opts().limit(50).store("x"), "COUNT", "50", "STORE", "x"); + + assert_args!( + opts().limit(100).store_dist("y"), + "COUNT", + "100", + "STOREDIST", + "y" + ); + + assert_args!( + opts().order(RadiusOrder::Asc).limit(10).with_dist(), + "WITHDIST", + "COUNT", + "10", + "ASC" + ); + } +} diff --git a/glide-core/redis-rs/redis/src/lib.rs b/glide-core/redis-rs/redis/src/lib.rs new file mode 100644 index 0000000000..4f138c2bb6 --- /dev/null +++ b/glide-core/redis-rs/redis/src/lib.rs @@ -0,0 +1,506 @@ +//! redis-rs is a Rust implementation of a Redis client library. It exposes +//! a general purpose interface to Redis and also provides specific helpers for +//! commonly used functionality. +//! +//! The crate is called `redis` and you can depend on it via cargo: +//! +//! ```ini +//! [dependencies.redis] +//! version = "*" +//! ``` +//! +//! If you want to use the git version: +//! +//! ```ini +//! [dependencies.redis] +//! git = "https://github.com/redis-rs/redis-rs.git" +//! ``` +//! +//! # Basic Operation +//! +//! redis-rs exposes two API levels: a low- and a high-level part. +//! The high-level part does not expose all the functionality of redis and +//! might take some liberties in how it speaks the protocol. The low-level +//! part of the API allows you to express any request on the redis level. +//! You can fluently switch between both API levels at any point. +//! +//! ## Connection Handling +//! +//! For connecting to redis you can use a client object which then can produce +//! actual connections. Connections and clients as well as results of +//! connections and clients are considered `ConnectionLike` objects and +//! can be used anywhere a request is made. +//! +//! The full canonical way to get a connection is to create a client and +//! to ask for a connection from it: +//! +//! ```rust,no_run +//! extern crate redis; +//! +//! fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! +//! /* do something here */ +//! +//! Ok(()) +//! } +//! ``` +//! +//! ## Optional Features +//! +//! There are a few features defined that can enable additional functionality +//! if so desired. Some of them are turned on by default. +//! +//! * `acl`: enables acl support (enabled by default) +//! * `aio`: enables async IO support (enabled by default) +//! * `geospatial`: enables geospatial support (enabled by default) +//! * `script`: enables script support (enabled by default) +//! * `r2d2`: enables r2d2 connection pool support (optional) +//! * `ahash`: enables ahash map/set support & uses ahash internally (+7-10% performance) (optional) +//! 
+//! * `cluster`: enables redis cluster support (optional)
+//! * `cluster-async`: enables async redis cluster support (optional)
+//! * `tokio-comp`: enables support for tokio (optional)
+//! * `connection-manager`: enables support for automatic reconnection (optional)
+//! * `keep-alive`: enables keep-alive option on socket by means of `socket2` crate (optional)
+//!
+//! ## Connection Parameters
+//!
+//! redis-rs knows different ways to define where a connection should
+//! go. The parameter to `Client::open` needs to implement the
+//! `IntoConnectionInfo` trait of which there are three implementations:
+//!
+//! * string slices in `redis://` URL format.
+//! * URL objects from the redis-url crate.
+//! * `ConnectionInfo` objects.
+//!
+//! The URL format is `redis://[<username>][:<password>@]<hostname>[:port][/<db>]`
+//!
+//! If Unix socket support is available you can use a unix URL in this format:
+//!
+//! `redis+unix:///<path>[?db=<db>[&pass=<password>][&user=<username>]]`
+//!
+//! For compatibility with some other redis libraries, the "unix" scheme
+//! is also supported:
+//!
+//! `unix:///<path>[?db=<db>][&pass=<password>][&user=<username>]`
+//!
+//! ## Executing Low-Level Commands
+//!
+//! To execute low-level commands you can use the `cmd` function which allows
+//! you to build redis requests. Once you have configured a command object
+//! to your liking you can send a query into any `ConnectionLike` object:
+//!
+//! ```rust,no_run
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> {
+//!     let _ : () = redis::cmd("SET").arg("my_key").arg(42).query(con)?;
+//!     Ok(())
+//! }
+//! ```
+//!
+//! Upon querying the return value is a result object. If you do not care
+//! about the actual return value (other than that it is not a failure)
+//! you can always type annotate it to the unit type `()`.
+//!
+//! Note that commands with a sub-command (like "MEMORY USAGE", "ACL WHOAMI",
+//! "LATENCY HISTORY", etc) must specify the sub-command as a separate `arg`:
+//!
+//! ```rust,no_run
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<usize> {
+//!     // This will result in a server error: "unknown command `MEMORY USAGE`"
+//!     // because "USAGE" is technically a sub-command of "MEMORY".
+//!     redis::cmd("MEMORY USAGE").arg("my_key").query(con)?;
+//!
+//!     // However, this will work as you'd expect
+//!     redis::cmd("MEMORY").arg("USAGE").arg("my_key").query(con)
+//! }
+//! ```
+//!
+//! ## Executing High-Level Commands
+//!
+//! The high-level interface is similar. For it to become available you
+//! need to use the `Commands` trait in which case all `ConnectionLike`
+//! objects the library provides will also have high-level methods which
+//! make working with the protocol easier:
+//!
+//! ```rust,no_run
+//! extern crate redis;
+//! use redis::Commands;
+//!
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> {
+//!     let _ : () = con.set("my_key", 42)?;
+//!     Ok(())
+//! }
+//! ```
+//!
+//! Note that high-level commands are work in progress and many are still
+//! missing!
+//!
+//! ## Type Conversions
+//!
+//! Because redis inherently is mostly type-less and the protocol is not
+//! exactly friendly to developers, this library provides flexible support
+//! for casting values to the intended results. This is driven through the `FromRedisValue` and `ToRedisArgs` traits.
+//!
+//! The `arg` method of the command will accept a wide range of types through
+//! the `ToRedisArgs` trait and the `query` method of a command can convert the
+//! value to what you expect the function to return through the `FromRedisValue`
+//! trait. This is quite flexible and allows vectors, tuples, hashsets, hashmaps
+//! as well as optional values:
+//!
+//! ```rust,no_run
+//! # use redis::Commands;
+//! # use std::collections::{HashMap, HashSet};
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let count : i32 = con.get("my_counter")?;
+//! let count = con.get("my_counter").unwrap_or(0i32);
+//! let k : Option<String> = con.get("missing_key")?;
+//! let name : String = con.get("my_name")?;
+//! let bin : Vec<u8> = con.get("my_binary")?;
+//! let map : HashMap<String, i32> = con.hgetall("my_hash")?;
+//! let keys : Vec<String> = con.hkeys("my_hash")?;
+//! let mems : HashSet<i32> = con.smembers("my_set")?;
+//! let (k1, k2) : (String, String) = con.get(&["k1", "k2"])?;
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! # Iteration Protocol
+//!
+//! In addition to sending a single query, iterators are also supported. When
+//! used with regular bulk responses they don't give you much over querying and
+//! converting into a vector (both use a vector internally) but they can also
+//! be used with `SCAN` like commands in which case iteration will send more
+//! queries until the cursor is exhausted:
+//!
+//! ```rust,ignore
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let mut iter : redis::Iter<isize> = redis::cmd("SSCAN").arg("my_set")
+//!     .cursor_arg(0).clone().iter(&mut con)?;
+//! for x in iter {
+//!     // do something with the item
+//! }
+//! # Ok(()) }
+//! ```
+//!
+//! As you can see the cursor argument needs to be defined with `cursor_arg`
+//! instead of `arg` so that the library knows which argument needs updating
+//! as the query is run for more items.
+//!
+//! # Pipelining
+//!
+//! In addition to simple queries you can also send command pipelines. This
+//! is provided through the `pipe` function. It works very similar to sending
+//! individual commands but you can send more than one in one go. This also
+//! allows you to ignore individual results so that matching on the end result
+//! is easier:
+//!
+//! ```rust,no_run
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let (k1, k2) : (i32, i32) = redis::pipe()
+//!     .cmd("SET").arg("key_1").arg(42).ignore()
+//!     .cmd("SET").arg("key_2").arg(43).ignore()
+//!     .cmd("GET").arg("key_1")
+//!     .cmd("GET").arg("key_2").query(&mut con)?;
+//! # Ok(()) }
+//! ```
+//!
+//! If you want the pipeline to be wrapped in a `MULTI`/`EXEC` block you can
+//! easily do that by switching the pipeline into `atomic` mode. From the
+//! caller's point of view nothing changes, the pipeline itself will take
+//! care of the rest for you:
+//!
+//! ```rust,no_run
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let (k1, k2) : (i32, i32) = redis::pipe()
+//!     .atomic()
+//!     .cmd("SET").arg("key_1").arg(42).ignore()
+//!     .cmd("SET").arg("key_2").arg(43).ignore()
+//!     .cmd("GET").arg("key_1")
+//!     .cmd("GET").arg("key_2").query(&mut con)?;
+//! # Ok(()) }
+//! ```
+//!
+//!
You can also use high-level commands on pipelines: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let (k1, k2) : (i32, i32) = redis::pipe() +//! .atomic() +//! .set("key_1", 42).ignore() +//! .set("key_2", 43).ignore() +//! .get("key_1") +//! .get("key_2").query(&mut con)?; +//! # Ok(()) } +//! ``` +//! +//! # Transactions +//! +//! Transactions are available through atomic pipelines. In order to use +//! them in a more simple way you can use the `transaction` function of a +//! connection: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! use redis::Commands; +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let key = "the_key"; +//! let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { +//! let old_val : isize = con.get(key)?; +//! pipe +//! .set(key, old_val + 1).ignore() +//! .get(key).query(con) +//! })?; +//! println!("The incremented number is: {}", new_val); +//! # Ok(()) } +//! ``` +//! +//! For more information see the `transaction` function. +//! +//! # PubSub +//! +//! Pubsub is currently work in progress but provided through the `PubSub` +//! connection object. Due to the fact that Rust does not have support +//! for async IO in libnative yet, the API does not provide a way to +//! read messages with any form of timeout yet. +//! +//! Example usage: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! let mut pubsub = con.as_pubsub(); +//! pubsub.subscribe("channel_1")?; +//! pubsub.subscribe("channel_2")?; +//! +//! loop { +//! let msg = pubsub.get_message()?; +//! let payload : String = msg.get_payload()?; +//! println!("channel '{}': {}", msg.get_channel_name(), payload); +//! } +//! # } +//! ``` +//! +#![cfg_attr( + feature = "script", + doc = r##" +# Scripts + +Lua scripts are supported through the `Script` type in a convenient +way (it does not support pipelining currently). It will automatically +load the script if it does not exist and invoke it. + +Example: + +```rust,no_run +# fn do_something() -> redis::RedisResult<()> { +# let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +# let mut con = client.get_connection(None).unwrap(); +let script = redis::Script::new(r" + return tonumber(ARGV[1]) + tonumber(ARGV[2]); +"); +let result : isize = script.arg(1).arg(2).invoke(&mut con)?; +assert_eq!(result, 3); +# Ok(()) } +``` +"## +)] +//! +#![cfg_attr( + feature = "aio", + doc = r##" +# Async + +In addition to the synchronous interface that's been explained above there also exists an +asynchronous interface based on [`futures`][] and [`tokio`][]. + +This interface exists under the `aio` (async io) module (which requires that the `aio` feature +is enabled) and largely mirrors the synchronous with a few concessions to make it fit the +constraints of `futures`. 
+ +```rust,no_run +use futures::prelude::*; +use redis::AsyncCommands; + +# #[tokio::main] +# async fn main() -> redis::RedisResult<()> { +let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +let mut con = client.get_async_connection(None).await?; + +con.set("key1", b"foo").await?; + +redis::cmd("SET").arg(&["key2", "bar"]).query_async(&mut con).await?; + +let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; +assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); +# Ok(()) } +``` +"## +)] +//! +//! [`futures`]:https://crates.io/crates/futures +//! [`tokio`]:https://tokio.rs + +#![deny(non_camel_case_types)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, warn(rustdoc::broken_intra_doc_links))] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] + +// public api +pub use crate::client::Client; +pub use crate::client::GlideConnectionOptions; +pub use crate::cmd::{cmd, pack_command, pipe, Arg, Cmd, Iter}; +pub use crate::commands::{ + Commands, ControlFlow, Direction, LposOptions, PubSubCommands, SetOptions, +}; +pub use crate::connection::{ + parse_redis_url, transaction, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, + IntoConnectionInfo, Msg, PubSub, PubSubChannelOrPattern, PubSubSubscriptionInfo, + PubSubSubscriptionKind, RedisConnectionInfo, TlsMode, +}; +pub use crate::parser::{parse_redis_value, Parser}; +pub use crate::pipeline::Pipeline; +pub use push_manager::{PushInfo, PushManager}; + +#[cfg(feature = "script")] +#[cfg_attr(docsrs, doc(cfg(feature = "script")))] +pub use crate::script::{Script, ScriptInvocation}; + +// preserve grouping and order +#[rustfmt::skip] +pub use crate::types::{ + // utility functions + from_redis_value, + from_owned_redis_value, + + // error kinds + ErrorKind, + + // conversion traits + FromRedisValue, + + // utility types + InfoDict, + NumericBehavior, + Expiry, + SetExpiry, + ExistenceCheck, + + // error and result types + RedisError, + RedisResult, + RedisWrite, + ToRedisArgs, + + // low level values + Value, + PushKind, + VerbatimFormat, + ProtocolVersion +}; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use crate::{ + cmd::AsyncIter, commands::AsyncCommands, parser::parse_redis_value_async, types::RedisFuture, +}; + +mod macros; +mod pipeline; + +#[cfg(feature = "acl")] +#[cfg_attr(docsrs, doc(cfg(feature = "acl")))] +pub mod acl; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub mod aio; + +#[cfg(feature = "json")] +pub use crate::commands::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use crate::commands::JsonAsyncCommands; + +#[cfg(feature = "geospatial")] +#[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] +pub mod geo; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +mod cluster_slotmap; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ObjectType; + +#[cfg(feature = "cluster")] +mod cluster_client; + +/// for testing purposes +pub mod testing { + #[cfg(feature = "cluster")] + pub use crate::cluster_client::ClusterParams; +} + +#[cfg(feature = "cluster")] +mod cluster_pipeline; + +/// Routing information for cluster commands. 
+#[cfg(feature = "cluster")] +pub mod cluster_routing; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster_topology; + +#[cfg(feature = "r2d2")] +#[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))] +mod r2d2; + +#[cfg(feature = "streams")] +#[cfg_attr(docsrs, doc(cfg(feature = "streams")))] +pub mod streams; + +#[cfg(feature = "cluster-async")] +pub mod cluster_async; + +#[cfg(feature = "sentinel")] +pub mod sentinel; + +#[cfg(feature = "tls-rustls")] +mod tls; + +#[cfg(feature = "tls-rustls")] +pub use crate::tls::{ClientTlsConfig, TlsCertificates}; + +mod client; +mod cmd; +mod commands; +mod connection; +mod parser; +mod push_manager; +mod script; +mod types; diff --git a/glide-core/redis-rs/redis/src/macros.rs b/glide-core/redis-rs/redis/src/macros.rs new file mode 100644 index 0000000000..b8886cc759 --- /dev/null +++ b/glide-core/redis-rs/redis/src/macros.rs @@ -0,0 +1,7 @@ +#![macro_use] + +macro_rules! fail { + ($expr:expr) => { + return Err(::std::convert::From::from($expr)) + }; +} diff --git a/glide-core/redis-rs/redis/src/parser.rs b/glide-core/redis-rs/redis/src/parser.rs new file mode 100644 index 0000000000..96e0bcd8f1 --- /dev/null +++ b/glide-core/redis-rs/redis/src/parser.rs @@ -0,0 +1,658 @@ +use std::{ + io::{self, Read}, + str, +}; + +use crate::types::{ + ErrorKind, InternalValue, PushKind, RedisError, RedisResult, ServerError, ServerErrorKind, + Value, VerbatimFormat, +}; + +use combine::{ + any, + error::StreamError, + opaque, + parser::{ + byte::{crlf, take_until_bytes}, + combinator::{any_send_sync_partial_state, AnySendSyncPartialState}, + range::{recognize, take}, + }, + stream::{PointerOffset, RangeStream, StreamErrorFor}, + ParseError, Parser as _, +}; +use num_bigint::BigInt; + +const MAX_RECURSE_DEPTH: usize = 100; + +fn err_parser(line: &str) -> ServerError { + let mut pieces = line.splitn(2, ' '); + let kind = match pieces.next().unwrap() { + "ERR" => ServerErrorKind::ResponseError, + "EXECABORT" => ServerErrorKind::ExecAbortError, + "LOADING" => ServerErrorKind::BusyLoadingError, + "NOSCRIPT" => ServerErrorKind::NoScriptError, + "MOVED" => ServerErrorKind::Moved, + "ASK" => ServerErrorKind::Ask, + "TRYAGAIN" => ServerErrorKind::TryAgain, + "CLUSTERDOWN" => ServerErrorKind::ClusterDown, + "CROSSSLOT" => ServerErrorKind::CrossSlot, + "MASTERDOWN" => ServerErrorKind::MasterDown, + "READONLY" => ServerErrorKind::ReadOnly, + "NOTBUSY" => ServerErrorKind::NotBusy, + code => { + return ServerError::ExtensionError { + code: code.to_string(), + detail: pieces.next().map(|str| str.to_string()), + } + } + }; + let detail = pieces.next().map(|str| str.to_string()); + ServerError::KnownError { kind, detail } +} + +pub fn get_push_kind(kind: String) -> PushKind { + match kind.as_str() { + "invalidate" => PushKind::Invalidate, + "message" => PushKind::Message, + "pmessage" => PushKind::PMessage, + "smessage" => PushKind::SMessage, + "unsubscribe" => PushKind::Unsubscribe, + "punsubscribe" => PushKind::PUnsubscribe, + "sunsubscribe" => PushKind::SUnsubscribe, + "subscribe" => PushKind::Subscribe, + "psubscribe" => PushKind::PSubscribe, + "ssubscribe" => PushKind::SSubscribe, + _ => PushKind::Other(kind), + } +} + +fn value<'a, I>( + count: Option, +) -> impl combine::Parser +where + I: RangeStream, + I::Error: combine::ParseError, +{ + let count = count.unwrap_or(1); + + opaque!(any_send_sync_partial_state( + any() + .then_partial(move |&mut b| { + if b == b'*' && count > MAX_RECURSE_DEPTH { + combine::unexpected_any("Maximum 
recursion depth exceeded").left() + } else { + combine::value(b).right() + } + }) + .then_partial(move |&mut b| { + let line = || { + recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then( + |line: &[u8]| { + str::from_utf8(&line[..line.len() - 2]) + .map_err(StreamErrorFor::::other) + }, + ) + }; + + let simple_string = || { + line().map(|line| { + if line == "OK" { + InternalValue::Okay + } else { + InternalValue::SimpleString(line.into()) + } + }) + }; + + let int = || { + line().and_then(|line| { + line.trim().parse::().map_err(|_| { + StreamErrorFor::::message_static_message( + "Expected integer, got garbage", + ) + }) + }) + }; + + let bulk_string = || { + int().then_partial(move |size| { + if *size < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + take(*size as usize) + .map(|bs: &[u8]| InternalValue::BulkString(bs.to_vec())) + .skip(crlf()) + .right() + } + }) + }; + let blob = || { + int().then_partial(move |size| { + take(*size as usize) + .map(|bs: &[u8]| String::from_utf8_lossy(bs).to_string()) + .skip(crlf()) + }) + }; + + let array = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Array) + .right() + } + }) + }; + + let error = || line().map(err_parser); + let map = || { + int().then_partial(move |&mut kv_length| { + let length = kv_length as usize * 2; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut x = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + x.push((k, v)) + } + } + InternalValue::Map(x) + }, + ) + }) + }; + let attribute = || { + int().then_partial(move |&mut kv_length| { + // + 1 is for data! 
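+                // A RESP3 attribute frame (`|<n>`) carries n key-value metadata
+                // pairs followed by the actual reply value, so n * 2 elements
+                // are read for the pairs plus one trailing element for the data.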
+ let length = kv_length as usize * 2 + 1; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut attributes = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + attributes.push((k, v)) + } + } + InternalValue::Attribute { + data: Box::new(it.next().unwrap()), + attributes, + } + }, + ) + }) + }; + let set = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Set) + .right() + } + }) + }; + let push = || { + int().then_partial(move |&mut length| { + if length <= 0 { + combine::produce(|| InternalValue::Push { + kind: PushKind::Other("".to_string()), + data: vec![], + }) + .left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .and_then(|result: Vec| { + let mut it = result.into_iter(); + let first = it.next().unwrap_or(InternalValue::Nil); + if let InternalValue::BulkString(kind) = first { + let push_kind = String::from_utf8(kind) + .map_err(StreamErrorFor::::other)?; + Ok(InternalValue::Push { + kind: get_push_kind(push_kind), + data: it.collect(), + }) + } else if let InternalValue::SimpleString(kind) = first { + Ok(InternalValue::Push { + kind: get_push_kind(kind), + data: it.collect(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding push", + )) + } + }) + .right() + } + }) + }; + let null = || line().map(|_| InternalValue::Nil); + let double = || { + line().and_then(|line| { + line.trim() + .parse::() + .map_err(StreamErrorFor::::other) + }) + }; + let boolean = || { + line().and_then(|line: &str| match line { + "t" => Ok(true), + "f" => Ok(false), + _ => Err(StreamErrorFor::::message_static_message( + "Expected boolean, got garbage", + )), + }) + }; + let blob_error = || blob().map(|line| err_parser(&line)); + let verbatim = || { + blob().and_then(|line| { + if let Some((format, text)) = line.split_once(':') { + let format = match format { + "txt" => VerbatimFormat::Text, + "mkd" => VerbatimFormat::Markdown, + x => VerbatimFormat::Unknown(x.to_string()), + }; + Ok(InternalValue::VerbatimString { + format, + text: text.to_string(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding verbatim string", + )) + } + }) + }; + let big_number = || { + line().and_then(|line| { + BigInt::parse_bytes(line.as_bytes(), 10).ok_or_else(|| { + StreamErrorFor::::message_static_message( + "Expected bigint, got garbage", + ) + }) + }) + }; + combine::dispatch!(b; + b'+' => simple_string(), + b':' => int().map(InternalValue::Int), + b'$' => bulk_string(), + b'*' => array(), + b'%' => map(), + b'|' => attribute(), + b'~' => set(), + b'-' => error().map(InternalValue::ServerError), + b'_' => null(), + b',' => double().map(InternalValue::Double), + b'#' => boolean().map(InternalValue::Boolean), + b'!' 
=> blob_error().map(InternalValue::ServerError), + b'=' => verbatim(), + b'(' => big_number().map(InternalValue::BigNumber), + b'>' => push(), + b => combine::unexpected_any(combine::error::Token(b)) + ) + }) + )) +} + +#[cfg(feature = "aio")] +mod aio_support { + use super::*; + + use bytes::{Buf, BytesMut}; + use tokio::io::AsyncRead; + use tokio_util::codec::{Decoder, Encoder}; + + #[derive(Default)] + pub struct ValueCodec { + state: AnySendSyncPartialState, + } + + impl ValueCodec { + fn decode_stream( + &mut self, + bytes: &mut BytesMut, + eof: bool, + ) -> RedisResult>> { + let (opt, removed_len) = { + let buffer = &bytes[..]; + let mut stream = + combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof)); + match combine::stream::decode_tokio(value(None), &mut stream, &mut self.state) { + Ok(x) => x, + Err(err) => { + let err = err + .map_position(|pos| pos.translate_position(buffer)) + .map_range(|range| format!("{range:?}")) + .to_string(); + return Err(RedisError::from(( + ErrorKind::ParseError, + "parse error", + err, + ))); + } + } + }; + + bytes.advance(removed_len); + match opt { + Some(result) => Ok(Some(result.try_into())), + None => Ok(None), + } + } + } + + impl Encoder> for ValueCodec { + type Error = RedisError; + fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.extend_from_slice(item.as_ref()); + Ok(()) + } + } + + impl Decoder for ValueCodec { + type Item = RedisResult; + type Error = RedisError; + + fn decode(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, false) + } + + fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, true) + } + } + + /// Parses a redis value asynchronously. + pub async fn parse_redis_value_async( + decoder: &mut combine::stream::Decoder>, + read: &mut R, + ) -> RedisResult + where + R: AsyncRead + std::marker::Unpin, + { + let result = combine::decode_tokio!(*decoder, *read, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use self::aio_support::*; + +/// The internal redis response parser. +pub struct Parser { + decoder: combine::stream::decoder::Decoder>, +} + +impl Default for Parser { + fn default() -> Self { + Parser::new() + } +} + +/// The parser can be used to parse redis responses into values. Generally +/// you normally do not use this directly as it's already done for you by +/// the client but in some more complex situations it might be useful to be +/// able to parse the redis responses. +impl Parser { + /// Creates a new parser that parses the data behind the reader. More + /// than one value can be behind the reader in which case the parser can + /// be invoked multiple times. In other words: the stream does not have + /// to be terminated. 
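+    ///
+    /// A minimal sketch of that reuse, feeding two back-to-back RESP frames
+    /// from a plain byte buffer (illustrative input, not a live connection):
+    ///
+    /// ```rust,no_run
+    /// let mut parser = redis::Parser::new();
+    /// let mut reader: &[u8] = b"+OK\r\n:42\r\n";
+    /// let first = parser.parse_value(&mut reader).unwrap();  // Value::Okay
+    /// let second = parser.parse_value(&mut reader).unwrap(); // Value::Int(42)
+    /// ```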
+ pub fn new() -> Parser { + Parser { + decoder: combine::stream::decoder::Decoder::new(), + } + } + + // public api + + /// Parses synchronously into a single value from the reader. + pub fn parse_value(&mut self, mut reader: T) -> RedisResult { + let mut decoder = &mut self.decoder; + let result = combine::decode!(decoder, reader, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +/// Parses bytes into a redis value. +/// +/// This is the most straightforward way to parse something into a low +/// level redis value instead of having to use a whole parser. +pub fn parse_redis_value(bytes: &[u8]) -> RedisResult { + let mut parser = Parser::new(); + parser.parse_value(bytes) +} + +#[cfg(test)] +mod tests { + use crate::types::make_extension_error; + + use super::*; + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_none_at_eof() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = bytes::BytesMut::from(&b"+GET 123\r\n"[..]); + assert_eq!( + codec.decode_eof(&mut bytes), + Ok(Some(Ok(parse_redis_value(b"+GET 123\r\n").unwrap()))) + ); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + } + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_error_inside_array_and_can_parse_more_inputs() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = + bytes::BytesMut::from(b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let mut bytes = bytes::BytesMut::from(b"+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!(result, Ok(Value::Okay)); + } + + #[test] + fn parse_nested_error_and_handle_more_inputs() { + // from https://redis.io/docs/interact/transactions/ - + // "EXEC returned two-element bulk string reply where one is an OK code and the other an error reply. It's up to the client library to find a sensible way to provide the error to the user." 
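+        //
+        // In other words: the nested -LOADING error becomes the result of the
+        // whole EXEC reply, and the parser must remain usable for whatever
+        // frames follow it.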
+ + let bytes = b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n"; + let result = parse_redis_value(bytes); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let result = parse_redis_value(b"+OK\r\n").unwrap(); + + assert_eq!(result, Value::Okay); + } + + #[test] + fn decode_resp3_double() { + let val = parse_redis_value(b",1.23\r\n").unwrap(); + assert_eq!(val, Value::Double(1.23)); + let val = parse_redis_value(b",nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_positive()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + // -nan is supported prior to redis 7.2 + let val = parse_redis_value(b",-nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_negative()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + //Allow doubles in scientific E notation + let val = parse_redis_value(b",2.67923e+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",-2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(-267923000.0)); + let val = parse_redis_value(b",2.1E-2\r\n").unwrap(); + assert_eq!(val, Value::Double(0.021)); + + let val = parse_redis_value(b",-inf\r\n").unwrap(); + assert_eq!(val, Value::Double(-f64::INFINITY)); + let val = parse_redis_value(b",inf\r\n").unwrap(); + assert_eq!(val, Value::Double(f64::INFINITY)); + } + + #[test] + fn decode_resp3_map() { + let val = parse_redis_value(b"%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n").unwrap(); + let mut v = val.as_map_iter().unwrap(); + assert_eq!( + (&Value::SimpleString("first".to_string()), &Value::Int(1)), + v.next().unwrap() + ); + assert_eq!( + (&Value::SimpleString("second".to_string()), &Value::Int(2)), + v.next().unwrap() + ); + } + + #[test] + fn decode_resp3_boolean() { + let val = parse_redis_value(b"#t\r\n").unwrap(); + assert_eq!(val, Value::Boolean(true)); + let val = parse_redis_value(b"#f\r\n").unwrap(); + assert_eq!(val, Value::Boolean(false)); + let val = parse_redis_value(b"#x\r\n"); + assert!(val.is_err()); + let val = parse_redis_value(b"#\r\n"); + assert!(val.is_err()); + } + + #[test] + fn decode_resp3_blob_error() { + let val = parse_redis_value(b"!21\r\nSYNTAX invalid syntax\r\n"); + assert_eq!( + val.err(), + Some(make_extension_error( + "SYNTAX".to_string(), + Some("invalid syntax".to_string()) + )) + ) + } + + #[test] + fn decode_resp3_big_number() { + let val = parse_redis_value(b"(3492890328409238509324850943850943825024385\r\n").unwrap(); + assert_eq!( + val, + Value::BigNumber( + BigInt::parse_bytes(b"3492890328409238509324850943850943825024385", 10).unwrap() + ) + ); + } + + #[test] + fn decode_resp3_set() { + let val = parse_redis_value(b"~5\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n").unwrap(); + let v = val.as_sequence().unwrap(); + assert_eq!(Value::SimpleString("orange".to_string()), v[0]); + assert_eq!(Value::SimpleString("apple".to_string()), v[1]); + assert_eq!(Value::Boolean(true), v[2]); + assert_eq!(Value::Int(100), v[3]); + assert_eq!(Value::Int(999), v[4]); + } + + #[test] + fn decode_resp3_push() { + let val = parse_redis_value(b">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n") + .unwrap(); + if let Value::Push { ref kind, ref data } = val { + assert_eq!(&PushKind::Message, kind); + 
assert_eq!(Value::SimpleString("somechannel".to_string()), data[0]); + assert_eq!( + Value::SimpleString("this is the message".to_string()), + data[1] + ); + } else { + panic!("Expected Value::Push") + } + } + + #[test] + fn test_max_recursion_depth() { + let bytes = b"*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\
r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n"; + match parse_redis_value(bytes) { + Ok(_) => panic!("Expected Err"), + Err(e) => assert!(matches!(e.kind(), ErrorKind::ParseError)), + } + } +} diff --git a/glide-core/redis-rs/redis/src/pipeline.rs b/glide-core/redis-rs/redis/src/pipeline.rs new file mode 100644 index 0000000000..babb57a1ff --- /dev/null +++ b/glide-core/redis-rs/redis/src/pipeline.rs @@ -0,0 +1,324 @@ +#![macro_use] + +use crate::cmd::{cmd, cmd_len, Cmd}; +use crate::connection::ConnectionLike; +use crate::types::{ + from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value, +}; + +/// Represents a redis command pipeline. +#[derive(Clone)] +pub struct Pipeline { + commands: Vec, + transaction_mode: bool, + ignored_commands: HashSet, +} + +/// A pipeline allows you to send multiple commands in one go to the +/// redis server. API wise it's very similar to just using a command +/// but it allows multiple commands to be chained and some features such +/// as iteration are not available. +/// +/// Basic example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let ((k1, k2),) : ((i32, i32),) = redis::pipe() +/// .cmd("SET").arg("key_1").arg(42).ignore() +/// .cmd("SET").arg("key_2").arg(43).ignore() +/// .cmd("MGET").arg(&["key_1", "key_2"]).query(&mut con).unwrap(); +/// ``` +/// +/// As you can see with `cmd` you can start a new command. By default +/// each command produces a value but for some you can ignore them by +/// calling `ignore` on the command. That way it will be skipped in the +/// return value which is useful for `SET` commands and others, which +/// do not have a useful return value. +impl Pipeline { + /// Creates an empty pipeline. For consistency with the `cmd` + /// api a `pipe` function is provided as alias. + pub fn new() -> Pipeline { + Self::with_capacity(0) + } + + /// Creates an empty pipeline with pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> Pipeline { + Pipeline { + commands: Vec::with_capacity(capacity), + transaction_mode: false, + ignored_commands: HashSet::new(), + } + } + + /// This enables atomic mode. In atomic mode the whole pipeline is + /// enclosed in `MULTI`/`EXEC`. From the user's point of view nothing + /// changes however. This is easier than using `MULTI`/`EXEC` yourself + /// as the format does not change. + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let (k1, k2) : (i32, i32) = redis::pipe() + /// .atomic() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn atomic(&mut self) -> &mut Pipeline { + self.transaction_mode = true; + self + } + + /// Returns the encoded pipeline commands. 
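+    ///
+    /// A small sketch of what the packed bytes look like (illustrative only):
+    ///
+    /// ```rust,no_run
+    /// let mut pipe = redis::pipe();
+    /// pipe.cmd("SET").arg("key_1").arg(42).ignore();
+    /// let bytes: Vec<u8> = pipe.get_packed_pipeline();
+    /// // One RESP array per command: *3, $3 SET, $5 key_1, $2 42.
+    /// assert!(bytes.starts_with(b"*3\r\n$3\r\nSET\r\n"));
+    /// ```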
+ pub fn get_packed_pipeline(&self) -> Vec { + encode_pipeline(&self.commands, self.transaction_mode) + } + + #[cfg(feature = "aio")] + pub(crate) fn write_packed_pipeline(&self, out: &mut Vec) { + write_pipeline(out, &self.commands, self.transaction_mode) + } + + fn execute_pipelined(&self, con: &mut dyn ConnectionLike) -> RedisResult { + Ok(self.make_pipeline_results(con.req_packed_commands( + &encode_pipeline(&self.commands, false), + 0, + self.commands.len(), + )?)) + } + + fn execute_transaction(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let mut resp = con.req_packed_commands( + &encode_pipeline(&self.commands, true), + self.commands.len() + 1, + 1, + )?; + match resp.pop() { + Some(Value::Nil) => Ok(Value::Nil), + Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)), + _ => fail!(( + ErrorKind::ResponseError, + "Invalid response when parsing multi response" + )), + } + } + + /// Executes the pipeline and fetches the return values. Since most + /// pipelines return different types it's recommended to use tuple + /// matching to process the results: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let (k1, k2) : (i32, i32) = redis::pipe() + /// .cmd("SET").arg("key_1").arg(42).ignore() + /// .cmd("SET").arg("key_2").arg(43).ignore() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + /// + /// NOTE: A Pipeline object may be reused after `query()` with all the commands as were inserted + /// to them. In order to clear a Pipeline object with minimal memory released/allocated, + /// it is necessary to call the `clear()` before inserting new commands. + #[inline] + pub fn query(&self, con: &mut dyn ConnectionLike) -> RedisResult { + if !con.supports_pipelining() { + fail!(( + ErrorKind::ResponseError, + "This connection does not support pipelining." + )); + } + from_owned_redis_value(if self.commands.is_empty() { + Value::Array(vec![]) + } else if self.transaction_mode { + self.execute_transaction(con)? + } else { + self.execute_pipelined(con)? + }) + } + + #[cfg(feature = "aio")] + async fn execute_pipelined_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let value = con + .req_packed_commands(self, 0, self.commands.len()) + .await?; + Ok(self.make_pipeline_results(value)) + } + + #[cfg(feature = "aio")] + async fn execute_transaction_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let mut resp = con + .req_packed_commands(self, self.commands.len() + 1, 1) + .await?; + match resp.pop() { + Some(Value::Nil) => Ok(Value::Nil), + Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)), + _ => Err(( + ErrorKind::ResponseError, + "Invalid response when parsing multi response", + ) + .into()), + } + } + + /// Async version of `query`. + #[inline] + #[cfg(feature = "aio")] + pub async fn query_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let v = if self.commands.is_empty() { + return from_owned_redis_value(Value::Array(vec![])); + } else if self.transaction_mode { + self.execute_transaction_async(con).await? + } else { + self.execute_pipelined_async(con).await? + }; + from_owned_redis_value(v) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query of the pipeline fails. 
+ /// + /// This is equivalent to a call of query like this: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let _ : () = redis::pipe().cmd("PING").query(&mut con).unwrap(); + /// ``` + /// + /// NOTE: A Pipeline object may be reused after `query()` with all the commands as were inserted + /// to them. In order to clear a Pipeline object with minimal memory released/allocated, + /// it is necessary to call the `clear()` before inserting new commands. + #[inline] + pub fn execute(&self, con: &mut dyn ConnectionLike) { + self.query::<()>(con).unwrap(); + } +} + +fn encode_pipeline(cmds: &[Cmd], atomic: bool) -> Vec { + let mut rv = vec![]; + write_pipeline(&mut rv, cmds, atomic); + rv +} + +fn write_pipeline(rv: &mut Vec, cmds: &[Cmd], atomic: bool) { + let cmds_len = cmds.iter().map(cmd_len).sum(); + + if atomic { + let multi = cmd("MULTI"); + let exec = cmd("EXEC"); + rv.reserve(cmd_len(&multi) + cmd_len(&exec) + cmds_len); + + multi.write_packed_command_preallocated(rv); + for cmd in cmds { + cmd.write_packed_command_preallocated(rv); + } + exec.write_packed_command_preallocated(rv); + } else { + rv.reserve(cmds_len); + + for cmd in cmds { + cmd.write_packed_command_preallocated(rv); + } + } +} + +// Macro to implement shared methods between Pipeline and ClusterPipeline +macro_rules! implement_pipeline_commands { + ($struct_name:ident) => { + impl $struct_name { + /// Adds a command to the cluster pipeline. + #[inline] + pub fn add_command(&mut self, cmd: Cmd) -> &mut Self { + self.commands.push(cmd); + self + } + + /// Starts a new command. Functions such as `arg` then become + /// available to add more arguments to that command. + #[inline] + pub fn cmd(&mut self, name: &str) -> &mut Self { + self.add_command(cmd(name)) + } + + /// Returns an iterator over all the commands currently in this pipeline + pub fn cmd_iter(&self) -> impl Iterator { + self.commands.iter() + } + + /// Instructs the pipeline to ignore the return value of this command. + /// It will still be ensured that it is not an error, but any successful + /// result is just thrown away. This makes result processing through + /// tuples much easier because you do not need to handle all the items + /// you do not care about. + #[inline] + pub fn ignore(&mut self) -> &mut Self { + match self.commands.len() { + 0 => true, + x => self.ignored_commands.insert(x - 1), + }; + self + } + + /// Adds an argument to the last started command. This works similar + /// to the `arg` method of the `Cmd` object. + /// + /// Note that this function fails the task if executed on an empty pipeline. + #[inline] + pub fn arg(&mut self, arg: T) -> &mut Self { + { + let cmd = self.get_last_command(); + cmd.arg(arg); + } + self + } + + /// Clear a pipeline object's internal data structure. + /// + /// This allows reusing a pipeline object as a clear object while performing a minimal + /// amount of memory released/reallocated. 
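+            ///
+            /// A short sketch of the reuse pattern (names are illustrative):
+            ///
+            /// ```rust,no_run
+            /// let mut pipe = redis::pipe();
+            /// pipe.cmd("PING").ignore();
+            /// // ... run the pipeline, then recycle its buffers:
+            /// pipe.clear();
+            /// pipe.cmd("ECHO").arg("hello");
+            /// ```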
+ #[inline] + pub fn clear(&mut self) { + self.commands.clear(); + self.ignored_commands.clear(); + } + + #[inline] + fn get_last_command(&mut self) -> &mut Cmd { + let idx = match self.commands.len() { + 0 => panic!("No command on stack"), + x => x - 1, + }; + &mut self.commands[idx] + } + + fn make_pipeline_results(&self, resp: Vec) -> Value { + let mut rv = Vec::with_capacity(resp.len() - self.ignored_commands.len()); + for (idx, result) in resp.into_iter().enumerate() { + if !self.ignored_commands.contains(&idx) { + rv.push(result); + } + } + Value::Array(rv) + } + } + + impl Default for $struct_name { + fn default() -> Self { + Self::new() + } + } + }; +} + +implement_pipeline_commands!(Pipeline); diff --git a/glide-core/redis-rs/redis/src/push_manager.rs b/glide-core/redis-rs/redis/src/push_manager.rs new file mode 100644 index 0000000000..8a22e06a57 --- /dev/null +++ b/glide-core/redis-rs/redis/src/push_manager.rs @@ -0,0 +1,234 @@ +use crate::{PushKind, RedisResult, Value}; +use arc_swap::ArcSwap; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Holds information about received Push data +#[derive(Debug, Clone)] +pub struct PushInfo { + /// Push Kind + pub kind: PushKind, + /// Data from push message + pub data: Vec, +} + +/// Manages Push messages for single tokio channel +#[derive(Clone, Default)] +pub struct PushManager { + sender: Arc>>>, +} +impl PushManager { + /// It checks if value's type is Push + /// then invokes `try_send_raw` method + pub(crate) fn try_send(&self, value: &RedisResult) { + if let Ok(value) = &value { + self.try_send_raw(value); + } + } + + /// It checks if value's type is Push and there is a provided sender + /// then creates PushInfo and invokes `send` method of sender + pub(crate) fn try_send_raw(&self, value: &Value) { + if let Value::Push { kind, data } = value { + let guard = self.sender.load(); + if let Some(sender) = guard.as_ref() { + let push_info = PushInfo { + kind: kind.clone(), + data: data.clone(), + }; + if sender.send(push_info).is_err() { + self.sender.compare_and_swap(guard, Arc::new(None)); + } + } + } + } + /// Replace mpsc channel of `PushManager` with provided sender. 
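+    ///
+    /// A minimal sketch of wiring up a channel (illustrative only):
+    ///
+    /// ```rust,no_run
+    /// use redis::PushManager;
+    ///
+    /// let manager = PushManager::new();
+    /// let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
+    /// manager.replace_sender(tx);
+    /// // RESP3 push values observed from now on are forwarded to `_rx`.
+    /// ```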
+    pub fn replace_sender(&self, sender: mpsc::UnboundedSender<PushInfo>) {
+        self.sender.store(Arc::new(Some(sender)));
+    }
+
+    /// Creates a new `PushManager`
+    pub fn new() -> Self {
+        PushManager {
+            sender: Arc::from(ArcSwap::from(Arc::new(None))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_send_and_receive_push_info() {
+        let push_manager = PushManager::new();
+        let (tx, mut rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+
+        let value = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        });
+
+        push_manager.try_send(&value);
+
+        let push_info = rx.try_recv().unwrap();
+        assert_eq!(push_info.kind, PushKind::Message);
+        assert_eq!(
+            push_info.data,
+            vec![Value::BulkString("hello".to_string().into_bytes())]
+        );
+    }
+    #[test]
+    fn test_push_manager_receiver_dropped() {
+        let push_manager = PushManager::new();
+        let (tx, rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+
+        let value = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        });
+
+        drop(rx);
+
+        push_manager.try_send(&value);
+        push_manager.try_send(&value);
+        push_manager.try_send(&value);
+    }
+    #[test]
+    fn test_push_manager_without_sender() {
+        let push_manager = PushManager::new();
+
+        push_manager.try_send(&Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        })); // nothing happens!
+
+        let (tx, mut rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+        push_manager.try_send(&Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello2".to_string().into_bytes())],
+        }));
+
+        assert_eq!(
+            rx.try_recv().unwrap().data,
+            vec![Value::BulkString("hello2".to_string().into_bytes())]
+        );
+    }
+    #[test]
+    fn test_push_manager_multiple_channels_and_messages() {
+        let push_manager = PushManager::new();
+        let (tx1, mut rx1) = mpsc::unbounded_channel();
+        let (tx2, mut rx2) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx1);
+
+        let value1 = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::Int(1)],
+        });
+
+        let value2 = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::Int(2)],
+        });
+
+        push_manager.try_send(&value1);
+        push_manager.try_send(&value2);
+
+        assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(1)]);
+        assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(2)]);
+
+        push_manager.replace_sender(tx2);
+        // make sure rx1 is disconnected after replacing tx1 with tx2.
+        assert_eq!(
+            rx1.try_recv().err().unwrap(),
+            mpsc::error::TryRecvError::Disconnected
+        );
+
+        push_manager.try_send(&value1);
+        push_manager.try_send(&value2);
+
+        assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(1)]);
+        assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(2)]);
+    }
+
+    #[tokio::test]
+    async fn test_push_manager_multi_threaded() {
+        // In this test we create 4 channels and send 1000 messages, switching
+        // channels for each message sent. We then check that every message is
+        // received, that their sum equals the expected sum, and that all
+        // channels were used.
+ let push_manager = PushManager::new(); + let (tx1, mut rx1) = mpsc::unbounded_channel(); + let (tx2, mut rx2) = mpsc::unbounded_channel(); + let (tx3, mut rx3) = mpsc::unbounded_channel(); + let (tx4, mut rx4) = mpsc::unbounded_channel(); + + let mut handles = vec![]; + let txs = [tx1, tx2, tx3, tx4]; + let mut expected_sum = 0; + for i in 0..1000 { + expected_sum += i; + let push_manager_clone = push_manager.clone(); + let new_tx = txs[(i % 4) as usize].clone(); + let value = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::Int(i)], + }); + let handle = tokio::spawn(async move { + push_manager_clone.replace_sender(new_tx); + push_manager_clone.try_send(&value); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + let mut count1 = 0; + let mut count2 = 0; + let mut count3 = 0; + let mut count4 = 0; + let mut received_sum = 0; + while let Ok(push_info) = rx1.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count1 += 1; + } + while let Ok(push_info) = rx2.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count2 += 1; + } + + while let Ok(push_info) = rx3.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count3 += 1; + } + + while let Ok(push_info) = rx4.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count4 += 1; + } + + assert_ne!(count1, 0); + assert_ne!(count2, 0); + assert_ne!(count3, 0); + assert_ne!(count4, 0); + + assert_eq!(count1 + count2 + count3 + count4, 1000); + assert_eq!(received_sum, expected_sum); + } +} diff --git a/glide-core/redis-rs/redis/src/r2d2.rs b/glide-core/redis-rs/redis/src/r2d2.rs new file mode 100644 index 0000000000..e34d2c7bb9 --- /dev/null +++ b/glide-core/redis-rs/redis/src/r2d2.rs @@ -0,0 +1,36 @@ +use std::io; + +use crate::{ConnectionLike, RedisError}; + +macro_rules! impl_manage_connection { + ($client:ty, $connection:ty) => { + impl r2d2::ManageConnection for $client { + type Connection = $connection; + type Error = RedisError; + + fn connect(&self) -> Result { + self.get_connection(None) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + if conn.check_connection() { + Ok(()) + } else { + Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + } + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + !conn.is_open() + } + } + }; +} + +impl_manage_connection!(crate::Client, crate::Connection); + +#[cfg(feature = "cluster")] +impl_manage_connection!( + crate::cluster::ClusterClient, + crate::cluster::ClusterConnection +); diff --git a/glide-core/redis-rs/redis/src/script.rs b/glide-core/redis-rs/redis/src/script.rs new file mode 100644 index 0000000000..c62d2344ae --- /dev/null +++ b/glide-core/redis-rs/redis/src/script.rs @@ -0,0 +1,255 @@ +#![cfg(feature = "script")] +use sha1_smol::Sha1; + +use crate::cmd::cmd; +use crate::connection::ConnectionLike; +use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs}; +use crate::Cmd; + +/// Represents a lua script. +#[derive(Debug, Clone)] +pub struct Script { + code: String, + hash: String, +} + +/// The script object represents a lua script that can be executed on the +/// redis server. 
The object itself takes care of automatic uploading and +/// execution. The script object itself can be shared and is immutable. +/// +/// Example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let script = redis::Script::new(r" +/// return tonumber(ARGV[1]) + tonumber(ARGV[2]); +/// "); +/// let result = script.arg(1).arg(2).invoke(&mut con); +/// assert_eq!(result, Ok(3)); +/// ``` +impl Script { + /// Creates a new script object. + pub fn new(code: &str) -> Script { + let mut hash = Sha1::new(); + hash.update(code.as_bytes()); + Script { + code: code.to_string(), + hash: hash.digest().to_string(), + } + } + + /// Returns the script's SHA1 hash in hexadecimal format. + pub fn get_hash(&self) -> &str { + &self.hash + } + + /// Creates a script invocation object with a key filled in. + #[inline] + pub fn key(&self, key: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: key.to_redis_args(), + } + } + + /// Creates a script invocation object with an argument filled in. + #[inline] + pub fn arg(&self, arg: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: arg.to_redis_args(), + keys: vec![], + } + } + + /// Returns an empty script invocation object. This is primarily useful + /// for programmatically adding arguments and keys because the type will + /// not change. Normally you can use `arg` and `key` directly. + #[inline] + pub fn prepare_invoke(&self) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + } + + /// Invokes the script directly without arguments. + #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke(con) + } + + /// Asynchronously invokes the script without arguments. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke_async(con) + .await + } +} + +/// Represents a prepared script call. +pub struct ScriptInvocation<'a> { + script: &'a Script, + args: Vec>, + keys: Vec>, +} + +/// This type collects keys and other arguments for the script so that it +/// can be then invoked. While the `Script` type itself holds the script, +/// the `ScriptInvocation` holds the arguments that should be invoked until +/// it's sent to the server. +impl<'a> ScriptInvocation<'a> { + /// Adds a regular argument to the invocation. This ends up as `ARGV[i]` + /// in the script. + #[inline] + pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + arg.write_redis_args(&mut self.args); + self + } + + /// Adds a key argument to the invocation. This ends up as `KEYS[i]` + /// in the script. + #[inline] + pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + key.write_redis_args(&mut self.keys); + self + } + + /// Invokes the script and returns the result. 
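+    ///
+    /// Internally this first issues `EVALSHA` with the script's hash; on a
+    /// `NOSCRIPT` error it uploads the script via `SCRIPT LOAD` and retries,
+    /// so the script body is only transferred when the server lacks it.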
+ #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query(con) { + Ok(val) => Ok(val), + Err(err) => { + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query(con)?; + eval_cmd.query(con) + } else { + Err(err) + } + } + } + } + + /// Asynchronously invokes the script and returns the result. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query_async(con).await { + Ok(val) => { + // Return the value from the script evaluation + Ok(val) + } + Err(err) => { + // Load the script into Redis if the script hash wasn't there already + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query_async(con).await?; + eval_cmd.query_async(con).await + } else { + Err(err) + } + } + } + } + + /// Loads the script and returns the SHA1 of it. + #[inline] + pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let hash: String = self.load_cmd().query(con)?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + /// Asynchronously loads the script and returns the SHA1 of it. + #[inline] + #[cfg(feature = "aio")] + pub async fn load_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let hash: String = self.load_cmd().query_async(con).await?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + fn load_cmd(&self) -> Cmd { + let mut cmd = cmd("SCRIPT"); + cmd.arg("LOAD").arg(self.script.code.as_bytes()); + cmd + } + + fn estimate_buflen(&self) -> usize { + self + .keys + .iter() + .chain(self.args.iter()) + .fold(0, |acc, e| acc + e.len()) + + 7 /* "EVALSHA".len() */ + + self.script.hash.len() + + 4 /* Slots reserved for the length of keys. */ + } + + fn eval_cmd(&self) -> Cmd { + let args_len = 3 + self.keys.len() + self.args.len(); + let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen()); + cmd.arg("EVALSHA") + .arg(self.script.hash.as_bytes()) + .arg(self.keys.len()) + .arg(&*self.keys) + .arg(&*self.args); + cmd + } +} + +#[cfg(test)] +mod tests { + use super::Script; + + #[test] + fn script_eval_should_work() { + let script = Script::new("return KEYS[1]"); + let invocation = script.key("dummy"); + let estimated_buflen = invocation.estimate_buflen(); + let cmd = invocation.eval_cmd(); + assert!(estimated_buflen >= cmd.capacity().1); + let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n"; + assert_eq!( + expected, + std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap() + ); + } +} diff --git a/glide-core/redis-rs/redis/src/sentinel.rs b/glide-core/redis-rs/redis/src/sentinel.rs new file mode 100644 index 0000000000..ac6aac65cc --- /dev/null +++ b/glide-core/redis-rs/redis/src/sentinel.rs @@ -0,0 +1,778 @@ +//! Defines a Sentinel type that connects to Redis sentinels and creates clients to +//! master or replica nodes. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::Sentinel; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! let mut master = sentinel.master_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! 
let mut replica = sentinel.replica_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! +//! let _: () = master.set("test", "test_data").unwrap(); +//! let rv: String = replica.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! There is also a SentinelClient which acts like a regular Client, providing the +//! `get_connection` and `get_async_connection` methods, internally using the Sentinel +//! type to create clients on demand for the desired node type (Master or Replica). +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::{ SentinelServerType, SentinelClient }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build(nodes.clone(), String::from("master_name"), None, SentinelServerType::Master).unwrap(); +//! let mut replica_client = SentinelClient::build(nodes, String::from("master_name"), None, SentinelServerType::Replica).unwrap(); +//! let mut master_conn = master_client.get_connection().unwrap(); +//! let mut replica_conn = replica_client.get_connection().unwrap(); +//! +//! let _: () = master_conn.set("test", "test_data").unwrap(); +//! let rv: String = replica_conn.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! If the sentinel's nodes are using TLS or require authentication, a full +//! SentinelNodeConnectionInfo struct may be used instead of just the master's name: +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ Sentinel, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! +//! let mut master_with_auth = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: None, +//! redis_connection_info: Some(RedisConnectionInfo { +//! db: 1, +//! username: Some(String::from("foo")), +//! password: Some(String::from("bar")), +//! ..Default::default() +//! }), +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! +//! let mut replica_with_tls = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Secure), +//! redis_connection_info: None, +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! ``` +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ SentinelServerType, SentinelClient, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build( +//! nodes, +//! String::from("master1"), +//! Some(SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Insecure), +//! redis_connection_info: Some(RedisConnectionInfo { +//! username: Some(String::from("user")), +//! password: Some(String::from("pass")), +//! ..Default::default() +//! }), +//! }), +//! redis::sentinel::SentinelServerType::Master, +//! ) +//! .unwrap(); +//! ``` +//! 
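+//!
+//! With the `aio` feature enabled, async counterparts of these calls are
+//! available on the same types (e.g. `async_master_for` and
+//! `async_replica_for`), taking the same connection-info arguments as the
+//! synchronous versions shown above.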
+ +use std::{collections::HashMap, num::NonZeroUsize}; + +#[cfg(feature = "aio")] +use futures_util::StreamExt; +use rand::Rng; + +#[cfg(feature = "aio")] +use crate::aio::MultiplexedConnection as AsyncConnection; + +use crate::{ + client::GlideConnectionOptions, connection::ConnectionInfo, types::RedisResult, Client, Cmd, + Connection, ErrorKind, FromRedisValue, IntoConnectionInfo, RedisConnectionInfo, TlsMode, Value, +}; + +/// The Sentinel type, serves as a special purpose client which builds other clients on +/// demand. +pub struct Sentinel { + sentinels_connection_info: Vec, + connections_cache: Vec>, + #[cfg(feature = "aio")] + async_connections_cache: Vec>, + replica_start_index: usize, +} + +/// Holds the connection information that a sentinel should use when connecting to the +/// servers (masters and replicas) belonging to it. +#[derive(Clone, Default)] +pub struct SentinelNodeConnectionInfo { + /// The TLS mode of the connection, or None if we do not want to connect using TLS + /// (just a plain TCP connection). + pub tls_mode: Option, + + /// The Redis specific/connection independent information to be used. + pub redis_connection_info: Option, +} + +impl SentinelNodeConnectionInfo { + fn create_connection_info(&self, ip: String, port: u16) -> ConnectionInfo { + let addr = match self.tls_mode { + None => crate::ConnectionAddr::Tcp(ip, port), + Some(TlsMode::Secure) => crate::ConnectionAddr::TcpTls { + host: ip, + port, + insecure: false, + tls_params: None, + }, + Some(TlsMode::Insecure) => crate::ConnectionAddr::TcpTls { + host: ip, + port, + insecure: true, + tls_params: None, + }, + }; + + ConnectionInfo { + addr, + redis: self.redis_connection_info.clone().unwrap_or_default(), + } + } +} + +impl Default for &SentinelNodeConnectionInfo { + fn default() -> Self { + static DEFAULT_VALUE: SentinelNodeConnectionInfo = SentinelNodeConnectionInfo { + tls_mode: None, + redis_connection_info: None, + }; + &DEFAULT_VALUE + } +} + +fn sentinel_masters_cmd() -> crate::Cmd { + let mut cmd = crate::cmd("SENTINEL"); + cmd.arg("MASTERS"); + cmd +} + +fn sentinel_replicas_cmd(master_name: &str) -> crate::Cmd { + let mut cmd = crate::cmd("SENTINEL"); + cmd.arg("SLAVES"); // For compatibility with older redis versions + cmd.arg(master_name); + cmd +} + +fn is_master_valid(master_info: &HashMap, service_name: &str) -> bool { + master_info.get("name").map(|s| s.as_str()) == Some(service_name) + && master_info.contains_key("ip") + && master_info.contains_key("port") + && master_info.get("flags").map_or(false, |flags| { + flags.contains("master") && !flags.contains("s_down") && !flags.contains("o_down") + }) + && master_info["port"].parse::().is_ok() +} + +fn is_replica_valid(replica_info: &HashMap) -> bool { + replica_info.contains_key("ip") + && replica_info.contains_key("port") + && replica_info.get("flags").map_or(false, |flags| { + !flags.contains("s_down") && !flags.contains("o_down") + }) + && replica_info["port"].parse::().is_ok() +} + +/// Generates a random value in the 0..max range. 
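+/// Used to pick a random starting replica so that repeated connection
+/// attempts are spread across the replica set rather than always hitting
+/// the first address returned by the sentinels.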
+fn random_replica_index(max: NonZeroUsize) -> usize { + rand::thread_rng().gen_range(0..max.into()) +} + +fn try_connect_to_first_replica( + addresses: &[ConnectionInfo], + start_index: Option, +) -> Result { + if addresses.is_empty() { + fail!(( + ErrorKind::NoValidReplicasFoundBySentinel, + "No valid replica found in sentinel for given name", + )); + } + + let start_index = start_index.unwrap_or(0); + + let mut last_err = None; + for i in 0..addresses.len() { + let index = (i + start_index) % addresses.len(); + match Client::open(addresses[index].clone()) { + Ok(client) => return Ok(client), + Err(err) => last_err = Some(err), + } + } + + // We can unwrap here because we know there is at least one error, since there is at + // least one address, so we'll either return a client for it or store an error in + // last_err. + Err(last_err.expect("There should be an error because there is should be at least one address")) +} + +fn valid_addrs<'a>( + servers_info: Vec>, + validate: impl Fn(&HashMap) -> bool + 'a, +) -> impl Iterator { + servers_info + .into_iter() + .filter(move |info| validate(info)) + .map(|mut info| { + // We can unwrap here because we already checked everything + let ip = info.remove("ip").unwrap(); + let port = info["port"].parse::().unwrap(); + (ip, port) + }) +} + +fn check_role_result(result: &RedisResult>, target_role: &str) -> bool { + if let Ok(values) = result { + if !values.is_empty() { + if let Ok(role) = String::from_redis_value(&values[0]) { + return role.to_ascii_lowercase() == target_role; + } + } + } + false +} + +fn check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { + if let Ok(client) = Client::open(connection_info.clone()) { + if let Ok(mut conn) = client.get_connection(None) { + let result: RedisResult> = crate::cmd("ROLE").query(&mut conn); + return check_role_result(&result, target_role); + } + } + false +} + +/// Searches for a valid master with the given name in the list of masters returned by +/// a sentinel. A valid master is one which has a role of "master" (checked by running +/// the `ROLE` command and by seeing if its flags contains the "master" flag) and which +/// does not have the flags s_down or o_down set to it (these flags are returned by the +/// `SENTINEL MASTERS` command, and we expect the `masters` parameter to be the result of +/// that command). +fn find_valid_master( + masters: Vec>, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, +) -> RedisResult { + for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) { + let connection_info = node_connection_info.create_connection_info(ip, port); + if check_role(&connection_info, "master") { + return Ok(connection_info); + } + } + + fail!(( + ErrorKind::MasterNameNotFoundBySentinel, + "Master with given name not found in sentinel", + )) +} + +#[cfg(feature = "aio")] +async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { + if let Ok(client) = Client::open(connection_info.clone()) { + if let Ok(mut conn) = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { + let result: RedisResult> = crate::cmd("ROLE").query_async(&mut conn).await; + return check_role_result(&result, target_role); + } + } + false +} + +/// Async version of [find_valid_master]. 
+#[cfg(feature = "aio")] +async fn async_find_valid_master( + masters: Vec>, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, +) -> RedisResult { + for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) { + let connection_info = node_connection_info.create_connection_info(ip, port); + if async_check_role(&connection_info, "master").await { + return Ok(connection_info); + } + } + + fail!(( + ErrorKind::MasterNameNotFoundBySentinel, + "Master with given name not found in sentinel", + )) +} + +fn get_valid_replicas_addresses( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + valid_addrs(replicas, is_replica_valid) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter(|connection_info| check_role(connection_info, "slave")) + .collect() +} + +#[cfg(feature = "aio")] +async fn async_get_valid_replicas_addresses<'a>( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + async fn is_replica_role_valid(connection_info: ConnectionInfo) -> Option { + if async_check_role(&connection_info, "slave").await { + Some(connection_info) + } else { + None + } + } + + futures_util::stream::iter(valid_addrs(replicas, is_replica_valid)) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter_map(is_replica_role_valid) + .collect() + .await +} + +#[cfg(feature = "aio")] +async fn async_reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + connection.replace(new_connection); + Ok(()) +} + +#[cfg(feature = "aio")] +async fn async_try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + async_reconnect(cached_connection, connection_info).await?; + } + + let result = cmd.query_async(cached_connection.as_mut().unwrap()).await; + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + async_reconnect(cached_connection, connection_info).await?; + cmd.query_async(cached_connection.as_mut().unwrap()).await + } else { + Err(err) + } + } else { + result + } +} + +fn reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client.get_connection(None)?; + connection.replace(new_connection); + Ok(()) +} + +fn try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + reconnect(cached_connection, connection_info)?; + } + + let result = cmd.query(cached_connection.as_mut().unwrap()); + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + reconnect(cached_connection, connection_info)?; + cmd.query(cached_connection.as_mut().unwrap()) + } else { + Err(err) + } + } else { + result + } +} + +// non-async methods +impl Sentinel { + /// Creates a Sentinel client performing some basic + /// checks on the URLs that might make the operation fail. 
+ pub fn build(params: Vec) -> RedisResult { + if params.is_empty() { + fail!(( + ErrorKind::EmptySentinelList, + "At least one sentinel is required", + )) + } + + let sentinels_connection_info = params + .into_iter() + .map(|p| p.into_connection_info()) + .collect::>>()?; + + let mut connections_cache = vec![]; + connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + #[cfg(feature = "aio")] + { + let mut async_connections_cache = vec![]; + async_connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + async_connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + + #[cfg(not(feature = "aio"))] + { + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + } + + /// Try to execute the given command in each sentinel, returning the result of the + /// first one that executes without errors. If all return errors, we return the + /// error of the last attempt. + /// + /// For each sentinel, we first check if there is a cached connection, and if not + /// we attempt to connect to it (skipping that sentinel if there is an error during + /// the connection). Then, we attempt to execute the given command with the cached + /// connection. If there is an error indicating that the connection is invalid, we + /// reconnect and try one more time in the new connection. + /// + fn try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.connections_cache.iter_mut()) + { + match try_single_sentinel(cmd.clone(), connection_info, cached_connection) { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + /// Get a list of all masters (using the command SENTINEL MASTERS) from the + /// sentinels. + fn get_sentinel_masters(&mut self) -> RedisResult>> { + self.try_all_sentinels(sentinel_masters_cmd()) + } + + fn get_sentinel_replicas( + &mut self, + service_name: &str, + ) -> RedisResult>> { + self.try_all_sentinels(sentinel_replicas_cmd(service_name)) + } + + fn find_master_address( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.get_sentinel_masters()?; + find_valid_master(masters, service_name, node_connection_info) + } + + fn find_valid_replica_addresses( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.get_sentinel_replicas(service_name)?; + Ok(get_valid_replicas_addresses(replicas, node_connection_info)) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub fn master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let connection_info = + self.find_master_address(service_name, node_connection_info.unwrap_or_default())?; + Client::open(connection_info) + } + + /// Connects to a randomly chosen replica of the given master name. 
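+    ///
+    /// The starting replica is chosen at random so that repeated calls spread
+    /// connections across the valid replicas; if the chosen replica cannot be
+    /// reached, the remaining addresses are tried in order.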
+ pub fn replica_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index); + try_connect_to_first_replica(&addresses, start_index) + } + + /// Attempts to connect to a different replica of the given master name each time. + /// There is no guarantee that we'll actually be connecting to a different replica + /// in the next call, but in a static set of replicas (no replicas added or + /// removed), on average we'll choose each replica the same number of times. + pub fn replica_rotate_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + if !addresses.is_empty() { + self.replica_start_index = (self.replica_start_index + 1) % addresses.len(); + } + try_connect_to_first_replica(&addresses, Some(self.replica_start_index)) + } +} + +// Async versions of the public methods above, along with async versions of private +// methods required for the public methods. +#[cfg(feature = "aio")] +impl Sentinel { + async fn async_try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.async_connections_cache.iter_mut()) + { + match async_try_single_sentinel(cmd.clone(), connection_info, cached_connection).await { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + async fn async_get_sentinel_masters(&mut self) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_masters_cmd()).await + } + + async fn async_get_sentinel_replicas<'a>( + &mut self, + service_name: &'a str, + ) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_replicas_cmd(service_name)) + .await + } + + async fn async_find_master_address<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.async_get_sentinel_masters().await?; + async_find_valid_master(masters, service_name, node_connection_info).await + } + + async fn async_find_valid_replica_addresses<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.async_get_sentinel_replicas(service_name).await?; + Ok(async_get_valid_replicas_addresses(replicas, node_connection_info).await) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub async fn async_master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let address = self + .async_find_master_address(service_name, node_connection_info.unwrap_or_default()) + .await?; + Client::open(address) + } + + /// Connects to a randomly chosen replica of the given master name. 
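+    ///
+    /// An async sketch (assumes the `aio` feature plus a runtime feature such as
+    /// `tokio-comp`; the URL and master name are placeholders):
+    ///
+    /// ```no_run
+    /// # use redis::sentinel::Sentinel;
+    /// # async fn example() -> redis::RedisResult<()> {
+    /// let mut sentinel = Sentinel::build(vec!["redis://127.0.0.1:26379/"])?;
+    /// let replica_client = sentinel.async_replica_for("mymaster", None).await?;
+    /// # Ok(())
+    /// # }
+    /// ```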
+    pub async fn async_replica_for(
+        &mut self,
+        service_name: &str,
+        node_connection_info: Option<&SentinelNodeConnectionInfo>,
+    ) -> RedisResult<Client> {
+        let addresses = self
+            .async_find_valid_replica_addresses(
+                service_name,
+                node_connection_info.unwrap_or_default(),
+            )
+            .await?;
+        let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index);
+        try_connect_to_first_replica(&addresses, start_index)
+    }
+
+    /// Attempts to connect to a different replica of the given master name each time.
+    /// There is no guarantee that we'll actually be connecting to a different replica
+    /// in the next call, but in a static set of replicas (no replicas added or
+    /// removed), on average we'll choose each replica the same number of times.
+    pub async fn async_replica_rotate_for<'a>(
+        &mut self,
+        service_name: &str,
+        node_connection_info: Option<&SentinelNodeConnectionInfo>,
+    ) -> RedisResult<Client> {
+        let addresses = self
+            .async_find_valid_replica_addresses(
+                service_name,
+                node_connection_info.unwrap_or_default(),
+            )
+            .await?;
+        if !addresses.is_empty() {
+            self.replica_start_index = (self.replica_start_index + 1) % addresses.len();
+        }
+        try_connect_to_first_replica(&addresses, Some(self.replica_start_index))
+    }
+}
+
+/// Enum defining the server types from a sentinel's point of view.
+#[derive(Debug, Clone)]
+pub enum SentinelServerType {
+    /// Master connections only
+    Master,
+    /// Replica connections only
+    Replica,
+}
+
+/// An alternative to the Client type which creates connections from clients created
+/// on-demand based on information fetched from the sentinels. Uses the Sentinel type
+/// internally. This is basically a utility to make it easier to use sentinels, with
+/// an interface similar to the client (`get_connection` and
+/// `get_async_connection`). The type of server (master or replica) and name of the
+/// desired master are specified when constructing an instance, so it will always
+/// return connections to the same target (for example, always to the master with name
+/// "mymaster123", or always to replicas of the master "another-master-abc").
+pub struct SentinelClient {
+    sentinel: Sentinel,
+    service_name: String,
+    node_connection_info: SentinelNodeConnectionInfo,
+    server_type: SentinelServerType,
+}
+
+impl SentinelClient {
+    /// Creates a SentinelClient performing some basic checks on the URLs that might
+    /// result in an error.
+    pub fn build<T: IntoConnectionInfo>(
+        params: Vec<T>,
+        service_name: String,
+        node_connection_info: Option<SentinelNodeConnectionInfo>,
+        server_type: SentinelServerType,
+    ) -> RedisResult<Self> {
+        Ok(SentinelClient {
+            sentinel: Sentinel::build(params)?,
+            service_name,
+            node_connection_info: node_connection_info.unwrap_or_default(),
+            server_type,
+        })
+    }
+
+    fn get_client(&mut self) -> RedisResult<Client> {
+        match self.server_type {
+            SentinelServerType::Master => self
+                .sentinel
+                .master_for(self.service_name.as_str(), Some(&self.node_connection_info)),
+            SentinelServerType::Replica => self
+                .sentinel
+                .replica_for(self.service_name.as_str(), Some(&self.node_connection_info)),
+        }
+    }
+
+    /// Creates a new connection to the desired type of server (based on the
+    /// service/master name, and the server type). We use a Sentinel to create a client
+    /// for the target type of server, and then create a connection using that client.
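+    ///
+    /// A minimal sketch (the sentinel URL, master name, and server type are
+    /// illustrative placeholders):
+    ///
+    /// ```no_run
+    /// use redis::sentinel::{SentinelClient, SentinelServerType};
+    ///
+    /// let mut client = SentinelClient::build(
+    ///     vec!["redis://127.0.0.1:26379/"],
+    ///     String::from("mymaster"),
+    ///     None,
+    ///     SentinelServerType::Master,
+    /// ).unwrap();
+    /// let mut conn = client.get_connection().unwrap();
+    /// ```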
+    pub fn get_connection(&mut self) -> RedisResult<Connection> {
+        let client = self.get_client()?;
+        client.get_connection(None)
+    }
+}
+
+/// To enable async support you need to choose one of the supported runtimes and activate its
+/// corresponding feature: `tokio-comp` or `async-std-comp`
+#[cfg(feature = "aio")]
+#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
+impl SentinelClient {
+    async fn async_get_client(&mut self) -> RedisResult<Client> {
+        match self.server_type {
+            SentinelServerType::Master => {
+                self.sentinel
+                    .async_master_for(self.service_name.as_str(), Some(&self.node_connection_info))
+                    .await
+            }
+            SentinelServerType::Replica => {
+                self.sentinel
+                    .async_replica_for(self.service_name.as_str(), Some(&self.node_connection_info))
+                    .await
+            }
+        }
+    }
+
+    /// Returns an async connection from the client, using the same logic from
+    /// `SentinelClient::get_connection`.
+    #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
+    pub async fn get_async_connection(&mut self) -> RedisResult<MultiplexedConnection> {
+        let client = self.async_get_client().await?;
+        client
+            .get_multiplexed_async_connection(GlideConnectionOptions::default())
+            .await
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/streams.rs b/glide-core/redis-rs/redis/src/streams.rs
new file mode 100644
index 0000000000..62505d6d75
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/streams.rs
@@ -0,0 +1,670 @@
+//! Defines types to use with the streams commands.
+
+use crate::{
+    from_redis_value, types::HashMap, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs, Value,
+};
+
+use std::io::{Error, ErrorKind};
+
+// Stream Maxlen Enum
+
+/// Utility enum for passing `MAXLEN [= or ~] [COUNT]`
+/// arguments into `StreamCommands`.
+/// The enum value represents the count.
+#[derive(PartialEq, Eq, Clone, Debug, Copy)]
+pub enum StreamMaxlen {
+    /// Match an exact count
+    Equals(usize),
+    /// Match an approximate count
+    Approx(usize),
+}
+
+impl ToRedisArgs for StreamMaxlen {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        let (ch, val) = match *self {
+            StreamMaxlen::Equals(v) => ("=", v),
+            StreamMaxlen::Approx(v) => ("~", v),
+        };
+        out.write_arg(b"MAXLEN");
+        out.write_arg(ch.as_bytes());
+        val.write_redis_args(out);
+    }
+}
+
+/// Builder options for [`xclaim_options`] command.
+///
+/// [`xclaim_options`]: ../trait.Commands.html#method.xclaim_options
+///
+#[derive(Default, Debug)]
+pub struct StreamClaimOptions {
+    /// Set `IDLE <ms>` cmd arg.
+    idle: Option<usize>,
+    /// Set `TIME <ms-unix-time>` cmd arg.
+    time: Option<usize>,
+    /// Set `RETRYCOUNT <count>` cmd arg.
+    retry: Option<usize>,
+    /// Set `FORCE` cmd arg.
+    force: bool,
+    /// Set `JUSTID` cmd arg. Be advised: the response
+    /// type changes with this option.
+    justid: bool,
+}
+
+impl StreamClaimOptions {
+    /// Set `IDLE <ms>` cmd arg.
+    pub fn idle(mut self, ms: usize) -> Self {
+        self.idle = Some(ms);
+        self
+    }
+
+    /// Set `TIME <ms-unix-time>` cmd arg.
+    pub fn time(mut self, ms_time: usize) -> Self {
+        self.time = Some(ms_time);
+        self
+    }
+
+    /// Set `RETRYCOUNT <count>` cmd arg.
+    pub fn retry(mut self, count: usize) -> Self {
+        self.retry = Some(count);
+        self
+    }
+
+    /// Set `FORCE` cmd arg to true.
+    pub fn with_force(mut self) -> Self {
+        self.force = true;
+        self
+    }
+
+    /// Set `JUSTID` cmd arg to true. Be advised: the response
+    /// type changes with this option.
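+    ///
+    /// For illustration only, options might be chained like this (a sketch,
+    /// not taken from this patch):
+    ///
+    /// ```no_run
+    /// use redis::streams::StreamClaimOptions;
+    ///
+    /// // Claim entries idle for at least 10s and return only their IDs.
+    /// let opts = StreamClaimOptions::default().idle(10_000).with_justid();
+    /// ```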
+ pub fn with_justid(mut self) -> Self { + self.justid = true; + self + } +} + +impl ToRedisArgs for StreamClaimOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref ms) = self.idle { + out.write_arg(b"IDLE"); + out.write_arg(format!("{ms}").as_bytes()); + } + if let Some(ref ms_time) = self.time { + out.write_arg(b"TIME"); + out.write_arg(format!("{ms_time}").as_bytes()); + } + if let Some(ref count) = self.retry { + out.write_arg(b"RETRYCOUNT"); + out.write_arg(format!("{count}").as_bytes()); + } + if self.force { + out.write_arg(b"FORCE"); + } + if self.justid { + out.write_arg(b"JUSTID"); + } + } +} + +/// Argument to `StreamReadOptions` +/// Represents the Redis `GROUP ` cmd arg. +/// This option will toggle the cmd from `XREAD` to `XREADGROUP` +type SRGroup = Option<(Vec>, Vec>)>; +/// Builder options for [`xread_options`] command. +/// +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug)] +pub struct StreamReadOptions { + /// Set the `BLOCK ` cmd arg. + block: Option, + /// Set the `COUNT ` cmd arg. + count: Option, + /// Set the `NOACK` cmd arg. + noack: Option, + /// Set the `GROUP ` cmd arg. + /// This option will toggle the cmd from XREAD to XREADGROUP. + group: SRGroup, +} + +impl StreamReadOptions { + /// Indicates whether the command is participating in a group + /// and generating ACKs + pub fn read_only(&self) -> bool { + self.group.is_none() + } + + /// Sets the command so that it avoids adding the message + /// to the PEL in cases where reliability is not a requirement + /// and the occasional message loss is acceptable. + pub fn noack(mut self) -> Self { + self.noack = Some(true); + self + } + + /// Sets the block time in milliseconds. + pub fn block(mut self, ms: usize) -> Self { + self.block = Some(ms); + self + } + + /// Sets the maximum number of elements to return per stream. + pub fn count(mut self, n: usize) -> Self { + self.count = Some(n); + self + } + + /// Sets the name of a consumer group associated to the stream. + pub fn group( + mut self, + group_name: GN, + consumer_name: CN, + ) -> Self { + self.group = Some(( + ToRedisArgs::to_redis_args(&group_name), + ToRedisArgs::to_redis_args(&consumer_name), + )); + self + } +} + +impl ToRedisArgs for StreamReadOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref group) = self.group { + out.write_arg(b"GROUP"); + for i in &group.0 { + out.write_arg(i); + } + for i in &group.1 { + out.write_arg(i); + } + } + + if let Some(ref ms) = self.block { + out.write_arg(b"BLOCK"); + out.write_arg(format!("{ms}").as_bytes()); + } + + if let Some(ref n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg(format!("{n}").as_bytes()); + } + + if self.group.is_some() { + // noack is only available w/ xreadgroup + if self.noack == Some(true) { + out.write_arg(b"NOACK"); + } + } + } +} + +/// Reply type used with [`xread`] or [`xread_options`] commands. +/// +/// [`xread`]: ../trait.Commands.html#method.xread +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug, Clone)] +pub struct StreamReadReply { + /// Complex data structure containing a payload for each key in this array + pub keys: Vec, +} + +/// Reply type used with [`xrange`], [`xrange_count`], [`xrange_all`], [`xrevrange`], [`xrevrange_count`], [`xrevrange_all`] commands. +/// +/// Represents stream entries matching a given range of `id`'s. 
+/// +/// [`xrange`]: ../trait.Commands.html#method.xrange +/// [`xrange_count`]: ../trait.Commands.html#method.xrange_count +/// [`xrange_all`]: ../trait.Commands.html#method.xrange_all +/// [`xrevrange`]: ../trait.Commands.html#method.xrevrange +/// [`xrevrange_count`]: ../trait.Commands.html#method.xrevrange_count +/// [`xrevrange_all`]: ../trait.Commands.html#method.xrevrange_all +/// +#[derive(Default, Debug, Clone)] +pub struct StreamRangeReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xclaim`] command. +/// +/// Represents that ownership of the specified messages was changed. +/// +/// [`xclaim`]: ../trait.Commands.html#method.xclaim +/// +#[derive(Default, Debug, Clone)] +pub struct StreamClaimReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xpending`] command. +/// +/// Data returned here were fetched from the stream without +/// having been acknowledged. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +/// +#[derive(Debug, Clone, Default)] +pub enum StreamPendingReply { + /// The stream is empty. + #[default] + Empty, + /// Data with payload exists in the stream. + Data(StreamPendingData), +} + +impl StreamPendingReply { + /// Returns how many records are in the reply. + pub fn count(&self) -> usize { + match self { + StreamPendingReply::Empty => 0, + StreamPendingReply::Data(x) => x.count, + } + } +} + +/// Inner reply type when an [`xpending`] command has data. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingData { + /// Limit on the number of messages to return per call. + pub count: usize, + /// ID for the first pending record. + pub start_id: String, + /// ID for the final pending record. + pub end_id: String, + /// Every consumer in the consumer group with at + /// least one pending message, + /// and the number of pending messages it has. + pub consumers: Vec, +} + +/// Reply type used with [`xpending_count`] and +/// [`xpending_consumer_count`] commands. +/// +/// Data returned here have been fetched from the stream without +/// any acknowledgement. +/// +/// [`xpending_count`]: ../trait.Commands.html#method.xpending_count +/// [`xpending_consumer_count`]: ../trait.Commands.html#method.xpending_consumer_count +/// +#[derive(Default, Debug, Clone)] +pub struct StreamPendingCountReply { + /// An array of structs containing information about + /// message IDs yet to be acknowledged by various consumers, + /// time since last ack, and total number of acks by that consumer. + pub ids: Vec, +} + +/// Reply type used with [`xinfo_stream`] command, containing +/// general information about the stream stored at the specified key. +/// +/// The very first and last IDs in the stream are shown, +/// in order to give some sense about what is the stream content. +/// +/// [`xinfo_stream`]: ../trait.Commands.html#method.xinfo_stream +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoStreamReply { + /// The last generated ID that may not be the same as the last + /// entry ID in case some entry was deleted. + pub last_generated_id: String, + /// Details about the radix tree representing the stream mostly + /// useful for optimization and debugging tasks. + pub radix_tree_keys: usize, + /// The number of consumer groups associated with the stream. + pub groups: usize, + /// Number of elements of the stream. 
+ pub length: usize, + /// The very first entry in the stream. + pub first_entry: StreamId, + /// The very last entry in the stream. + pub last_entry: StreamId, +} + +/// Reply type used with [`xinfo_consumer`] command, an array of every +/// consumer in a specific consumer group. +/// +/// [`xinfo_consumer`]: ../trait.Commands.html#method.xinfo_consumer +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumersReply { + /// An array of every consumer in a specific consumer group. + pub consumers: Vec, +} + +/// Reply type used with [`xinfo_groups`] command. +/// +/// This output represents all the consumer groups associated with +/// the stream. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroupsReply { + /// All the consumer groups associated with the stream. + pub groups: Vec, +} + +/// A consumer parsed from [`xinfo_consumers`] command. +/// +/// [`xinfo_consumers`]: ../trait.Commands.html#method.xinfo_consumers +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumer { + /// Name of the consumer group. + pub name: String, + /// Number of pending messages for this specific consumer. + pub pending: usize, + /// This consumer's idle time in milliseconds. + pub idle: usize, +} + +/// A group parsed from [`xinfo_groups`] command. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroup { + /// The group name. + pub name: String, + /// Number of consumers known in the group. + pub consumers: usize, + /// Number of pending messages (delivered but not yet acknowledged) in the group. + pub pending: usize, + /// Last ID delivered to this group. + pub last_delivered_id: String, +} + +/// Represents a pending message parsed from [`xpending`] methods. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingId { + /// The ID of the message. + pub id: String, + /// The name of the consumer that fetched the message and has + /// still to acknowledge it. We call it the current owner + /// of the message. + pub consumer: String, + /// The number of milliseconds that elapsed since the + /// last time this message was delivered to this consumer. + pub last_delivered_ms: usize, + /// The number of times this message was delivered. + pub times_delivered: usize, +} + +/// Represents a stream `key` and its `id`'s parsed from `xread` methods. +#[derive(Default, Debug, Clone)] +pub struct StreamKey { + /// The stream `key`. + pub key: String, + /// The parsed stream `id`'s. + pub ids: Vec, +} + +/// Represents a stream `id` and its field/values as a `HashMap` +#[derive(Default, Debug, Clone)] +pub struct StreamId { + /// The stream `id` (entry ID) of this particular message. + pub id: String, + /// All fields in this message, associated with their respective values. + pub map: HashMap, +} + +impl StreamId { + /// Converts a `Value::Array` into a `StreamId`. + fn from_array_value(v: &Value) -> RedisResult { + let mut stream_id = StreamId::default(); + if let Value::Array(ref values) = *v { + if let Some(v) = values.first() { + stream_id.id = from_redis_value(v)?; + } + if let Some(v) = values.get(1) { + stream_id.map = from_redis_value(v)?; + } + } + + Ok(stream_id) + } + + /// Fetches value of a given field and converts it to the specified + /// type. 
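+    ///
+    /// Sketch (the field name is illustrative):
+    ///
+    /// ```no_run
+    /// # use redis::streams::StreamId;
+    /// # let entry = StreamId::default();
+    /// let temperature: Option<f64> = entry.get("temperature");
+    /// ```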
+ pub fn get(&self, key: &str) -> Option { + match self.map.get(key) { + Some(x) => from_redis_value(x).ok(), + None => None, + } + } + + /// Does the message contain a particular field? + pub fn contains_key(&self, key: &str) -> bool { + self.map.contains_key(key) + } + + /// Returns how many field/value pairs exist in this message. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns true if there are no field/value pairs in this message. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +type SRRows = Vec>>>>; +impl FromRedisValue for StreamReadReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: SRRows = from_redis_value(v)?; + let keys = rows + .into_iter() + .flat_map(|row| { + row.into_iter().map(|(key, entry)| { + let ids = entry + .into_iter() + .flat_map(|id_row| id_row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + StreamKey { key, ids } + }) + }) + .collect(); + Ok(StreamReadReply { keys }) + } +} + +impl FromRedisValue for StreamRangeReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamRangeReply { ids }) + } +} + +impl FromRedisValue for StreamClaimReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamClaimReply { ids }) + } +} + +type SPRInner = ( + usize, + Option, + Option, + Vec>, +); +impl FromRedisValue for StreamPendingReply { + fn from_redis_value(v: &Value) -> RedisResult { + let (count, start, end, consumer_data): SPRInner = from_redis_value(v)?; + + if count == 0 { + Ok(StreamPendingReply::Empty) + } else { + let mut result = StreamPendingData::default(); + + let start_id = start.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects start id", + ) + })?; + + let end_id = end.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects end id", + ) + })?; + + result.count = count; + result.start_id = start_id; + result.end_id = end_id; + + result.consumers = consumer_data + .into_iter() + .flatten() + .map(|(name, pending)| StreamInfoConsumer { + name, + pending: pending.parse().unwrap_or_default(), + ..Default::default() + }) + .collect(); + + Ok(StreamPendingReply::Data(result)) + } + } +} + +impl FromRedisValue for StreamPendingCountReply { + fn from_redis_value(v: &Value) -> RedisResult { + let mut reply = StreamPendingCountReply::default(); + match v { + Value::Array(outer_tuple) => { + for outer in outer_tuple { + match outer { + Value::Array(inner_tuple) => match &inner_tuple[..] 
{ + [Value::BulkString(id_bytes), Value::BulkString(consumer_bytes), Value::Int(last_delivered_ms_u64), Value::Int(times_delivered_u64)] => + { + let id = String::from_utf8(id_bytes.to_vec())?; + let consumer = String::from_utf8(consumer_bytes.to_vec())?; + let last_delivered_ms = *last_delivered_ms_u64 as usize; + let times_delivered = *times_delivered_u64 as usize; + reply.ids.push(StreamPendingId { + id, + consumer, + last_delivered_ms, + times_delivered, + }); + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (3)" + )), + }, + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (2)" + )), + } + } + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (1)" + )), + }; + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoStreamReply { + fn from_redis_value(v: &Value) -> RedisResult { + let map: HashMap = from_redis_value(v)?; + let mut reply = StreamInfoStreamReply::default(); + if let Some(v) = &map.get("last-generated-id") { + reply.last_generated_id = from_redis_value(v)?; + } + if let Some(v) = &map.get("radix-tree-nodes") { + reply.radix_tree_keys = from_redis_value(v)?; + } + if let Some(v) = &map.get("groups") { + reply.groups = from_redis_value(v)?; + } + if let Some(v) = &map.get("length") { + reply.length = from_redis_value(v)?; + } + if let Some(v) = &map.get("first-entry") { + reply.first_entry = StreamId::from_array_value(v)?; + } + if let Some(v) = &map.get("last-entry") { + reply.last_entry = StreamId::from_array_value(v)?; + } + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoConsumersReply { + fn from_redis_value(v: &Value) -> RedisResult { + let consumers: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoConsumersReply::default(); + for map in consumers { + let mut c = StreamInfoConsumer::default(); + if let Some(v) = &map.get("name") { + c.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + c.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("idle") { + c.idle = from_redis_value(v)?; + } + reply.consumers.push(c); + } + + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoGroupsReply { + fn from_redis_value(v: &Value) -> RedisResult { + let groups: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoGroupsReply::default(); + for map in groups { + let mut g = StreamInfoGroup::default(); + if let Some(v) = &map.get("name") { + g.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + g.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("consumers") { + g.consumers = from_redis_value(v)?; + } + if let Some(v) = &map.get("last-delivered-id") { + g.last_delivered_id = from_redis_value(v)?; + } + reply.groups.push(g); + } + Ok(reply) + } +} diff --git a/glide-core/redis-rs/redis/src/tls.rs b/glide-core/redis-rs/redis/src/tls.rs new file mode 100644 index 0000000000..6886efb836 --- /dev/null +++ b/glide-core/redis-rs/redis/src/tls.rs @@ -0,0 +1,142 @@ +use std::io::{BufRead, Error, ErrorKind as IOErrorKind}; + +use rustls::RootCertStore; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + +use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult}; + +/// Structure to hold mTLS client _certificate_ and _key_ binaries in PEM format +/// +#[derive(Clone)] +pub struct ClientTlsConfig { + /// client certificate byte stream in PEM format + pub client_cert: Vec, + /// client key byte stream in PEM format + pub client_key: Vec, +} + +/// Structure to hold TLS 
certificates +/// - `client_tls`: binaries of clientkey and certificate within a `ClientTlsConfig` structure if mTLS is used +/// - `root_cert`: binary CA certificate in PEM format if CA is not in local truststore +/// +#[derive(Clone)] +pub struct TlsCertificates { + /// 'ClientTlsConfig' containing client certificate and key if mTLS is to be used + pub client_tls: Option, + /// root certificate byte stream in PEM format if the local truststore is *not* to be used + pub root_cert: Option>, +} + +pub(crate) fn inner_build_with_tls( + mut connection_info: ConnectionInfo, + certificates: TlsCertificates, +) -> RedisResult { + let tls_params = retrieve_tls_certificates(certificates)?; + + connection_info.addr = if let ConnectionAddr::TcpTls { + host, + port, + insecure, + .. + } = connection_info.addr + { + ConnectionAddr::TcpTls { + host, + port, + insecure, + tls_params: Some(tls_params), + } + } else { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Constructing a TLS client requires a URL with the `rediss://` scheme", + ))); + }; + + Ok(Client { connection_info }) +} + +pub(crate) fn retrieve_tls_certificates( + certificates: TlsCertificates, +) -> RedisResult { + let TlsCertificates { + client_tls, + root_cert, + } = certificates; + + let client_tls_params = if let Some(ClientTlsConfig { + client_cert, + client_key, + }) = client_tls + { + let buf = &mut client_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let client_cert_chain = certs.collect::, _>>()?; + + let client_key = + rustls_pemfile::private_key(&mut client_key.as_slice() as &mut dyn BufRead)? + .ok_or_else(|| { + Error::new( + IOErrorKind::Other, + "Unable to extract private key from PEM file", + ) + })?; + + Some(ClientTlsParams { + client_cert_chain, + client_key, + }) + } else { + None + }; + + let root_cert_store = if let Some(root_cert) = root_cert { + let buf = &mut root_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let mut root_cert_store = RootCertStore::empty(); + for result in certs { + if root_cert_store.add(result?.to_owned()).is_err() { + return Err( + Error::new(IOErrorKind::Other, "Unable to parse TLS trust anchors").into(), + ); + } + } + + Some(root_cert_store) + } else { + None + }; + + Ok(TlsConnParams { + client_tls_params, + root_cert_store, + }) +} + +#[derive(Debug)] +pub struct ClientTlsParams { + pub(crate) client_cert_chain: Vec>, + pub(crate) client_key: PrivateKeyDer<'static>, +} + +/// [`PrivateKeyDer`] does not implement `Clone` so we need to implement it manually. 
+impl Clone for ClientTlsParams { + fn clone(&self) -> Self { + use PrivateKeyDer::*; + Self { + client_cert_chain: self.client_cert_chain.clone(), + client_key: match &self.client_key { + Pkcs1(key) => Pkcs1(key.secret_pkcs1_der().to_vec().into()), + Pkcs8(key) => Pkcs8(key.secret_pkcs8_der().to_vec().into()), + Sec1(key) => Sec1(key.secret_sec1_der().to_vec().into()), + _ => unreachable!(), + }, + } + } +} + +#[derive(Debug, Clone)] +pub struct TlsConnParams { + pub(crate) client_tls_params: Option, + pub(crate) root_cert_store: Option, +} diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs new file mode 100644 index 0000000000..a024f16a7d --- /dev/null +++ b/glide-core/redis-rs/redis/src/types.rs @@ -0,0 +1,2460 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::default::Default; +use std::error; +use std::ffi::{CString, NulError}; +use std::fmt; +use std::hash::{BuildHasher, Hash}; +use std::io; +use std::str::{from_utf8, Utf8Error}; +use std::string::FromUtf8Error; + +#[cfg(feature = "ahash")] +pub(crate) use ahash::{AHashMap as HashMap, AHashSet as HashSet}; +use num_bigint::BigInt; +#[cfg(not(feature = "ahash"))] +pub(crate) use std::collections::{HashMap, HashSet}; +use std::ops::Deref; + +macro_rules! invalid_type_error { + ($v:expr, $det:expr) => {{ + fail!(invalid_type_error_inner!($v, $det)) + }}; +} + +macro_rules! invalid_type_error_inner { + ($v:expr, $det:expr) => { + RedisError::from(( + ErrorKind::TypeError, + "Response was of incompatible type", + format!("{:?} (response was {:?})", $det, $v), + )) + }; +} + +/// Helper enum that is used to define expiry time +pub enum Expiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// PERSIST -- Remove the time to live associated with the key. + PERSIST, +} + +/// Helper enum that is used to define expiry time for SET command +#[derive(Clone, Copy)] +pub enum SetExpiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// KEEPTTL -- Retain the time to live associated with the key. + KEEPTTL, +} + +/// Helper enum that is used to define existence checks +#[derive(Clone, Copy)] +pub enum ExistenceCheck { + /// NX -- Only set the key if it does not already exist. + NX, + /// XX -- Only set the key if it already exists. + XX, +} + +/// Helper enum that is used in some situations to describe +/// the behavior of arguments in a numeric context. +#[derive(PartialEq, Eq, Clone, Debug, Copy)] +pub enum NumericBehavior { + /// This argument is not numeric. + NonNumeric, + /// This argument is an integer. + NumberIsInteger, + /// This argument is a floating point value. + NumberIsFloat, +} + +/// An enum of all error kinds. +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +#[non_exhaustive] +pub enum ErrorKind { + /// The server generated an invalid response. 
+    ResponseError,
+    /// The parser failed to parse the server response.
+    ParseError,
+    /// The authentication with the server failed.
+    AuthenticationFailed,
+    /// Operation failed because of a type mismatch.
+    TypeError,
+    /// A script execution was aborted.
+    ExecAbortError,
+    /// The server cannot respond because it's loading a dump.
+    BusyLoadingError,
+    /// A script that was requested does not actually exist.
+    NoScriptError,
+    /// An error that was caused because the parameters to the
+    /// client were wrong.
+    InvalidClientConfig,
+    /// Raised if a key moved to a different node.
+    Moved,
+    /// Raised if a key moved to a different node but we need to ask.
+    Ask,
+    /// Raised if a request needs to be retried.
+    TryAgain,
+    /// Raised if a redis cluster is down.
+    ClusterDown,
+    /// A request spans multiple slots
+    CrossSlot,
+    /// A cluster master is unavailable.
+    MasterDown,
+    /// This kind is returned if the redis error is one that is
+    /// not native to the system. This is usually the case if
+    /// the cause is another error.
+    IoError,
+    /// An error raised that was identified on the client before execution.
+    ClientError,
+    /// An extension error. This is an error created by the server
+    /// that is not directly understood by the library.
+    ExtensionError,
+    /// Attempt to write to a read-only server
+    ReadOnly,
+    /// Requested name not found among masters returned by the sentinels
+    MasterNameNotFoundBySentinel,
+    /// No valid replicas found in the sentinels, for a given master name
+    NoValidReplicasFoundBySentinel,
+    /// At least one sentinel connection info is required
+    EmptySentinelList,
+    /// Attempted to kill a script/function while they weren't executing
+    NotBusy,
+    /// Used when no valid node connections remain in the cluster connection
+    AllConnectionsUnavailable,
+    /// Used when a connection is not found for the specified route.
+    ConnectionNotFoundForRoute,
+
+    #[cfg(feature = "json")]
+    /// Error Serializing a struct to JSON form
+    Serialize,
+
+    /// Redis servers prior to v6.0.0 don't support RESP3.
+    /// Try disabling the resp3 option
+    RESP3NotSupported,
+
+    /// Not all slots are covered by the cluster
+    NotAllSlotsCovered,
+}
+
+#[derive(PartialEq, Debug)]
+pub(crate) enum ServerErrorKind {
+    ResponseError,
+    ExecAbortError,
+    BusyLoadingError,
+    NoScriptError,
+    Moved,
+    Ask,
+    TryAgain,
+    ClusterDown,
+    CrossSlot,
+    MasterDown,
+    ReadOnly,
+    NotBusy,
+}
+
+#[derive(PartialEq, Debug)]
+pub(crate) enum ServerError {
+    ExtensionError {
+        code: String,
+        detail: Option<String>,
+    },
+    KnownError {
+        kind: ServerErrorKind,
+        detail: Option<String>,
+    },
+}
+
+impl From<ServerError> for RedisError {
+    fn from(value: ServerError) -> Self {
+        // TODO - Consider changing RedisError to explicitly represent whether an error came from the server or not. Today it is only implied.
+ match value { + ServerError::ExtensionError { code, detail } => make_extension_error(code, detail), + ServerError::KnownError { kind, detail } => { + let desc = "An error was signalled by the server"; + let kind = match kind { + ServerErrorKind::ResponseError => ErrorKind::ResponseError, + ServerErrorKind::ExecAbortError => ErrorKind::ExecAbortError, + ServerErrorKind::BusyLoadingError => ErrorKind::BusyLoadingError, + ServerErrorKind::NoScriptError => ErrorKind::NoScriptError, + ServerErrorKind::Moved => ErrorKind::Moved, + ServerErrorKind::Ask => ErrorKind::Ask, + ServerErrorKind::TryAgain => ErrorKind::TryAgain, + ServerErrorKind::ClusterDown => ErrorKind::ClusterDown, + ServerErrorKind::CrossSlot => ErrorKind::CrossSlot, + ServerErrorKind::MasterDown => ErrorKind::MasterDown, + ServerErrorKind::ReadOnly => ErrorKind::ReadOnly, + ServerErrorKind::NotBusy => ErrorKind::NotBusy, + }; + match detail { + Some(detail) => RedisError::from((kind, desc, detail)), + None => RedisError::from((kind, desc)), + } + } + } + } +} + +/// Internal low-level redis value enum. +#[derive(PartialEq, Debug)] +pub(crate) enum InternalValue { + /// A nil response from the server. + Nil, + /// An integer response. Note that there are a few situations + /// in which redis actually returns a string for an integer which + /// is why this library generally treats integers and strings + /// the same for all numeric responses. + Int(i64), + /// An arbitrary binary data, usually represents a binary-safe string. + BulkString(Vec), + /// A response containing an array with more data. This is generally used by redis + /// to express nested structures. + Array(Vec), + /// A simple string response, without line breaks and not binary safe. + SimpleString(String), + /// A status response which represents the string "OK". + Okay, + /// Unordered key,value list from the server. Use `as_map_iter` function. + Map(Vec<(InternalValue, InternalValue)>), + /// Attribute value from the server. Client will give data instead of whole Attribute type. + Attribute { + /// Data that attributes belong to. + data: Box, + /// Key,Value list of attributes. + attributes: Vec<(InternalValue, InternalValue)>, + }, + /// Unordered set value from the server. + Set(Vec), + /// A floating number response from the server. + Double(f64), + /// A boolean response from the server. + Boolean(bool), + /// First String is format and other is the string + VerbatimString { + /// Text's format type + format: VerbatimFormat, + /// Remaining string check format before using! + text: String, + }, + /// Very large number that out of the range of the signed 64 bit numbers + BigNumber(BigInt), + /// Push data from the server. 
+ Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, + ServerError(ServerError), +} + +impl InternalValue { + pub(crate) fn try_into(self) -> RedisResult { + match self { + InternalValue::Nil => Ok(Value::Nil), + InternalValue::Int(val) => Ok(Value::Int(val)), + InternalValue::BulkString(val) => Ok(Value::BulkString(val)), + InternalValue::Array(val) => Ok(Value::Array(Self::try_into_vec(val)?)), + InternalValue::SimpleString(val) => Ok(Value::SimpleString(val)), + InternalValue::Okay => Ok(Value::Okay), + InternalValue::Map(map) => Ok(Value::Map(Self::try_into_map(map)?)), + InternalValue::Attribute { data, attributes } => { + let data = Box::new((*data).try_into()?); + let attributes = Self::try_into_map(attributes)?; + Ok(Value::Attribute { data, attributes }) + } + InternalValue::Set(set) => Ok(Value::Set(Self::try_into_vec(set)?)), + InternalValue::Double(double) => Ok(Value::Double(double)), + InternalValue::Boolean(boolean) => Ok(Value::Boolean(boolean)), + InternalValue::VerbatimString { format, text } => { + Ok(Value::VerbatimString { format, text }) + } + InternalValue::BigNumber(number) => Ok(Value::BigNumber(number)), + InternalValue::Push { kind, data } => Ok(Value::Push { + kind, + data: Self::try_into_vec(data)?, + }), + + InternalValue::ServerError(err) => Err(err.into()), + } + } + + fn try_into_vec(vec: Vec) -> RedisResult> { + vec.into_iter() + .map(InternalValue::try_into) + .collect::>>() + } + + fn try_into_map(map: Vec<(InternalValue, InternalValue)>) -> RedisResult> { + let mut vec = Vec::with_capacity(map.len()); + for (key, value) in map.into_iter() { + vec.push((key.try_into()?, value.try_into()?)); + } + Ok(vec) + } +} + +/// Internal low-level redis value enum. +#[derive(PartialEq, Clone)] +pub enum Value { + /// A nil response from the server. + Nil, + /// An integer response. Note that there are a few situations + /// in which redis actually returns a string for an integer which + /// is why this library generally treats integers and strings + /// the same for all numeric responses. + Int(i64), + /// An arbitrary binary data, usually represents a binary-safe string. + BulkString(Vec), + /// A response containing an array with more data. This is generally used by redis + /// to express nested structures. + Array(Vec), + /// A simple string response, without line breaks and not binary safe. + SimpleString(String), + /// A status response which represents the string "OK". + Okay, + /// Unordered key,value list from the server. Use `as_map_iter` function. + Map(Vec<(Value, Value)>), + /// Attribute value from the server. Client will give data instead of whole Attribute type. + Attribute { + /// Data that attributes belong to. + data: Box, + /// Key,Value list of attributes. + attributes: Vec<(Value, Value)>, + }, + /// Unordered set value from the server. + Set(Vec), + /// A floating number response from the server. + Double(f64), + /// A boolean response from the server. + Boolean(bool), + /// First String is format and other is the string + VerbatimString { + /// Text's format type + format: VerbatimFormat, + /// Remaining string check format before using! + text: String, + }, + /// Very large number that out of the range of the signed 64 bit numbers + BigNumber(BigInt), + /// Push data from the server. 
+ Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, +} + +/// `VerbatimString`'s format types defined by spec +#[derive(PartialEq, Clone, Debug)] +pub enum VerbatimFormat { + /// Unknown type to catch future formats. + Unknown(String), + /// `mkd` format + Markdown, + /// `txt` format + Text, +} + +/// `Push` type's currently known kinds. +#[derive(PartialEq, Clone, Debug)] +pub enum PushKind { + /// `Disconnection` is sent from the **library** when connection is closed. + Disconnection, + /// Other kind to catch future kinds. + Other(String), + /// `invalidate` is received when a key is changed/deleted. + Invalidate, + /// `message` is received when pubsub message published by another client. + Message, + /// `pmessage` is received when pubsub message published by another client and client subscribed to topic via pattern. + PMessage, + /// `smessage` is received when pubsub message published by another client and client subscribed to it with sharding. + SMessage, + /// `unsubscribe` is received when client unsubscribed from a channel. + Unsubscribe, + /// `punsubscribe` is received when client unsubscribed from a pattern. + PUnsubscribe, + /// `sunsubscribe` is received when client unsubscribed from a shard channel. + SUnsubscribe, + /// `subscribe` is received when client subscribed to a channel. + Subscribe, + /// `psubscribe` is received when client subscribed to a pattern. + PSubscribe, + /// `ssubscribe` is received when client subscribed to a shard channel. + SSubscribe, +} + +impl PushKind { + #[cfg(feature = "aio")] + pub(crate) fn has_reply(&self) -> bool { + matches!( + self, + &PushKind::Unsubscribe + | &PushKind::PUnsubscribe + | &PushKind::SUnsubscribe + | &PushKind::Subscribe + | &PushKind::PSubscribe + | &PushKind::SSubscribe + ) + } +} + +impl fmt::Display for VerbatimFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VerbatimFormat::Markdown => write!(f, "mkd"), + VerbatimFormat::Unknown(val) => write!(f, "{val}"), + VerbatimFormat::Text => write!(f, "txt"), + } + } +} + +impl fmt::Display for PushKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PushKind::Other(kind) => write!(f, "{}", kind), + PushKind::Invalidate => write!(f, "invalidate"), + PushKind::Message => write!(f, "message"), + PushKind::PMessage => write!(f, "pmessage"), + PushKind::SMessage => write!(f, "smessage"), + PushKind::Unsubscribe => write!(f, "unsubscribe"), + PushKind::PUnsubscribe => write!(f, "punsubscribe"), + PushKind::SUnsubscribe => write!(f, "sunsubscribe"), + PushKind::Subscribe => write!(f, "subscribe"), + PushKind::PSubscribe => write!(f, "psubscribe"), + PushKind::SSubscribe => write!(f, "ssubscribe"), + PushKind::Disconnection => write!(f, "disconnection"), + } + } +} + +pub enum MapIter<'a> { + Array(std::slice::Iter<'a, Value>), + Map(std::slice::Iter<'a, (Value, Value)>), +} + +impl<'a> Iterator for MapIter<'a> { + type Item = (&'a Value, &'a Value); + + fn next(&mut self) -> Option { + match self { + MapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + MapIter::Map(iter) => { + let (k, v) = iter.next()?; + Some((k, v)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MapIter::Array(iter) => iter.size_hint(), + MapIter::Map(iter) => iter.size_hint(), + } + } +} + +pub enum OwnedMapIter { + Array(std::vec::IntoIter), + Map(std::vec::IntoIter<(Value, Value)>), +} + +impl Iterator for OwnedMapIter { + type Item = (Value, Value); + + fn 
next(&mut self) -> Option { + match self { + OwnedMapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + OwnedMapIter::Map(iter) => iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + OwnedMapIter::Array(iter) => { + let (low, high) = iter.size_hint(); + (low / 2, high.map(|h| h / 2)) + } + OwnedMapIter::Map(iter) => iter.size_hint(), + } + } +} + +/// Values are generally not used directly unless you are using the +/// more low level functionality in the library. For the most part +/// this is hidden with the help of the `FromRedisValue` trait. +/// +/// While on the redis protocol there is an error type this is already +/// separated at an early point so the value only holds the remaining +/// types. +impl Value { + /// Checks if the return value looks like it fulfils the cursor + /// protocol. That means the result is an array item of length + /// two with the first one being a cursor and the second an + /// array response. + pub fn looks_like_cursor(&self) -> bool { + match *self { + Value::Array(ref items) => { + if items.len() != 2 { + return false; + } + matches!(items[0], Value::BulkString(_)) && matches!(items[1], Value::Array(_)) + } + _ => false, + } + } + + /// Returns an `&[Value]` if `self` is compatible with a sequence type + pub fn as_sequence(&self) -> Option<&[Value]> { + match self { + Value::Array(items) => Some(&items[..]), + Value::Set(items) => Some(&items[..]), + Value::Nil => Some(&[]), + _ => None, + } + } + + /// Returns a `Vec` if `self` is compatible with a sequence type, + /// otherwise returns `Err(self)`. + pub fn into_sequence(self) -> Result, Value> { + match self { + Value::Array(items) => Ok(items), + Value::Set(items) => Ok(items), + Value::Nil => Ok(vec![]), + _ => Err(self), + } + } + + /// Returns an iterator of `(&Value, &Value)` if `self` is compatible with a map type + pub fn as_map_iter(&self) -> Option> { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Some(MapIter::Array(items.iter())) + } else { + None + } + } + Value::Map(items) => Some(MapIter::Map(items.iter())), + _ => None, + } + } + + /// Returns an iterator of `(Value, Value)` if `self` is compatible with a map type. + /// If not, returns `Err(self)`. 
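+    ///
+    /// Sketch (builds a map value inline for illustration):
+    ///
+    /// ```no_run
+    /// # use redis::Value;
+    /// let value = Value::Map(vec![(Value::SimpleString("key".into()), Value::Int(1))]);
+    /// for (field, item) in value.into_map_iter().expect("map-compatible value") {
+    ///     println!("{field:?} => {item:?}");
+    /// }
+    /// ```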
+ pub fn into_map_iter(self) -> Result { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Ok(OwnedMapIter::Array(items.into_iter())) + } else { + Err(Value::Array(items)) + } + } + Value::Map(items) => Ok(OwnedMapIter::Map(items.into_iter())), + _ => Err(self), + } + } +} + +impl fmt::Debug for Value { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Value::Nil => write!(fmt, "nil"), + Value::Int(val) => write!(fmt, "int({val:?})"), + Value::BulkString(ref val) => match from_utf8(val) { + Ok(x) => write!(fmt, "bulk-string('{x:?}')"), + Err(_) => write!(fmt, "binary-data({val:?})"), + }, + Value::Array(ref values) => write!(fmt, "array({values:?})"), + Value::Push { ref kind, ref data } => write!(fmt, "push({kind:?}, {data:?})"), + Value::Okay => write!(fmt, "ok"), + Value::SimpleString(ref s) => write!(fmt, "simple-string({s:?})"), + Value::Map(ref values) => write!(fmt, "map({values:?})"), + Value::Attribute { + ref data, + attributes: _, + } => write!(fmt, "attribute({data:?})"), + Value::Set(ref values) => write!(fmt, "set({values:?})"), + Value::Double(ref d) => write!(fmt, "double({d:?})"), + Value::Boolean(ref b) => write!(fmt, "boolean({b:?})"), + Value::VerbatimString { + ref format, + ref text, + } => { + write!(fmt, "verbatim-string({:?},{:?})", format, text) + } + Value::BigNumber(ref m) => write!(fmt, "big-number({:?})", m), + } + } +} + +/// Represents a redis error. For the most part you should be using +/// the Error trait to interact with this rather than the actual +/// struct. +pub struct RedisError { + repr: ErrorRepr, +} + +#[cfg(feature = "json")] +impl From for RedisError { + fn from(serde_err: serde_json::Error) -> RedisError { + RedisError::from(( + ErrorKind::Serialize, + "Serialization Error", + format!("{serde_err}"), + )) + } +} + +#[derive(Debug)] +enum ErrorRepr { + WithDescription(ErrorKind, &'static str), + WithDescriptionAndDetail(ErrorKind, &'static str, String), + ExtensionError(String, String), + IoError(io::Error), +} + +impl PartialEq for RedisError { + fn eq(&self, other: &RedisError) -> bool { + match (&self.repr, &other.repr) { + (&ErrorRepr::WithDescription(kind_a, _), &ErrorRepr::WithDescription(kind_b, _)) => { + kind_a == kind_b + } + ( + &ErrorRepr::WithDescriptionAndDetail(kind_a, _, _), + &ErrorRepr::WithDescriptionAndDetail(kind_b, _, _), + ) => kind_a == kind_b, + (ErrorRepr::ExtensionError(a, _), ErrorRepr::ExtensionError(b, _)) => *a == *b, + _ => false, + } + } +} + +impl From for RedisError { + fn from(err: io::Error) -> RedisError { + RedisError { + repr: ErrorRepr::IoError(err), + } + } +} + +impl From for RedisError { + fn from(_: Utf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Invalid UTF-8"), + } + } +} + +impl From for RedisError { + fn from(err: NulError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value contains interior nul terminator", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-native-tls")] +impl From for RedisError { + fn from(err: native_tls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + 
+#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls_pki_types::InvalidDnsNameError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS Error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "uuid")] +impl From for RedisError { + fn from(err: uuid::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value is not a valid UUID", + err.to_string(), + ), + } + } +} + +impl From for RedisError { + fn from(_: FromUtf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Cannot convert from UTF-8"), + } + } +} + +impl From<(ErrorKind, &'static str)> for RedisError { + fn from((kind, desc): (ErrorKind, &'static str)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(kind, desc), + } + } +} + +impl From<(ErrorKind, &'static str, String)> for RedisError { + fn from((kind, desc, detail): (ErrorKind, &'static str, String)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail(kind, desc, detail), + } + } +} + +impl error::Error for RedisError { + #[allow(deprecated)] + fn description(&self) -> &str { + match self.repr { + ErrorRepr::WithDescription(_, desc) => desc, + ErrorRepr::WithDescriptionAndDetail(_, desc, _) => desc, + ErrorRepr::ExtensionError(_, _) => "extension error", + ErrorRepr::IoError(ref err) => err.description(), + } + } + + fn cause(&self) -> Option<&dyn error::Error> { + match self.repr { + ErrorRepr::IoError(ref err) => Some(err as &dyn error::Error), + _ => None, + } + } +} + +impl fmt::Display for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.repr { + ErrorRepr::WithDescription(kind, desc) => { + desc.fmt(f)?; + f.write_str("- ")?; + fmt::Debug::fmt(&kind, f) + } + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + desc.fmt(f)?; + f.write_str(" - ")?; + fmt::Debug::fmt(&kind, f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + code.fmt(f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::IoError(ref err) => err.fmt(f), + } + } +} + +impl fmt::Debug for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt::Display::fmt(self, f) + } +} + +pub(crate) enum RetryMethod { + Reconnect, + NoRetry, + RetryImmediately, + WaitAndRetry, + AskRedirect, + MovedRedirect, + WaitAndRetryOnPrimaryRedirectOnReplica, +} + +/// Indicates a general failure in the library. +impl RedisError { + /// Returns the kind of the error. + pub fn kind(&self) -> ErrorKind { + match self.repr { + ErrorRepr::WithDescription(kind, _) + | ErrorRepr::WithDescriptionAndDetail(kind, _, _) => kind, + ErrorRepr::ExtensionError(_, _) => ErrorKind::ExtensionError, + ErrorRepr::IoError(_) => ErrorKind::IoError, + } + } + + /// Returns the error detail. + pub fn detail(&self) -> Option<&str> { + match self.repr { + ErrorRepr::WithDescriptionAndDetail(_, _, ref detail) + | ErrorRepr::ExtensionError(_, ref detail) => Some(detail.as_str()), + _ => None, + } + } + + /// Returns the raw error code if available. 
+ pub fn code(&self) -> Option<&str> { + match self.kind() { + ErrorKind::ResponseError => Some("ERR"), + ErrorKind::ExecAbortError => Some("EXECABORT"), + ErrorKind::BusyLoadingError => Some("LOADING"), + ErrorKind::NoScriptError => Some("NOSCRIPT"), + ErrorKind::Moved => Some("MOVED"), + ErrorKind::Ask => Some("ASK"), + ErrorKind::TryAgain => Some("TRYAGAIN"), + ErrorKind::ClusterDown => Some("CLUSTERDOWN"), + ErrorKind::CrossSlot => Some("CROSSSLOT"), + ErrorKind::MasterDown => Some("MASTERDOWN"), + ErrorKind::ReadOnly => Some("READONLY"), + ErrorKind::NotBusy => Some("NOTBUSY"), + _ => match self.repr { + ErrorRepr::ExtensionError(ref code, _) => Some(code), + _ => None, + }, + } + } + + /// Returns the name of the error category for display purposes. + pub fn category(&self) -> &str { + match self.kind() { + ErrorKind::ResponseError => "response error", + ErrorKind::AuthenticationFailed => "authentication failed", + ErrorKind::TypeError => "type error", + ErrorKind::ExecAbortError => "script execution aborted", + ErrorKind::BusyLoadingError => "busy loading", + ErrorKind::NoScriptError => "no script", + ErrorKind::InvalidClientConfig => "invalid client config", + ErrorKind::Moved => "key moved", + ErrorKind::Ask => "key moved (ask)", + ErrorKind::TryAgain => "try again", + ErrorKind::ClusterDown => "cluster down", + ErrorKind::CrossSlot => "cross-slot", + ErrorKind::MasterDown => "master down", + ErrorKind::IoError => "I/O error", + ErrorKind::ExtensionError => "extension error", + ErrorKind::ClientError => "client error", + ErrorKind::ReadOnly => "read-only", + ErrorKind::MasterNameNotFoundBySentinel => "master name not found by sentinel", + ErrorKind::NoValidReplicasFoundBySentinel => "no valid replicas found by sentinel", + ErrorKind::EmptySentinelList => "empty sentinel list", + ErrorKind::NotBusy => "not busy", + ErrorKind::AllConnectionsUnavailable => "no valid connections remain in the cluster", + ErrorKind::ConnectionNotFoundForRoute => "No connection found for the requested route", + #[cfg(feature = "json")] + ErrorKind::Serialize => "serializing", + ErrorKind::RESP3NotSupported => "resp3 is not supported by server", + ErrorKind::ParseError => "parse error", + ErrorKind::NotAllSlotsCovered => "not all slots are covered", + } + } + + /// Indicates that this failure is an IO failure. + pub fn is_io_error(&self) -> bool { + self.as_io_error().is_some() + } + + pub(crate) fn as_io_error(&self) -> Option<&io::Error> { + match &self.repr { + ErrorRepr::IoError(e) => Some(e), + _ => None, + } + } + + /// Indicates that this is a cluster error. + pub fn is_cluster_error(&self) -> bool { + matches!( + self.kind(), + ErrorKind::Moved | ErrorKind::Ask | ErrorKind::TryAgain | ErrorKind::ClusterDown + ) + } + + /// Returns true if this error indicates that the connection was + /// refused. You should generally not rely much on this function + /// unless you are writing unit tests that want to detect if a + /// local server is available. + pub fn is_connection_refusal(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => { + #[allow(clippy::match_like_matches_macro)] + match err.kind() { + io::ErrorKind::ConnectionRefused => true, + // if we connect to a unix socket and the file does not + // exist yet, then we want to treat this as if it was a + // connection refusal. + io::ErrorKind::NotFound => cfg!(unix), + _ => false, + } + } + _ => false, + } + } + + /// Returns true if error was caused by I/O time out. + /// Note that this may not be accurate depending on platform. 
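+    ///
+    /// Sketch (wraps a timed-out I/O error for illustration):
+    ///
+    /// ```no_run
+    /// # use redis::RedisError;
+    /// # use std::io;
+    /// let err = RedisError::from(io::Error::new(io::ErrorKind::TimedOut, "timed out"));
+    /// assert!(err.is_timeout());
+    /// ```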
+    pub fn is_timeout(&self) -> bool {
+        match self.repr {
+            ErrorRepr::IoError(ref err) => matches!(
+                err.kind(),
+                io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock
+            ),
+            _ => false,
+        }
+    }
+
+    /// Returns true if error was caused by a dropped connection.
+    pub fn is_connection_dropped(&self) -> bool {
+        match self.repr {
+            ErrorRepr::IoError(ref err) => matches!(
+                err.kind(),
+                io::ErrorKind::BrokenPipe
+                    | io::ErrorKind::ConnectionReset
+                    | io::ErrorKind::UnexpectedEof
+            ),
+            _ => false,
+        }
+    }
+
+    /// Returns true if the error is likely to not be recoverable, and the connection must be replaced.
+    pub fn is_unrecoverable_error(&self) -> bool {
+        match self.retry_method() {
+            RetryMethod::Reconnect => true,
+
+            RetryMethod::NoRetry => false,
+            RetryMethod::RetryImmediately => false,
+            RetryMethod::WaitAndRetry => false,
+            RetryMethod::AskRedirect => false,
+            RetryMethod::MovedRedirect => false,
+            RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => false,
+        }
+    }
+
+    /// Returns the node the error refers to.
+    ///
+    /// This returns `(addr, slot_id)`.
+    pub fn redirect_node(&self) -> Option<(&str, u16)> {
+        match self.kind() {
+            ErrorKind::Ask | ErrorKind::Moved => (),
+            _ => return None,
+        }
+        let mut iter = self.detail()?.split_ascii_whitespace();
+        let slot_id: u16 = iter.next()?.parse().ok()?;
+        let addr = iter.next()?;
+        Some((addr, slot_id))
+    }
+
+    /// Returns the extension error code.
+    ///
+    /// This method should not be used because every time the redis library
+    /// adds support for a new error code it would disappear from this method.
+    /// `code()` always returns the code.
+    #[deprecated(note = "use code() instead")]
+    pub fn extension_error_code(&self) -> Option<&str> {
+        match self.repr {
+            ErrorRepr::ExtensionError(ref code, _) => Some(code),
+            _ => None,
+        }
+    }
+
+    /// Clone the `RedisError`, throwing away non-cloneable parts of an `IoError`.
+    ///
+    /// Deriving `Clone` is not possible because the wrapped `io::Error` is not
+    /// cloneable.
+    ///
+    /// The `ioerror_description` parameter will be prepended to the message in
+    /// case an `IoError` is found.
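+    ///
+    /// Usage sketch (crate-internal, hence marked `ignore`):
+    /// ```rust,ignore
+    /// // Keep the original error and hand an equivalent copy to another caller.
+    /// let copy = err.clone_mostly("Error cloned");
+    /// ```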
+ #[cfg(feature = "connection-manager")] // Used to avoid "unused method" warning + pub(crate) fn clone_mostly(&self, ioerror_description: &'static str) -> Self { + let repr = match self.repr { + ErrorRepr::WithDescription(kind, desc) => ErrorRepr::WithDescription(kind, desc), + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + ErrorRepr::WithDescriptionAndDetail(kind, desc, detail.clone()) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + ErrorRepr::ExtensionError(code.clone(), detail.clone()) + } + ErrorRepr::IoError(ref e) => ErrorRepr::IoError(io::Error::new( + e.kind(), + format!("{ioerror_description}: {e}"), + )), + }; + Self { repr } + } + + pub(crate) fn retry_method(&self) -> RetryMethod { + match self.kind() { + ErrorKind::Moved => RetryMethod::MovedRedirect, + ErrorKind::Ask => RetryMethod::AskRedirect, + + ErrorKind::TryAgain => RetryMethod::WaitAndRetry, + ErrorKind::MasterDown => RetryMethod::WaitAndRetry, + ErrorKind::ClusterDown => RetryMethod::WaitAndRetry, + ErrorKind::MasterNameNotFoundBySentinel => RetryMethod::WaitAndRetry, + ErrorKind::NoValidReplicasFoundBySentinel => RetryMethod::WaitAndRetry, + + ErrorKind::BusyLoadingError => RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica, + + ErrorKind::ResponseError => RetryMethod::NoRetry, + ErrorKind::ReadOnly => RetryMethod::NoRetry, + ErrorKind::ExtensionError => RetryMethod::NoRetry, + ErrorKind::ExecAbortError => RetryMethod::NoRetry, + ErrorKind::TypeError => RetryMethod::NoRetry, + ErrorKind::NoScriptError => RetryMethod::NoRetry, + ErrorKind::InvalidClientConfig => RetryMethod::NoRetry, + ErrorKind::CrossSlot => RetryMethod::NoRetry, + ErrorKind::ClientError => RetryMethod::NoRetry, + ErrorKind::EmptySentinelList => RetryMethod::NoRetry, + ErrorKind::NotBusy => RetryMethod::NoRetry, + #[cfg(feature = "json")] + ErrorKind::Serialize => RetryMethod::NoRetry, + ErrorKind::RESP3NotSupported => RetryMethod::NoRetry, + + ErrorKind::ParseError => RetryMethod::Reconnect, + ErrorKind::AuthenticationFailed => RetryMethod::Reconnect, + ErrorKind::AllConnectionsUnavailable => RetryMethod::Reconnect, + ErrorKind::ConnectionNotFoundForRoute => RetryMethod::Reconnect, + + ErrorKind::IoError => match &self.repr { + ErrorRepr::IoError(err) => match err.kind() { + io::ErrorKind::ConnectionRefused => RetryMethod::Reconnect, + io::ErrorKind::NotFound => RetryMethod::Reconnect, + io::ErrorKind::ConnectionReset => RetryMethod::Reconnect, + io::ErrorKind::ConnectionAborted => RetryMethod::Reconnect, + io::ErrorKind::NotConnected => RetryMethod::Reconnect, + io::ErrorKind::BrokenPipe => RetryMethod::Reconnect, + io::ErrorKind::UnexpectedEof => RetryMethod::Reconnect, + + io::ErrorKind::PermissionDenied => RetryMethod::NoRetry, + io::ErrorKind::Unsupported => RetryMethod::NoRetry, + + _ => RetryMethod::RetryImmediately, + }, + _ => RetryMethod::RetryImmediately, + }, + ErrorKind::NotAllSlotsCovered => RetryMethod::NoRetry, + } + } +} + +pub fn make_extension_error(code: String, detail: Option) -> RedisError { + RedisError { + repr: ErrorRepr::ExtensionError( + code, + match detail { + Some(x) => x, + None => "Unknown extension error encountered".to_string(), + }, + ), + } +} + +/// Library generic result type. +pub type RedisResult = Result; + +/// Library generic future type. +#[cfg(feature = "aio")] +pub type RedisFuture<'a, T> = futures_util::future::BoxFuture<'a, RedisResult>; + +/// An info dictionary type. 
+#[derive(Debug, Clone)]
+pub struct InfoDict {
+    map: HashMap<String, Value>,
+}
+
+/// This type provides convenient access to key/value data returned by
+/// the "INFO" command. It acts like a regular mapping but also has
+/// a convenience method `get` which can return data in the appropriate
+/// type.
+///
+/// For instance this can be used to query the server for the role it's
+/// in (master, slave) etc:
+///
+/// ```rust,no_run
+/// # fn do_something() -> redis::RedisResult<()> {
+/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+/// # let mut con = client.get_connection(None).unwrap();
+/// let info : redis::InfoDict = redis::cmd("INFO").query(&mut con)?;
+/// let role : Option<String> = info.get("role");
+/// # Ok(()) }
+/// ```
+impl InfoDict {
+    /// Creates a new info dictionary from a string in the response of
+    /// the INFO command. Each line is a key, value pair with the
+    /// key and value separated by a colon (`:`). Lines starting with a
+    /// hash (`#`) are ignored.
+    pub fn new(kvpairs: &str) -> InfoDict {
+        let mut map = HashMap::new();
+        for line in kvpairs.lines() {
+            if line.is_empty() || line.starts_with('#') {
+                continue;
+            }
+            let mut p = line.splitn(2, ':');
+            let (k, v) = match (p.next(), p.next()) {
+                (Some(k), Some(v)) => (k.to_string(), v.to_string()),
+                _ => continue,
+            };
+            map.insert(k, Value::SimpleString(v));
+        }
+        InfoDict { map }
+    }
+
+    /// Fetches a value by key and converts it into the given type.
+    /// Typical types are `String`, `bool` and integer types.
+    pub fn get<T: FromRedisValue>(&self, key: &str) -> Option<T> {
+        match self.find(&key) {
+            Some(x) => from_redis_value(x).ok(),
+            None => None,
+        }
+    }
+
+    /// Looks up a key in the info dict.
+    pub fn find(&self, key: &&str) -> Option<&Value> {
+        self.map.get(*key)
+    }
+
+    /// Checks if a key is contained in the info dict.
+    pub fn contains_key(&self, key: &&str) -> bool {
+        self.find(key).is_some()
+    }
+
+    /// Returns the size of the info dict.
+    pub fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    /// Checks if the dict is empty.
+    pub fn is_empty(&self) -> bool {
+        self.map.is_empty()
+    }
+}
+
+impl Deref for InfoDict {
+    type Target = HashMap<String, Value>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.map
+    }
+}
+
+/// Abstraction trait for redis command abstractions.
+pub trait RedisWrite {
+    /// Accepts a serialized redis command.
+    fn write_arg(&mut self, arg: &[u8]);
+
+    /// Accepts a serialized redis command.
+    fn write_arg_fmt(&mut self, arg: impl fmt::Display) {
+        self.write_arg(arg.to_string().as_bytes())
+    }
+}
+
+impl RedisWrite for Vec<Vec<u8>> {
+    fn write_arg(&mut self, arg: &[u8]) {
+        self.push(arg.to_owned());
+    }
+
+    fn write_arg_fmt(&mut self, arg: impl fmt::Display) {
+        self.push(arg.to_string().into_bytes())
+    }
+}
+
+/// Used to convert a value into one or multiple redis argument
+/// strings. Most values will produce exactly one item but in
+/// some cases it might make sense to produce more than one.
+pub trait ToRedisArgs: Sized {
+    /// This converts the value into a vector of bytes. Each item
+    /// is a single argument. Most items generate a vector of a
+    /// single item.
+    ///
+    /// The exception to this rule currently is vectors of items.
+    fn to_redis_args(&self) -> Vec<Vec<u8>> {
+        let mut out = Vec::new();
+        self.write_redis_args(&mut out);
+        out
+    }
+
+    /// This writes the value into a vector of bytes. Each item
+    /// is a single argument. Most items generate a single item.
+    ///
+    /// The exception to this rule currently is vectors of items.
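+    ///
+    /// A sketch of direct use; `Vec<Vec<u8>>` implements `RedisWrite` (see
+    /// above), so it can serve as the output sink:
+    ///
+    /// ```rust,no_run
+    /// use redis::ToRedisArgs;
+    ///
+    /// let mut out: Vec<Vec<u8>> = Vec::new();
+    /// "hello".write_redis_args(&mut out);
+    /// assert_eq!(out, vec![b"hello".to_vec()]);
+    /// ```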
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite;
+
+    /// Returns information about the contained value with regard
+    /// to its numeric behavior in a redis context. This is used in
+    /// some high level concepts to switch between different implementations
+    /// of redis functions (for instance `INCR` vs `INCRBYFLOAT`).
+    fn describe_numeric_behavior(&self) -> NumericBehavior {
+        NumericBehavior::NonNumeric
+    }
+
+    /// Returns an indication of whether the contained value is exactly one
+    /// argument. It returns false if it's zero or more than one. This
+    /// is used in some high level functions to intelligently switch
+    /// between `GET` and `MGET` variants.
+    fn is_single_arg(&self) -> bool {
+        true
+    }
+
+    /// This only exists internally as a workaround for the lack of
+    /// specialization.
+    #[doc(hidden)]
+    fn write_args_from_slice<W>(items: &[Self], out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        Self::make_arg_iter_ref(items.iter(), out)
+    }
+
+    /// This only exists internally as a workaround for the lack of
+    /// specialization.
+    #[doc(hidden)]
+    fn make_arg_iter_ref<'a, I, W>(items: I, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+        I: Iterator<Item = &'a Self>,
+        Self: 'a,
+    {
+        for item in items {
+            item.write_redis_args(out);
+        }
+    }
+
+    #[doc(hidden)]
+    fn is_single_vec_arg(items: &[Self]) -> bool {
+        items.len() == 1 && items[0].is_single_arg()
+    }
+}
+
+macro_rules! itoa_based_to_redis_impl {
+    ($t:ty, $numeric:expr) => {
+        impl ToRedisArgs for $t {
+            fn write_redis_args<W>(&self, out: &mut W)
+            where
+                W: ?Sized + RedisWrite,
+            {
+                let mut buf = ::itoa::Buffer::new();
+                let s = buf.format(*self);
+                out.write_arg(s.as_bytes())
+            }
+
+            fn describe_numeric_behavior(&self) -> NumericBehavior {
+                $numeric
+            }
+        }
+    };
+}
+
+macro_rules! non_zero_itoa_based_to_redis_impl {
+    ($t:ty, $numeric:expr) => {
+        impl ToRedisArgs for $t {
+            fn write_redis_args<W>(&self, out: &mut W)
+            where
+                W: ?Sized + RedisWrite,
+            {
+                let mut buf = ::itoa::Buffer::new();
+                let s = buf.format(self.get());
+                out.write_arg(s.as_bytes())
+            }
+
+            fn describe_numeric_behavior(&self) -> NumericBehavior {
+                $numeric
+            }
+        }
+    };
+}
+
+macro_rules! 
ryu_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::ryu::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +impl ToRedisArgs for u8 { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn write_args_from_slice(items: &[u8], out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(items); + } + + fn is_single_vec_arg(_items: &[u8]) -> bool { + true + } +} + +itoa_based_to_redis_impl!(i8, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(isize, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(usize, NumericBehavior::NumberIsInteger); + +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroUsize, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroIsize, NumericBehavior::NumberIsInteger); + +ryu_based_to_redis_impl!(f32, NumericBehavior::NumberIsFloat); +ryu_based_to_redis_impl!(f64, NumericBehavior::NumberIsFloat); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! 
bignum_to_redis_impl { + ($t:ty) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(&self.to_string().into_bytes()) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +bignum_to_redis_impl!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +bignum_to_redis_impl!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +bignum_to_redis_impl!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +bignum_to_redis_impl!(num_bigint::BigUint); + +impl ToRedisArgs for bool { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(if *self { b"1" } else { b"0" }) + } +} + +impl ToRedisArgs for String { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl<'a> ToRedisArgs for &'a str { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl ToRedisArgs for Vec { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(&self[..]) + } +} + +impl<'a, T: ToRedisArgs> ToRedisArgs for &'a [T] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self) + } +} + +impl ToRedisArgs for Option { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref x) = *self { + x.write_redis_args(out); + } + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + match *self { + Some(ref x) => x.describe_numeric_behavior(), + None => NumericBehavior::NonNumeric, + } + } + + fn is_single_arg(&self) -> bool { + match *self { + Some(ref x) => x.is_single_arg(), + None => false, + } + } +} + +impl ToRedisArgs for &T { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + (*self).write_redis_args(out) + } + + fn is_single_arg(&self) -> bool { + (*self).is_single_arg() + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs + for std::collections::HashSet +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +#[cfg(feature = "ahash")] +impl ToRedisArgs for ahash::AHashSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// this flattens BTreeMap into something that goes well with HMSET +/// @note: Redis cannot store empty sets so the application has to 
+/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeMap { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + // otherwise things like HMSET will simply NOT work + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +impl ToRedisArgs + for std::collections::HashMap +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +macro_rules! to_redis_args_for_tuple { + () => (); + ($($name:ident,)+) => ( + #[doc(hidden)] + impl<$($name: ToRedisArgs),*> ToRedisArgs for ($($name,)*) { + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn write_redis_args(&self, out: &mut W) where W: ?Sized + RedisWrite { + let ($(ref $name,)*) = *self; + $($name.write_redis_args(out);)* + } + + #[allow(non_snake_case, unused_variables)] + fn is_single_arg(&self) -> bool { + let mut n = 0u32; + $(let $name = (); n += 1;)* + n == 1 + } + } + to_redis_args_for_tuple_peel!($($name,)*); + ) +} + +/// This chips of the leading one and recurses for the rest. So if the first +/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples +/// of size 1 (does not implement down to unit). +macro_rules! to_redis_args_for_tuple_peel { + ($name:ident, $($other:ident,)*) => (to_redis_args_for_tuple!($($other,)*);) +} + +to_redis_args_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, } + +impl ToRedisArgs for &[T; N] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self.as_slice(), out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self.as_slice()) + } +} + +fn vec_to_array(items: Vec, original_value: &Value) -> RedisResult<[T; N]> { + match items.try_into() { + Ok(array) => Ok(array), + Err(items) => { + let msg = format!( + "Response has wrong dimension, expected {N}, got {}", + items.len() + ); + invalid_type_error!(original_value, msg) + } + } +} + +impl FromRedisValue for [T; N] { + fn from_redis_value(value: &Value) -> RedisResult<[T; N]> { + match *value { + Value::BulkString(ref bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(items) => vec_to_array(items, value), + None => { + let msg = format!( + "Conversion to Array[{}; {N}] failed", + std::any::type_name::() + ); + invalid_type_error!(value, msg) + } + }, + Value::Array(ref items) => { + let items = FromRedisValue::from_redis_values(items)?; + vec_to_array(items, value) + } + Value::Nil => vec_to_array(vec![], value), + _ => invalid_type_error!(value, "Response type not array compatible"), + } + } +} + +/// This trait is used to convert a redis value into a more appropriate +/// type. While a redis `Value` can represent any response that comes +/// back from the redis server, usually you want to map this into something +/// that works better in rust. For instance you might want to convert the +/// return value into a `String` or an integer. 
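+///
+/// A small sketch of such conversions, using the `from_redis_value` helper
+/// defined later in this module:
+///
+/// ```rust,no_run
+/// use redis::{from_redis_value, Value};
+///
+/// let role: String = from_redis_value(&Value::SimpleString("master".into())).unwrap();
+/// let count: i64 = from_redis_value(&Value::Int(42)).unwrap();
+/// ```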
+///
+/// This trait is well supported throughout the library and you can
+/// implement it for your own types if you want.
+///
+/// In addition to what you can see from the docs, this is also implemented
+/// for tuples up to size 12 and for `Vec<T>`.
+pub trait FromRedisValue: Sized {
+    /// Given a redis `Value` this attempts to convert it into the given
+    /// destination type. If that fails because it's not compatible, an
+    /// appropriate error is generated.
+    fn from_redis_value(v: &Value) -> RedisResult<Self>;
+
+    /// Given a redis `Value` this attempts to convert it into the given
+    /// destination type. If that fails because it's not compatible, an
+    /// appropriate error is generated.
+    fn from_owned_redis_value(v: Value) -> RedisResult<Self> {
+        // By default, fall back to `from_redis_value`.
+        // This function only needs to be implemented if it can benefit
+        // from taking `v` by value.
+        Self::from_redis_value(&v)
+    }
+
+    /// Similar to `from_redis_value` but constructs a vector of objects
+    /// from another vector of values. This primarily exists internally
+    /// to customize the behavior for vectors of tuples.
+    fn from_redis_values(items: &[Value]) -> RedisResult<Vec<Self>> {
+        items.iter().map(FromRedisValue::from_redis_value).collect()
+    }
+
+    /// The same as `from_redis_values`, but takes a `Vec<Value>` instead
+    /// of a `&[Value]`.
+    fn from_owned_redis_values(items: Vec<Value>) -> RedisResult<Vec<Self>> {
+        items
+            .into_iter()
+            .map(FromRedisValue::from_owned_redis_value)
+            .collect()
+    }
+
+    /// Convert bytes to a single element vector.
+    fn from_byte_vec(_vec: &[u8]) -> Option<Vec<Self>> {
+        Self::from_owned_redis_value(Value::BulkString(_vec.into()))
+            .map(|rv| vec![rv])
+            .ok()
+    }
+
+    /// Convert bytes to a single element vector.
+    fn from_owned_byte_vec(_vec: Vec<u8>) -> RedisResult<Vec<Self>> {
+        Self::from_owned_redis_value(Value::BulkString(_vec)).map(|rv| vec![rv])
+    }
+}
+
+fn get_inner_value(v: &Value) -> &Value {
+    if let Value::Attribute {
+        data,
+        attributes: _,
+    } = v
+    {
+        data.as_ref()
+    } else {
+        v
+    }
+}
+
+fn get_owned_inner_value(v: Value) -> Value {
+    if let Value::Attribute {
+        data,
+        attributes: _,
+    } = v
+    {
+        *data
+    } else {
+        v
+    }
+}
+
+macro_rules! from_redis_value_for_num_internal {
+    ($t:ty, $v:expr) => {{
+        let v = if let Value::Attribute {
+            data,
+            attributes: _,
+        } = $v
+        {
+            data
+        } else {
+            $v
+        };
+        match *v {
+            Value::Int(val) => Ok(val as $t),
+            Value::SimpleString(ref s) => match s.parse::<$t>() {
+                Ok(rv) => Ok(rv),
+                Err(_) => invalid_type_error!(v, "Could not convert from string."),
+            },
+            Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() {
+                Ok(rv) => Ok(rv),
+                Err(_) => invalid_type_error!(v, "Could not convert from string."),
+            },
+            Value::Double(val) => Ok(val as $t),
+            _ => invalid_type_error!(v, "Response type not convertible to numeric."),
+        }
+    }};
+}
+
+macro_rules! from_redis_value_for_num {
+    ($t:ty) => {
+        impl FromRedisValue for $t {
+            fn from_redis_value(v: &Value) -> RedisResult<$t> {
+                from_redis_value_for_num_internal!($t, v)
+            }
+        }
+    };
+}
+
+impl FromRedisValue for u8 {
+    fn from_redis_value(v: &Value) -> RedisResult<u8> {
+        from_redis_value_for_num_internal!(u8, v)
+    }
+
+    // this hack allows us to specialize Vec<u8> to work with binary data.
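+    // Without the override below, `Vec<u8>` would go through the generic
+    // "bulk string becomes a single-element vector" path; returning the raw
+    // bytes instead lets binary payloads round-trip unchanged.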
+ fn from_byte_vec(vec: &[u8]) -> Option> { + Some(vec.to_vec()) + } + fn from_owned_byte_vec(vec: Vec) -> RedisResult> { + Ok(vec) + } +} + +from_redis_value_for_num!(i8); +from_redis_value_for_num!(i16); +from_redis_value_for_num!(u16); +from_redis_value_for_num!(i32); +from_redis_value_for_num!(u32); +from_redis_value_for_num!(i64); +from_redis_value_for_num!(u64); +from_redis_value_for_num!(i128); +from_redis_value_for_num!(u128); +from_redis_value_for_num!(f32); +from_redis_value_for_num!(f64); +from_redis_value_for_num!(isize); +from_redis_value_for_num!(usize); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum_internal { + ($t:ty, $v:expr) => {{ + let v = $v; + match *v { + Value::Int(val) => <$t>::try_from(val) + .map_err(|_| invalid_type_error_inner!(v, "Could not convert from integer.")), + Value::SimpleString(ref s) => match s.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + _ => invalid_type_error!(v, "Response type not convertible to numeric."), + } + }}; +} + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum { + ($t:ty) => { + impl FromRedisValue for $t { + fn from_redis_value(v: &Value) -> RedisResult<$t> { + from_redis_value_for_bignum_internal!($t, v) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +from_redis_value_for_bignum!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +from_redis_value_for_bignum!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigUint); + +impl FromRedisValue for bool { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(false), + Value::Int(val) => Ok(val != 0), + Value::SimpleString(ref s) => { + if &s[..] == "1" { + Ok(true) + } else if &s[..] 
== "0" { + Ok(false) + } else { + invalid_type_error!(v, "Response status not valid boolean"); + } + } + Value::BulkString(ref bytes) => { + if bytes == b"1" { + Ok(true) + } else if bytes == b"0" { + Ok(false) + } else { + invalid_type_error!(v, "Response type not bool compatible."); + } + } + Value::Boolean(b) => Ok(b), + Value::Okay => Ok(true), + _ => invalid_type_error!(v, "Response type not bool compatible."), + } + } +} + +impl FromRedisValue for CString { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(CString::new(bytes.as_slice())?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(ref val) => Ok(CString::new(val.as_bytes())?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(CString::new(bytes)?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(val) => Ok(CString::new(val)?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } +} + +impl FromRedisValue for String { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(from_utf8(bytes)?.to_string()), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(ref val) => Ok(val.to_string()), + Value::VerbatimString { + format: _, + ref text, + } => Ok(text.to_string()), + Value::Double(ref val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } + + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(String::from_utf8(bytes)?), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(val) => Ok(val), + Value::VerbatimString { format: _, text } => Ok(text), + Value::Double(val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } +} + +/// Implement `FromRedisValue` for `$Type` (which should use the generic parameter `$T`). +/// +/// The implementation parses the value into a vec, and then passes the value through `$convert`. +/// If `$convert` is ommited, it defaults to `Into::into`. +macro_rules! from_vec_from_redis_value { + (<$T:ident> $Type:ty) => { + from_vec_from_redis_value!(<$T> $Type; Into::into); + }; + + (<$T:ident> $Type:ty; $convert:expr) => { + impl<$T: FromRedisValue> FromRedisValue for $Type { + fn from_redis_value(v: &Value) -> RedisResult<$Type> { + match v { + // All binary data except u8 will try to parse into a single element vector. + // u8 has its own implementation of from_byte_vec. 
+ Value::BulkString(bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(x) => Ok($convert(x)), + None => invalid_type_error!( + v, + format!("Conversion to {} failed.", std::any::type_name::<$Type>()) + ), + }, + Value::Array(items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Set(ref items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Map(ref items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_redis_value(&Value::Map(vec![item.clone()])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult<$Type> { + match v { + // Binary data is parsed into a single-element vector, except + // for the element type `u8`, which directly consumes the entire + // array of bytes. + Value::BulkString(bytes) => FromRedisValue::from_owned_byte_vec(bytes).map($convert), + Value::Array(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Set(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Map(items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_owned_redis_value(Value::Map(vec![item])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + } + }; +} + +from_vec_from_redis_value!( Vec); +from_vec_from_redis_value!( std::sync::Arc<[T]>); +from_vec_from_redis_value!( Box<[T]>; Vec::into_boxed_slice); + +impl FromRedisValue + for std::collections::HashMap +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(Default::default()), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(Default::default()), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect(), + } + } +} + +#[cfg(feature = "ahash")] +impl FromRedisValue for ahash::AHashMap { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? 
+ .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect(), + } + } +} + +impl FromRedisValue for BTreeMap +where + K: Ord, +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + v.as_map_iter() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreemap compatible"))? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + v.into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not btreemap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect() + } +} + +impl FromRedisValue + for std::collections::HashSet +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +#[cfg(feature = "ahash")] +impl FromRedisValue for ahash::AHashSet { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +impl FromRedisValue for BTreeSet +where + T: Ord, +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +impl FromRedisValue for Value { + fn from_redis_value(v: &Value) -> RedisResult { + Ok(v.clone()) + } + fn from_owned_redis_value(v: Value) -> RedisResult { + Ok(v) + } +} + +impl FromRedisValue for () { + fn from_redis_value(_v: &Value) -> RedisResult<()> { + Ok(()) + } +} + +macro_rules! from_redis_value_for_tuple { + () => (); + ($($name:ident,)+) => ( + #[doc(hidden)] + impl<$($name: FromRedisValue),*> FromRedisValue for ($($name,)*) { + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn from_redis_value(v: &Value) -> RedisResult<($($name,)*)> { + let v = get_inner_value(v); + match *v { + Value::Array(ref items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if items.len() != n { + invalid_type_error!(v, "Array response of wrong dimension") + } + + // this is pretty ugly too. 
The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &items[{ i += 1; i - 1 }])?},)*)) + } + + Value::Map(ref items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if n != 2 { + invalid_type_error!(v, "Map response of wrong dimension") + } + + let mut flatten_items = vec![]; + for (k,v) in items { + flatten_items.push(k); + flatten_items.push(v); + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &flatten_items[{ i += 1; i - 1 }])?},)*)) + } + + _ => invalid_type_error!(v, "Not a Array response") + } + } + + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn from_owned_redis_value(v: Value) -> RedisResult<($($name,)*)> { + let v = get_owned_inner_value(v); + match v { + Value::Array(mut items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if items.len() != n { + invalid_type_error!(Value::Array(items), "Array response of wrong dimension") + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_owned_redis_value( + ::std::mem::replace(&mut items[{ i += 1; i - 1 }], Value::Nil) + )?},)*)) + } + + Value::Map(items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if n != 2 { + invalid_type_error!(Value::Map(items), "Map response of wrong dimension") + } + + let mut flatten_items = vec![]; + for (k,v) in items { + flatten_items.push(k); + flatten_items.push(v); + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &flatten_items[{ i += 1; i - 1 }])?},)*)) + } + + _ => invalid_type_error!(v, "Not a Array response") + } + } + + #[allow(non_snake_case, unused_variables)] + fn from_redis_values(items: &[Value]) -> RedisResult> { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + let mut rv = vec![]; + if items.len() == 0 { + return Ok(rv) + } + //It's uglier then before! + for item in items { + match item { + Value::Array(ch) => { + if let [$($name),*] = &ch[..] { + rv.push(($(from_redis_value(&$name)?),*),) + } else { + unreachable!() + }; + }, + _ => {}, + + } + } + if !rv.is_empty(){ + return Ok(rv); + } + + if let [$($name),*] = items{ + rv.push(($(from_redis_value($name)?),*),); + return Ok(rv); + } + for chunk in items.chunks_exact(n) { + match chunk { + [$($name),*] => rv.push(($(from_redis_value($name)?),*),), + _ => {}, + } + } + Ok(rv) + } + + #[allow(non_snake_case, unused_variables)] + fn from_owned_redis_values(mut items: Vec) -> RedisResult> { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + + let mut rv = vec![]; + if items.len() == 0 { + return Ok(rv) + } + //It's uglier then before! + for item in items.iter() { + match item { + Value::Array(ch) => { + // TODO - this copies when we could've used the owned value. need to find out how to do this. + if let [$($name),*] = &ch[..] 
{ + rv.push(($(from_redis_value($name)?),*),) + } else { + unreachable!() + }; + }, + _ => {}, + } + } + if !rv.is_empty(){ + return Ok(rv); + } + + let mut rv = Vec::with_capacity(items.len() / n); + if items.len() == 0 { + return Ok(rv) + } + for chunk in items.chunks_mut(n) { + match chunk { + // Take each element out of the chunk with `std::mem::replace`, leaving a `Value::Nil` + // in its place. This allows each `Value` to be parsed without being copied. + // Since `items` is consume by this function and not used later, this replacement + // is not observable to the rest of the code. + [$($name),*] => rv.push(($(from_owned_redis_value(std::mem::replace($name, Value::Nil))?),*),), + _ => unreachable!(), + } + } + Ok(rv) + } + } + from_redis_value_for_tuple_peel!($($name,)*); + ) +} + +/// This chips of the leading one and recurses for the rest. So if the first +/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples +/// of size 1 (does not implement down to unit). +macro_rules! from_redis_value_for_tuple_peel { + ($name:ident, $($other:ident,)*) => (from_redis_value_for_tuple!($($other,)*);) +} + +from_redis_value_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, } + +impl FromRedisValue for InfoDict { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + let s: String = from_redis_value(v)?; + Ok(InfoDict::new(&s)) + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + let s: String = from_owned_redis_value(v)?; + Ok(InfoDict::new(&s)) + } +} + +impl FromRedisValue for Option { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + if *v == Value::Nil { + return Ok(None); + } + Ok(Some(from_redis_value(v)?)) + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + if v == Value::Nil { + return Ok(None); + } + Ok(Some(from_owned_redis_value(v)?)) + } +} + +#[cfg(feature = "bytes")] +impl FromRedisValue for bytes::Bytes { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match v { + Value::BulkString(bytes_vec) => Ok(bytes::Bytes::copy_from_slice(bytes_vec.as_ref())), + _ => invalid_type_error!(v, "Not a bulk string"), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes_vec) => Ok(bytes_vec.into()), + _ => invalid_type_error!(v, "Not a bulk string"), + } + } +} + +#[cfg(feature = "uuid")] +impl FromRedisValue for uuid::Uuid { + fn from_redis_value(v: &Value) -> RedisResult { + match *v { + Value::BulkString(ref bytes) => Ok(uuid::Uuid::from_slice(bytes)?), + _ => invalid_type_error!(v, "Response type not uuid compatible."), + } + } +} + +#[cfg(feature = "uuid")] +impl ToRedisArgs for uuid::Uuid { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()); + } +} + +/// A shortcut function to invoke `FromRedisValue::from_redis_value` +/// to make the API slightly nicer. +pub fn from_redis_value(v: &Value) -> RedisResult { + FromRedisValue::from_redis_value(v) +} + +/// A shortcut function to invoke `FromRedisValue::from_owned_redis_value` +/// to make the API slightly nicer. +pub fn from_owned_redis_value(v: Value) -> RedisResult { + FromRedisValue::from_owned_redis_value(v) +} + +/// Enum representing the communication protocol with the server. 
This enum represents the types +/// of data that the server can send to the client, and the capabilities that the client can use. +#[derive(Clone, Eq, PartialEq, Default, Debug, Copy)] +pub enum ProtocolVersion { + /// + #[default] + RESP2, + /// + RESP3, +} diff --git a/glide-core/redis-rs/redis/tests/parser.rs b/glide-core/redis-rs/redis/tests/parser.rs new file mode 100644 index 0000000000..c4083f44bd --- /dev/null +++ b/glide-core/redis-rs/redis/tests/parser.rs @@ -0,0 +1,195 @@ +use std::{io, pin::Pin}; + +use redis::Value; +use { + futures::{ + ready, + task::{self, Poll}, + }, + partial_io::{quickcheck_types::GenWouldBlock, quickcheck_types::PartialWithErrors, PartialOp}, + quickcheck::{quickcheck, Gen}, + tokio::io::{AsyncRead, ReadBuf}, +}; + +mod support; +use crate::support::{block_on_all, encode_value}; + +#[derive(Clone, Debug)] +struct ArbitraryValue(Value); + +impl ::quickcheck::Arbitrary for ArbitraryValue { + fn arbitrary(g: &mut Gen) -> Self { + let size = g.size(); + ArbitraryValue(arbitrary_value(g, size)) + } + + fn shrink(&self) -> Box> { + match self.0 { + Value::Nil | Value::Okay => Box::new(None.into_iter()), + Value::Int(i) => Box::new(i.shrink().map(Value::Int).map(ArbitraryValue)), + Value::BulkString(ref xs) => { + Box::new(xs.shrink().map(Value::BulkString).map(ArbitraryValue)) + } + Value::Array(ref xs) | Value::Set(ref xs) => { + let ys = xs + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + Value::Map(ref _xs) => Box::new(vec![ArbitraryValue(Value::Map(vec![]))].into_iter()), + Value::Attribute { + ref data, + ref attributes, + } => Box::new( + vec![ArbitraryValue(Value::Attribute { + data: data.clone(), + attributes: attributes.clone(), + })] + .into_iter(), + ), + Value::Push { ref kind, ref data } => { + let mut ys = data + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + ys.insert(0, ArbitraryValue(Value::SimpleString(kind.to_string()))); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + Value::SimpleString(ref status) => { + Box::new(status.shrink().map(Value::SimpleString).map(ArbitraryValue)) + } + Value::Double(i) => Box::new(i.shrink().map(Value::Double).map(ArbitraryValue)), + Value::Boolean(i) => Box::new(i.shrink().map(Value::Boolean).map(ArbitraryValue)), + Value::BigNumber(ref i) => { + Box::new(vec![ArbitraryValue(Value::BigNumber(i.clone()))].into_iter()) + } + Value::VerbatimString { + ref format, + ref text, + } => Box::new( + vec![ArbitraryValue(Value::VerbatimString { + format: format.clone(), + text: text.clone(), + })] + .into_iter(), + ), + } + } +} + +fn arbitrary_value(g: &mut Gen, recursive_size: usize) -> Value { + use quickcheck::Arbitrary; + if recursive_size == 0 { + Value::Nil + } else { + match u8::arbitrary(g) % 6 { + 0 => Value::Nil, + 1 => Value::Int(Arbitrary::arbitrary(g)), + 2 => Value::BulkString(Arbitrary::arbitrary(g)), + 3 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + Value::Array( + (0..size) + .map(|_| arbitrary_value(g, recursive_size / size)) + .collect(), + ) + } + 4 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + + let mut string = String::with_capacity(size); + for _ in 0..size { + let c = char::arbitrary(g); + if c.is_ascii_alphabetic() { + string.push(c); + } + } + + if string == "OK" { + Value::Okay + } else { + 
Value::SimpleString(string) + } + } + 5 => Value::Okay, + _ => unreachable!(), + } + } +} + +struct PartialAsyncRead { + inner: R, + ops: Box + Send>, +} + +impl AsyncRead for PartialAsyncRead +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.ops.next() { + Some(PartialOp::Limited(n)) => { + let len = std::cmp::min(n, buf.remaining()); + buf.initialize_unfilled(); + let mut sub_buf = buf.take(len); + ready!(Pin::new(&mut self.inner).poll_read(cx, &mut sub_buf))?; + let filled = sub_buf.filled().len(); + buf.advance(filled); + Poll::Ready(Ok(())) + } + Some(PartialOp::Err(err)) => { + if err == io::ErrorKind::WouldBlock { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Err(io::Error::new( + err, + "error during read, generated by partial-io", + )) + .into() + } + } + Some(PartialOp::Unlimited) | None => Pin::new(&mut self.inner).poll_read(cx, buf), + } + } +} + +quickcheck! { + fn partial_io_parse(input: ArbitraryValue, seq: PartialWithErrors) -> () { + + let mut encoded_input = Vec::new(); + encode_value(&input.0, &mut encoded_input).unwrap(); + + let mut reader = &encoded_input[..]; + let mut partial_reader = PartialAsyncRead { inner: &mut reader, ops: Box::new(seq.into_iter()) }; + let mut decoder = combine::stream::Decoder::new(); + + let result = block_on_all(redis::parse_redis_value_async(&mut decoder, &mut partial_reader)); + assert!(result.as_ref().is_ok(), "{}", result.unwrap_err()); + assert_eq!( + result.unwrap(), + input.0, + ); + } +} diff --git a/glide-core/redis-rs/redis/tests/support/cluster.rs b/glide-core/redis-rs/redis/tests/support/cluster.rs new file mode 100644 index 0000000000..991331cfca --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/cluster.rs @@ -0,0 +1,792 @@ +#![cfg(feature = "cluster")] +#![allow(dead_code)] + +use std::convert::identity; +use std::env; +use std::process; +use std::thread::sleep; +use std::time::Duration; + +use redis::cluster_routing::RoutingInfo; +use redis::cluster_routing::SingleNodeRoutingInfo; +use redis::from_redis_value; + +#[cfg(feature = "cluster-async")] +use redis::aio::ConnectionLike; +#[cfg(feature = "cluster-async")] +use redis::cluster_async::Connect; +use redis::ConnectionInfo; +use redis::ProtocolVersion; +use redis::PushInfo; +use redis::RedisResult; +use redis::Value; +use tempfile::TempDir; + +use crate::support::{build_keys_and_certs_for_tls, Module}; + +#[cfg(feature = "tls-rustls")] +use super::{build_single_client, load_certs_from_file}; + +use super::use_protocol; +use super::RedisServer; +use super::TlsFilePaths; +use tokio::sync::mpsc; + +const LOCALHOST: &str = "127.0.0.1"; + +enum ClusterType { + Tcp, + TcpTls, +} + +impl ClusterType { + fn get_intended() -> ClusterType { + match env::var("REDISRS_SERVER_TYPE") + .ok() + .as_ref() + .map(|x| &x[..]) + { + Some("tcp") => ClusterType::Tcp, + Some("tcp+tls") => ClusterType::TcpTls, + Some(val) => { + panic!("Unknown server type {val:?}"); + } + None => ClusterType::Tcp, + } + } + + fn build_addr(port: u16) -> redis::ConnectionAddr { + match ClusterType::get_intended() { + ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port), + ClusterType::TcpTls => redis::ConnectionAddr::TcpTls { + host: "127.0.0.1".into(), + port, + insecure: true, + tls_params: None, + }, + } + } +} + +fn port_in_use(addr: &str) -> bool { + let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address"); + let socket = 
socket2::Socket::new( + socket2::Domain::for_address(socket_addr), + socket2::Type::STREAM, + None, + ) + .expect("Failed to create socket"); + + socket.connect(&socket_addr.into()).is_ok() +} + +pub struct RedisCluster { + pub servers: Vec, + pub folders: Vec, + pub tls_paths: Option, +} + +impl RedisCluster { + pub fn username() -> &'static str { + "hello" + } + + pub fn password() -> &'static str { + "world" + } + + pub fn client_name() -> &'static str { + "test_cluster_client" + } + + pub fn new(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], true) + } + + pub fn with_modules( + nodes: u16, + replicas: u16, + modules: &[Module], + mtls_enabled: bool, + ) -> RedisCluster { + let mut servers = vec![]; + let mut folders = vec![]; + let mut addrs = vec![]; + let start_port = 7000; + let mut tls_paths = None; + + let mut is_tls = false; + + if let ClusterType::TcpTls = ClusterType::get_intended() { + // Create a shared set of keys in cluster mode + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let files = build_keys_and_certs_for_tls(&tempdir); + folders.push(tempdir); + tls_paths = Some(files); + is_tls = true; + } + + let max_attempts = 5; + + for node in 0..nodes { + let port = start_port + node; + + servers.push(RedisServer::new_with_addr_tls_modules_and_spawner( + ClusterType::build_addr(port), + None, + tls_paths.clone(), + mtls_enabled, + modules, + |cmd| { + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let acl_path = tempdir.path().join("users.acl"); + let acl_content = format!( + "user {} on allcommands allkeys >{}", + Self::username(), + Self::password() + ); + std::fs::write(&acl_path, acl_content).expect("failed to write acl file"); + cmd.arg("--cluster-enabled") + .arg("yes") + .arg("--cluster-config-file") + .arg(tempdir.path().join("nodes.conf")) + .arg("--cluster-node-timeout") + .arg("5000") + .arg("--appendonly") + .arg("yes") + .arg("--aclfile") + .arg(&acl_path); + if is_tls { + cmd.arg("--tls-cluster").arg("yes"); + if replicas > 0 { + cmd.arg("--tls-replication").arg("yes"); + } + } + let addr = format!("127.0.0.1:{port}"); + cmd.current_dir(tempdir.path()); + folders.push(tempdir); + addrs.push(addr.clone()); + + let mut cur_attempts = 0; + loop { + let mut process = cmd.spawn().unwrap(); + sleep(Duration::from_millis(100)); + + match process.try_wait() { + Ok(Some(status)) => { + let err = + format!("redis server creation failed with status {status:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + cur_attempts += 1; + } + Ok(None) => { + let max_attempts = 20; + let mut cur_attempts = 0; + loop { + if cur_attempts == max_attempts { + panic!("redis server creation failed: Port {port} closed") + } + if port_in_use(&addr) { + return process; + } + eprintln!("Waiting for redis process to initialize"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + Err(e) => { + panic!("Unexpected error in redis server creation {e}"); + } + } + } + }, + )); + } + + let mut cmd = process::Command::new("redis-cli"); + cmd.stdout(process::Stdio::null()) + .arg("--cluster") + .arg("create") + .args(&addrs); + if replicas > 0 { + cmd.arg("--cluster-replicas").arg(replicas.to_string()); 
+ } + cmd.arg("--cluster-yes"); + + if is_tls { + if mtls_enabled { + if let Some(TlsFilePaths { + redis_crt, + redis_key, + ca_crt, + }) = &tls_paths + { + cmd.arg("--cert"); + cmd.arg(redis_crt); + cmd.arg("--key"); + cmd.arg(redis_key); + cmd.arg("--cacert"); + cmd.arg(ca_crt); + cmd.arg("--tls"); + } + } else { + cmd.arg("--tls").arg("--insecure"); + } + } + + let mut cur_attempts = 0; + loop { + let output = cmd.output().unwrap(); + if output.status.success() { + break; + } else { + let err = format!("Cluster creation failed: {output:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + + let cluster = RedisCluster { + servers, + folders, + tls_paths, + }; + if replicas > 0 { + cluster.wait_for_replicas(replicas, mtls_enabled); + } + + wait_for_status_ok(&cluster); + cluster + } + + // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active + #[allow(dead_code)] + fn wait_for_replicas(&self, replicas: u16, _mtls_enabled: bool) { + 'server: for server in &self.servers { + let conn_info = server.connection_info(); + eprintln!( + "waiting until {:?} knows required number of replicas", + conn_info.addr + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &self.tls_paths, _mtls_enabled) + .unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + + // retry 500 times + for _ in 1..500 { + let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap(); + let slots: Vec> = redis::from_owned_redis_value(value).unwrap(); + + // all slots should have following items: + // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... 
] + if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) { + continue 'server; + } + + sleep(Duration::from_millis(100)); + } + + panic!("failed to create enough replicas"); + } + } + + pub fn stop(&mut self) { + for server in &mut self.servers { + server.stop(); + } + } + + pub fn iter_servers(&self) -> impl Iterator { + self.servers.iter() + } +} + +fn wait_for_status_ok(cluster: &RedisCluster) { + 'server: for server in &cluster.servers { + let log_file = RedisServer::log_file(&server.tempdir); + + for _ in 1..500 { + let contents = + std::fs::read_to_string(&log_file).expect("Should have been able to read the file"); + + if contents.contains("Cluster state changed: ok") { + continue 'server; + } + sleep(Duration::from_millis(20)); + } + panic!("failed to reach state change: OK"); + } +} + +impl Drop for RedisCluster { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestClusterContext { + pub cluster: RedisCluster, + pub client: redis::cluster::ClusterClient, + pub mtls_enabled: bool, + pub nodes: Vec, + pub protocol: ProtocolVersion, +} + +impl TestClusterContext { + pub fn new(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, true) + } + + pub fn new_with_cluster_client_builder( + nodes: u16, + replicas: u16, + initializer: F, + mtls_enabled: bool, + ) -> TestClusterContext + where + F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, + { + let cluster = RedisCluster::new(nodes, replicas); + let initial_nodes: Vec = cluster + .iter_servers() + .map(RedisServer::connection_info) + .collect(); + let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes.clone()) + .use_protocol(use_protocol()); + + #[cfg(feature = "tls-rustls")] + if mtls_enabled { + if let Some(tls_file_paths) = &cluster.tls_paths { + builder = builder.certs(load_certs_from_file(tls_file_paths)); + } + } + + builder = initializer(builder); + + let client = builder.build().unwrap(); + + TestClusterContext { + cluster, + client, + mtls_enabled, + nodes: initial_nodes, + protocol: use_protocol(), + } + } + + pub fn connection(&self) -> redis::cluster::ClusterConnection { + self.client.get_connection(None).unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_connection( + &self, + push_sender: Option>, + ) -> redis::cluster_async::ClusterConnection { + self.client.get_async_connection(push_sender).await.unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_generic_connection< + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, + >( + &self, + ) -> redis::cluster_async::ClusterConnection { + self.client + .get_async_generic_connection::() + .await + .unwrap() + } + + pub fn wait_for_cluster_up(&self) { + let mut con = self.connection(); + let mut c = redis::cmd("CLUSTER"); + c.arg("INFO"); + + for _ in 0..100 { + let r: String = c.query::(&mut con).unwrap(); + if r.starts_with("cluster_state:ok") { + return; + } + + sleep(Duration::from_millis(25)); + } + + panic!("failed waiting for cluster to be ready"); + } + + pub fn disable_default_user(&self) { + for server in &self.cluster.servers { + #[cfg(feature = "tls-rustls")] + let client = build_single_client( + server.connection_info(), + &self.cluster.tls_paths, + self.mtls_enabled, + ) + 
.unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg("default") + .arg("off") + .query(&mut con) + .unwrap(); + + // subsequent unauthenticated command should fail: + if let Ok(mut con) = client.get_connection(None) { + assert!(redis::cmd("PING").query::<()>(&mut con).is_err()); + } + } + } + + pub fn get_version(&self) -> super::Version { + let mut conn = self.connection(); + super::get_version(&mut conn) + } + + pub fn get_node_ids(&self) -> Vec { + let mut conn = self.connection(); + let nodes: Vec = redis::cmd("CLUSTER") + .arg("NODES") + .query::(&mut conn) + .unwrap() + .split('\n') + .map(|s| s.to_string()) + .collect(); + let node_ids: Vec = nodes + .iter() + .map(|node| node.split(' ').next().unwrap().to_string()) + .collect(); + node_ids + .iter() + .filter(|id| !id.is_empty()) + .cloned() + .collect() + } + + // Migrate half the slots from one node to another + pub async fn migrate_slots_from_node_to_another( + &self, + slot_distribution: Vec<(String, String, String, Vec>)>, + ) { + let slots_ranges_of_node_id = slot_distribution[0].3.clone(); + + let mut conn = self.async_connection(None).await; + + let from = slot_distribution[0].clone(); + let target = slot_distribution[1].clone(); + + let from_node_id = from.0.clone(); + let target_node_id = target.0.clone(); + + let from_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: from.1.clone(), + port: from.2.clone().parse::().unwrap(), + }); + let target_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: target.1.clone(), + port: target.2.clone().parse::().unwrap(), + }); + + // Migrate the slots + for range in slots_ranges_of_node_id { + let mut slots_of_nodes: std::ops::Range = range[0]..range[1]; + let number_of_slots = range[1] - range[0] + 1; + // Migrate half the slots + for _i in 0..(number_of_slots as f64 / 2.0).floor() as usize { + let slot = slots_of_nodes.next().unwrap(); + // Set the nodes to MIGRATING and IMPORTING + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("IMPORTING") + .arg(from_node_id.clone()); + let result: RedisResult = + conn.route_command(&set_cmd, target_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to IMPORTING with error {}", + slot, err + ); + } + } + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("MIGRATING") + .arg(target_node_id.clone()); + let result: RedisResult = + conn.route_command(&set_cmd, from_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to MIGRATING with error {}", + slot, err + ); + } + } + // Get a key from the slot + let mut get_key_cmd = redis::cmd("CLUSTER"); + get_key_cmd.arg("GETKEYSINSLOT").arg(slot).arg(1); + let result: RedisResult = + conn.route_command(&get_key_cmd, from_route.clone()).await; + let vec_string_result: Vec = match result { + Ok(val) => { + let val: Vec = from_redis_value(&val).unwrap(); + val + } + Err(err) => { + println!("Failed to get keys in slot {}: {:?}", slot, err); + continue; + } + }; + if vec_string_result.is_empty() { + continue; + } + let key = vec_string_result[0].clone(); + // Migrate the key, which will make the whole slot to move + let mut migrate_cmd = redis::cmd("MIGRATE"); + migrate_cmd + .arg(target.1.clone()) + 
.arg(target.2.clone())
+                    .arg(key.clone())
+                    .arg(0)
+                    .arg(5000);
+                let result: RedisResult<Value> =
+                    conn.route_command(&migrate_cmd, from_route.clone()).await;
+
+                match result {
+                    Ok(Value::Okay) => {}
+                    Ok(Value::SimpleString(str)) => {
+                        if str != "NOKEY" {
+                            println!(
+                                "Failed to migrate key {} to target node with status {}",
+                                key, str
+                            );
+                        } else {
+                            println!("Key {} does not exist", key);
+                        }
+                    }
+                    Ok(_) => {}
+                    Err(err) => {
+                        println!(
+                            "Failed to migrate key {} to target node with error {}",
+                            key, err
+                        );
+                    }
+                }
+                // Tell the source and target nodes to propagate the slot change to the cluster
+                let mut setslot_cmd = redis::cmd("CLUSTER");
+                setslot_cmd
+                    .arg("SETSLOT")
+                    .arg(slot)
+                    .arg("NODE")
+                    .arg(target_node_id.clone());
+                let result: RedisResult<Value> =
+                    conn.route_command(&setslot_cmd, target_route.clone()).await;
+                match result {
+                    Ok(_) => {}
+                    Err(err) => {
+                        println!(
+                            "Failed to set slot {} to target NODE with error {}",
+                            slot, err
+                        );
+                    }
+                };
+                self.wait_for_connection_is_ready(&from_route)
+                    .await
+                    .unwrap();
+                self.wait_for_connection_is_ready(&target_route)
+                    .await
+                    .unwrap();
+                self.wait_for_cluster_up();
+            }
+        }
+    }
+
+    // Return the slot distribution of the cluster as a vector of tuples, where the first
+    // element is the node id, the second is the host, the third is the port, and the last
+    // is a vector of slot ranges
+    pub fn get_slots_ranges_distribution(
+        &self,
+        cluster_nodes: &str,
+    ) -> Vec<(String, String, String, Vec<Vec<u16>>)> {
+        let nodes_string: Vec<String> = cluster_nodes
+            .split('\n')
+            .map(|s| s.to_string())
+            .filter(|s| !s.is_empty())
+            .collect();
+        let mut nodes: Vec<Vec<String>> = vec![];
+        for node in nodes_string {
+            let node_vec: Vec<String> = node.split(' ').map(|s| s.to_string()).collect();
+            // Lines ending with the link state ("connected"/"disconnected") carry no
+            // slot ranges, so they are skipped.
+            if node_vec.last().unwrap() == "connected" || node_vec.last().unwrap() == "disconnected"
+            {
+                continue;
+            } else {
+                nodes.push(node_vec);
+            }
+        }
+        let mut slot_distribution = vec![];
+        for node in &nodes {
+            let mut slots_ranges: Vec<Vec<u16>> = vec![];
+            let mut slots_ranges_vec: Vec<u16> = vec![];
+            let node_id = node[0].clone();
+            let host_and_port: Vec<String> = node[1].split(':').map(|s| s.to_string()).collect();
+            let host = host_and_port[0].clone();
+            let port = host_and_port[1].split('@').next().unwrap().to_string();
+            let slots = node[8..].to_vec();
+            for slot in slots {
+                if slot.contains("->") || slot.contains("<-") {
+                    continue;
+                }
+                if slot.contains('-') {
+                    let range: Vec<u16> =
+                        slot.split('-').map(|s| s.parse::<u16>().unwrap()).collect();
+                    slots_ranges_vec.push(range[0]);
+                    slots_ranges_vec.push(range[1]);
+                    slots_ranges.push(slots_ranges_vec.clone());
+                    slots_ranges_vec.clear();
+                } else {
+                    let slot: u16 = slot.parse::<u16>().unwrap();
+                    slots_ranges_vec.push(slot);
+                    slots_ranges_vec.push(slot);
+                    slots_ranges.push(slots_ranges_vec.clone());
+                    slots_ranges_vec.clear();
+                }
+            }
+            let parsed_node: (String, String, String, Vec<Vec<u16>>) =
+                (node_id, host, port, slots_ranges);
+            slot_distribution.push(parsed_node);
+        }
+        slot_distribution
+    }
+
+    pub async fn get_masters(&self, cluster_nodes: &str) -> Vec<Vec<String>> {
+        let mut masters = vec![];
+        for line in cluster_nodes.lines() {
+            let parts: Vec<&str> = line.split_whitespace().collect();
+            if parts.len() < 3 {
+                continue;
+            }
+            if parts[2] == "master" || parts[2] == "myself,master" {
+                let id = parts[0];
+                let host_and_port = parts[1].split(':');
+                let host = host_and_port.clone().next().unwrap();
+                let port = host_and_port
+                    .clone()
+                    .last()
+                    .unwrap()
+                    .split('@')
+                    .next()
+                    .unwrap();
+                masters.push(vec![id.to_string(), host.to_string(),
port.to_string()]);
+            }
+        }
+        masters
+    }
+
+    pub async fn get_replicas(&self, cluster_nodes: &str) -> Vec<Vec<String>> {
+        let mut replicas = vec![];
+        for line in cluster_nodes.lines() {
+            let parts: Vec<&str> = line.split_whitespace().collect();
+            if parts.len() < 3 {
+                continue;
+            }
+            if parts[2] == "slave" || parts[2] == "myself,slave" {
+                let id = parts[0];
+                let host_and_port = parts[1].split(':');
+                let host = host_and_port.clone().next().unwrap();
+                let port = host_and_port
+                    .clone()
+                    .last()
+                    .unwrap()
+                    .split('@')
+                    .next()
+                    .unwrap();
+                replicas.push(vec![id.to_string(), host.to_string(), port.to_string()]);
+            }
+        }
+        replicas
+    }
+
+    pub async fn get_cluster_nodes(&self) -> String {
+        let mut conn = self.async_connection(None).await;
+        let mut cmd = redis::cmd("CLUSTER");
+        cmd.arg("NODES");
+        let res: RedisResult<Value> = conn
+            .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+            .await;
+        let res: String = from_redis_value(&res.unwrap()).unwrap();
+        res
+    }
+
+    pub async fn wait_for_fail_to_finish(&self, route: &RoutingInfo) -> RedisResult<()> {
+        for _ in 0..500 {
+            let mut conn = self.async_connection(None).await;
+            let cmd = redis::cmd("PING");
+            let res: RedisResult<Value> = conn.route_command(&cmd, route.clone()).await;
+            if res.is_err() {
+                return Ok(());
+            }
+            sleep(Duration::from_millis(50));
+        }
+        Err(redis::RedisError::from((
+            redis::ErrorKind::IoError,
+            "Failed to get connection",
+        )))
+    }
+
+    pub async fn wait_for_connection_is_ready(&self, route: &RoutingInfo) -> RedisResult<()> {
+        let mut i = 1;
+        while i < 1000 {
+            let mut conn = self.async_connection(None).await;
+            let cmd = redis::cmd("PING");
+            let res: RedisResult<Value> = conn.route_command(&cmd, route.clone()).await;
+            if res.is_ok() {
+                return Ok(());
+            }
+            sleep(Duration::from_millis(i * 10));
+            i += 10;
+        }
+        Err(redis::RedisError::from((
+            redis::ErrorKind::IoError,
+            "Failed to get connection",
+        )))
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/support/mock_cluster.rs b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs
new file mode 100644
index 0000000000..ce91988cef
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs
@@ -0,0 +1,487 @@
+use redis::{
+    cluster::{self, ClusterClient, ClusterClientBuilder},
+    ErrorKind, FromRedisValue, GlideConnectionOptions, RedisError,
+};
+
+use std::{
+    collections::HashMap,
+    net::{IpAddr, Ipv4Addr, SocketAddr},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, RwLock,
+    },
+    time::Duration,
+};
+
+use {
+    once_cell::sync::Lazy,
+    redis::{IntoConnectionInfo, RedisResult, Value},
+};
+
+#[cfg(feature = "cluster-async")]
+use redis::{aio, cluster_async, RedisFuture};
+
+#[cfg(feature = "cluster-async")]
+use futures::future;
+
+#[cfg(feature = "cluster-async")]
+use tokio::runtime::Runtime;
+
+type Handler = Arc<dyn Fn(&[u8], u16) -> Result<(), RedisResult<Value>> + Send + Sync>;
+
+pub struct MockConnectionBehavior {
+    pub id: String,
+    pub handler: Handler,
+    pub connection_id_provider: AtomicUsize,
+    pub returned_ip_type: ConnectionIPReturnType,
+    pub return_connection_err: ShouldReturnConnectionError,
+}
+
+impl MockConnectionBehavior {
+    fn new(id: &str, handler: Handler) -> Self {
+        Self {
+            id: id.to_string(),
+            handler,
+            connection_id_provider: AtomicUsize::new(0),
+            returned_ip_type: ConnectionIPReturnType::default(),
+            return_connection_err: ShouldReturnConnectionError::default(),
+        }
+    }
+
+    #[must_use]
+    pub fn register_new(id: &str, handler: Handler) -> RemoveHandler {
+        get_behaviors().insert(id.to_string(), Self::new(id, handler));
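+        // The returned guard unregisters the behavior installed above when it is
+        // dropped (see `impl Drop for RemoveHandler` below).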
+        RemoveHandler(vec![id.to_string()])
+    }
+
+    fn get_handler(&self) -> Handler {
+        self.handler.clone()
+    }
+}
+
+pub fn modify_mock_connection_behavior(name: &str, func: impl FnOnce(&mut MockConnectionBehavior)) {
+    func(
+        get_behaviors()
+            .get_mut(name)
+            .unwrap_or_else(|| panic!("Handler `{name}` was not installed")),
+    );
+}
+
+pub fn get_mock_connection_handler(name: &str) -> Handler {
+    MOCK_CONN_BEHAVIORS
+        .read()
+        .unwrap()
+        .get(name)
+        .unwrap_or_else(|| panic!("Handler `{name}` was not installed"))
+        .get_handler()
+}
+
+pub fn get_mock_connection(name: &str, id: usize) -> MockConnection {
+    get_mock_connection_with_port(name, id, 6379)
+}
+
+pub fn get_mock_connection_with_port(name: &str, id: usize, port: u16) -> MockConnection {
+    MockConnection {
+        id,
+        handler: get_mock_connection_handler(name),
+        port,
+    }
+}
+
+static MOCK_CONN_BEHAVIORS: Lazy<RwLock<HashMap<String, MockConnectionBehavior>>> =
+    Lazy::new(Default::default);
+
+fn get_behaviors(
+) -> std::sync::RwLockWriteGuard<'static, HashMap<String, MockConnectionBehavior>> {
+    MOCK_CONN_BEHAVIORS.write().unwrap()
+}
+
+#[derive(Default)]
+pub enum ConnectionIPReturnType {
+    /// New connections' IP will be returned as None
+    #[default]
+    None,
+    /// Creates connections with the specified IP
+    Specified(IpAddr),
+    /// Each new connection will be created with a different IP based on the passed atomic integer
+    Different(AtomicUsize),
+}
+
+#[derive(Default)]
+pub enum ShouldReturnConnectionError {
+    /// Don't return a connection error
+    #[default]
+    No,
+    /// Always return a connection error
+    Yes,
+    /// Return a connection error when the internal index is an odd number
+    OnOddIdx(AtomicUsize),
+}
+
+#[derive(Clone)]
+pub struct MockConnection {
+    pub id: usize,
+    pub handler: Handler,
+    pub port: u16,
+}
+
+#[cfg(feature = "cluster-async")]
+impl cluster_async::Connect for MockConnection {
+    fn connect<'a, T>(
+        info: T,
+        _response_timeout: Duration,
+        _connection_timeout: Duration,
+        _socket_addr: Option<SocketAddr>,
+        _glide_connection_options: GlideConnectionOptions,
+    ) -> RedisFuture<'a, (Self, Option<IpAddr>)>
+    where
+        T: IntoConnectionInfo + Send + 'a,
+    {
+        let info = info.into_connection_info().unwrap();
+
+        let (name, port) = match &info.addr {
+            redis::ConnectionAddr::Tcp(addr, port) => (addr, *port),
+            _ => unreachable!(),
+        };
+        let binding = MOCK_CONN_BEHAVIORS.read().unwrap();
+        let conn_utils = binding
+            .get(name)
+            .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed"));
+        let conn_err = Box::pin(future::err(RedisError::from(std::io::Error::new(
+            std::io::ErrorKind::ConnectionReset,
+            "mock-io-error",
+        ))));
+        match &conn_utils.return_connection_err {
+            ShouldReturnConnectionError::No => {}
+            ShouldReturnConnectionError::Yes => return conn_err,
+            ShouldReturnConnectionError::OnOddIdx(curr_idx) => {
+                if curr_idx.fetch_add(1, Ordering::SeqCst) % 2 != 0 {
+                    // raise an error on each odd number
+                    return conn_err;
+                }
+            }
+        }
+
+        let ip = match &conn_utils.returned_ip_type {
+            ConnectionIPReturnType::Specified(ip) => Some(*ip),
+            ConnectionIPReturnType::Different(ip_getter) => {
+                let first_ip_num = ip_getter.fetch_add(1, Ordering::SeqCst) as u8;
+                Some(IpAddr::V4(Ipv4Addr::new(first_ip_num, 0, 0, 0)))
+            }
+            ConnectionIPReturnType::None => None,
+        };
+
+        Box::pin(future::ok((
+            MockConnection {
+                id: conn_utils
+                    .connection_id_provider
+                    .fetch_add(1, Ordering::SeqCst),
+                handler: conn_utils.get_handler(),
+                port,
+            },
+            ip,
+        )))
+    }
+}
+
+impl cluster::Connect for MockConnection {
+    fn connect<'a, T>(info: T, _timeout: Option<Duration>) -> RedisResult<Self>
+    where
+        T: IntoConnectionInfo,
+    {
+        let info =
info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + let binding = MOCK_CONN_BEHAVIORS.read().unwrap(); + let conn_utils = binding + .get(name) + .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed")); + Ok(MockConnection { + id: conn_utils + .connection_id_provider + .fetch_add(1, Ordering::SeqCst), + handler: conn_utils.get_handler(), + port, + }) + } + + fn send_packed_command(&mut self, _cmd: &[u8]) -> RedisResult<()> { + Ok(()) + } + + fn set_write_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn set_read_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn recv_response(&mut self) -> RedisResult { + Ok(Value::Nil) + } +} + +pub fn contains_slice(xs: &[u8], ys: &[u8]) -> bool { + for i in 0..xs.len() { + if xs[i..].starts_with(ys) { + return true; + } + } + false +} + +pub fn respond_startup(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct MockSlotRange { + pub primary_port: u16, + pub replica_ports: Vec, + pub slot_range: std::ops::Range, +} + +pub fn respond_startup_with_replica(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_replica_using_config(name, cmd, None) +} + +pub fn respond_startup_two_nodes(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, None, false) +} + +pub fn create_topology_from_config(name: &str, slots_config: Vec) -> Value { + let slots_vec = slots_config + .into_iter() + .map(|slot_config| { + let mut config = vec![ + Value::Int(slot_config.slot_range.start as i64), + Value::Int(slot_config.slot_range.end as i64), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(slot_config.primary_port as i64), + ]), + ]; + config.extend(slot_config.replica_ports.into_iter().map(|replica_port| { + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(replica_port as i64), + ]) + })); + Value::Array(config) + }) + .collect(); + Value::Array(slots_vec) +} + +pub fn respond_startup_with_replica_using_config( + name: &str, + cmd: &[u8], + slots_config: Option>, +) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, slots_config, true) +} + +/// If the configuration isn't provided, a configuration with two primary nodes, with or without replicas, will be used. 
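+///
+/// A minimal usage sketch (hypothetical: `example-node` is an arbitrary mock id,
+/// not a name this module defines):
+///
+/// ```ignore
+/// let env = MockEnv::new("example-node", |cmd, _port| {
+///     // Answer PING/SETNAME/CLUSTER SLOTS/READONLY during the handshake...
+///     respond_startup_with_config("example-node", cmd, None, false)?;
+///     // ...and reply `+OK` to every other command. Handlers report a response
+///     // by returning `Err(reply)`; `Ok(())` means "no response specified".
+///     Err(Ok(Value::SimpleString("OK".into())))
+/// });
+/// ```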
+pub fn respond_startup_with_config( + name: &str, + cmd: &[u8], + slots_config: Option>, + with_replicas: bool, +) -> Result<(), RedisResult> { + let slots_config = slots_config.unwrap_or(if with_replicas { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8192..16383), + }, + ] + } else { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + slot_range: (8192..16383), + }, + ] + }); + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + let slots = create_topology_from_config(name, slots_config); + Err(Ok(slots)) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[cfg(feature = "cluster-async")] +impl aio::ConnectionLike for MockConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a redis::Cmd) -> RedisFuture<'a, Value> { + Box::pin(future::ready( + (self.handler)(&cmd.get_packed_command(), self.port) + .expect_err("Handler did not specify a response"), + )) + } + + fn req_packed_commands<'a>( + &'a mut self, + _pipeline: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> RedisFuture<'a, Vec> { + Box::pin(future::ok(vec![])) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +impl redis::ConnectionLike for MockConnection { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + (self.handler)(cmd, self.port).expect_err("Handler did not specify a response") + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + _count: usize, + ) -> RedisResult> { + let res = (self.handler)(cmd, self.port).expect_err("Handler did not specify a response"); + match res { + Err(err) => Err(err), + Ok(res) => { + if let Value::Array(results) = res { + match results.into_iter().nth(offset) { + Some(Value::Array(res)) => Ok(res), + _ => Err((ErrorKind::ResponseError, "non-array response").into()), + } + } else { + Err(( + ErrorKind::ResponseError, + "non-array response", + String::from_owned_redis_value(res).unwrap(), + ) + .into()) + } + } + } + } + + fn get_db(&self) -> i64 { + 0 + } + + fn check_connection(&mut self) -> bool { + true + } + + fn is_open(&self) -> bool { + true + } +} + +pub struct MockEnv { + #[cfg(feature = "cluster-async")] + pub runtime: Runtime, + pub client: redis::cluster::ClusterClient, + pub connection: redis::cluster::ClusterConnection, + #[cfg(feature = "cluster-async")] + pub async_connection: redis::cluster_async::ClusterConnection, + #[allow(unused)] + pub handler: RemoveHandler, +} + +pub struct RemoveHandler(Vec); + +impl Drop for RemoveHandler { + fn drop(&mut self) { + for id in &self.0 { + get_behaviors().remove(id); + } + } +} + +impl MockEnv { + pub fn new( + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + Self::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{id}")]), + id, + handler, + ) + } + + pub fn with_client_builder( + client_builder: ClusterClientBuilder, + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + #[cfg(feature = "cluster-async")] + let runtime = 
tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let id = id.to_string(); + let handler = MockConnectionBehavior::register_new( + &id, + Arc::new(move |cmd, port| handler(cmd, port)), + ); + let client = client_builder.build().unwrap(); + let connection = client.get_generic_connection(None).unwrap(); + #[cfg(feature = "cluster-async")] + let async_connection = runtime + .block_on(client.get_async_generic_connection()) + .unwrap(); + MockEnv { + #[cfg(feature = "cluster-async")] + runtime, + client, + connection, + #[cfg(feature = "cluster-async")] + async_connection, + handler, + } + } +} diff --git a/glide-core/redis-rs/redis/tests/support/mod.rs b/glide-core/redis-rs/redis/tests/support/mod.rs new file mode 100644 index 0000000000..335cd045de --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/mod.rs @@ -0,0 +1,887 @@ +#![allow(dead_code)] + +use std::path::Path; +use std::{ + env, fs, io, net::SocketAddr, net::TcpListener, path::PathBuf, process, thread::sleep, + time::Duration, +}; +#[cfg(feature = "tls-rustls")] +use std::{ + fs::File, + io::{BufReader, Read}, +}; + +#[cfg(feature = "aio")] +use futures::Future; +use redis::{ConnectionAddr, InfoDict, Pipeline, ProtocolVersion, RedisConnectionInfo, Value}; + +#[cfg(feature = "tls-rustls")] +use redis::{ClientTlsConfig, TlsCertificates}; + +use socket2::{Domain, Socket, Type}; +use tempfile::TempDir; + +#[cfg(feature = "aio")] +use redis::GlideConnectionOptions; + +pub fn use_protocol() -> ProtocolVersion { + if env::var("PROTOCOL").unwrap_or_default() == "RESP3" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } +} + +pub fn current_thread_runtime() -> tokio::runtime::Runtime { + let mut builder = tokio::runtime::Builder::new_current_thread(); + + #[cfg(feature = "aio")] + builder.enable_io(); + + builder.enable_time(); + + builder.build().unwrap() +} + +#[cfg(feature = "aio")] +pub fn block_on_all(f: F) -> F::Output +where + F: Future>, +{ + use std::panic; + use std::sync::atomic::{AtomicBool, Ordering}; + + static CHECK: AtomicBool = AtomicBool::new(false); + + // TODO - this solution is purely single threaded, and won't work on multiple threads at the same time. + // This is needed because Tokio's Runtime silently ignores panics - https://users.rust-lang.org/t/tokio-runtime-what-happens-when-a-thread-panics/95819 + // Once Tokio stabilizes the `unhandled_panic` field on the runtime builder, it should be used instead. + panic::set_hook(Box::new(|panic| { + println!("Panic: {panic}"); + CHECK.store(true, Ordering::Relaxed); + })); + + // This continuously query the flag, in order to abort ASAP after a panic. + let check_future = futures_util::FutureExt::fuse(async { + loop { + if CHECK.load(Ordering::Relaxed) { + return Err((redis::ErrorKind::IoError, "panic was caught").into()); + } + futures_time::task::sleep(futures_time::time::Duration::from_millis(1)).await; + } + }); + let f = futures_util::FutureExt::fuse(f); + futures::pin_mut!(f, check_future); + + let res = current_thread_runtime().block_on(async { + futures::select! 
{res = f => res, err = check_future => err} + }); + + let _ = panic::take_hook(); + if CHECK.swap(false, Ordering::Relaxed) { + panic!("Internal thread panicked"); + } + + res +} + +#[cfg(feature = "async-std-comp")] +pub fn block_on_all_using_async_std(f: F) -> F::Output +where + F: Future, +{ + async_std::task::block_on(f) +} + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +mod cluster; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +mod mock_cluster; + +mod util; +#[allow(unused_imports)] +pub use self::util::*; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +#[allow(unused_imports)] +pub use self::cluster::*; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +#[allow(unused_imports)] +pub use self::mock_cluster::*; + +#[cfg(feature = "sentinel")] +mod sentinel; + +#[cfg(feature = "sentinel")] +#[allow(unused_imports)] +pub use self::sentinel::*; + +#[derive(PartialEq)] +enum ServerType { + Tcp { tls: bool }, + Unix, +} + +pub enum Module { + Json, +} + +pub struct RedisServer { + pub process: process::Child, + pub(crate) tempdir: tempfile::TempDir, + pub(crate) addr: redis::ConnectionAddr, + pub(crate) tls_paths: Option, +} + +impl ServerType { + fn get_intended() -> ServerType { + match env::var("REDISRS_SERVER_TYPE") + .ok() + .as_ref() + .map(|x| &x[..]) + { + Some("tcp") => ServerType::Tcp { tls: false }, + Some("tcp+tls") => ServerType::Tcp { tls: true }, + Some("unix") => ServerType::Unix, + Some(val) => { + panic!("Unknown server type {val:?}"); + } + None => ServerType::Tcp { tls: false }, + } + } +} + +impl RedisServer { + pub fn new() -> RedisServer { + RedisServer::with_modules(&[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls() -> RedisServer { + RedisServer::with_modules(&[], true) + } + + pub fn get_addr(port: u16) -> ConnectionAddr { + let server_type = ServerType::get_intended(); + match server_type { + ServerType::Tcp { tls } => { + if tls { + redis::ConnectionAddr::TcpTls { + host: "127.0.0.1".to_string(), + port, + insecure: true, + tls_params: None, + } + } else { + redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port) + } + } + ServerType::Unix => { + let (a, b) = rand::random::<(u64, u64)>(); + let path = format!("/tmp/redis-rs-test-{a}-{b}.sock"); + redis::ConnectionAddr::Unix(PathBuf::from(&path)) + } + } + } + + pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer { + // this is technically a race but we can't do better with + // the tools that redis gives us :( + let redis_port = get_random_available_port(); + let addr = RedisServer::get_addr(redis_port); + + RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + None, + mtls_enabled, + modules, + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ) + } + + pub fn new_with_addr_and_modules( + addr: redis::ConnectionAddr, + modules: &[Module], + mtls_enabled: bool, + ) -> RedisServer { + RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + None, + mtls_enabled, + modules, + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ) + } + + pub fn new_with_addr_tls_modules_and_spawner< + F: FnOnce(&mut process::Command) -> process::Child, + >( + addr: redis::ConnectionAddr, + config_file: Option<&Path>, + tls_paths: Option, + mtls_enabled: bool, + modules: &[Module], + spawner: F, + ) -> RedisServer { + let mut redis_cmd = process::Command::new("redis-server"); + + if let Some(config_path) = 
config_file { + redis_cmd.arg(config_path); + } + + // Load Redis Modules + for module in modules { + match module { + Module::Json => { + redis_cmd + .arg("--loadmodule") + .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect( + "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?", + )); + } + }; + } + + redis_cmd + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + redis_cmd.arg("--logfile").arg(Self::log_file(&tempdir)); + match addr { + redis::ConnectionAddr::Tcp(ref bind, server_port) => { + redis_cmd + .arg("--port") + .arg(server_port.to_string()) + .arg("--bind") + .arg(bind); + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + redis::ConnectionAddr::TcpTls { ref host, port, .. } => { + let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir)); + + let auth_client = if mtls_enabled { "yes" } else { "no" }; + + // prepare redis with TLS + redis_cmd + .arg("--tls-port") + .arg(port.to_string()) + .arg("--port") + .arg("0") + .arg("--tls-cert-file") + .arg(&tls_paths.redis_crt) + .arg("--tls-key-file") + .arg(&tls_paths.redis_key) + .arg("--tls-ca-cert-file") + .arg(&tls_paths.ca_crt) + .arg("--tls-auth-clients") + .arg(auth_client) + .arg("--bind") + .arg(host); + + // Insecure only disabled if `mtls` is enabled + let insecure = !mtls_enabled; + + let addr = redis::ConnectionAddr::TcpTls { + host: host.clone(), + port, + insecure, + tls_params: None, + }; + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: Some(tls_paths), + } + } + redis::ConnectionAddr::Unix(ref path) => { + redis_cmd + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(path); + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + } + } + + pub fn client_addr(&self) -> &redis::ConnectionAddr { + &self.addr + } + + pub fn connection_info(&self) -> redis::ConnectionInfo { + redis::ConnectionInfo { + addr: self.client_addr().clone(), + redis: RedisConnectionInfo { + protocol: use_protocol(), + ..Default::default() + }, + } + } + + pub fn stop(&mut self) { + let _ = self.process.kill(); + let _ = self.process.wait(); + if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() { + fs::remove_file(path).ok(); + } + } + + pub fn log_file(tempdir: &TempDir) -> PathBuf { + tempdir.path().join("redis.log") + } +} + +/// Finds a random open port available for listening at, by spawning a TCP server with +/// port "zero" (which prompts the OS to just use any available port). Between calling +/// this function and trying to bind to this port, the port may be given to another +/// process, so this must be used with care (since here we only use it for tests, it's +/// mostly okay). 
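+///
+/// A minimal usage sketch (binding immediately, as shown below, is just one way to
+/// narrow the race window described above):
+///
+/// ```ignore
+/// let port = get_random_available_port();
+/// // Claim the port right away so another process is less likely to grab it.
+/// let listener = std::net::TcpListener::bind(("127.0.0.1", port)).unwrap();
+/// ```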
+pub fn get_random_available_port() -> u16 { + let addr = &"127.0.0.1:0".parse::().unwrap().into(); + let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); + socket.set_reuse_address(true).unwrap(); + socket.bind(addr).unwrap(); + socket.listen(1).unwrap(); + let listener = TcpListener::from(socket); + listener.local_addr().unwrap().port() +} + +impl Drop for RedisServer { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestContext { + pub server: RedisServer, + pub client: redis::Client, + pub protocol: ProtocolVersion, +} + +pub(crate) fn is_tls_enabled() -> bool { + cfg!(all(feature = "tls-rustls", not(feature = "tls-native-tls"))) +} + +impl TestContext { + pub fn new() -> TestContext { + TestContext::with_modules(&[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls() -> TestContext { + Self::with_modules(&[], true) + } + + fn connect_with_retries(client: &redis::Client) { + let mut con; + + let millisecond = Duration::from_millis(1); + let mut retries = 0; + loop { + match client.get_connection(None) { + Err(err) => { + if err.is_connection_refusal() { + sleep(millisecond); + retries += 1; + if retries > 100000 { + panic!("Tried to connect too many times, last error: {err}"); + } + } else { + panic!("Could not connect: {err}"); + } + } + Ok(x) => { + con = x; + break; + } + } + } + redis::cmd("FLUSHDB").execute(&mut con); + } + + pub fn with_tls(tls_files: TlsFilePaths, mtls_enabled: bool) -> TestContext { + let redis_port = get_random_available_port(); + let addr: ConnectionAddr = RedisServer::get_addr(redis_port); + + let server = RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + Some(tls_files), + mtls_enabled, + &[], + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> TestContext { + let server = RedisServer::with_modules(modules, mtls_enabled); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn with_client_name(clientname: &str) -> TestContext { + let server = RedisServer::with_modules(&[], false); + let con_info = redis::ConnectionInfo { + addr: server.client_addr().clone(), + redis: redis::RedisConnectionInfo { + client_name: Some(clientname.to_string()), + ..Default::default() + }, + }; + + #[cfg(feature = "tls-rustls")] + let client = build_single_client(con_info, &server.tls_paths, false).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(con_info).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn connection(&self) -> redis::Connection { + self.client.get_connection(None).unwrap() + } + + #[cfg(feature = "aio")] + pub async fn async_connection(&self) -> redis::RedisResult { + self.client + 
.get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + } + + #[cfg(feature = "aio")] + pub async fn async_pubsub(&self) -> redis::RedisResult { + self.client.get_async_pubsub().await + } + + #[cfg(feature = "async-std-comp")] + pub async fn async_connection_async_std( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) + .await + } + + pub fn stop_server(&mut self) { + self.server.stop(); + } + + #[cfg(feature = "tokio-comp")] + pub async fn multiplexed_async_connection( + &self, + ) -> redis::RedisResult { + self.multiplexed_async_connection_tokio().await + } + + #[cfg(feature = "tokio-comp")] + pub async fn multiplexed_async_connection_tokio( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await + } + + #[cfg(feature = "async-std-comp")] + pub async fn multiplexed_async_connection_async_std( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) + .await + } + + pub fn get_version(&self) -> Version { + let mut conn = self.connection(); + get_version(&mut conn) + } +} + +fn encode_iter(values: &[Value], writer: &mut W, prefix: &str) -> io::Result<()> +where + W: io::Write, +{ + write!(writer, "{}{}\r\n", prefix, values.len())?; + for val in values.iter() { + encode_value(val, writer)?; + } + Ok(()) +} +fn encode_map(values: &[(Value, Value)], writer: &mut W, prefix: &str) -> io::Result<()> +where + W: io::Write, +{ + write!(writer, "{}{}\r\n", prefix, values.len())?; + for (k, v) in values.iter() { + encode_value(k, writer)?; + encode_value(v, writer)?; + } + Ok(()) +} +pub fn encode_value(value: &Value, writer: &mut W) -> io::Result<()> +where + W: io::Write, +{ + #![allow(clippy::write_with_newline)] + match *value { + Value::Nil => write!(writer, "$-1\r\n"), + Value::Int(val) => write!(writer, ":{val}\r\n"), + Value::BulkString(ref val) => { + write!(writer, "${}\r\n", val.len())?; + writer.write_all(val)?; + writer.write_all(b"\r\n") + } + Value::Array(ref values) => encode_iter(values, writer, "*"), + Value::Okay => write!(writer, "+OK\r\n"), + Value::SimpleString(ref s) => write!(writer, "+{s}\r\n"), + Value::Map(ref values) => encode_map(values, writer, "%"), + Value::Attribute { + ref data, + ref attributes, + } => { + encode_map(attributes, writer, "|")?; + encode_value(data, writer)?; + Ok(()) + } + Value::Set(ref values) => encode_iter(values, writer, "~"), + Value::Double(val) => write!(writer, ",{}\r\n", val), + Value::Boolean(v) => { + if v { + write!(writer, "#t\r\n") + } else { + write!(writer, "#f\r\n") + } + } + Value::VerbatimString { + ref format, + ref text, + } => { + // format is always 3 bytes + write!(writer, "={}\r\n{}:{}\r\n", 3 + text.len(), format, text) + } + Value::BigNumber(ref val) => write!(writer, "({}\r\n", val), + Value::Push { ref kind, ref data } => { + write!(writer, ">{}\r\n+{kind}\r\n", data.len() + 1)?; + for val in data.iter() { + encode_value(val, writer)?; + } + Ok(()) + } + } +} + +#[derive(Clone, Debug)] +pub struct TlsFilePaths { + pub(crate) redis_crt: PathBuf, + pub(crate) redis_key: PathBuf, + pub(crate) ca_crt: PathBuf, +} + +pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { + // Based on shell script in redis's server tests + // https://github.com/redis/redis/blob/8c291b97b95f2e011977b522acf77ead23e26f55/utils/gen-test-certs.sh + let ca_crt = 
tempdir.path().join("ca.crt"); + let ca_key = tempdir.path().join("ca.key"); + let ca_serial = tempdir.path().join("ca.txt"); + let redis_crt = tempdir.path().join("redis.crt"); + let redis_key = tempdir.path().join("redis.key"); + let ext_file = tempdir.path().join("openssl.cnf"); + + fn make_key>(name: S, size: usize) { + process::Command::new("openssl") + .arg("genrsa") + .arg("-out") + .arg(name) + .arg(format!("{size}")) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create key"); + } + + // Build CA Key + make_key(&ca_key, 4096); + + // Build redis key + make_key(&redis_key, 2048); + + // Build CA Cert + process::Command::new("openssl") + .arg("req") + .arg("-x509") + .arg("-new") + .arg("-nodes") + .arg("-sha256") + .arg("-key") + .arg(&ca_key) + .arg("-days") + .arg("3650") + .arg("-subj") + .arg("/O=Redis Test/CN=Certificate Authority") + .arg("-out") + .arg(&ca_crt) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create CA cert"); + + // Build x509v3 extensions file + fs::write( + &ext_file, + b"keyUsage = digitalSignature, keyEncipherment\n\ + subjectAltName = @alt_names\n\ + [alt_names]\n\ + IP.1 = 127.0.0.1\n", + ) + .expect("failed to create x509v3 extensions file"); + + // Read redis key + let mut key_cmd = process::Command::new("openssl") + .arg("req") + .arg("-new") + .arg("-sha256") + .arg("-subj") + .arg("/O=Redis Test/CN=Generic-cert") + .arg("-key") + .arg(&redis_key) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl"); + + // build redis cert + process::Command::new("openssl") + .arg("x509") + .arg("-req") + .arg("-sha256") + .arg("-CA") + .arg(&ca_crt) + .arg("-CAkey") + .arg(&ca_key) + .arg("-CAserial") + .arg(&ca_serial) + .arg("-CAcreateserial") + .arg("-days") + .arg("365") + .arg("-extfile") + .arg(&ext_file) + .arg("-out") + .arg(&redis_crt) + .stdin(key_cmd.stdout.take().expect("should have stdout")) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create redis cert"); + + key_cmd.wait().expect("failed to create redis key"); + + TlsFilePaths { + redis_crt, + redis_key, + ca_crt, + } +} + +pub type Version = (u16, u16, u16); + +fn get_version(conn: &mut impl redis::ConnectionLike) -> Version { + let info: InfoDict = redis::Cmd::new().arg("INFO").query(conn).unwrap(); + let version: String = info.get("redis_version").unwrap(); + let versions: Vec = version + .split('.') + .map(|version| version.parse::().unwrap()) + .collect(); + assert_eq!(versions.len(), 3); + (versions[0], versions[1], versions[2]) +} + +pub fn is_major_version(expected_version: u16, version: Version) -> bool { + expected_version <= version.0 +} + +pub fn is_version(expected_major_minor: (u16, u16), version: Version) -> bool { + expected_major_minor.0 < version.0 + || (expected_major_minor.0 == version.0 && expected_major_minor.1 <= version.1) +} + +#[cfg(feature = "tls-rustls")] +fn load_certs_from_file(tls_file_paths: &TlsFilePaths) -> TlsCertificates { + let ca_file = File::open(&tls_file_paths.ca_crt).expect("Cannot open CA cert file"); + let mut root_cert_vec = Vec::new(); + BufReader::new(ca_file) + .read_to_end(&mut root_cert_vec) + .expect("Unable to read CA cert file"); + + let cert_file = 
File::open(&tls_file_paths.redis_crt).expect("cannot open private cert file"); + let mut client_cert_vec = Vec::new(); + BufReader::new(cert_file) + .read_to_end(&mut client_cert_vec) + .expect("Unable to read client cert file"); + + let key_file = File::open(&tls_file_paths.redis_key).expect("Cannot open private key file"); + let mut client_key_vec = Vec::new(); + BufReader::new(key_file) + .read_to_end(&mut client_key_vec) + .expect("Unable to read client key file"); + + TlsCertificates { + client_tls: Some(ClientTlsConfig { + client_cert: client_cert_vec, + client_key: client_key_vec, + }), + root_cert: Some(root_cert_vec), + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) fn build_single_client( + connection_info: T, + tls_file_params: &Option, + mtls_enabled: bool, +) -> redis::RedisResult { + if mtls_enabled && tls_file_params.is_some() { + redis::Client::build_with_tls( + connection_info, + load_certs_from_file( + tls_file_params + .as_ref() + .expect("Expected certificates when `tls-rustls` feature is enabled"), + ), + ) + } else { + redis::Client::open(connection_info) + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) mod mtls_test { + use super::*; + use redis::{cluster::ClusterClient, ConnectionInfo, RedisError}; + + fn clean_node_info(nodes: &[ConnectionInfo]) -> Vec { + let nodes = nodes + .iter() + .map(|node| match node { + ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { host, port, .. }, + redis, + } => ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { + host: host.to_owned(), + port: *port, + insecure: false, + tls_params: None, + }, + redis: redis.clone(), + }, + _ => node.clone(), + }) + .collect(); + nodes + } + + pub(crate) fn create_cluster_client_from_cluster( + cluster: &TestClusterContext, + mtls_enabled: bool, + ) -> Result { + let server = cluster + .cluster + .servers + .first() + .expect("Expected at least 1 server"); + let tls_paths = server.tls_paths.as_ref(); + let nodes = clean_node_info(&cluster.nodes); + let builder = redis::cluster::ClusterClientBuilder::new(nodes); + if let Some(tls_paths) = tls_paths { + // server-side TLS available + if mtls_enabled { + builder.certs(load_certs_from_file(tls_paths)) + } else { + builder + } + } else { + // server-side TLS NOT available + builder + } + .build() + } +} + +pub fn build_simple_pipeline_for_invalidation() -> Pipeline { + let mut pipe = redis::pipe(); + pipe.cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore(); + pipe +} diff --git a/glide-core/redis-rs/redis/tests/support/sentinel.rs b/glide-core/redis-rs/redis/tests/support/sentinel.rs new file mode 100644 index 0000000000..d34d3dc88b --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/sentinel.rs @@ -0,0 +1,404 @@ +use std::fs::File; +use std::io::Write; +use std::thread::sleep; +use std::time::Duration; + +use redis::sentinel::SentinelNodeConnectionInfo; +use redis::Client; +use redis::ConnectionAddr; +use redis::ConnectionInfo; +use redis::FromRedisValue; +use redis::RedisResult; +use redis::TlsMode; +use tempfile::TempDir; + +use crate::support::build_single_client; + +use super::build_keys_and_certs_for_tls; +use super::get_random_available_port; +use super::Module; +use super::RedisServer; +use super::TlsFilePaths; + +const LOCALHOST: &str = "127.0.0.1"; +const MTLS_NOT_ENABLED: bool = false; + +pub struct RedisSentinelCluster { + pub servers: Vec, + pub sentinel_servers: Vec, + pub folders: Vec, +} + +fn get_addr(port: u16) -> ConnectionAddr { + let addr = 
RedisServer::get_addr(port); + if let ConnectionAddr::Unix(_) = addr { + ConnectionAddr::Tcp(String::from("127.0.0.1"), port) + } else { + addr + } +} + +fn spawn_master_server( + port: u16, + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + None, + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + // Minimize startup delay + cmd.arg("--repl-diskless-sync-delay").arg("0"); + cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_replica_server( + port: u16, + master_port: u16, + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + File::create(&config_file_path).unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--replicaof") + .arg("127.0.0.1") + .arg(master_port.to_string()); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.arg("--appendonly").arg("yes"); + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_sentinel_server( + port: u16, + master_ports: &[u16], + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + let mut file = File::create(&config_file_path).unwrap(); + for (i, master_port) in master_ports.iter().enumerate() { + file.write_all( + format!("sentinel monitor master{} 127.0.0.1 {} 1\n", i, master_port).as_bytes(), + ) + .unwrap(); + } + file.flush().unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--sentinel"); + cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. 
} = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn wait_for_master_server( + mut get_client_fn: impl FnMut() -> RedisResult, +) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..100 { + let master_client = get_client_fn(); + match master_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + if role.starts_with("master") { + return Ok(()); + } else { + println!("failed check for master role - current role: {r:?}") + } + } + Err(err) => { + println!("failed to get master connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get master client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replica(mut get_client_fn: impl FnMut() -> RedisResult) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..200 { + let replica_client = get_client_fn(); + match replica_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + let state = String::from_redis_value(r.get(3).unwrap()).unwrap(); + if role.starts_with("slave") && state == "connected" { + return Ok(()); + } else { + println!("failed check for replica role - current role: {:?}", r) + } + } + Err(err) => { + println!("failed to get replica connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get replica client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replicas_to_sync(servers: &[RedisServer], masters: u16) { + let cluster_size = servers.len() / (masters as usize); + let clusters = servers.len() / cluster_size; + let replicas = cluster_size - 1; + + for cluster_index in 0..clusters { + let master_addr = servers[cluster_index * cluster_size].connection_info(); + let tls_paths = &servers.first().unwrap().tls_paths; + let r = wait_for_master_server(|| { + Ok(build_single_client(master_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap()) + }); + if r.is_err() { + panic!("failed waiting for master to be ready"); + } + + for replica_index in 0..replicas { + let replica_addr = + servers[(cluster_index * cluster_size) + 1 + replica_index].connection_info(); + let r = wait_for_replica(|| { + Ok(build_single_client(replica_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap()) + }); + if r.is_err() { + panic!("failed waiting for replica to be ready and in sync"); + } + } + } +} + +impl RedisSentinelCluster { + pub fn new(masters: u16, replicas_per_master: u16, sentinels: u16) -> RedisSentinelCluster { + RedisSentinelCluster::with_modules(masters, replicas_per_master, sentinels, &[]) + } + + pub fn with_modules( + masters: u16, + replicas_per_master: u16, + sentinels: u16, + modules: &[Module], + ) -> RedisSentinelCluster { + let mut servers = vec![]; + let mut folders = vec![]; + let mut master_ports = vec![]; + + let tempdir = tempfile::Builder::new() + .prefix("redistls") + .tempdir() + .expect("failed to create tempdir"); + let tlspaths = build_keys_and_certs_for_tls(&tempdir); + folders.push(tempdir); + + let required_number_of_sockets = masters * (replicas_per_master + 1) + sentinels; + let mut available_ports = std::collections::HashSet::new(); + while available_ports.len() < required_number_of_sockets as usize { + 
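// `get_random_available_port` can hand back the same port twice, so insert into a
+            // HashSet and keep drawing until enough distinct ports have been collected.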
available_ports.insert(get_random_available_port()); + } + let mut available_ports: Vec<_> = available_ports.into_iter().collect(); + + for _ in 0..masters { + let port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + servers.push(spawn_master_server(port, &tempdir, &tlspaths, modules)); + folders.push(tempdir); + master_ports.push(port); + + for _ in 0..replicas_per_master { + let replica_port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + servers.push(spawn_replica_server( + replica_port, + port, + &tempdir, + &tlspaths, + modules, + )); + folders.push(tempdir); + } + } + + // Wait for replicas to sync so that the sentinels discover them on the first try + wait_for_replicas_to_sync(&servers, masters); + + let mut sentinel_servers = vec![]; + for _ in 0..sentinels { + let port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + + sentinel_servers.push(spawn_sentinel_server( + port, + &master_ports, + &tempdir, + &tlspaths, + modules, + )); + folders.push(tempdir); + } + + RedisSentinelCluster { + servers, + sentinel_servers, + folders, + } + } + + pub fn stop(&mut self) { + for server in &mut self.servers { + server.stop(); + } + for server in &mut self.sentinel_servers { + server.stop(); + } + } + + pub fn iter_sentinel_servers(&self) -> impl Iterator { + self.sentinel_servers.iter() + } +} + +impl Drop for RedisSentinelCluster { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestSentinelContext { + pub cluster: RedisSentinelCluster, + pub sentinel: redis::sentinel::Sentinel, + pub sentinels_connection_info: Vec, + mtls_enabled: bool, // for future tests +} + +impl TestSentinelContext { + pub fn new(nodes: u16, replicas: u16, sentinels: u16) -> TestSentinelContext { + Self::new_with_cluster_client_builder(nodes, replicas, sentinels) + } + + pub fn new_with_cluster_client_builder( + nodes: u16, + replicas: u16, + sentinels: u16, + ) -> TestSentinelContext { + let cluster = RedisSentinelCluster::new(nodes, replicas, sentinels); + let initial_nodes: Vec = cluster + .iter_sentinel_servers() + .map(RedisServer::connection_info) + .collect(); + let sentinel = redis::sentinel::Sentinel::build(initial_nodes.clone()); + let sentinel = sentinel.unwrap(); + + let mut context = TestSentinelContext { + cluster, + sentinel, + sentinels_connection_info: initial_nodes, + mtls_enabled: MTLS_NOT_ENABLED, + }; + context.wait_for_cluster_up(); + context + } + + pub fn sentinel(&self) -> &redis::sentinel::Sentinel { + &self.sentinel + } + + pub fn sentinel_mut(&mut self) -> &mut redis::sentinel::Sentinel { + &mut self.sentinel + } + + pub fn sentinels_connection_info(&self) -> &Vec { + &self.sentinels_connection_info + } + + pub fn sentinel_node_connection_info(&self) -> SentinelNodeConnectionInfo { + SentinelNodeConnectionInfo { + tls_mode: if let ConnectionAddr::TcpTls { insecure, .. 
} = + self.cluster.servers[0].client_addr() + { + if *insecure { + Some(TlsMode::Insecure) + } else { + Some(TlsMode::Secure) + } + } else { + None + }, + redis_connection_info: None, + } + } + + pub fn wait_for_cluster_up(&mut self) { + let node_conn_info = self.sentinel_node_connection_info(); + let con = self.sentinel_mut(); + + let r = wait_for_master_server(|| con.master_for("master1", Some(&node_conn_info))); + if r.is_err() { + panic!("failed waiting for sentinel master1 to be ready"); + } + + let r = wait_for_replica(|| con.replica_for("master1", Some(&node_conn_info))); + if r.is_err() { + panic!("failed waiting for sentinel master1 replica to be ready"); + } + } +} diff --git a/glide-core/redis-rs/redis/tests/support/util.rs b/glide-core/redis-rs/redis/tests/support/util.rs new file mode 100644 index 0000000000..8026b83fb5 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/util.rs @@ -0,0 +1,23 @@ +use std::collections::HashMap; + +#[macro_export] +macro_rules! assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| std::str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } +} + +pub fn parse_client_info(client_info: &str) -> HashMap { + let mut res = HashMap::new(); + + for line in client_info.split(' ') { + let this_attr: Vec<&str> = line.split('=').collect(); + res.insert(this_attr[0].to_string(), this_attr[1].to_string()); + } + + res +} diff --git a/glide-core/redis-rs/redis/tests/test_acl.rs b/glide-core/redis-rs/redis/tests/test_acl.rs new file mode 100644 index 0000000000..093774f3bc --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_acl.rs @@ -0,0 +1,156 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "acl")] + +use std::collections::HashSet; + +use redis::acl::{AclInfo, Rule}; +use redis::{Commands, Value}; + +mod support; +use crate::support::*; + +#[test] +fn test_acl_whoami() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!(con.acl_whoami(), Ok("default".to_owned())); +} + +#[test] +fn test_acl_help() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: Vec = con.acl_help().expect("Got help manual"); + assert!(!res.is_empty()); +} + +//TODO: do we need this test? 
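+// The test below walks the full user lifecycle (ACL LIST/USERS, SETUSER with
+// rules, GETUSER, DELUSER); it stays `#[ignore]`d while the TODO above is open.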
+#[test] +#[ignore] +fn test_acl_getsetdel_users() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!( + con.acl_list(), + Ok(vec!["user default on nopass ~* +@all".to_owned()]) + ); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); + // bob + assert_eq!(con.acl_setuser("bob"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec!["bob".to_owned(), "default".to_owned()]) + ); + + // ACL SETUSER bob on ~redis:* +set + assert_eq!( + con.acl_setuser_rules( + "bob", + &[ + Rule::On, + Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + ), + Rule::Pattern("redis:*".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + ), + Ok(()) + ); + let acl_info: AclInfo = con.acl_getuser("bob").expect("Got user"); + assert_eq!( + acl_info, + AclInfo { + flags: vec![Rule::On], + passwords: vec![Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + )], + commands: vec![ + Rule::RemoveCategory("all".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + keys: vec![Rule::Pattern("redis:*".to_owned())], + } + ); + assert_eq!( + con.acl_list(), + Ok(vec![ + "user bob on #c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2 ~redis:* -@all +set".to_owned(), + "user default on nopass ~* +@all".to_owned(), + ]) + ); + + // ACL SETUSER eve + assert_eq!(con.acl_setuser("eve"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec![ + "bob".to_owned(), + "default".to_owned(), + "eve".to_owned() + ]) + ); + assert_eq!(con.acl_deluser(&["bob", "eve"]), Ok(2)); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); +} + +#[test] +fn test_acl_cat() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: HashSet = con.acl_cat().expect("Got categories"); + let expects = vec![ + "keyspace", + "read", + "write", + "set", + "sortedset", + "list", + "hash", + "string", + "bitmap", + "hyperloglog", + "geo", + "stream", + "pubsub", + "admin", + "fast", + "slow", + "blocking", + "dangerous", + "connection", + "transaction", + "scripting", + ]; + for cat in expects.iter() { + assert!(res.contains(*cat), "Category `{cat}` does not exist"); + } + + let expects = ["pfmerge", "pfcount", "pfselftest", "pfadd"]; + let res: HashSet = con + .acl_cat_categoryname("hyperloglog") + .expect("Got commands of a category"); + for cmd in expects.iter() { + assert!(res.contains(*cmd), "Command `{cmd}` does not exist"); + } +} + +#[test] +fn test_acl_genpass() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let pass: String = con.acl_genpass().expect("Got password"); + assert_eq!(pass.len(), 64); + + let pass: String = con.acl_genpass_bits(1024).expect("Got password"); + assert_eq!(pass.len(), 256); +} + +#[test] +fn test_acl_log() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let logs: Vec = con.acl_log(1).expect("Got logs"); + assert_eq!(logs.len(), 0); + assert_eq!(con.acl_log_reset(), Ok(())); +} diff --git a/glide-core/redis-rs/redis/tests/test_async.rs b/glide-core/redis-rs/redis/tests/test_async.rs new file mode 100644 index 0000000000..d16f1e0694 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async.rs @@ -0,0 +1,1132 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] + +mod support; + +#[cfg(test)] +mod basic_async { + use std::collections::HashMap; + + use futures::{prelude::*, StreamExt}; + use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + cmd, pipe, AsyncCommands, 
+        RedisResult, Value,
+    };
+    use tokio::sync::mpsc::error::TryRecvError;
+
+    use crate::support::*;
+
+    #[test]
+    fn test_args() {
+        let ctx = TestContext::new();
+        let connect = ctx.async_connection();
+
+        block_on_all(connect.and_then(|mut con| async move {
+            redis::cmd("SET")
+                .arg("key1")
+                .arg(b"foo")
+                .query_async(&mut con)
+                .await?;
+            redis::cmd("SET")
+                .arg(&["key2", "bar"])
+                .query_async(&mut con)
+                .await?;
+            let result = redis::cmd("MGET")
+                .arg(&["key1", "key2"])
+                .query_async(&mut con)
+                .await;
+            assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
+            result
+        }))
+        .unwrap();
+    }
+
+    #[test]
+    fn test_nice_hash_api() {
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut connection = ctx.async_connection().await.unwrap();
+
+            assert_eq!(
+                connection
+                    .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)])
+                    .await,
+                Ok(())
+            );
+
+            let hm: HashMap<String, isize> = connection.hgetall("my_hash").await.unwrap();
+            assert_eq!(hm.len(), 4);
+            assert_eq!(hm.get("f1"), Some(&1));
+            assert_eq!(hm.get("f2"), Some(&2));
+            assert_eq!(hm.get("f3"), Some(&4));
+            assert_eq!(hm.get("f4"), Some(&8));
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_nice_hash_api_in_pipe() {
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut connection = ctx.async_connection().await.unwrap();
+
+            assert_eq!(
+                connection
+                    .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)])
+                    .await,
+                Ok(())
+            );
+
+            let mut pipe = redis::pipe();
+            pipe.cmd("HGETALL").arg("my_hash");
+            let mut vec: Vec<HashMap<String, isize>> =
+                pipe.query_async(&mut connection).await.unwrap();
+            assert_eq!(vec.len(), 1);
+            let hash = vec.pop().unwrap();
+            assert_eq!(hash.len(), 4);
+            assert_eq!(hash.get("f1"), Some(&1));
+            assert_eq!(hash.get("f2"), Some(&2));
+            assert_eq!(hash.get("f3"), Some(&4));
+            assert_eq!(hash.get("f4"), Some(&8));
+
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn dont_panic_on_closed_multiplexed_connection() {
+        let ctx = TestContext::new();
+        let client = ctx.client.clone();
+        let connect = client.get_multiplexed_async_connection(GlideConnectionOptions::default());
+        drop(ctx);
+
+        block_on_all(async move {
+            connect
+                .and_then(|con| async move {
+                    let cmd = move || {
+                        let mut con = con.clone();
+                        async move {
+                            redis::cmd("SET")
+                                .arg("key1")
+                                .arg(b"foo")
+                                .query_async(&mut con)
+                                .await
+                        }
+                    };
+                    let result: RedisResult<()> = cmd().await;
+                    assert_eq!(
+                        result.as_ref().unwrap_err().kind(),
+                        redis::ErrorKind::IoError,
+                        "{}",
+                        result.as_ref().unwrap_err()
+                    );
+                    cmd().await
+                })
+                .map(|result| {
+                    assert_eq!(
+                        result.as_ref().unwrap_err().kind(),
+                        redis::ErrorKind::IoError,
+                        "{}",
+                        result.as_ref().unwrap_err()
+                    );
+                })
+                .await;
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_pipeline_transaction() {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            let mut con = ctx.async_connection().await?;
+            let mut pipe = redis::pipe();
+            pipe.atomic()
+                .cmd("SET")
+                .arg("key_1")
+                .arg(42)
+                .ignore()
+                .cmd("SET")
+                .arg("key_2")
+                .arg(43)
+                .ignore()
+                .cmd("MGET")
+                .arg(&["key_1", "key_2"]);
+            pipe.query_async(&mut con)
+                .map_ok(|((k1, k2),): ((i32, i32),)| {
+                    assert_eq!(k1, 42);
+                    assert_eq!(k2, 43);
+                })
+                .await
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_client_tracking_doesnt_block_execution() {
+        // It checks that the library distinguishes a push-type message from the others and continues its normal operation.
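+        // With `CLIENT TRACKING ON`, the server may interleave invalidation push
+        // messages with regular replies; the pipeline below must still parse its
+        // responses correctly despite those pushes.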
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            let mut con = ctx.async_connection().await.unwrap();
+            let mut pipe = redis::pipe();
+            pipe.cmd("CLIENT")
+                .arg("TRACKING")
+                .arg("ON")
+                .ignore()
+                .cmd("GET")
+                .arg("key_1")
+                .ignore()
+                .cmd("SET")
+                .arg("key_1")
+                .arg(42)
+                .ignore();
+            let _: RedisResult<()> = pipe.query_async(&mut con).await;
+            let num: i32 = con.get("key_1").await.unwrap();
+            assert_eq!(num, 42);
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_pipeline_transaction_with_errors() {
+        use redis::RedisError;
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut con = ctx.async_connection().await?;
+            con.set::<_, _, ()>("x", 42).await.unwrap();
+
+            // Make Redis a replica of a nonexistent master, thereby making it read-only.
+            redis::cmd("slaveof")
+                .arg("1.1.1.1")
+                .arg("1")
+                .query_async::<_, ()>(&mut con)
+                .await
+                .unwrap();
+
+            // Ensure that a write command fails with a READONLY error
+            let err: RedisResult<()> = redis::pipe()
+                .atomic()
+                .set("x", 142)
+                .ignore()
+                .get("x")
+                .query_async(&mut con)
+                .await;
+
+            assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly);
+
+            let x: i32 = con.get("x").await.unwrap();
+            assert_eq!(x, 42);
+
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+    }
+
+    fn test_cmd(
+        con: &MultiplexedConnection,
+        i: i32,
+    ) -> impl Future<Output = RedisResult<()>> + Send {
+        let mut con = con.clone();
+        async move {
+            let key = format!("key{i}");
+            let key_2 = key.clone();
+            let key2 = format!("key{i}_2");
+            let key2_2 = key2.clone();
+
+            let foo_val = format!("foo{i}");
+
+            redis::cmd("SET")
+                .arg(&key[..])
+                .arg(foo_val.as_bytes())
+                .query_async(&mut con)
+                .await?;
+            redis::cmd("SET")
+                .arg(&[&key2, "bar"])
+                .query_async(&mut con)
+                .await?;
+            redis::cmd("MGET")
+                .arg(&[&key_2, &key2_2])
+                .query_async(&mut con)
+                .map(|result| {
+                    assert_eq!(Ok((foo_val, b"bar".to_vec())), result);
+                    Ok(())
+                })
+                .await
+        }
+    }
+
+    fn test_error(con: &MultiplexedConnection) -> impl Future<Output = RedisResult<()>> {
+        let mut con = con.clone();
+        async move {
+            redis::cmd("SET")
+                .query_async(&mut con)
+                .map(|result| match result {
+                    Ok(()) => panic!("Expected redis to return an error"),
+                    Err(_) => Ok(()),
+                })
+                .await
+        }
+    }
+
+    #[test]
+    fn test_pipe_over_multiplexed_connection() {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            let mut con = ctx.multiplexed_async_connection().await?;
+            let mut pipe = pipe();
+            pipe.zrange("zset", 0, 0);
+            pipe.zrange("zset", 0, 0);
+            let frames = con.send_packed_commands(&pipe, 0, 2).await?;
+            assert_eq!(frames.len(), 2);
+            assert!(matches!(frames[0], redis::Value::Array(_)));
+            assert!(matches!(frames[1], redis::Value::Array(_)));
+            RedisResult::Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_args_multiplexed_connection() {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            ctx.multiplexed_async_connection()
+                .and_then(|con| {
+                    let cmds = (0..100).map(move |i| test_cmd(&con, i));
+                    future::try_join_all(cmds).map_ok(|results| {
+                        assert_eq!(results.len(), 100);
+                    })
+                })
+                .map_err(|err| panic!("{}", err))
+                .await
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_args_with_errors_multiplexed_connection() {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            ctx.multiplexed_async_connection()
+                .and_then(|con| {
+                    let cmds = (0..100).map(move |i| {
+                        let con = con.clone();
+                        async move {
+                            if i % 2 == 0 {
+                                test_cmd(&con, i).await
+                            } else {
+                                test_error(&con).await
+                            }
+                        }
+                    });
+                    future::try_join_all(cmds).map_ok(|results| {
+                        assert_eq!(results.len(), 100);
+                    })
+                })
+
.map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_transaction_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let mut con = con.clone(); + async move { + let foo_val = i; + let bar_val = format!("bar{i}"); + + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key") + .arg(foo_val) + .ignore() + .cmd("SET") + .arg(&["key2", &bar_val[..]]) + .ignore() + .cmd("MGET") + .arg(&["key", "key2"]); + + pipe.query_async(&mut con) + .map(move |result| { + assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); + result + }) + .await + } + }); + future::try_join_all(cmds) + }) + .map_ok(|results| { + assert_eq!(results.len(), 100); + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + fn test_async_scanning(batch_size: usize) { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|mut con| { + async move { + let mut unseen = std::collections::HashSet::new(); + + for x in 0..batch_size { + redis::cmd("SADD") + .arg("foo") + .arg(x) + .query_async(&mut con) + .await?; + unseen.insert(x); + } + + let mut iter = redis::cmd("SSCAN") + .arg("foo") + .cursor_arg(0) + .clone() + .iter_async(&mut con) + .await + .unwrap(); + + while let Some(x) = iter.next_item().await { + // type inference limitations + let x: usize = x; + // if this assertion fails, too many items were returned by the iterator. + assert!(unseen.remove(&x)); + } + + assert_eq!(unseen.len(), 0); + Ok(()) + } + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_async_scanning_big_batch() { + test_async_scanning(1000) + } + + #[test] + fn test_async_scanning_small_batch() { + test_async_scanning(2) + } + + #[test] + fn test_response_timeout_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut connection = ctx.multiplexed_async_connection().await.unwrap(); + connection.set_response_timeout(std::time::Duration::from_millis(1)); + let mut cmd = redis::Cmd::new(); + cmd.arg("BLPOP").arg("foo").arg(0); // 0 timeout blocks indefinitely + let result = connection.req_packed_command(&cmd).await; + assert!(result.is_err()); + assert!(result.unwrap_err().is_timeout()); + Ok(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "script")] + fn test_script() { + use redis::RedisError; + + // Note this test runs both scripts twice to test when they have already been loaded + // into Redis and when they need to be loaded in + let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); + let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); + + let ctx = TestContext::new(); + + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection().await?; + script1 + .key("key1") + .arg("foo") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + script1 + .key("key1") + .arg("bar") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + Ok::<_, RedisError>(()) + }) + 
.unwrap();
+    }
+
+    #[test]
+    #[cfg(feature = "script")]
+    fn test_script_load() {
+        let ctx = TestContext::new();
+        let script = redis::Script::new("return 'Hello World'");
+
+        block_on_all(async move {
+            let mut con = ctx.multiplexed_async_connection().await.unwrap();
+
+            let hash = script.prepare_invoke().load_async(&mut con).await.unwrap();
+            assert_eq!(hash, script.get_hash().to_string());
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    #[cfg(feature = "script")]
+    fn test_script_returning_complex_type() {
+        let ctx = TestContext::new();
+        block_on_all(async {
+            let mut con = ctx.multiplexed_async_connection().await?;
+            redis::Script::new("return {1, ARGV[1], true}")
+                .arg("hello")
+                .invoke_async(&mut con)
+                .map_ok(|(i, s, b): (i32, String, bool)| {
+                    assert_eq!(i, 1);
+                    assert_eq!(s, "hello");
+                    assert!(b);
+                })
+                .await
+        })
+        .unwrap();
+    }
+
+    // Allowing `nth(0)` for similarity with the following `nth(1)`.
+    // Allowing `let ()` as `query_async` requires the type it converts the result to.
+    #[allow(clippy::let_unit_value, clippy::iter_nth_zero)]
+    #[tokio::test]
+    async fn io_error_on_kill_issue_320() {
+        let ctx = TestContext::new();
+
+        let mut conn_to_kill = ctx.async_connection().await.unwrap();
+        cmd("CLIENT")
+            .arg("SETNAME")
+            .arg("to-kill")
+            .query_async::<_, ()>(&mut conn_to_kill)
+            .await
+            .unwrap();
+
+        let client_list: String = cmd("CLIENT")
+            .arg("LIST")
+            .query_async(&mut conn_to_kill)
+            .await
+            .unwrap();
+
+        eprintln!("{client_list}");
+        let client_to_kill = client_list
+            .split('\n')
+            .find(|line| line.contains("to-kill"))
+            .expect("line")
+            .split(' ')
+            .nth(0)
+            .expect("id")
+            .split('=')
+            .nth(1)
+            .expect("id value");
+
+        let mut killer_conn = ctx.async_connection().await.unwrap();
+        let () = cmd("CLIENT")
+            .arg("KILL")
+            .arg("ID")
+            .arg(client_to_kill)
+            .query_async(&mut killer_conn)
+            .await
+            .unwrap();
+        let mut killed_client = conn_to_kill;
+
+        let err = loop {
+            match killed_client.get::<_, Option<String>>("a").await {
+                // We are racing against the server being shut down, so retry until we get an io error
+                Ok(_) => tokio::time::sleep(std::time::Duration::from_millis(50)).await,
+                Err(err) => break err,
+            }
+        };
+        assert_eq!(err.kind(), ErrorKind::IoError);
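+        // (An IoError is the expected kind here: CLIENT KILL closes the socket, so the
+        // next read on the killed connection fails at the I/O layer.)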
+    }
+
+    #[tokio::test]
+    async fn invalid_password_issue_343() {
+        let ctx = TestContext::new();
+        let coninfo = redis::ConnectionInfo {
+            addr: ctx.server.client_addr().clone(),
+            redis: redis::RedisConnectionInfo {
+                password: Some("asdcasc".to_string()),
+                ..Default::default()
+            },
+        };
+        let client = redis::Client::open(coninfo).unwrap();
+
+        let err = client
+            .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+            .await
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.kind(),
+            ErrorKind::AuthenticationFailed,
+            "Unexpected error: {err}",
+        );
+    }
+
+    // Test issue of Stream trait blocking if we try to iterate more than 10 items
+    // https://github.com/mitsuhiko/redis-rs/issues/537 and https://github.com/mitsuhiko/redis-rs/issues/583
+    #[tokio::test]
+    async fn test_issue_stream_blocks() {
+        let ctx = TestContext::new();
+        let mut con = ctx.multiplexed_async_connection().await.unwrap();
+        for i in 0..20usize {
+            let _: () = con.append(format!("test/{i}"), i).await.unwrap();
+        }
+        let values = con.scan_match::<&str, String>("test/*").await.unwrap();
+        tokio::time::timeout(std::time::Duration::from_millis(100), async move {
+            let values: Vec<_> = values.collect().await;
+            assert_eq!(values.len(), 20);
+        })
+        .await
+        .unwrap();
+    }
+
+    // Test issue of AsyncCommands::scan returning the wrong number of keys
+    // https://github.com/redis-rs/redis-rs/issues/759
+    #[tokio::test]
+    async fn test_issue_async_commands_scan_broken() {
+        let ctx = TestContext::new();
+        let mut con = ctx.async_connection().await.unwrap();
+        let mut keys: Vec<String> = (0..100).map(|k| format!("async-key{k}")).collect();
+        keys.sort();
+        for key in &keys {
+            let _: () = con.set(key, b"foo").await.unwrap();
+        }
+
+        let iter: redis::AsyncIter<String> = con.scan().await.unwrap();
+        let mut keys_from_redis: Vec<_> = iter.collect().await;
+        keys_from_redis.sort();
+        assert_eq!(keys, keys_from_redis);
+        assert_eq!(keys.len(), 100);
+    }
+
+    mod pub_sub {
+        use std::time::Duration;
+
+        use redis::ProtocolVersion;
+
+        use super::*;
+
+        #[test]
+        fn pub_sub_subscription() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe("phonewave").await?;
+                let mut pubsub_stream = pubsub_conn.on_message();
+                let mut publish_conn = ctx.async_connection().await?;
+                publish_conn.publish("phonewave", "banana").await?;
+
+                let msg_payload: String = pubsub_stream.next().await.unwrap().get_payload()?;
+                assert_eq!("banana".to_string(), msg_payload);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_unsubscription() {
+            use redis::RedisError;
+
+            const SUBSCRIPTION_KEY: &str = "phonewave-pub-sub-unsubscription";
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?;
+                pubsub_conn.unsubscribe(SUBSCRIPTION_KEY).await?;
+
+                let mut conn = ctx.async_connection().await?;
+                let subscriptions_counts: HashMap<String, u32> = redis::cmd("PUBSUB")
+                    .arg("NUMSUB")
+                    .arg(SUBSCRIPTION_KEY)
+                    .query_async(&mut conn)
+                    .await?;
+                let subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap();
+                assert_eq!(subscription_count, 0);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn automatic_unsubscription() {
+            use redis::RedisError;
+
+            const SUBSCRIPTION_KEY: &str = "phonewave-automatic-unsubscription";
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?;
+                drop(pubsub_conn);
+
+                let mut conn = ctx.async_connection().await?;
+                let mut subscription_count = 1;
+                // Allow for the unsubscription to occur within 5 seconds
+                for _ in 0..100 {
+                    let subscriptions_counts: HashMap<String, u32> = redis::cmd("PUBSUB")
+                        .arg("NUMSUB")
+                        .arg(SUBSCRIPTION_KEY)
+                        .query_async(&mut conn)
+                        .await?;
+                    subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap();
+                    if subscription_count == 0 {
+                        break;
+                    }
+
+                    std::thread::sleep(Duration::from_millis(50));
+                }
+                assert_eq!(subscription_count, 0);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_conn_reuse() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe("phonewave").await?;
+                pubsub_conn.psubscribe("*").await?;
+
+                #[allow(deprecated)]
+                let mut conn = pubsub_conn.into_connection().await;
+                redis::cmd("SET")
+                    .arg("foo")
+                    .arg("bar")
+                    .query_async(&mut conn)
+                    .await?;
+
+                let res: String = redis::cmd("GET").arg("foo").query_async(&mut conn).await?;
+                assert_eq!(&res, "bar");
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pipe_errors_do_not_affect_subsequent_commands() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut conn = ctx.multiplexed_async_connection().await?;
+
+                conn.lpush::<&str, &str, ()>("key", "value").await?;
+
+                let res: Result<(String, usize), redis::RedisError> = redis::pipe()
+                    .get("key") // WRONGTYPE
+                    .llen("key")
+                    .query_async(&mut conn)
+                    .await;
+
+                assert!(res.is_err());
+
+                let list: Vec<String> = conn.lrange("key", 0, -1).await?;
+
+                assert_eq!(list, vec!["value".to_owned()]);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_multiple() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            if ctx.protocol == ProtocolVersion::RESP2 {
+                return;
+            }
+            block_on_all(async move {
+                let mut conn = ctx.multiplexed_async_connection().await?;
+                let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+                let pub_count = 10;
+                let channel_name = "phonewave".to_string();
+                conn.get_push_manager().replace_sender(tx.clone());
+                conn.subscribe(channel_name.clone()).await?;
+                rx.recv().await.unwrap(); //PASS SUBSCRIBE
+
+                let mut publish_conn = ctx.async_connection().await?;
+                for i in 0..pub_count {
+                    publish_conn
+                        .publish(channel_name.clone(), format!("banana {i}"))
+                        .await?;
+                }
+                for _ in 0..pub_count {
+                    rx.recv().await.unwrap();
+                }
+                assert!(rx.try_recv().is_err());
+
+                {
+                    // Let's test if unsubscribing from an individual channel subscription works
+                    publish_conn
+                        .publish(channel_name.clone(), "banana!")
+                        .await?;
+                    rx.recv().await.unwrap();
+                }
+                {
+                    // Giving none for the channel id should unsubscribe all subscriptions from that channel and send an unsubscribe command to the server.
+                    conn.unsubscribe(channel_name.clone()).await?;
+                    rx.recv().await.unwrap(); //PASS UNSUBSCRIBE
+                    publish_conn
+                        .publish(channel_name.clone(), "banana!")
+                        .await?;
+                    // Let's wait for 100ms to make sure there is nothing in the channel.
+ tokio::time::sleep(Duration::from_millis(100)).await; + assert!(rx.try_recv().is_err()); + } + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn push_manager_active_context() { + use redis::RedisError; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let mut sub_conn = ctx.multiplexed_async_connection().await?; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let channel_name = "test_channel".to_string(); + sub_conn.get_push_manager().replace_sender(tx.clone()); + sub_conn.subscribe(channel_name.clone()).await?; + + let rcv_msg = rx.recv().await.unwrap(); + println!("Received PushInfo: {:?}", rcv_msg); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn push_manager_disconnection() { + use redis::RedisError; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + conn.get_push_manager().replace_sender(tx.clone()); + + conn.set("A", "1").await?; + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + drop(ctx); + let x: RedisResult<()> = conn.set("A", "1").await; + assert!(x.is_err()); + assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + } + + #[test] + fn test_async_basic_pipe_with_parsing_error() { + // Tests a specific case involving repeated errors in transactions. + let ctx = TestContext::new(); + + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + + // create a transaction where 2 errors are returned. + // we call EVALSHA twice with no loaded script, thus triggering 2 errors. + redis::pipe() + .atomic() + .cmd("EVALSHA") + .arg("foobar") + .arg(0) + .cmd("EVALSHA") + .arg("foobar") + .arg(0) + .query_async::<_, ((), ())>(&mut conn) + .await + .expect_err("should return an error"); + + assert!( + // Arbitrary Redis command that should not return an error. + redis::cmd("SMEMBERS") + .arg("nonexistent_key") + .query_async::<_, Vec>(&mut conn) + .await + .is_ok(), + "Failed transaction should not interfere with future calls." 
+ ); + + Ok::<_, redis::RedisError>(()) + }) + .unwrap() + } + + #[cfg(feature = "connection-manager")] + async fn wait_for_server_to_become_ready(client: redis::Client) { + let millisecond = std::time::Duration::from_millis(1); + let mut retries = 0; + loop { + match client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { + Err(err) => { + if err.is_connection_refusal() { + tokio::time::sleep(millisecond).await; + retries += 1; + if retries > 100000 { + panic!("Tried to connect too many times, last error: {err}"); + } + } else { + panic!("Could not connect: {err}"); + } + } + Ok(mut con) => { + let _: RedisResult<()> = redis::cmd("FLUSHDB").query_async(&mut con).await; + break; + } + } + } + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_connection_manager_reconnect_after_delay() { + use redis::ProtocolVersion; + + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let tls_files = build_keys_and_certs_for_tls(&tempdir); + + let ctx = TestContext::with_tls(tls_files.clone(), false); + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let server = ctx.server; + let addr = server.client_addr().clone(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + drop(server); + + let _result: RedisResult = manager.set("foo", "bar").await; // one call is ignored because it's required to trigger the connection manager's reconnect. + if ctx.protocol != ProtocolVersion::RESP2 { + assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let _new_server = RedisServer::new_with_addr_and_modules(addr.clone(), &[], false); + wait_for_server_to_become_ready(ctx.client.clone()).await; + + let result: redis::Value = manager.set("foo", "bar").await.unwrap(); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + assert_eq!(result, redis::Value::Okay); + Ok(()) + }) + .unwrap(); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use super::*; + + #[test] + fn test_should_connect_mtls() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })) + .unwrap(); + } + + #[test] + fn test_should_not_connect_if_tls_active() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + let result = block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })); + + // depends on server type set (REDISRS_SERVER_TYPE) + match ctx.server.connection_info() { + redis::ConnectionInfo { + addr: 
redis::ConnectionAddr::TcpTls { .. }, + .. + } => { + if result.is_ok() { + panic!("Must NOT be able to connect without client credentials if server accepts TLS"); + } + } + _ => { + if result.is_err() { + panic!("Must be able to connect without client credentials if server does NOT accept TLS"); + } + } + } + } + } + + #[test] + fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + use redis::RedisError; + let ctx = TestContext::with_client_name(CLIENT_NAME); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + + let client_info: String = redis::cmd("CLIENT") + .arg("INFO") + .query_async(&mut con) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_push_manager_cm() { + use redis::ProtocolVersion; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + manager + .send_packed_command(cmd("CLIENT").arg("TRACKING").arg("ON")) + .await + .unwrap(); + let pipe = build_simple_pipeline_for_invalidation(); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(new_tx); + drop(rx); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = new_rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + assert_eq!(TryRecvError::Empty, new_rx.try_recv().err().unwrap()); + Ok(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_async_async_std.rs b/glide-core/redis-rs/redis/tests/test_async_async_std.rs new file mode 100644 index 0000000000..656d1979f6 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async_async_std.rs @@ -0,0 +1,328 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use futures::prelude::*; + +use crate::support::*; + +use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; + +mod support; + +#[test] +fn test_args() { + let ctx = TestContext::new(); + let connect = ctx.async_connection_async_std(); + + block_on_all_using_async_std(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + result + })) + 
.unwrap(); +} + +#[test] +fn test_args_async_std() { + let ctx = TestContext::new(); + let connect = ctx.async_connection_async_std(); + + block_on_all_using_async_std(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + result + })) + .unwrap(); +} + +#[test] +fn dont_panic_on_closed_multiplexed_connection() { + let ctx = TestContext::new(); + let client = ctx.client.clone(); + let connect = client.get_multiplexed_async_std_connection(GlideConnectionOptions::default()); + drop(ctx); + + block_on_all_using_async_std(async move { + connect + .and_then(|con| async move { + let cmd = move || { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await + } + }; + let result: RedisResult<()> = cmd().await; + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + cmd().await + }) + .map(|result| { + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + }) + .await + }); +} + +#[test] +fn test_pipeline_transaction() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + let mut con = ctx.async_connection_async_std().await?; + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]); + pipe.query_async(&mut con) + .map_ok(|((k1, k2),): ((i32, i32),)| { + assert_eq!(k1, 42); + assert_eq!(k2, 43); + }) + .await + }) + .unwrap(); +} + +fn test_cmd(con: &MultiplexedConnection, i: i32) -> impl Future> + Send { + let mut con = con.clone(); + async move { + let key = format!("key{i}"); + let key_2 = key.clone(); + let key2 = format!("key{i}_2"); + let key2_2 = key2.clone(); + + let foo_val = format!("foo{i}"); + + redis::cmd("SET") + .arg(&key[..]) + .arg(foo_val.as_bytes()) + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&[&key2, "bar"]) + .query_async(&mut con) + .await?; + redis::cmd("MGET") + .arg(&[&key_2, &key2_2]) + .query_async(&mut con) + .map(|result| { + assert_eq!(Ok((foo_val, b"bar".to_vec())), result); + Ok(()) + }) + .await + } +} + +fn test_error(con: &MultiplexedConnection) -> impl Future> { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .query_async(&mut con) + .map(|result| match result { + Ok(()) => panic!("Expected redis to return an error"), + Err(_) => Ok(()), + }) + .await + } +} + +#[test] +fn test_args_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| test_cmd(&con, i)); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +fn test_args_with_errors_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let con = con.clone(); + async move { + if i % 2 
== 0 { + test_cmd(&con, i).await + } else { + test_error(&con).await + } + } + }); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +fn test_transaction_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let mut con = con.clone(); + async move { + let foo_val = i; + let bar_val = format!("bar{i}"); + + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key") + .arg(foo_val) + .ignore() + .cmd("SET") + .arg(&["key2", &bar_val[..]]) + .ignore() + .cmd("MGET") + .arg(&["key", "key2"]); + + pipe.query_async(&mut con) + .map(move |result| { + assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); + result + }) + .await + } + }); + future::try_join_all(cmds) + }) + .map_ok(|results| { + assert_eq!(results.len(), 100); + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script() { + use redis::RedisError; + + // Note this test runs both scripts twice to test when they have already been loaded + // into Redis and when they need to be loaded in + let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); + let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); + + let ctx = TestContext::new(); + + block_on_all_using_async_std(async move { + let mut con = ctx.multiplexed_async_connection_async_std().await?; + script1 + .key("key1") + .arg("foo") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + script1 + .key("key1") + .arg("bar") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script_load() { + let ctx = TestContext::new(); + let script = redis::Script::new("return 'Hello World'"); + + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection_async_std().await.unwrap(); + + let hash = script.prepare_invoke().load_async(&mut con).await.unwrap(); + assert_eq!(hash, script.get_hash().to_string()); + Ok(()) + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script_returning_complex_type() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async { + let mut con = ctx.multiplexed_async_connection_async_std().await?; + redis::Script::new("return {1, ARGV[1], true}") + .arg("hello") + .invoke_async(&mut con) + .map_ok(|(i, s, b): (i32, String, bool)| { + assert_eq!(i, 1); + assert_eq!(s, "hello"); + assert!(b); + }) + .await + }) + .unwrap(); +} diff --git a/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs new file mode 100644 index 0000000000..0230d1de17 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs @@ -0,0 +1,563 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "cluster-async")] 
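+
+// These tests exercise the cluster connection-refresh logic in isolation, driving it
+// with the mock connections from `support` instead of a live server.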
+mod support;
+
+use redis::{
+    cluster_async::testing::{AsyncClusterNode, RefreshConnectionType},
+    testing::ClusterParams,
+    ErrorKind, GlideConnectionOptions,
+};
+use std::net::{IpAddr, Ipv4Addr};
+use std::sync::Arc;
+use support::{
+    get_mock_connection, get_mock_connection_with_port, modify_mock_connection_behavior,
+    respond_startup, ConnectionIPReturnType, MockConnection, MockConnectionBehavior,
+};
+
+mod test_connect_and_check {
+    use std::sync::atomic::AtomicUsize;
+
+    use super::*;
+    use crate::support::{get_mock_connection_handler, ShouldReturnConnectionError};
+    use redis::cluster_async::testing::{
+        connect_and_check, ConnectAndCheckResult, ConnectionWithIp,
+    };
+
+    fn assert_partial_result(
+        result: ConnectAndCheckResult<MockConnection>,
+    ) -> (AsyncClusterNode<MockConnection>, redis::RedisError) {
+        match result {
+            ConnectAndCheckResult::ManagementConnectionFailed { node, err } => (node, err),
+            ConnectAndCheckResult::Success(_) => {
+                panic!("Expected partial result, got full success")
+            }
+            ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"),
+        }
+    }
+
+    fn assert_full_success(
+        result: ConnectAndCheckResult<MockConnection>,
+    ) -> AsyncClusterNode<MockConnection> {
+        match result {
+            ConnectAndCheckResult::Success(node) => node,
+            ConnectAndCheckResult::ManagementConnectionFailed { .. } => {
+                panic!("Expected full success, got partial success")
+            }
+            ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_connect_successfully() {
+        // Test that upon refreshing all connections, if both connections were successful,
+        // the returned node contains both user and management connection
+        let name = "test_connect_and_check_connect_successfully";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip)
+        });
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::AllConnections,
+            None,
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let node = assert_full_success(result);
+        assert!(node.management_connection.is_some());
+        assert_eq!(node.user_connection.ip, Some(ip));
+        assert_eq!(node.management_connection.unwrap().ip, Some(ip));
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_all_connections_one_connection_err_returns_only_user_conn() {
+        // Test that upon refreshing all connections, if only one of the new connections fails,
+        // the other successful connection will be used as the user connection, as a partial success.
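+        // A partial success keeps the node usable for user traffic while its
+        // `management_connection` is left as `None`.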
+        let name = "all_connections_one_connection_err";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+        modify_mock_connection_behavior(name, |behavior| {
+            // The second connection will fail
+            behavior.return_connection_err =
+                ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(0))
+        });
+
+        let params = ClusterParams::default();
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            params.clone(),
+            None,
+            RefreshConnectionType::AllConnections,
+            None,
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let (node, _) = assert_partial_result(result);
+        assert!(node.management_connection.is_none());
+
+        modify_mock_connection_behavior(name, |behavior| {
+            // The first connection will fail
+            behavior.return_connection_err =
+                ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(1));
+        });
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            params,
+            None,
+            RefreshConnectionType::AllConnections,
+            None,
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let (node, _) = assert_partial_result(result);
+        assert!(node.management_connection.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_all_connections_different_ip_returns_both_connections() {
+        // Test that a node's connections (e.g. user and management) can have different IPs for the same DNS endpoint.
+        // This is relevant for cases where the DNS entry holds multiple IPs that route to the same node, for example with load balancers.
+        // The test verifies that upon refreshing all connections, if the IPs of the new connections differ,
+        // the function uses both connections.
+        let name = "all_connections_different_ip";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.returned_ip_type = ConnectionIPReturnType::Different(AtomicUsize::new(0));
+        });
+
+        // The first connection will have 0.0.0.0 IP, the second 1.0.0.0
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::AllConnections,
+            None,
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let node = assert_full_success(result);
+        assert!(node.management_connection.is_some());
+        assert_eq!(
+            node.user_connection.ip,
+            Some(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)))
+        );
+        assert_eq!(
+            node.management_connection.unwrap().ip,
+            Some(IpAddr::V4(Ipv4Addr::new(1, 0, 0, 0)))
+        );
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_all_connections_both_conn_error_returns_err() {
+        // Test that when trying to refresh all connections and both connections fail, the function returns with an error
+        let name = "both_conn_error_returns_err";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.return_connection_err = ShouldReturnConnectionError::Yes
+        });
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::AllConnections,
+            None,
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let err = result.get_error().unwrap();
+        assert!(
+            err.to_string()
+                .contains("Failed to refresh both connections")
+                && err.kind() == ErrorKind::IoError
+        );
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_only_management_same_ip() {
+        // Test that when we refresh only the management connection and the new connection returns with the same IP as the user's,
+        // the returned node contains a new management connection and the user connection remains unchanged
+        let name = "only_management_same_ip";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip)
+        });
+
+        let user_conn_id: usize = 1000;
+        let user_conn = MockConnection {
+            id: user_conn_id,
+            handler: get_mock_connection_handler(name),
+            port: 6379,
+        };
+        let node = AsyncClusterNode::new(
+            ConnectionWithIp {
+                conn: user_conn,
+                ip: Some(ip),
+            }
+            .into_future(),
+            None,
+        );
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::OnlyManagementConnection,
+            Some(node),
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let node = assert_full_success(result);
+        assert!(node.management_connection.is_some());
+        // Confirm that the user connection remains unchanged
+        assert_eq!(node.user_connection.conn.await.id, user_conn_id);
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_only_management_connection_err() {
+        // Test that when we try to refresh only the management connection and it fails, we receive a partial success with the same node.
+        let name = "only_management_connection_err";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.return_connection_err = ShouldReturnConnectionError::Yes;
+        });
+
+        let user_conn_id: usize = 1000;
+        let user_conn = MockConnection {
+            id: user_conn_id,
+            handler: get_mock_connection_handler(name),
+            port: 6379,
+        };
+        let prev_ip = Some(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)));
+        let node = AsyncClusterNode::new(
+            ConnectionWithIp {
+                conn: user_conn,
+                ip: prev_ip,
+            }
+            .into_future(),
+            None,
+        );
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::OnlyManagementConnection,
+            Some(node),
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let (node, _) = assert_partial_result(result);
+        assert!(node.management_connection.is_none());
+        // Confirm that the user connection remains unchanged
+        assert_eq!(node.user_connection.conn.await.id, user_conn_id);
+        assert_eq!(node.user_connection.ip, prev_ip);
+    }
+
+    #[tokio::test]
+    async fn test_connect_and_check_only_user_connection_same_ip() {
+        // Test that upon refreshing only the user connection, if the newly created connection shares the same IP as the existing management connection,
+        // the management connection remains unchanged
+        let name = "only_user_connection_same_ip";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let prev_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
+        modify_mock_connection_behavior(name, |behavior| {
+            behavior.returned_ip_type = ConnectionIPReturnType::Specified(prev_ip);
+        });
+        let old_user_conn_id: usize = 1000;
+        let management_conn_id: usize = 2000;
+        let old_user_conn = MockConnection {
+            id: old_user_conn_id,
+            handler: get_mock_connection_handler(name),
+            port: 6379,
+        };
+        let management_conn = MockConnection {
+            id: management_conn_id,
+            handler: get_mock_connection_handler(name),
+            port: 6379,
+        };
+
+        let node = AsyncClusterNode::new(
+            ConnectionWithIp {
+                conn: old_user_conn,
+                ip: Some(prev_ip),
+            }
+            .into_future(),
+            Some(
+                ConnectionWithIp {
+                    conn: management_conn,
+                    ip: Some(prev_ip),
+                }
+                .into_future(),
+            ),
+        );
+
+        let result = connect_and_check::<MockConnection>(
+            &format!("{name}:6379"),
+            ClusterParams::default(),
+            None,
+            RefreshConnectionType::OnlyUserConnection,
+            Some(node),
+            GlideConnectionOptions::default(),
+        )
+        .await;
+        let node = assert_full_success(result);
+        // Confirm that a new user connection was created
+        assert_ne!(node.user_connection.conn.await.id, old_user_conn_id);
+        // Confirm that the management connection remains unchanged
+        assert_eq!(
+            node.management_connection.unwrap().conn.await.id,
+            management_conn_id
+        );
+    }
+}
+
+mod test_check_node_connections {
+
+    use super::*;
+    use redis::cluster_async::testing::{check_node_connections, ConnectionWithIp};
+    fn create_node_with_all_connections(name: &str) -> AsyncClusterNode<MockConnection> {
+        let ip = None;
+        AsyncClusterNode::new(
+            ConnectionWithIp {
+                conn: get_mock_connection_with_port(name, 1, 6380),
+                ip,
+            }
+            .into_future(),
+            Some(
+                ConnectionWithIp {
+                    conn: get_mock_connection_with_port(name, 2, 6381),
+                    ip,
+                }
+                .into_future(),
+            ),
+        )
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_find_no_problem() {
+        // Test that when checking both connections, if both connections are healthy, no issue is returned.
+        let name = "test_check_node_connections_find_no_problem";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+        let node = create_node_with_all_connections(name);
+        let response = check_node_connections::<MockConnection>(
+            &node,
+            &ClusterParams::default(),
+            RefreshConnectionType::AllConnections,
+            name,
+        )
+        .await;
+        assert_eq!(response, None);
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_find_management_connection_issue() {
+        // Test that upon checking both connections, if the management connection isn't responding to pings, `OnlyManagementConnection` will be returned.
+        let name = "test_check_node_connections_find_management_connection_issue";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, port| {
+                if port == 6381 {
+                    return Err(Err((ErrorKind::ClientError, "some error").into()));
+                }
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let node = create_node_with_all_connections(name);
+        let response = check_node_connections::<MockConnection>(
+            &node,
+            &ClusterParams::default(),
+            RefreshConnectionType::AllConnections,
+            name,
+        )
+        .await;
+        assert_eq!(
+            response,
+            Some(RefreshConnectionType::OnlyManagementConnection)
+        );
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_find_missing_management_connection() {
+        // Test that upon checking both connections, if the management connection isn't present, `OnlyManagementConnection` will be returned.
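+        // (A missing management connection is treated the same as an unhealthy one.)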
+        let name = "test_check_node_connections_find_missing_management_connection";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, _| {
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let ip = None;
+        let node = AsyncClusterNode::new(
+            ConnectionWithIp {
+                conn: get_mock_connection(name, 1),
+                ip,
+            }
+            .into_future(),
+            None,
+        );
+        let response = check_node_connections::<MockConnection>(
+            &node,
+            &ClusterParams::default(),
+            RefreshConnectionType::AllConnections,
+            name,
+        )
+        .await;
+        assert_eq!(
+            response,
+            Some(RefreshConnectionType::OnlyManagementConnection)
+        );
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_find_both_connections_issue() {
+        // Test that upon checking both connections, if neither connection is responding to pings, `AllConnections` will be returned.
+        let name = "test_check_node_connections_find_both_connections_issue";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|_, _| Err(Err((ErrorKind::ClientError, "some error").into()))),
+        );
+
+        let node = create_node_with_all_connections(name);
+        let response = check_node_connections::<MockConnection>(
+            &node,
+            &ClusterParams::default(),
+            RefreshConnectionType::AllConnections,
+            name,
+        )
+        .await;
+        assert_eq!(response, Some(RefreshConnectionType::AllConnections));
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_find_user_connection_issue() {
+        // Test that upon checking both connections, if the user connection isn't responding to pings, `OnlyUserConnection` will be returned.
+        let name = "test_check_node_connections_find_user_connection_issue";
+
+        let _handle = MockConnectionBehavior::register_new(
+            name,
+            Arc::new(|cmd, port| {
+                if port == 6380 {
+                    return Err(Err((ErrorKind::ClientError, "some error").into()));
+                }
+                respond_startup(name, cmd)?;
+                Ok(())
+            }),
+        );
+
+        let node = create_node_with_all_connections(name);
+        let response = check_node_connections::<MockConnection>(
+            &node,
+            &ClusterParams::default(),
+            RefreshConnectionType::AllConnections,
+            name,
+        )
+        .await;
+        assert_eq!(response, Some(RefreshConnectionType::OnlyUserConnection));
+    }
+
+    #[tokio::test]
+    async fn test_check_node_connections_ignore_missing_management_connection_when_refreshing_user()
+    {
+        // Test that upon checking only the user connection, issues with the management connection won't affect the result.
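+        // (The node below is deliberately built without a management connection.)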
+ let name = + "test_check_node_connections_ignore_management_connection_issue_when_refreshing_user"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: get_mock_connection(name, 1), + ip: None, + } + .into_future(), + None, + ); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::OnlyUserConnection, + name, + ) + .await; + assert_eq!(response, None); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_basic.rs b/glide-core/redis-rs/redis/tests/test_basic.rs new file mode 100644 index 0000000000..e31c33384c --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_basic.rs @@ -0,0 +1,1581 @@ +#![allow(clippy::let_unit_value)] + +mod support; + +#[cfg(test)] +mod basic { + use redis::{cmd, ProtocolVersion, PushInfo}; + use redis::{ + Commands, ConnectionInfo, ConnectionLike, ControlFlow, ErrorKind, ExistenceCheck, Expiry, + PubSubCommands, PushKind, RedisResult, SetExpiry, SetOptions, ToRedisArgs, Value, + }; + use std::collections::{BTreeMap, BTreeSet}; + use std::collections::{HashMap, HashSet}; + use std::thread::{sleep, spawn}; + use std::time::Duration; + use std::vec; + use tokio::sync::mpsc::error::TryRecvError; + + use crate::{assert_args, support::*}; + + #[test] + fn test_parse_redis_url() { + let redis_url = "redis://127.0.0.1:1234/0".to_string(); + redis::parse_redis_url(&redis_url).unwrap(); + redis::parse_redis_url("unix:/var/run/redis/redis.sock").unwrap(); + assert!(redis::parse_redis_url("127.0.0.1").is_none()); + } + + #[test] + fn test_redis_url_fromstr() { + let _info: ConnectionInfo = "redis://127.0.0.1:1234/0".parse().unwrap(); + } + + #[test] + fn test_args() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("key1").arg(b"foo").execute(&mut con); + redis::cmd("SET").arg(&["key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET").arg(&["key1", "key2"]).query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_getset() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + + redis::cmd("SET").arg("bar").arg("foo").execute(&mut con); + assert_eq!( + redis::cmd("GET").arg("bar").query(&mut con), + Ok(b"foo".to_vec()) + ); + } + + //unit test for key_type function + #[test] + fn test_key_type() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + //The key is a simple value + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + let string_key_type: String = con.key_type("foo").unwrap(); + assert_eq!(string_key_type, "string"); + + //The key is a list + redis::cmd("LPUSH") + .arg("list_bar") + .arg("foo") + .execute(&mut con); + let list_key_type: String = con.key_type("list_bar").unwrap(); + assert_eq!(list_key_type, "list"); + + //The key is a set + redis::cmd("SADD") + .arg("set_bar") + .arg("foo") + .execute(&mut con); + let set_key_type: String = con.key_type("set_bar").unwrap(); + assert_eq!(set_key_type, "set"); + + //The key is a sorted set + redis::cmd("ZADD") + .arg("sorted_set_bar") + .arg("1") + .arg("foo") + .execute(&mut con); + let zset_key_type: String = con.key_type("sorted_set_bar").unwrap(); + assert_eq!(zset_key_type, "zset"); + + //The key is a hash + redis::cmd("HSET") + .arg("hset_bar") + 
.arg("hset_key_1") + .arg("foo") + .execute(&mut con); + let hash_key_type: String = con.key_type("hset_bar").unwrap(); + assert_eq!(hash_key_type, "hash"); + } + + #[test] + fn test_client_tracking_doesnt_block_execution() { + //It checks if the library distinguish a push-type message from the others and continues its normal operation. + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let (k1, k2): (i32, i32) = redis::pipe() + .cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .ignore() + .cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("GET") + .arg("key_1") + .cmd("GET") + .arg("key_2") + .cmd("SET") + .arg("key_1") + .arg(45) + .ignore() + .query(&mut con) + .unwrap(); + assert_eq!(k1, 42); + assert_eq!(k2, 43); + let num: i32 = con.get("key_1").unwrap(); + assert_eq!(num, 45); + } + + #[test] + fn test_incr() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("INCR").arg("foo").query(&mut con), Ok(43usize)); + } + + #[test] + fn test_getdel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + + assert_eq!(con.get_del("foo"), Ok(42usize)); + + assert_eq!( + redis::cmd("GET").arg("foo").query(&mut con), + Ok(None::) + ); + } + + #[test] + fn test_getex() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42usize).execute(&mut con); + + // Return of get_ex must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::EX(1)).unwrap(); + assert_eq!(ret_value, 42usize); + + // Get before expiry time must also return value + sleep(Duration::from_millis(100)); + let delayed_get = con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 42usize); + + // Get after expiry time mustn't return value + sleep(Duration::from_secs(1)); + let after_expire_get = con.get::<_, Option>("foo").unwrap(); + assert_eq!(after_expire_get, None); + + // Persist option test prep + redis::cmd("SET").arg("foo").arg(420usize).execute(&mut con); + + // Return of get_ex with persist option must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::PERSIST).unwrap(); + assert_eq!(ret_value, 420usize); + + // Get after persist get_ex must return value + sleep(Duration::from_millis(200)); + let delayed_get = con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 420usize); + } + + #[test] + fn test_info() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let info: redis::InfoDict = redis::cmd("INFO").query(&mut con).unwrap(); + assert_eq!( + info.find(&"role"), + Some(&redis::Value::SimpleString("master".to_string())) + ); + assert_eq!(info.get("role"), Some("master".to_string())); + assert_eq!(info.get("loading"), Some(false)); + assert!(!info.is_empty()); + assert!(info.contains_key(&"role")); + } + + #[test] + fn test_hash_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("HSET") + .arg("foo") + .arg("key_1") + .arg(1) + .execute(&mut con); + redis::cmd("HSET") + .arg("foo") + .arg("key_2") + .arg(2) + .execute(&mut con); + + let h: HashMap = redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + + let h: BTreeMap = 
redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + } + + // Requires redis-server >= 4.0.0. + // Not supported with the current appveyor/windows binary deployed. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_unlink() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + assert_eq!(con.unlink("foo"), Ok(1)); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + redis::cmd("SET").arg("bar").arg(42).execute(&mut con); + assert_eq!(con.unlink(&["foo", "bar"]), Ok(2)); + } + + #[test] + fn test_set_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let mut s: Vec = con.smembers("foo").unwrap(); + s.sort_unstable(); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + + let set: HashSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + + let set: BTreeSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + } + + #[test] + fn test_scan() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let (cur, mut s): (i32, Vec) = redis::cmd("SSCAN") + .arg("foo") + .arg(0) + .query(&mut con) + .unwrap(); + s.sort_unstable(); + assert_eq!(cur, 0i32); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + } + + #[test] + fn test_optionals() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(1).execute(&mut con); + + let (a, b): (Option, Option) = redis::cmd("MGET") + .arg("foo") + .arg("missing") + .query(&mut con) + .unwrap(); + assert_eq!(a, Some(1i32)); + assert_eq!(b, None); + + let a = redis::cmd("GET") + .arg("missing") + .query(&mut con) + .unwrap_or(0i32); + assert_eq!(a, 0i32); + } + + #[test] + fn test_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..1000 { + redis::cmd("SADD").arg("foo").arg(x).execute(&mut con); + unseen.insert(x); + } + + let iter = redis::cmd("SSCAN") + .arg("foo") + .cursor_arg(0) + .clone() + .iter(&mut con) + .unwrap(); + + for x in iter { + // type inference limitations + let x: usize = x; + unseen.remove(&x); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + fn test_filtered_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..3000 { + let _: () = con + .hset("foo", format!("key_{}_{}", x % 100, x), x) + .unwrap(); + if x % 100 == 0 { + unseen.insert(x); + } + } + + let iter = con + .hscan_match::<&str, &str, (String, usize)>("foo", "key_0_*") + .unwrap(); + + for (_field, value) in iter { + unseen.remove(&value); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + fn test_pipeline() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ((k1, k2),): ((i32, i32),) = redis::pipe() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + 
assert_eq!(k2, 43);
+ }
+
+ #[test]
+ fn test_pipeline_with_err() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = redis::cmd("SET")
+ .arg("x")
+ .arg("x-value")
+ .query(&mut con)
+ .unwrap();
+ let _: () = redis::cmd("SET")
+ .arg("y")
+ .arg("y-value")
+ .query(&mut con)
+ .unwrap();
+
+ let _: () = redis::cmd("SLAVEOF")
+ .arg("1.1.1.1")
+ .arg("99")
+ .query(&mut con)
+ .unwrap();
+
+ let res = redis::pipe()
+ .set("x", "another-x-value")
+ .ignore()
+ .get("y")
+ .query::<()>(&mut con);
+ assert!(res.is_err() && res.unwrap_err().kind() == ErrorKind::ReadOnly);
+
+ // Make sure we don't get leftover responses from the pipeline ("y-value"). See #436.
+ let res = redis::cmd("GET")
+ .arg("x")
+ .query::<String>(&mut con)
+ .unwrap();
+ assert_eq!(res, "x-value");
+ }
+
+ #[test]
+ fn test_empty_pipeline() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = redis::pipe().cmd("PING").ignore().query(&mut con).unwrap();
+
+ let _: () = redis::pipe().query(&mut con).unwrap();
+ }
+
+ #[test]
+ fn test_pipeline_transaction() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let ((k1, k2),): ((i32, i32),) = redis::pipe()
+ .atomic()
+ .cmd("SET")
+ .arg("key_1")
+ .arg(42)
+ .ignore()
+ .cmd("SET")
+ .arg("key_2")
+ .arg(43)
+ .ignore()
+ .cmd("MGET")
+ .arg(&["key_1", "key_2"])
+ .query(&mut con)
+ .unwrap();
+
+ assert_eq!(k1, 42);
+ assert_eq!(k2, 43);
+ }
+
+ #[test]
+ fn test_pipeline_transaction_with_errors() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con.set("x", 42).unwrap();
+
+ // Make Redis a replica of a nonexistent master, thereby making it read-only.
+ let _: () = redis::cmd("slaveof")
+ .arg("1.1.1.1")
+ .arg("1")
+ .query(&mut con)
+ .unwrap();
+
+ // Ensure that a write command fails with a READONLY error
+ let err: RedisResult<()> = redis::pipe()
+ .atomic()
+ .set("x", 142)
+ .ignore()
+ .get("x")
+ .query(&mut con);
+
+ assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly);
+
+ let x: i32 = con.get("x").unwrap();
+ assert_eq!(x, 42);
+ }
+
+ #[test]
+ fn test_pipeline_reuse_query() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let mut pl = redis::pipe();
+
+ let ((k1,),): ((i32,),) = pl
+ .cmd("SET")
+ .arg("pkey_1")
+ .arg(42)
+ .ignore()
+ .cmd("MGET")
+ .arg(&["pkey_1"])
+ .query(&mut con)
+ .unwrap();
+
+ assert_eq!(k1, 42);
+
+ redis::cmd("DEL").arg("pkey_1").execute(&mut con);
+
+ // The internal commands vector of the pipeline still contains the previous commands.
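+ // Reusing the pipeline below therefore replays the earlier SET/MGET together
+ // with the newly added commands; pl.clear() (exercised in the next test) is
+ // what resets a pipeline for fresh use.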
+ let ((k1,), (k2, k3)): ((i32,), (i32, i32)) = pl + .cmd("SET") + .arg("pkey_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 42); + assert_eq!(k3, 43); + } + + #[test] + fn test_pipeline_reuse_query_clear() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let mut pl = redis::pipe(); + + let ((k1,),): ((i32,),) = pl + .cmd("SET") + .arg("pkey_1") + .arg(44) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert_eq!(k1, 44); + + redis::cmd("DEL").arg("pkey_1").execute(&mut con); + + let ((k1, k2),): ((bool, i32),) = pl + .cmd("SET") + .arg("pkey_2") + .arg(45) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert!(!k1); + assert_eq!(k2, 45); + } + + #[test] + fn test_real_transaction() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + loop { + let _: () = redis::cmd("WATCH").arg(key).query(&mut con).unwrap(); + let val: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap(); + let response: Option<(isize,)> = redis::pipe() + .atomic() + .cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(&mut con) + .unwrap(); + + match response { + None => { + continue; + } + Some(response) => { + assert_eq!(response, (43,)); + break; + } + } + } + } + + #[test] + fn test_real_transaction_highlevel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + let response: (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { + let val: isize = redis::cmd("GET").arg(key).query(con)?; + pipe.cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(con) + }) + .unwrap(); + + assert_eq!(response, (43,)); + } + + #[test] + fn test_pubsub() { + use std::sync::{Arc, Barrier}; + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // Connection for subscriber api + let mut pubsub_con = ctx.connection(); + + // Barrier is used to make test thread wait to publish + // until after the pubsub thread has subscribed. 
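+ // A Barrier releases every waiting thread once the configured number of
+ // parties (two here) has called wait(), so the publisher cannot run ahead
+ // of SUBSCRIBE; Redis pub/sub does not buffer messages for clients that
+ // have not subscribed yet.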
+ let barrier = Arc::new(Barrier::new(2)); + let pubsub_barrier = barrier.clone(); + + let thread = spawn(move || { + let mut pubsub = pubsub_con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + + let _ = pubsub_barrier.wait(); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(42)); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(23)); + }); + + let _ = barrier.wait(); + redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con); + // We can also call the command directly + assert_eq!(con.publish("foo", 23), Ok(1)); + + thread.join().expect("Something went wrong"); + } + + #[test] + fn test_pubsub_unsubscribe() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + pubsub.subscribe("bar").unwrap(); + pubsub.subscribe("baz").unwrap(); + pubsub.psubscribe("foo*").unwrap(); + pubsub.psubscribe("bar*").unwrap(); + pubsub.psubscribe("baz*").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + fn test_pubsub_subscribe_while_messages_are_sent() { + let ctx = TestContext::new(); + let mut conn_external = ctx.connection(); + let mut conn_internal = ctx.connection(); + let received = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); + let received_clone = received.clone(); + let (sender, receiver) = std::sync::mpsc::channel(); + // receive message from foo channel + let thread = std::thread::spawn(move || { + let mut pubsub = conn_internal.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + sender.send(()).unwrap(); + loop { + let msg = pubsub.get_message().unwrap(); + let channel = msg.get_channel_name(); + let content: i32 = msg.get_payload().unwrap(); + received + .lock() + .unwrap() + .push(format!("{channel}:{content}")); + if content == -1 { + return; + } + if content == 5 { + // subscribe bar channel using the same pubsub + pubsub.subscribe("bar").unwrap(); + sender.send(()).unwrap(); + } + } + }); + receiver.recv().unwrap(); + + // send message to foo channel after channel is ready. 
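+ // The mpsc channel doubles as a readiness signal: recv() blocks the
+ // publisher until the subscriber thread has sent (), i.e. until SUBSCRIBE
+ // has completed, so none of the published messages can be dropped.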
+ for index in 0..10 {
+ println!("publishing on foo {index}");
+ redis::cmd("PUBLISH")
+ .arg("foo")
+ .arg(index)
+ .query::<i32>(&mut conn_external)
+ .unwrap();
+ }
+ receiver.recv().unwrap();
+ redis::cmd("PUBLISH")
+ .arg("bar")
+ .arg(-1)
+ .query::<i32>(&mut conn_external)
+ .unwrap();
+ thread.join().unwrap();
+ assert_eq!(
+ *received_clone.lock().unwrap(),
+ (0..10)
+ .map(|index| format!("foo:{}", index))
+ .chain(std::iter::once("bar:-1".to_string()))
+ .collect::<Vec<String>>()
+ );
+ }
+
+ #[test]
+ fn test_pubsub_unsubscribe_no_subs() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ {
+ let _pubsub = con.as_pubsub();
+ }
+
+ // Connection should be usable again for non-pubsub commands
+ let _: redis::Value = con.set("foo", "bar").unwrap();
+ let value: String = con.get("foo").unwrap();
+ assert_eq!(&value[..], "bar");
+ }
+
+ #[test]
+ fn test_pubsub_unsubscribe_one_sub() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ {
+ let mut pubsub = con.as_pubsub();
+ pubsub.subscribe("foo").unwrap();
+ }
+
+ // Connection should be usable again for non-pubsub commands
+ let _: redis::Value = con.set("foo", "bar").unwrap();
+ let value: String = con.get("foo").unwrap();
+ assert_eq!(&value[..], "bar");
+ }
+
+ #[test]
+ fn test_pubsub_unsubscribe_one_sub_one_psub() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ {
+ let mut pubsub = con.as_pubsub();
+ pubsub.subscribe("foo").unwrap();
+ pubsub.psubscribe("foo*").unwrap();
+ }
+
+ // Connection should be usable again for non-pubsub commands
+ let _: redis::Value = con.set("foo", "bar").unwrap();
+ let value: String = con.get("foo").unwrap();
+ assert_eq!(&value[..], "bar");
+ }
+
+ #[test]
+ fn scoped_pubsub() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ // Connection for subscriber api
+ let mut pubsub_con = ctx.connection();
+
+ let thread = spawn(move || {
+ let mut count = 0;
+ pubsub_con
+ .subscribe(&["foo", "bar"], |msg| {
+ count += 1;
+ match count {
+ 1 => {
+ assert_eq!(msg.get_channel(), Ok("foo".to_string()));
+ assert_eq!(msg.get_payload(), Ok(42));
+ ControlFlow::Continue
+ }
+ 2 => {
+ assert_eq!(msg.get_channel(), Ok("bar".to_string()));
+ assert_eq!(msg.get_payload(), Ok(23));
+ ControlFlow::Break(())
+ }
+ _ => ControlFlow::Break(()),
+ }
+ })
+ .unwrap();
+
+ pubsub_con
+ });
+
+ // Can't use a barrier in this case since there's no opportunity to run code
+ // between channel subscription and blocking for messages.
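+ // The fixed sleep below is a best-effort substitute for that handshake;
+ // it is inherently racy, but 100ms is normally ample for the spawned
+ // thread to subscribe against a local server.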
+ sleep(Duration::from_millis(100));
+
+ redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con);
+ assert_eq!(con.publish("bar", 23), Ok(1));
+
+ // Wait for thread
+ let mut pubsub_con = thread.join().expect("pubsub thread terminates ok");
+
+ // Connection should be usable again for non-pubsub commands
+ let _: redis::Value = pubsub_con.set("foo", "bar").unwrap();
+ let value: String = pubsub_con.get("foo").unwrap();
+ assert_eq!(&value[..], "bar");
+ }
+
+ #[test]
+ #[cfg(feature = "script")]
+ fn test_script() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let script = redis::Script::new(
+ r"
+ return {redis.call('GET', KEYS[1]), ARGV[1]}
+ ",
+ );
+
+ let _: () = redis::cmd("SET")
+ .arg("my_key")
+ .arg("foo")
+ .query(&mut con)
+ .unwrap();
+ let response = script.key("my_key").arg(42).invoke(&mut con);
+
+ assert_eq!(response, Ok(("foo".to_string(), 42)));
+ }
+
+ #[test]
+ #[cfg(feature = "script")]
+ fn test_script_load() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let script = redis::Script::new("return 'Hello World'");
+
+ let hash = script.prepare_invoke().load(&mut con);
+
+ assert_eq!(hash, Ok(script.get_hash().to_string()));
+ }
+
+ #[test]
+ fn test_tuple_args() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ redis::cmd("HMSET")
+ .arg("my_key")
+ .arg(&[("field_1", 42), ("field_2", 23)])
+ .execute(&mut con);
+
+ assert_eq!(
+ redis::cmd("HGET")
+ .arg("my_key")
+ .arg("field_1")
+ .query(&mut con),
+ Ok(42)
+ );
+ assert_eq!(
+ redis::cmd("HGET")
+ .arg("my_key")
+ .arg("field_2")
+ .query(&mut con),
+ Ok(23)
+ );
+ }
+
+ #[test]
+ fn test_nice_api() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.set("my_key", 42), Ok(()));
+ assert_eq!(con.get("my_key"), Ok(42));
+
+ let (k1, k2): (i32, i32) = redis::pipe()
+ .atomic()
+ .set("key_1", 42)
+ .ignore()
+ .set("key_2", 43)
+ .ignore()
+ .get("key_1")
+ .get("key_2")
+ .query(&mut con)
+ .unwrap();
+
+ assert_eq!(k1, 42);
+ assert_eq!(k2, 43);
+ }
+
+ #[test]
+ fn test_auto_m_versions() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.mset(&[("key1", 1), ("key2", 2)]), Ok(()));
+ assert_eq!(con.get(&["key1", "key2"]), Ok((1, 2)));
+ assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2)));
+ assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2)));
+ }
+
+ #[test]
+ fn test_nice_hash_api() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(
+ con.hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)]),
+ Ok(())
+ );
+
+ let hm: HashMap<String, isize> = con.hgetall("my_hash").unwrap();
+ assert_eq!(hm.get("f1"), Some(&1));
+ assert_eq!(hm.get("f2"), Some(&2));
+ assert_eq!(hm.get("f3"), Some(&4));
+ assert_eq!(hm.get("f4"), Some(&8));
+ assert_eq!(hm.len(), 4);
+
+ let hm: BTreeMap<String, isize> = con.hgetall("my_hash").unwrap();
+ assert_eq!(hm.get("f1"), Some(&1));
+ assert_eq!(hm.get("f2"), Some(&2));
+ assert_eq!(hm.get("f3"), Some(&4));
+ assert_eq!(hm.get("f4"), Some(&8));
+ assert_eq!(hm.len(), 4);
+
+ let v: Vec<(String, isize)> = con.hgetall("my_hash").unwrap();
+ assert_eq!(
+ v,
+ vec![
+ ("f1".to_string(), 1),
+ ("f2".to_string(), 2),
+ ("f3".to_string(), 4),
+ ("f4".to_string(), 8),
+ ]
+ );
+
+ assert_eq!(con.hget("my_hash", &["f2", "f4"]), Ok((2, 8)));
+ assert_eq!(con.hincr("my_hash", "f1", 1), Ok(2));
+ assert_eq!(con.hincr("my_hash", "f2", 1.5f32), Ok(3.5f32));
+ assert_eq!(con.hexists("my_hash", "f2"), Ok(true));
+ 
assert_eq!(con.hdel("my_hash", &["f1", "f2"]), Ok(()));
+ assert_eq!(con.hexists("my_hash", "f2"), Ok(false));
+
+ let iter: redis::Iter<'_, (String, isize)> = con.hscan("my_hash").unwrap();
+ let mut found = HashSet::new();
+ for item in iter {
+ found.insert(item);
+ }
+
+ assert_eq!(found.len(), 2);
+ assert!(found.contains(&("f3".to_string(), 4)));
+ assert!(found.contains(&("f4".to_string(), 8)));
+ }
+
+ #[test]
+ fn test_nice_list_api() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.rpush("my_list", &[1, 2, 3, 4]), Ok(4));
+ assert_eq!(con.rpush("my_list", &[5, 6, 7, 8]), Ok(8));
+ assert_eq!(con.llen("my_list"), Ok(8));
+
+ assert_eq!(con.lpop("my_list", Default::default()), Ok(1));
+ assert_eq!(con.llen("my_list"), Ok(7));
+
+ assert_eq!(con.lrange("my_list", 0, 2), Ok((2, 3, 4)));
+
+ assert_eq!(con.lset("my_list", 0, 4), Ok(true));
+ assert_eq!(con.lrange("my_list", 0, 2), Ok((4, 3, 4)));
+
+ #[cfg(not(windows))]
+ //Windows version of redis is limited to v3.x
+ {
+ let my_list: Vec<isize> = con.lrange("my_list", 0, 10).expect("To get range");
+ assert_eq!(
+ con.lpop("my_list", core::num::NonZeroUsize::new(10)),
+ Ok(my_list)
+ );
+ }
+ }
+
+ #[test]
+ fn test_tuple_decoding_regression() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.del("my_zset"), Ok(()));
+ assert_eq!(con.zadd("my_zset", "one", 1), Ok(1));
+ assert_eq!(con.zadd("my_zset", "two", 2), Ok(1));
+
+ let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap();
+ assert_eq!(vec.len(), 2);
+
+ assert_eq!(con.del("my_zset"), Ok(1));
+
+ let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap();
+ assert_eq!(vec.len(), 0);
+ }
+
+ #[test]
+ fn test_bit_operations() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.setbit("bitvec", 10, true), Ok(false));
+ assert_eq!(con.getbit("bitvec", 10), Ok(true));
+ }
+
+ #[test]
+ fn test_redis_server_down() {
+ let mut ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let ping = redis::cmd("PING").query::<String>(&mut con);
+ assert_eq!(ping, Ok("PONG".into()));
+
+ ctx.stop_server();
+
+ let ping = redis::cmd("PING").query::<String>(&mut con);
+
+ assert!(ping.is_err());
+ eprintln!("{}", ping.unwrap_err());
+ assert!(!con.is_open());
+ }
+
+ #[test]
+ fn test_zinterstore_weights() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con
+ .zadd_multiple("zset1", &[(1, "one"), (2, "two"), (4, "four")])
+ .unwrap();
+ let _: () = con
+ .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")])
+ .unwrap();
+
+ // zinterstore_weights
+ assert_eq!(
+ con.zinterstore_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(2)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "5".to_string()),
+ ("two".to_string(), "10".to_string())
+ ])
+ );
+
+ // zinterstore_min_weights
+ assert_eq!(
+ con.zinterstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(2)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "2".to_string()),
+ ("two".to_string(), "4".to_string()),
+ ])
+ );
+
+ // zinterstore_max_weights
+ assert_eq!(
+ con.zinterstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(2)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "3".to_string()),
+ ("two".to_string(), "6".to_string()),
+ ])
+ );
+ }
+
+ #[test]
+ fn test_zunionstore_weights() {
+ 
let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con
+ .zadd_multiple("zset1", &[(1, "one"), (2, "two")])
+ .unwrap();
+ let _: () = con
+ .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")])
+ .unwrap();
+
+ // zunionstore_weights
+ assert_eq!(
+ con.zunionstore_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(3)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "5".to_string()),
+ ("three".to_string(), "9".to_string()),
+ ("two".to_string(), "10".to_string())
+ ])
+ );
+ // test converting to double
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), 5.0),
+ ("three".to_string(), 9.0),
+ ("two".to_string(), 10.0)
+ ])
+ );
+
+ // zunionstore_min_weights
+ assert_eq!(
+ con.zunionstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(3)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "2".to_string()),
+ ("two".to_string(), "4".to_string()),
+ ("three".to_string(), "9".to_string())
+ ])
+ );
+
+ // zunionstore_max_weights
+ assert_eq!(
+ con.zunionstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]),
+ Ok(3)
+ );
+
+ assert_eq!(
+ con.zrange_withscores("out", 0, -1),
+ Ok(vec![
+ ("one".to_string(), "3".to_string()),
+ ("two".to_string(), "6".to_string()),
+ ("three".to_string(), "9".to_string())
+ ])
+ );
+ }
+
+ #[test]
+ fn test_zrembylex() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let setname = "myzset";
+ assert_eq!(
+ con.zadd_multiple(
+ setname,
+ &[
+ (0, "apple"),
+ (0, "banana"),
+ (0, "carrot"),
+ (0, "durian"),
+ (0, "eggplant"),
+ (0, "grapes"),
+ ],
+ ),
+ Ok(6)
+ );
+
+ // Will remove "banana", "carrot", "durian" and "eggplant"
+ let num_removed: u32 = con.zrembylex(setname, "[banana", "[eggplant").unwrap();
+ assert_eq!(4, num_removed);
+
+ let remaining: Vec<String> = con.zrange(setname, 0, -1).unwrap();
+ assert_eq!(remaining, vec!["apple".to_string(), "grapes".to_string()]);
+ }
+
+ // Requires redis-server >= 6.2.0.
+ // Not supported with the current appveyor/windows binary deployed.
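+ // ZRANDMEMBER semantics: a positive count returns at most `count` distinct
+ // members, while a negative count may repeat members, so both Some(5) and
+ // Some(-5) yield exactly 5 results from the 5-member set below.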
+ #[cfg(not(target_os = "windows"))]
+ #[test]
+ fn test_zrandmember() {
+ use redis::ProtocolVersion;
+
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let setname = "myzrandset";
+ let () = con.zadd(setname, "one", 1).unwrap();
+
+ let result: String = con.zrandmember(setname, None).unwrap();
+ assert_eq!(result, "one".to_string());
+
+ let result: Vec<String> = con.zrandmember(setname, Some(1)).unwrap();
+ assert_eq!(result.len(), 1);
+ assert_eq!(result[0], "one".to_string());
+
+ let result: Vec<String> = con.zrandmember(setname, Some(2)).unwrap();
+ assert_eq!(result.len(), 1);
+ assert_eq!(result[0], "one".to_string());
+
+ assert_eq!(
+ con.zadd_multiple(
+ setname,
+ &[(2, "two"), (3, "three"), (4, "four"), (5, "five")]
+ ),
+ Ok(4)
+ );
+
+ let results: Vec<String> = con.zrandmember(setname, Some(5)).unwrap();
+ assert_eq!(results.len(), 5);
+
+ let results: Vec<String> = con.zrandmember(setname, Some(-5)).unwrap();
+ assert_eq!(results.len(), 5);
+
+ if ctx.protocol == ProtocolVersion::RESP2 {
+ let results: Vec<String> = con.zrandmember_withscores(setname, 5).unwrap();
+ assert_eq!(results.len(), 10);
+
+ let results: Vec<String> = con.zrandmember_withscores(setname, -5).unwrap();
+ assert_eq!(results.len(), 10);
+ }
+
+ let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, 5).unwrap();
+ assert_eq!(results.len(), 5);
+
+ let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, -5).unwrap();
+ assert_eq!(results.len(), 5);
+ }
+
+ #[test]
+ fn test_sismember() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let setname = "myset";
+ assert_eq!(con.sadd(setname, &["a"]), Ok(1));
+
+ let result: bool = con.sismember(setname, &["a"]).unwrap();
+ assert!(result);
+
+ let result: bool = con.sismember(setname, &["b"]).unwrap();
+ assert!(!result);
+ }
+
+ // Requires redis-server >= 6.2.0.
+ // Not supported with the current appveyor/windows binary deployed.
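+ // SMISMEMBER returns one membership flag per requested member, in request
+ // order, which is why the result below maps onto a Vec<bool>.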
+ #[cfg(not(target_os = "windows"))]
+ #[test]
+ fn test_smismember() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let setname = "myset";
+ assert_eq!(con.sadd(setname, &["a", "b", "c"]), Ok(3));
+ let results: Vec<bool> = con.smismember(setname, &["0", "a", "b", "c", "x"]).unwrap();
+ assert_eq!(results, vec![false, true, true, true, false]);
+ }
+
+ #[test]
+ fn test_object_commands() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con.set("object_key_str", "object_value_str").unwrap();
+ let _: () = con.set("object_key_int", 42).unwrap();
+
+ assert_eq!(
+ con.object_encoding::<_, String>("object_key_str").unwrap(),
+ "embstr"
+ );
+
+ assert_eq!(
+ con.object_encoding::<_, String>("object_key_int").unwrap(),
+ "int"
+ );
+
+ assert!(con.object_idletime::<_, i32>("object_key_str").unwrap() <= 1);
+ assert_eq!(con.object_refcount::<_, i32>("object_key_str").unwrap(), 1);
+
+ // Needed for OBJECT FREQ and can't be set before object_idletime
+ // since that will break getting the idletime before idletime adjusts
+ redis::cmd("CONFIG")
+ .arg("SET")
+ .arg(b"maxmemory-policy")
+ .arg("allkeys-lfu")
+ .execute(&mut con);
+
+ let _: () = con.get("object_key_str").unwrap();
+ // since maxmemory-policy changed, freq should reset to 1 since we only called
+ // get after that
+ assert_eq!(con.object_freq::<_, i32>("object_key_str").unwrap(), 1);
+ }
+
+ #[test]
+ fn test_mget() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con.set(1, "1").unwrap();
+ let data: Vec<String> = con.mget(&[1]).unwrap();
+ assert_eq!(data, vec!["1"]);
+
+ let _: () = con.set(2, "2").unwrap();
+ let data: Vec<String> = con.mget(&[1, 2]).unwrap();
+ assert_eq!(data, vec!["1", "2"]);
+
+ let data: Vec<Option<String>> = con.mget(&[4]).unwrap();
+ assert_eq!(data, vec![None]);
+
+ let data: Vec<Option<String>> = con.mget(&[2, 4]).unwrap();
+ assert_eq!(data, vec![Some("2".to_string()), None]);
+ }
+
+ #[test]
+ fn test_variable_length_get() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let _: () = con.set(1, "1").unwrap();
+ let keys = vec![1];
+ assert_eq!(keys.len(), 1);
+ let data: Vec<String> = con.get(&keys).unwrap();
+ assert_eq!(data, vec!["1"]);
+ }
+
+ #[test]
+ fn test_multi_generics() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ assert_eq!(con.sadd(b"set1", vec![5, 42]), Ok(2));
+ assert_eq!(con.sadd(999_i64, vec![42, 123]), Ok(2));
+ let _: () = con.rename(999_i64, b"set2").unwrap();
+ assert_eq!(con.sunionstore("res", &[b"set1", b"set2"]), Ok(3));
+ }
+
+ #[test]
+ fn test_set_options_with_get() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ let opts = SetOptions::default().get(true);
+ let data: Option<String> = con.set_options(1, "1", opts).unwrap();
+ assert_eq!(data, None);
+
+ let opts = SetOptions::default().get(true);
+ let data: Option<String> = con.set_options(1, "1", opts).unwrap();
+ assert_eq!(data, Some("1".to_string()));
+ }
+
+ #[test]
+ fn test_set_options_options() {
+ let empty = SetOptions::default();
+ assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0);
+
+ let opts = SetOptions::default()
+ .conditional_set(ExistenceCheck::NX)
+ .get(true)
+ .with_expiration(SetExpiry::PX(1000));
+
+ assert_args!(&opts, "NX", "GET", "PX", "1000");
+
+ let opts = SetOptions::default()
+ .conditional_set(ExistenceCheck::XX)
+ .get(true)
+ .with_expiration(SetExpiry::PX(1000));
+
+ assert_args!(&opts, "XX", "GET", "PX", "1000");
+
+ let opts = SetOptions::default()
+ .conditional_set(ExistenceCheck::XX)
+ 
.with_expiration(SetExpiry::KEEPTTL);
+
+ assert_args!(&opts, "XX", "KEEPTTL");
+
+ let opts = SetOptions::default()
+ .conditional_set(ExistenceCheck::XX)
+ .with_expiration(SetExpiry::EXAT(100));
+
+ assert_args!(&opts, "XX", "EXAT", "100");
+
+ let opts = SetOptions::default().with_expiration(SetExpiry::EX(1000));
+
+ assert_args!(&opts, "EX", "1000");
+ }
+
+ #[test]
+ fn test_blocking_sorted_set_api() {
+ let ctx = TestContext::new();
+ let mut con = ctx.connection();
+
+ // setup version & input data followed by assertions that take into account Redis version
+ // BZPOPMIN & BZPOPMAX are available from Redis version 5.0.0
+ // BZMPOP is available from Redis version 7.0.0
+
+ let redis_version = ctx.get_version();
+ assert!(redis_version.0 >= 5);
+
+ assert_eq!(con.zadd("a", "1a", 1), Ok(()));
+ assert_eq!(con.zadd("b", "2b", 2), Ok(()));
+ assert_eq!(con.zadd("c", "3c", 3), Ok(()));
+ assert_eq!(con.zadd("d", "4d", 4), Ok(()));
+ assert_eq!(con.zadd("a", "5a", 5), Ok(()));
+ assert_eq!(con.zadd("b", "6b", 6), Ok(()));
+ assert_eq!(con.zadd("c", "7c", 7), Ok(()));
+ assert_eq!(con.zadd("d", "8d", 8), Ok(()));
+
+ let min = con.bzpopmin::<&str, (String, String, String)>("b", 0.0);
+ let max = con.bzpopmax::<&str, (String, String, String)>("b", 0.0);
+
+ assert_eq!(
+ min.unwrap(),
+ (String::from("b"), String::from("2b"), String::from("2"))
+ );
+ assert_eq!(
+ max.unwrap(),
+ (String::from("b"), String::from("6b"), String::from("6"))
+ );
+
+ if redis_version.0 >= 7 {
+ let min = con.bzmpop_min::<&str, (String, Vec<Vec<(String, String)>>)>(
+ 0.0,
+ vec!["a", "b", "c", "d"].as_slice(),
+ 1,
+ );
+ let max = con.bzmpop_max::<&str, (String, Vec<Vec<(String, String)>>)>(
+ 0.0,
+ vec!["a", "b", "c", "d"].as_slice(),
+ 1,
+ );
+
+ assert_eq!(
+ min.unwrap().1[0][0],
+ (String::from("1a"), String::from("1"))
+ );
+ assert_eq!(
+ max.unwrap().1[0][0],
+ (String::from("5a"), String::from("5"))
+ );
+ }
+ }
+
+ #[test]
+ fn test_set_client_name_by_config() {
+ const CLIENT_NAME: &str = "TEST_CLIENT_NAME";
+
+ let ctx = TestContext::with_client_name(CLIENT_NAME);
+ let mut con = ctx.connection();
+
+ let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap();
+
+ let client_attrs = parse_client_info(&client_info);
+
+ assert!(
+ client_attrs.contains_key("name"),
+ "Could not detect the 'name' attribute in CLIENT INFO output"
+ );
+
+ assert_eq!(
+ client_attrs["name"], CLIENT_NAME,
+ "Incorrect client name, expecting: {}, got {}",
+ CLIENT_NAME, client_attrs["name"]
+ );
+ }
+
+ #[test]
+ fn test_push_manager() {
+ let ctx = TestContext::new();
+ if ctx.protocol == ProtocolVersion::RESP2 {
+ return;
+ }
+ let mut con = ctx.connection();
+ let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+ con.get_push_manager().replace_sender(tx.clone());
+ let _ = cmd("CLIENT")
+ .arg("TRACKING")
+ .arg("ON")
+ .query::<()>(&mut con)
+ .unwrap();
+ let pipe = build_simple_pipeline_for_invalidation();
+ for _ in 0..10 {
+ let _: RedisResult<()> = pipe.query(&mut con);
+ let _: i32 = con.get("key_1").unwrap();
+ let PushInfo { kind, data } = rx.try_recv().unwrap();
+ assert_eq!(
+ (
+ PushKind::Invalidate,
+ vec![Value::Array(vec![Value::BulkString(
+ "key_1".as_bytes().to_vec()
+ )])]
+ ),
+ (kind, data)
+ );
+ }
+ let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel();
+ con.get_push_manager().replace_sender(new_tx.clone());
+ drop(rx);
+ let _: RedisResult<()> = pipe.query(&mut con);
+ let _: i32 = con.get("key_1").unwrap();
+ let PushInfo { kind, data } = new_rx.try_recv().unwrap();
+ assert_eq!(
+ (
+ 
PushKind::Invalidate,
+ vec![Value::Array(vec![Value::BulkString(
+ "key_1".as_bytes().to_vec()
+ )])]
+ ),
+ (kind, data)
+ );
+
+ {
+ drop(new_rx);
+ for _ in 0..10 {
+ let _: RedisResult<()> = pipe.query(&mut con);
+ let v: i32 = con.get("key_1").unwrap();
+ assert_eq!(v, 42);
+ }
+ }
+ }
+
+ #[test]
+ fn test_push_manager_disconnection() {
+ let ctx = TestContext::new();
+ if ctx.protocol == ProtocolVersion::RESP2 {
+ return;
+ }
+ let mut con = ctx.connection();
+ let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+ con.get_push_manager().replace_sender(tx.clone());
+
+ let _: () = con.set("A", "1").unwrap();
+ assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
+ drop(ctx);
+ let x: RedisResult<()> = con.set("A", "1");
+ assert!(x.is_err());
+ assert_eq!(rx.try_recv().unwrap().kind, PushKind::Disconnection);
+ }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_bignum.rs b/glide-core/redis-rs/redis/tests/test_bignum.rs
new file mode 100644
index 0000000000..20beefbc66
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_bignum.rs
@@ -0,0 +1,61 @@
+#![cfg(any(
+ feature = "rust_decimal",
+ feature = "bigdecimal",
+ feature = "num-bigint"
+))]
+use redis::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs, Value};
+use std::str::FromStr;
+
+fn test<T>(content: &str)
+where
+ T: FromRedisValue
+ + ToRedisArgs
+ + std::str::FromStr
+ + std::convert::From<u32>
+ + std::cmp::PartialEq
+ + std::fmt::Debug,
+ <T as std::str::FromStr>::Err: std::fmt::Debug,
+{
+ let v: RedisResult<T> =
+ FromRedisValue::from_redis_value(&Value::BulkString(Vec::from(content)));
+ assert_eq!(v, Ok(T::from_str(content).unwrap()));
+
+ let arg = ToRedisArgs::to_redis_args(&v.unwrap());
+ assert_eq!(arg[0], Vec::from(content));
+
+ let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Int(0));
+ assert_eq!(v.unwrap(), T::from(0u32));
+
+ let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Int(42));
+ assert_eq!(v.unwrap(), T::from(42u32));
+
+ let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Okay);
+ assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError);
+
+ let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Nil);
+ assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError);
+}
+
+#[test]
+#[cfg(feature = "rust_decimal")]
+fn test_rust_decimal() {
+ test::<rust_decimal::Decimal>("-79228162514264.337593543950335");
+}
+
+#[test]
+#[cfg(feature = "bigdecimal")]
+fn test_bigdecimal() {
+ test::<bigdecimal::BigDecimal>("-14272476927059598810582859.69449495136382746623");
+}
+
+#[test]
+#[cfg(feature = "num-bigint")]
+fn test_bigint() {
+ test::<num_bigint::BigInt>("-1427247692705959881058285969449495136382746623");
+}
+
+#[test]
+#[cfg(feature = "num-bigint")]
+fn test_biguint() {
+ test::<num_bigint::BigUint>("1427247692705959881058285969449495136382746623");
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster.rs b/glide-core/redis-rs/redis/tests/test_cluster.rs
new file mode 100644
index 0000000000..cbeddd2fe4
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster.rs
@@ -0,0 +1,1093 @@
+#![cfg(feature = "cluster")]
+mod support;
+
+#[cfg(test)]
+mod cluster {
+ use std::sync::{
+ atomic::{self, AtomicI32, Ordering},
+ Arc,
+ };
+
+ use crate::support::*;
+ use redis::{
+ cluster::{cluster_pipe, ClusterClient},
+ cmd, parse_redis_value, Commands, ConnectionLike, ErrorKind, ProtocolVersion, RedisError,
+ Value,
+ };
+
+ #[test]
+ fn test_cluster_basics() {
+ let cluster = TestClusterContext::new(3, 0);
+ let mut con = cluster.connection();
+
+ redis::cmd("SET")
+ .arg("{x}key1")
+ .arg(b"foo")
+ .execute(&mut con);
+ redis::cmd("SET").arg(&["{x}key2", 
"bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_with_username_and_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password(RedisCluster::password().to_string()) + }, + false, + ); + cluster.disable_default_user(); + + let mut con = cluster.connection(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_with_bad_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password("not the right password".to_string()) + }, + false, + ); + assert!(cluster.client.get_connection(None).is_err()); + } + + #[test] + fn test_cluster_read_from_replicas() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| builder.read_from_replicas(), + false, + ); + let mut con = cluster.connection(); + + // Write commands would go to the primary nodes + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + // Read commands would go to the replica nodes + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_eval() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let rv = redis::cmd("EVAL") + .arg( + r#" + redis.call("SET", KEYS[1], "1"); + redis.call("SET", KEYS[2], "2"); + return redis.call("MGET", KEYS[1], KEYS[2]); + "#, + ) + .arg("2") + .arg("{x}a") + .arg("{x}b") + .query(&mut con); + + assert_eq!(rv, Ok(("1".to_string(), "2".to_string()))); + } + + #[test] + fn test_cluster_resp3() { + if use_protocol() == ProtocolVersion::RESP2 { + return; + } + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let hello: std::collections::HashMap = + redis::cmd("HELLO").query(&mut connection).unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").unwrap(); + let result: Value = connection.hgetall("hash").unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + } + + #[test] + fn test_cluster_multi_shard_commands() { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let res: String = connection + .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")]) + .unwrap(); + assert_eq!(res, "OK"); + let res: Vec = connection.mget(&["baz", "foo", "bar"]).unwrap(); + assert_eq!(res, vec!["bazz", "bar", "foo"]); + } + + #[test] + #[cfg(feature = "script")] + fn test_cluster_script() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let script = 
redis::Script::new(
+ r#"
+ redis.call("SET", KEYS[1], "1");
+ redis.call("SET", KEYS[2], "2");
+ return redis.call("MGET", KEYS[1], KEYS[2]);
+ "#,
+ );
+
+ let rv = script.key("{x}a").key("{x}b").invoke(&mut con);
+ assert_eq!(rv, Ok(("1".to_string(), "2".to_string())));
+ }
+
+ #[test]
+ fn test_cluster_pipeline() {
+ let cluster = TestClusterContext::new(3, 0);
+ cluster.wait_for_cluster_up();
+ let mut con = cluster.connection();
+
+ let resp = cluster_pipe()
+ .cmd("SET")
+ .arg("key_1")
+ .arg(42)
+ .query::<Vec<String>>(&mut con)
+ .unwrap();
+
+ assert_eq!(resp, vec!["OK".to_string()]);
+ }
+
+ #[test]
+ fn test_cluster_pipeline_multiple_keys() {
+ use redis::FromRedisValue;
+ let cluster = TestClusterContext::new(3, 0);
+ cluster.wait_for_cluster_up();
+ let mut con = cluster.connection();
+
+ let resp = cluster_pipe()
+ .cmd("HSET")
+ .arg("hash_1")
+ .arg("key_1")
+ .arg("value_1")
+ .cmd("ZADD")
+ .arg("zset")
+ .arg(1)
+ .arg("zvalue_2")
+ .query::<Vec<i64>>(&mut con)
+ .unwrap();
+
+ assert_eq!(resp, vec![1i64, 1i64]);
+
+ let resp = cluster_pipe()
+ .cmd("HGET")
+ .arg("hash_1")
+ .arg("key_1")
+ .cmd("ZCARD")
+ .arg("zset")
+ .query::<Vec<Value>>(&mut con)
+ .unwrap();
+
+ let resp_1: String = FromRedisValue::from_redis_value(&resp[0]).unwrap();
+ assert_eq!(resp_1, "value_1".to_string());
+
+ let resp_2: usize = FromRedisValue::from_redis_value(&resp[1]).unwrap();
+ assert_eq!(resp_2, 1);
+ }
+
+ #[test]
+ fn test_cluster_pipeline_invalid_command() {
+ let cluster = TestClusterContext::new(3, 0);
+ cluster.wait_for_cluster_up();
+ let mut con = cluster.connection();
+
+ let err = cluster_pipe()
+ .cmd("SET")
+ .arg("foo")
+ .arg(42)
+ .ignore()
+ .cmd(" SCRIPT kill ")
+ .query::<()>(&mut con)
+ .unwrap_err();
+
+ assert_eq!(
+ err.to_string(),
+ "This command cannot be safely routed in cluster mode - ClientError: Command 'SCRIPT KILL' can't be executed in a cluster pipeline."
+ );
+
+ let err = cluster_pipe().keys("*").query::<()>(&mut con).unwrap_err();
+
+ assert_eq!(
+ err.to_string(),
+ "This command cannot be safely routed in cluster mode - ClientError: Command 'KEYS' can't be executed in a cluster pipeline."
+ );
+ }
+
+ #[test]
+ fn test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() {
+ let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name";
+
+ let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![Value::Array(vec![
+ Value::Int(0),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString("".as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ])])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Value>(&mut connection);
+
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() {
+ let name =
+ "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name";
+
+ let MockEnv { mut connection, .. 
} = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![Value::Array(vec![
+ Value::Int(0),
+ Value::Int(16383),
+ Value::Array(vec![Value::Nil, Value::Int(6379)]),
+ ])])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Value>(&mut connection);
+
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name(
+ ) {
+ let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name";
+
+ let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![
+ Value::Array(vec![
+ Value::Int(0),
+ Value::Int(7000),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ]),
+ Value::Array(vec![
+ Value::Int(7001),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString("?".as_bytes().to_vec()),
+ Value::Int(6380),
+ ]),
+ ]),
+ ])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Value>(&mut connection);
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_cluster_pipeline_command_ordering() {
+ let cluster = TestClusterContext::new(3, 0);
+ cluster.wait_for_cluster_up();
+ let mut con = cluster.connection();
+ let mut pipe = cluster_pipe();
+
+ let mut queries = Vec::new();
+ let mut expected = Vec::new();
+ for i in 0..100 {
+ queries.push(format!("foo{i}"));
+ expected.push(format!("bar{i}"));
+ pipe.set(&queries[i], &expected[i]).ignore();
+ }
+ pipe.execute(&mut con);
+
+ pipe.clear();
+ for q in &queries {
+ pipe.get(q);
+ }
+
+ let got = pipe.query::<Vec<String>>(&mut con).unwrap();
+ assert_eq!(got, expected);
+ }
+
+ #[test]
+ #[ignore] // Flaky
+ fn test_cluster_pipeline_ordering_with_improper_command() {
+ let cluster = TestClusterContext::new(3, 0);
+ cluster.wait_for_cluster_up();
+ let mut con = cluster.connection();
+ let mut pipe = cluster_pipe();
+
+ let mut queries = Vec::new();
+ let mut expected = Vec::new();
+ for i in 0..10 {
+ if i == 5 {
+ pipe.cmd("hset").arg("foo").ignore();
+ } else {
+ let query = format!("foo{i}");
+ let r = format!("bar{i}");
+ pipe.set(&query, &r).ignore();
+ queries.push(query);
+ expected.push(r);
+ }
+ }
+ pipe.query::<()>(&mut con).unwrap_err();
+
+ std::thread::sleep(std::time::Duration::from_secs(5));
+
+ pipe.clear();
+ for q in &queries {
+ pipe.get(q);
+ }
+
+ let got = pipe.query::<Vec<String>>(&mut con).unwrap();
+ assert_eq!(got, expected);
+ }
+
+ #[test]
+ fn test_cluster_retries() {
+ let name = "tryagain";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5),
+ name,
+ move |cmd: &[u8], _| {
+ respond_startup(name, cmd)?;
+
+ match requests.fetch_add(1, atomic::Ordering::SeqCst) {
+ 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")),
+ _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+ }
+ },
+ );
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_cluster_exhaust_retries() {
+ let name = "tryagain_exhaust_retries";
+
+ let requests = Arc::new(atomic::AtomicUsize::new(0));
+
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+ name,
+ {
+ let requests = requests.clone();
+ move |cmd: &[u8], _| {
+ respond_startup(name, cmd)?;
+ requests.fetch_add(1, atomic::Ordering::SeqCst);
+ Err(parse_redis_value(b"-TRYAGAIN mock\r\n"))
+ }
+ },
+ );
+
+ let result = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ match result {
+ Ok(_) => panic!("result should be an error"),
+ Err(e) => match e.kind() {
+ ErrorKind::TryAgain => {}
+ _ => panic!("Expected TryAgain but got {:?}", e.kind()),
+ },
+ }
+ assert_eq!(requests.load(atomic::Ordering::SeqCst), 3);
+ }
+
+ #[test]
+ fn test_cluster_move_error_when_new_node_is_added() {
+ let name = "rebuild_with_extra_nodes";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup(name, cmd)?;
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+
+ if contains_slice(cmd, b"PING") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+ match i {
+ // Respond that the key exists on a node that does not yet have a connection:
+ 0 => Err(parse_redis_value(b"-MOVED 123\r\n")),
+ // Respond with the new masters
+ 1 => Err(Ok(Value::Array(vec![
+ Value::Array(vec![
+ Value::Int(0),
+ Value::Int(1),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ]),
+ Value::Array(vec![
+ Value::Int(2),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6380),
+ ]),
+ ]),
+ ]))),
+ _ => {
+ // Check that the correct node receives the request after rebuilding
+ assert_eq!(port, 6380);
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_cluster_ask_redirect() {
+ let name = "node";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]),
+ name,
+ {
+ move |cmd: &[u8], port| {
+ respond_startup_two_nodes(name, cmd)?;
+ // First reply with an ASK redirect, then expect ASKING followed by
+ // GET on the target node; no other requests should reach node 6379.
+ let count = completed.fetch_add(1, Ordering::SeqCst);
+ match port {
+ 6379 => match count {
+ 0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")),
+ _ => panic!("Node should not be called now"),
+ },
+ 6380 => match count {
+ 1 => {
+ assert!(contains_slice(cmd, b"ASKING"));
+ Err(Ok(Value::Okay))
+ }
+ 2 => {
+ assert!(contains_slice(cmd, b"GET"));
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ _ => panic!("Node should not be called now"),
+ },
+ _ => panic!("Wrong node"),
+ }
+ }
+ },
+ );
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_cluster_ask_error_when_new_node_is_added() {
+ let name = "ask_with_extra_nodes";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup(name, cmd)?;
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+
+ if contains_slice(cmd, b"PING") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+ match i {
+ // Respond that the key exists on a node that does not yet have a connection:
+ 0 => Err(parse_redis_value(
+ format!("-ASK 123 {name}:6380\r\n").as_bytes(),
+ )),
+ 1 => {
+ assert_eq!(port, 6380);
+ assert!(contains_slice(cmd, b"ASKING"));
+ Err(Ok(Value::Okay))
+ }
+ 2 => {
+ assert_eq!(port, 6380);
+ assert!(contains_slice(cmd, b"GET"));
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ _ => {
+ panic!("Unexpected request: {:?}", cmd);
+ }
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_cluster_replica_read() {
+ let name = "node";
+
+ // requests should route to replica
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |cmd: &[u8], port| {
+ respond_startup_with_replica(name, cmd)?;
+
+ match port {
+ 6380 => Err(Ok(Value::BulkString(b"123".to_vec()))),
+ _ => panic!("Wrong node"),
+ }
+ },
+ );
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+ assert_eq!(value, Ok(Some(123)));
+
+ // requests should route to primary
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |cmd: &[u8], port| {
+ respond_startup_with_replica(name, cmd)?;
+ match port {
+ 6379 => Err(Ok(Value::SimpleString("OK".into()))),
+ _ => panic!("Wrong node"),
+ }
+ },
+ );
+
+ let value = cmd("SET")
+ .arg("test")
+ .arg("123")
+ .query::<Option<Value>>(&mut connection);
+ assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+ }
+
+ #[test]
+ fn test_cluster_io_error() {
+ let name = "node";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+ name,
+ move |cmd: &[u8], port| {
+ respond_startup_two_nodes(name, cmd)?;
+ // Error twice with io-error, ensure connection is reestablished w/out calling
+ // other node (i.e., not doing a full slot rebuild)
+ match port {
+ 6380 => panic!("Node should not be called"),
+ _ => match completed.fetch_add(1, Ordering::SeqCst) {
+ 0..=1 => Err(Err(RedisError::from(std::io::Error::new(
+ std::io::ErrorKind::ConnectionReset,
+ "mock-io-error",
+ )))),
+ _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+ },
+ }
+ },
+ );
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_cluster_non_retryable_error_should_not_retry() {
+ let name = "node";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv { mut connection, .. } = MockEnv::new(name, {
+ let completed = completed.clone();
+ move |cmd: &[u8], _| {
+ respond_startup_two_nodes(name, cmd)?;
+ // Fail with a non-retryable error on every request; the client
+ // must surface it immediately instead of retrying.
+ completed.fetch_add(1, Ordering::SeqCst);
+ Err(Err((ErrorKind::ReadOnly, "").into()))
+ }
+ });
+
+ let value = cmd("GET").arg("test").query::<Option<i32>>(&mut connection);
+
+ match value {
+ Ok(_) => panic!("result should be an error"),
+ Err(e) => match e.kind() {
+ ErrorKind::ReadOnly => {}
+ _ => panic!("Expected ReadOnly but got {:?}", e.kind()),
+ },
+ }
+ assert_eq!(completed.load(Ordering::SeqCst), 1);
+ }
+
+ fn test_cluster_fan_out(
+ command: &'static str,
+ expected_ports: Vec<u16>,
+ slots_config: Option<Vec<MockSlotRange>>,
+ ) {
+ let name = "node";
+ let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+ let ports_clone = found_ports.clone();
+ let mut cmd = redis::Cmd::new();
+ for arg in command.split_whitespace() {
+ cmd.arg(arg);
+ }
+ let packed_cmd = cmd.get_packed_command();
+ // requests should route to replica
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(
+ name,
+ received_cmd,
+ slots_config.clone(),
+ )?;
+ if received_cmd == packed_cmd {
+ ports_clone.lock().unwrap().push(port);
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+ Ok(())
+ },
+ );
+
+ let _ = cmd.query::<Option<()>>(&mut connection);
+ found_ports.lock().unwrap().sort();
+ // MockEnv creates 2 mock connections.
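+ // The sorted list of ports that actually received the fan-out command is
+ // then compared against the full set of expected ports.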
+ assert_eq!(*found_ports.lock().unwrap(), expected_ports);
+ }
+
+ #[test]
+ fn test_cluster_fan_out_to_all_primaries() {
+ test_cluster_fan_out("FLUSHALL", vec![6379, 6381], None);
+ }
+
+ #[test]
+ fn test_cluster_fan_out_to_all_nodes() {
+ test_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None);
+ }
+
+ #[test]
+ fn test_cluster_fan_out_out_once_to_each_primary_when_no_replicas_are_available() {
+ test_cluster_fan_out(
+ "CONFIG SET",
+ vec![6379, 6381],
+ Some(vec![
+ MockSlotRange {
+ primary_port: 6379,
+ replica_ports: Vec::new(),
+ slot_range: (0..8191),
+ },
+ MockSlotRange {
+ primary_port: 6381,
+ replica_ports: Vec::new(),
+ slot_range: (8192..16383),
+ },
+ ]),
+ );
+ }
+
+ #[test]
+ fn test_cluster_fan_out_out_once_even_if_primary_has_multiple_slot_ranges() {
+ test_cluster_fan_out(
+ "CONFIG SET",
+ vec![6379, 6380, 6381, 6382],
+ Some(vec![
+ MockSlotRange {
+ primary_port: 6379,
+ replica_ports: vec![6380],
+ slot_range: (0..4000),
+ },
+ MockSlotRange {
+ primary_port: 6381,
+ replica_ports: vec![6382],
+ slot_range: (4001..8191),
+ },
+ MockSlotRange {
+ primary_port: 6379,
+ replica_ports: vec![6380],
+ slot_range: (8192..8200),
+ },
+ MockSlotRange {
+ primary_port: 6381,
+ replica_ports: vec![6382],
+ slot_range: (8201..16383),
+ },
+ ]),
+ );
+ }
+
+ #[test]
+ fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() {
+ let name = "test_cluster_split_multi_shard_command_and_combine_arrays_of_values";
+ let mut cmd = cmd("MGET");
+ cmd.arg("foo").arg("bar").arg("baz");
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(name, received_cmd, None)?;
+ let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+ let results = ["foo", "bar", "baz"]
+ .iter()
+ .filter_map(|expected_key| {
+ if cmd_str.contains(expected_key) {
+ Some(Value::BulkString(
+ format!("{expected_key}-{port}").into_bytes(),
+ ))
+ } else {
+ None
+ }
+ })
+ .collect();
+ Err(Ok(Value::Array(results)))
+ },
+ );
+
+ let result = cmd.query::<Vec<String>>(&mut connection).unwrap();
+ assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]);
+ }
+
+ #[test]
+ fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests() {
+ let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests";
+ let mut pipeline = redis::pipe();
+ pipeline.atomic().set("foo", "bar").get("foo");
+ let packed_pipeline = pipeline.get_packed_pipeline();
+
+ let MockEnv {
+ mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + return Err(Ok(Value::Array(results))); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection + .req_packed_commands(&packed_pipeline, 3, 1) + .unwrap(); + assert_eq!( + result, + vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ] + ); + } + + #[test] + fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2() { + let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2"; + let mut pipeline = redis::pipe(); + pipeline.atomic().set("foo", "bar").get("foo"); + let packed_pipeline = pipeline.get_packed_pipeline(); + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + let expected_result = Value::Array(results); + let cloned_result = expected_result.clone(); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Ok(cloned_result.clone())); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection.req_packed_command(&packed_pipeline).unwrap(); + assert_eq!(result, expected_result); + } + + #[test] + fn test_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + let mut con = cluster.connection(); + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); + } + + #[test] + fn test_cluster_can_be_created_with_partial_slot_coverage() { + let name = "test_cluster_can_be_created_with_partial_slot_coverage"; + let slots_config = Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![], + slot_range: (8201..16380), + }, + ]); + + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], _| {
+ respond_startup_with_replica_using_config(
+ name,
+ received_cmd,
+ slots_config.clone(),
+ )?;
+ Err(Ok(Value::SimpleString("PONG".into())))
+ },
+ );
+
+ let res = connection.req_command(&redis::cmd("PING"));
+ assert!(res.is_ok());
+ }
+
+ #[cfg(feature = "tls-rustls")]
+ mod mtls_test {
+ use super::*;
+ use crate::support::mtls_test::create_cluster_client_from_cluster;
+ use redis::ConnectionInfo;
+
+ #[test]
+ fn test_cluster_basics_with_mtls() {
+ let cluster = TestClusterContext::new_with_mtls(3, 0);
+
+ let client = create_cluster_client_from_cluster(&cluster, true).unwrap();
+ let mut con = client.get_connection(None).unwrap();
+
+ redis::cmd("SET")
+ .arg("{x}key1")
+ .arg(b"foo")
+ .execute(&mut con);
+ redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con);
+
+ assert_eq!(
+ redis::cmd("MGET")
+ .arg(&["{x}key1", "{x}key2"])
+ .query(&mut con),
+ Ok(("foo".to_string(), b"bar".to_vec()))
+ );
+ }
+
+ #[test]
+ fn test_cluster_should_not_connect_without_mtls() {
+ let cluster = TestClusterContext::new_with_mtls(3, 0);
+
+ let client = create_cluster_client_from_cluster(&cluster, false).unwrap();
+ let connection = client.get_connection(None);
+
+ match cluster.cluster.servers.first().unwrap().connection_info() {
+ ConnectionInfo {
+ addr: redis::ConnectionAddr::TcpTls { .. },
+ ..
+ } => {
+ if connection.is_ok() {
+ panic!("Must NOT be able to connect without client credentials if server accepts TLS");
+ }
+ }
+ _ => {
+ if let Err(e) = connection {
+ panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}");
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster_async.rs b/glide-core/redis-rs/redis/tests/test_cluster_async.rs
new file mode 100644
index 0000000000..b690ed87b5
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster_async.rs
@@ -0,0 +1,4245 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+#![cfg(feature = "cluster-async")]
+mod support;
+
+#[cfg(test)]
+mod cluster_async {
+ use std::{
+ collections::{HashMap, HashSet},
+ net::{IpAddr, SocketAddr},
+ str::from_utf8,
+ sync::{
+ atomic::{self, AtomicBool, AtomicI32, AtomicU16, AtomicU32, Ordering},
+ Arc,
+ },
+ time::Duration,
+ };
+
+ use futures::prelude::*;
+ use futures_time::{future::FutureExt, task::sleep};
+ use once_cell::sync::Lazy;
+ use std::ops::Add;
+
+ use redis::{
+ aio::{ConnectionLike, MultiplexedConnection},
+ cluster::ClusterClient,
+ cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect},
+ cluster_routing::{
+ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr,
+ },
+ cluster_topology::{get_slot, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES},
+ cmd, from_owned_redis_value, parse_redis_value, AsyncCommands, Cmd, ErrorKind,
+ FromRedisValue, GlideConnectionOptions, InfoDict, IntoConnectionInfo, ProtocolVersion,
+ PubSubChannelOrPattern, PubSubSubscriptionInfo, PubSubSubscriptionKind, PushInfo, PushKind,
+ RedisError, RedisFuture, RedisResult, Script, Value,
+ };
+
+ use crate::support::*;
+
+ use tokio::sync::mpsc;
+ fn broken_pipe_error() -> RedisError {
+ RedisError::from(std::io::Error::new(
+ std::io::ErrorKind::BrokenPipe,
+ "mock-io-error",
+ ))
+ }
+
+ fn validate_subscriptions(
+ pubsub_subs: &PubSubSubscriptionInfo,
+ notifications_rx: &mut mpsc::UnboundedReceiver<PushInfo>,
+ allow_disconnects: bool,
+ ) {
+ let mut subscribe_cnt =
+ if let Some(exact_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Exact) {
+ exact_subs.len()
+ } else {
+ 0
+ };
+
+ let mut psubscribe_cnt =
+ if let Some(pattern_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Pattern) {
+ pattern_subs.len()
+ } else {
+ 0
+ };
+
+ let mut ssubscribe_cnt =
+ if let Some(sharded_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Sharded) {
+ sharded_subs.len()
+ } else {
+ 0
+ };
+
+ for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) {
+ let result = notifications_rx.try_recv();
+ assert!(result.is_ok());
+ let PushInfo { kind, data: _ } = result.unwrap();
+ assert!(
+ kind == PushKind::Subscribe
+ || kind == PushKind::PSubscribe
+ || kind == PushKind::SSubscribe
+ || if allow_disconnects {
+ kind == PushKind::Disconnection
+ } else {
+ false
+ }
+ );
+ if kind == PushKind::Subscribe {
+ subscribe_cnt -= 1;
+ } else if kind == PushKind::PSubscribe {
+ psubscribe_cnt -= 1;
+ } else if kind == PushKind::SSubscribe {
+ ssubscribe_cnt -= 1;
+ }
+ }
+
+ assert!(subscribe_cnt == 0);
+ assert!(psubscribe_cnt == 0);
+ assert!(ssubscribe_cnt == 0);
+ }
+
+ #[test]
+ fn test_async_cluster_basic_cmd() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ cmd("SET")
+ .arg("test")
+ .arg("test_data")
+ .query_async(&mut connection)
+ .await?;
+ let res: String = cmd("GET")
+ .arg("test")
+ .clone()
+ .query_async(&mut connection)
+ .await?;
+ assert_eq!(res, "test_data");
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_basic_eval() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ let res: String = cmd("EVAL")
+ .arg(r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#)
+ .arg(1)
+ .arg("key")
+ .arg("test")
+ .query_async(&mut connection)
+ .await?;
+ assert_eq!(res, "test");
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_basic_script() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ let res: String = Script::new(
+ r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#,
+ )
+ .key("key")
+ .arg("test")
+ .invoke_async(&mut connection)
+ .await?;
+ assert_eq!(res, "test");
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_route_flush_to_specific_node() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ let _: () = connection.set("foo", "bar").await.unwrap();
+ let _: () = connection.set("bar", "foo").await.unwrap();
+
+ let res: String = connection.get("foo").await.unwrap();
+ assert_eq!(res, "bar".to_string());
+ let res2: Option<String> = connection.get("bar").await.unwrap();
+ assert_eq!(res2, Some("foo".to_string()));
+
+ let route =
+ redis::cluster_routing::Route::new(1, redis::cluster_routing::SlotAddr::Master);
+ let single_node_route =
+ redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route);
+ let routing = RoutingInfo::SingleNode(single_node_route);
+ assert_eq!(
+ connection
+ .route_command(&redis::cmd("FLUSHALL"), routing)
+ .await
+ .unwrap(),
+ Value::Okay
+ );
+ let res: String = connection.get("foo").await.unwrap();
+ assert_eq!(res, "bar".to_string());
+ let res2:
Option<String> = connection.get("bar").await.unwrap();
+ assert_eq!(res2, None);
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_route_flush_to_node_by_address() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ let mut cmd = redis::cmd("INFO");
+ // The other sections change with time.
+ // TODO - after we remove support of redis 6, we can add more than a single section - .arg("Persistence").arg("Memory").arg("Replication")
+ cmd.arg("Clients");
+ let value = connection
+ .route_command(
+ &cmd,
+ RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)),
+ )
+ .await
+ .unwrap();
+
+ let info_by_address = from_owned_redis_value::<HashMap<String, String>>(value).unwrap();
+ // find the info of the first returned node
+ let (address, info) = info_by_address.into_iter().next().unwrap();
+ let mut split_address = address.split(':');
+ let host = split_address.next().unwrap().to_string();
+ let port = split_address.next().unwrap().parse().unwrap();
+
+ let value = connection
+ .route_command(
+ &cmd,
+ RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { host, port }),
+ )
+ .await
+ .unwrap();
+ let new_info = from_owned_redis_value::<String>(value).unwrap();
+
+ assert_eq!(new_info, info);
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_route_info_to_nodes() {
+ let cluster = TestClusterContext::new(12, 1);
+
+ let split_to_addresses_and_info = |res| -> (Vec<String>, Vec<String>) {
+ if let Value::Map(values) = res {
+ let mut pairs: Vec<_> = values
+ .into_iter()
+ .map(|(key, value)| {
+ (
+ redis::from_redis_value::<String>(&key).unwrap(),
+ redis::from_redis_value::<String>(&value).unwrap(),
+ )
+ })
+ .collect();
+ pairs.sort_by(|(address1, _), (address2, _)| address1.cmp(address2));
+ pairs.into_iter().unzip()
+ } else {
+ unreachable!("{:?}", res);
+ }
+ };
+
+ block_on_all(async move {
+ let cluster_addresses: Vec<_> = cluster
+ .cluster
+ .servers
+ .iter()
+ .map(|server| server.connection_info())
+ .collect();
+ let client = ClusterClient::builder(cluster_addresses.clone())
+ .read_from_replicas()
+ .build()?;
+ let mut connection = client.get_async_connection(None).await?;
+
+ let route_to_all_nodes = redis::cluster_routing::MultipleNodeRoutingInfo::AllNodes;
+ let routing = RoutingInfo::MultiNode((route_to_all_nodes, None));
+ let res = connection
+ .route_command(&redis::cmd("INFO"), routing)
+ .await
+ .unwrap();
+ let (addresses, infos) = split_to_addresses_and_info(res);
+
+ let mut cluster_addresses: Vec<_> = cluster_addresses
+ .into_iter()
+ .map(|info| info.addr.to_string())
+ .collect();
+ cluster_addresses.sort();
+
+ assert_eq!(addresses.len(), 12);
+ assert_eq!(addresses, cluster_addresses);
+ assert_eq!(infos.len(), 12);
+ for i in 0..12 {
+ let split: Vec<_> = addresses[i].split(':').collect();
+ assert!(infos[i].contains(&format!("tcp_port:{}", split[1])));
+ }
+
+ let route_to_all_primaries =
+ redis::cluster_routing::MultipleNodeRoutingInfo::AllMasters;
+ let routing = RoutingInfo::MultiNode((route_to_all_primaries, None));
+ let res = connection
+ .route_command(&redis::cmd("INFO"), routing)
+ .await
+ .unwrap();
+ let (addresses, infos) = split_to_addresses_and_info(res);
+ assert_eq!(addresses.len(), 6);
+ assert_eq!(infos.len(), 6);
+ // verify that all primaries have the correct port & host, and are marked as primaries.
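+ // Note: the check below accepts either spelling, so the test passes whether the
+ // server reports the node role as "master" or as "primary".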
+ for i in 0..6 {
+ assert!(cluster_addresses.contains(&addresses[i]));
+ let split: Vec<_> = addresses[i].split(':').collect();
+ assert!(infos[i].contains(&format!("tcp_port:{}", split[1])));
+ assert!(infos[i].contains("role:primary") || infos[i].contains("role:master"));
+ }
+
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_resp3() {
+ if use_protocol() == ProtocolVersion::RESP2 {
+ return;
+ }
+ block_on_all(async move {
+ let cluster = TestClusterContext::new(3, 0);
+
+ let mut connection = cluster.async_connection(None).await;
+
+ let hello: HashMap<String, Value> = redis::cmd("HELLO")
+ .query_async(&mut connection)
+ .await
+ .unwrap();
+ assert_eq!(hello.get("proto").unwrap(), &Value::Int(3));
+
+ let _: () = connection.hset("hash", "foo", "baz").await.unwrap();
+ let _: () = connection.hset("hash", "bar", "foobar").await.unwrap();
+ let result: Value = connection.hgetall("hash").await.unwrap();
+
+ assert_eq!(
+ result,
+ Value::Map(vec![
+ (
+ Value::BulkString("foo".as_bytes().to_vec()),
+ Value::BulkString("baz".as_bytes().to_vec())
+ ),
+ (
+ Value::BulkString("bar".as_bytes().to_vec()),
+ Value::BulkString("foobar".as_bytes().to_vec())
+ )
+ ])
+ );
+
+ Ok(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_basic_pipe() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+ let mut pipe = redis::pipe();
+ pipe.add_command(cmd("SET").arg("test").arg("test_data").clone());
+ pipe.add_command(cmd("SET").arg("{test}3").arg("test_data3").clone());
+ pipe.query_async(&mut connection).await?;
+ let res: String = connection.get("test").await?;
+ assert_eq!(res, "test_data");
+ let res: String = connection.get("{test}3").await?;
+ assert_eq!(res, "test_data3");
+ Ok::<_, RedisError>(())
+ })
+ .unwrap()
+ }
+
+ #[test]
+ fn test_async_cluster_multi_shard_commands() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut connection = cluster.async_connection(None).await;
+
+ let res: String = connection
+ .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")])
+ .await?;
+ assert_eq!(res, "OK");
+ let res: Vec<String> = connection.mget(&["baz", "foo", "bar"]).await?;
+ assert_eq!(res, vec!["bazz", "bar", "foo"]);
+ Ok::<_, RedisError>(())
+ })
+ .unwrap()
+ }
+
+ #[test]
+ fn test_async_cluster_basic_failover() {
+ block_on_all(async move {
+ test_failover(&TestClusterContext::new(6, 1), 10, 123, false).await;
+ Ok::<_, RedisError>(())
+ })
+ .unwrap()
+ }
+
+ async fn do_failover(
+ redis: &mut redis::aio::MultiplexedConnection,
+ ) -> Result<(), anyhow::Error> {
+ cmd("CLUSTER").arg("FAILOVER").query_async(redis).await?;
+ Ok(())
+ }
+
+ // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active
+ #[allow(dead_code)]
+ async fn test_failover(
+ env: &TestClusterContext,
+ requests: i32,
+ value: i32,
+ _mtls_enabled: bool,
+ ) {
+ let completed = Arc::new(AtomicI32::new(0));
+
+ let connection = env.async_connection(None).await;
+ let mut node_conns: Vec<MultiplexedConnection> = Vec::new();
+
+ 'outer: loop {
+ node_conns.clear();
+ let cleared_nodes = async {
+ for server in env.cluster.iter_servers() {
+ let addr = server.client_addr();
+
+ #[cfg(feature = "tls-rustls")]
+ let client = build_single_client(
+ server.connection_info(),
+ &server.tls_paths,
+ _mtls_enabled,
+ )
+ .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}"));
+
+ #[cfg(not(feature = "tls-rustls"))]
+ let client =
redis::Client::open(server.connection_info())
+ .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}"));
+
+ let mut conn = client
+ .get_multiplexed_async_connection(GlideConnectionOptions::default())
+ .await
+ .unwrap_or_else(|e| panic!("Failed to get connection: {e}"));
+
+ let info: InfoDict = redis::Cmd::new()
+ .arg("INFO")
+ .query_async(&mut conn)
+ .await
+ .expect("INFO");
+ let role: String = info.get("role").expect("cluster role");
+
+ if role == "master" {
+ tokio::time::timeout(std::time::Duration::from_secs(3), async {
+ Ok(redis::Cmd::new()
+ .arg("FLUSHALL")
+ .query_async(&mut conn)
+ .await?)
+ })
+ .await
+ .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?;
+ }
+
+ node_conns.push(conn);
+ }
+ Ok::<_, anyhow::Error>(())
+ }
+ .await;
+ match cleared_nodes {
+ Ok(()) => break 'outer,
+ Err(err) => {
+ // Failed to clear the databases, retry
+ tracing::warn!("{}", err);
+ }
+ }
+ }
+
+ (0..requests + 1)
+ .map(|i| {
+ let mut connection = connection.clone();
+ let mut node_conns = node_conns.clone();
+ let completed = completed.clone();
+ async move {
+ if i == requests / 2 {
+ // Failover all the nodes, error only if all the failover requests error
+ let mut results = future::join_all(
+ node_conns
+ .iter_mut()
+ .map(|conn| Box::pin(do_failover(conn))),
+ )
+ .await;
+ if results.iter().all(|res| res.is_err()) {
+ results.pop().unwrap()
+ } else {
+ Ok::<_, anyhow::Error>(())
+ }
+ } else {
+ let key = format!("test-{value}-{i}");
+ cmd("SET")
+ .arg(&key)
+ .arg(i)
+ .clone()
+ .query_async(&mut connection)
+ .await?;
+ let res: i32 = cmd("GET")
+ .arg(key)
+ .clone()
+ .query_async(&mut connection)
+ .await?;
+ assert_eq!(res, i);
+ completed.fetch_add(1, Ordering::SeqCst);
+ Ok::<_, anyhow::Error>(())
+ }
+ }
+ })
+ .collect::<stream::FuturesUnordered<_>>()
+ .try_collect()
+ .await
+ .unwrap_or_else(|e| panic!("{e}"));
+
+ assert_eq!(
+ completed.load(Ordering::SeqCst),
+ requests,
+ "Some requests never completed!"
+ );
+ }
+
+ static ERROR: Lazy<AtomicBool> = Lazy::new(Default::default);
+
+ #[derive(Clone)]
+ struct ErrorConnection {
+ inner: MultiplexedConnection,
+ }
+
+ impl Connect for ErrorConnection {
+ fn connect<'a, T>(
+ info: T,
+ response_timeout: std::time::Duration,
+ connection_timeout: std::time::Duration,
+ socket_addr: Option<SocketAddr>,
+ glide_connection_options: GlideConnectionOptions,
+ ) -> RedisFuture<'a, (Self, Option<IpAddr>)>
+ where
+ T: IntoConnectionInfo + Send + 'a,
+ {
+ Box::pin(async move {
+ let (inner, _ip) = MultiplexedConnection::connect(
+ info,
+ response_timeout,
+ connection_timeout,
+ socket_addr,
+ glide_connection_options,
+ )
+ .await?;
+ Ok((ErrorConnection { inner }, None))
+ })
+ }
+ }
+
+ impl ConnectionLike for ErrorConnection {
+ fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+ if ERROR.load(Ordering::SeqCst) {
+ Box::pin(async move { Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) })
+ } else {
+ self.inner.req_packed_command(cmd)
+ }
+ }
+
+ fn req_packed_commands<'a>(
+ &'a mut self,
+ pipeline: &'a redis::Pipeline,
+ offset: usize,
+ count: usize,
+ ) -> RedisFuture<'a, Vec<Value>> {
+ self.inner.req_packed_commands(pipeline, offset, count)
+ }
+
+ fn get_db(&self) -> i64 {
+ self.inner.get_db()
+ }
+
+ fn is_closed(&self) -> bool {
+ true
+ }
+ }
+
+ #[test]
+ fn test_async_cluster_error_in_inner_connection() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all(async move {
+ let mut con = cluster.async_generic_connection::<ErrorConnection>().await;
+
+ ERROR.store(false, Ordering::SeqCst);
+ let r: Option<i32> = con.get("test").await?;
+ assert_eq!(r, None::<i32>);
+
+ ERROR.store(true, Ordering::SeqCst);
+
+ let result: RedisResult<()> = con.get("test").await;
+ assert_eq!(
+ result,
+ Err(RedisError::from((redis::ErrorKind::Moved, "ERROR")))
+ );
+
+ Ok::<_, RedisError>(())
+ })
+ .unwrap();
+ }
+
+ #[test]
+ #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+ fn test_async_cluster_async_std_basic_cmd() {
+ let cluster = TestClusterContext::new(3, 0);
+
+ block_on_all_using_async_std(async {
+ let mut connection = cluster.async_connection(None).await;
+ redis::cmd("SET")
+ .arg("test")
+ .arg("test_data")
+ .query_async(&mut connection)
+ .await?;
+ redis::cmd("GET")
+ .arg("test")
+ .clone()
+ .query_async(&mut connection)
+ .map_ok(|res: String| {
+ assert_eq!(res, "test_data");
+ })
+ .await
+ })
+ .unwrap();
+ }
+
+ #[test]
+ fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() {
+ let name =
+ "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name";
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![Value::Array(vec![
+ Value::Int(0),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString("".as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ])])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Value>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() {
+ let name =
+ "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name";
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![Value::Array(vec![
+ Value::Int(0),
+ Value::Int(16383),
+ Value::Array(vec![Value::Nil, Value::Int(6379)]),
+ ])])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Value>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_async_cluster_cannot_connect_to_server_with_unknown_host_name() {
+ let name = "test_async_cluster_cannot_connect_to_server_with_unknown_host_name";
+ let handler = move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![Value::Array(vec![
+ Value::Int(0),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString("?".as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ])])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ };
+ let client_builder = ClusterClient::builder(vec![&*format!("redis://{name}")]);
+ let client: ClusterClient = client_builder.build().unwrap();
+ let _handler = MockConnectionBehavior::register_new(name, Arc::new(handler));
+ let connection = client.get_generic_connection::<MockConnection>(None);
+ assert!(connection.is_err());
+ let err = connection.err().unwrap();
+ assert!(err
+ .to_string()
+ .contains("Error parsing slots: No healthy node found"))
+ }
+
+ #[test]
+ fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name(
+ ) {
+ let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name";
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], _| {
+ if contains_slice(cmd, b"PING") {
+ Err(Ok(Value::SimpleString("OK".into())))
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ Err(Ok(Value::Array(vec![
+ Value::Array(vec![
+ Value::Int(0),
+ Value::Int(7000),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ]),
+ Value::Array(vec![
+ Value::Int(7001),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString("?".as_bytes().to_vec()),
+ Value::Int(6380),
+ ]),
+ ]),
+ ])))
+ } else {
+ Err(Ok(Value::Nil))
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Value>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Value::Nil));
+ }
+
+ #[test]
+ fn test_async_cluster_retries() {
+ let name = "tryagain";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5),
+ name,
+ move |cmd: &[u8], _| {
+ respond_startup(name, cmd)?;
+
+ match requests.fetch_add(1, atomic::Ordering::SeqCst) {
+ 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")),
+ _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_async_cluster_tryagain_exhaust_retries() {
+ let name = "tryagain_exhaust_retries";
+
+ let requests = Arc::new(atomic::AtomicUsize::new(0));
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+ name,
+ {
+ let requests = requests.clone();
+ move |cmd: &[u8], _| {
+ respond_startup(name, cmd)?;
+ requests.fetch_add(1, atomic::Ordering::SeqCst);
+ Err(parse_redis_value(b"-TRYAGAIN mock\r\n"))
+ }
+ },
+ );
+
+ let result = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ match result {
+ Ok(_) => panic!("result should be an error"),
+ Err(e) => match e.kind() {
+ ErrorKind::TryAgain => {}
+ _ => panic!("Expected TryAgain but got {:?}", e.kind()),
+ },
+ }
+ assert_eq!(requests.load(atomic::Ordering::SeqCst), 3);
+ }
+
+ // Obtain the view index associated with the node with [called_port] port
+ fn get_node_view_index(num_of_views: usize, ports: &Vec<u16>, called_port: u16) -> usize {
+ let port_index = ports
+ .iter()
+ .position(|&p| p == called_port)
+ .unwrap_or_else(|| {
+ panic!(
+ "CLUSTER SLOTS was called with unknown port: {called_port}; Known ports: {:?}",
+ ports
+ )
+ });
+ // If we have less views than nodes, use the last view
+ if port_index < num_of_views {
+ port_index
+ } else {
+ num_of_views - 1
+ }
+ }
+ #[test]
+ fn test_async_cluster_move_error_when_new_node_is_added() {
+ let name = "rebuild_with_extra_nodes";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+ let refreshed_map = HashMap::from([
+ (6379, atomic::AtomicBool::new(false)),
+ (6380, atomic::AtomicBool::new(false)),
+ ]);
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup(name, cmd)?;
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+
+ if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+ let is_get_cmd = contains_slice(cmd, b"GET");
+ let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+ match i {
+ // Respond that the key exists on a node that does not yet have a connection:
+ 0 => Err(parse_redis_value(
+ format!("-MOVED 123 {name}:6380\r\n").as_bytes(),
+ )),
+ _ => {
+ if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ // Should not attempt to refresh slots more than once,
+ // so we expect a single CLUSTER SLOTS request for each node
+ assert!(!refreshed_map
+ .get(&port)
+ .unwrap()
+ .swap(true, Ordering::SeqCst));
+ Err(Ok(Value::Array(vec![
+ Value::Array(vec![
+ Value::Int(0),
+ Value::Int(1),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ]),
+ Value::Array(vec![
+ Value::Int(2),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6380),
+ ]),
+ ]),
+ ])))
+ } else {
+ assert_eq!(port, 6380);
+ assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+ get_response
+ }
+ }
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ fn test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+ slots_config_vec: Vec<Vec<MockSlotRange>>,
+ ports: Vec<u16>,
+ has_a_majority: bool,
+ ) {
+ assert!(!ports.is_empty() && !slots_config_vec.is_empty());
+ let name = "refresh_topology_moved";
+ let num_of_nodes = ports.len();
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+ let refresh_calls = Arc::new(atomic::AtomicUsize::new(0));
+ let refresh_calls_cloned = refresh_calls.clone();
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ // Disable the rate limiter to refresh slots immediately on all MOVED errors.
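+ // (A zero wait duration with zero jitter means no MOVED error is ever rate-limited.)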
+ .slots_refresh_rate_limit(Duration::from_secs(0), 0),
+ name,
+ move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup_with_replica_using_config(
+ name,
+ cmd,
+ Some(slots_config_vec[0].clone()),
+ )?;
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+
+ if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+ let is_get_cmd = contains_slice(cmd, b"GET");
+ let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+ let moved_node = ports[0];
+ match i {
+ // Respond that the key exists on a node that does not yet have a connection:
+ 0 => Err(parse_redis_value(
+ format!("-MOVED 123 {name}:{moved_node}\r\n").as_bytes(),
+ )),
+ _ => {
+ if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst);
+ let view_index =
+ get_node_view_index(slots_config_vec.len(), &ports, port);
+ Err(Ok(create_topology_from_config(
+ name,
+ slots_config_vec[view_index].clone(),
+ )))
+ } else {
+ assert_eq!(port, moved_node);
+ assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+ get_response
+ }
+ }
+ }
+ },
+ );
+ runtime.block_on(async move {
+ let res = cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection)
+ .await;
+ assert_eq!(res, Ok(Some(123)));
+ // If there is a majority in the topology views, or if it's a 2-nodes cluster, we shall be able to calculate the topology on the first try,
+ // so each node will be queried only once with CLUSTER SLOTS.
+ // Otherwise, if we don't have a majority, we expect to see the refresh_slots function being called with the maximum retry number.
+ let expected_calls = if has_a_majority || num_of_nodes == 2 {num_of_nodes} else {DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES * num_of_nodes};
+ let mut refreshed_calls = 0;
+ for _ in 0..100 {
+ refreshed_calls = refresh_calls.load(atomic::Ordering::Relaxed);
+ if refreshed_calls == expected_calls {
+ return;
+ } else {
+ let sleep_duration = core::time::Duration::from_millis(100);
+ #[cfg(feature = "tokio-comp")]
+ tokio::time::sleep(sleep_duration).await;
+
+ #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+ async_std::task::sleep(sleep_duration).await;
+ }
+ }
+ panic!("Failed to reach the expected number of topology refresh retries. Found={refreshed_calls}, Expected={expected_calls}")
+ });
+ }
+
+ fn test_async_cluster_refresh_slots_rate_limiter_helper(
+ slots_config_vec: Vec<Vec<MockSlotRange>>,
+ ports: Vec<u16>,
+ should_skip: bool,
+ ) {
+ // This test queries GET, which returns a MOVED error. If `should_skip` is true,
+ // it indicates that we should skip refreshing slots because the specified time
+ // duration since the last refresh slots call has not yet passed. In this case,
+ // we expect CLUSTER SLOTS not to be called on the nodes after receiving the
+ // MOVED error.
+
+ // If `should_skip` is false, we verify that if the MOVED error occurs after the
+ // time duration of the rate limiter has passed, the refresh slots operation
+ // should not be skipped. We assert this by expecting calls to CLUSTER SLOTS on
+ // all nodes.
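+ // A minimal sketch of the builder configuration this helper exercises below
+ // (the address is illustrative only, not part of the test):
+ //
+ //     let builder = ClusterClient::builder(vec!["redis://127.0.0.1:6379"])
+ //         .slots_refresh_rate_limit(Duration::from_millis(10), 0);
+ //
+ // With this setting, a MOVED error arriving within 10ms of the previous refresh
+ // skips the CLUSTER SLOTS call; once the wait duration has passed, the refresh runs.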
+ let test_name = format!(
+ "test_async_cluster_refresh_slots_rate_limiter_helper_{}",
+ if should_skip {
+ "should_skip"
+ } else {
+ "not_skipping_waiting_time_passed"
+ }
+ );
+
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+ let refresh_calls = Arc::new(atomic::AtomicUsize::new(0));
+ let refresh_calls_cloned = Arc::clone(&refresh_calls);
+ let wait_duration = Duration::from_millis(10);
+ let num_of_nodes = ports.len();
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{test_name}")])
+ .slots_refresh_rate_limit(wait_duration, 0),
+ test_name.clone().as_str(),
+ move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup_with_replica_using_config(
+ test_name.as_str(),
+ cmd,
+ Some(slots_config_vec[0].clone()),
+ )?;
+ started.store(true, atomic::Ordering::SeqCst);
+ }
+
+ if contains_slice(cmd, b"PING") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+ let is_get_cmd = contains_slice(cmd, b"GET");
+ let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+ let moved_node = ports[0];
+ match i {
+ // The first request is the initial GET call, which we answer with a MOVED error
+ 0 => {
+ if !should_skip {
+ // Wait for the wait duration to pass
+ std::thread::sleep(wait_duration.add(Duration::from_millis(10)));
+ }
+ Err(parse_redis_value(
+ format!("-MOVED 123 {test_name}:{moved_node}\r\n").as_bytes(),
+ ))
+ }
+ _ => {
+ if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst);
+ let view_index =
+ get_node_view_index(slots_config_vec.len(), &ports, port);
+ Err(Ok(create_topology_from_config(
+ test_name.as_str(),
+ slots_config_vec[view_index].clone(),
+ )))
+ } else {
+ // Even if the slots weren't refreshed we still expect the command to be
+ // routed by the redirect host and port it received in the moved error
+ assert_eq!(port, moved_node);
+ assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+ get_response
+ }
+ }
+ }
+ },
+ );
+
+ runtime.block_on(async move {
+ // First GET request should raise MOVED error and then refresh slots
+ let res = cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection)
+ .await;
+ assert_eq!(res, Ok(Some(123)));
+
+ // If `should_skip` is false, we expect CLUSTER SLOTS to be called once per node
+ let expected_calls = if should_skip {
+ 0
+ } else {
+ num_of_nodes
+ };
+ for _ in 0..4 {
+ if refresh_calls.load(atomic::Ordering::Relaxed) == expected_calls {
+ return Ok::<_, RedisError>(());
+ }
+ let _ = sleep(Duration::from_millis(50).into()).await;
+ }
+ panic!("Refresh slots wasn't called as expected!\nExpected CLUSTER SLOTS calls: {}, actual calls: {:?}", expected_calls, refresh_calls.load(atomic::Ordering::Relaxed));
+ }).unwrap()
+ }
+
+ fn test_async_cluster_refresh_topology_in_client_init_get_succeed(
+ slots_config_vec: Vec<Vec<MockSlotRange>>,
+ ports: Vec<u16>,
+ ) {
+ assert!(!ports.is_empty() && !slots_config_vec.is_empty());
+ let name = "refresh_topology_client_init";
+ let started = atomic::AtomicBool::new(false);
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder::<String>(
+ ports
+ .iter()
+ .map(|port| format!("redis://{name}:{port}"))
+ .collect::<Vec<_>>(),
+ ),
+ name,
+ move |cmd: &[u8], port| {
+ let is_started = started.load(atomic::Ordering::SeqCst);
+ if !is_started {
+ if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ let view_index = get_node_view_index(slots_config_vec.len(), &ports, port);
+ return Err(Ok(create_topology_from_config(
+ name,
+ slots_config_vec[view_index].clone(),
+ )));
+ } else if contains_slice(cmd, b"READONLY") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+ if contains_slice(cmd, b"PING") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let is_get_cmd = contains_slice(cmd, b"GET");
+ let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+ {
+ assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+ get_response
+ }
+ },
+ );
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ fn generate_topology_view(
+ ports: &[u16],
+ interval: usize,
+ full_slot_coverage: bool,
+ ) -> Vec<MockSlotRange> {
+ let mut slots_res = vec![];
+ let mut start_pos: usize = 0;
+ for (idx, port) in ports.iter().enumerate() {
+ let end_pos: usize = if idx == ports.len() - 1 && full_slot_coverage {
+ 16383
+ } else {
+ start_pos + interval
+ };
+ let mock_slot = MockSlotRange {
+ primary_port: *port,
+ replica_ports: vec![],
+ slot_range: (start_pos as u16..end_pos as u16),
+ };
+ slots_res.push(mock_slot);
+ start_pos = end_pos + 1;
+ }
+ slots_res
+ }
+
+ fn get_ports(num_of_nodes: usize) -> Vec<u16> {
+ (6379_u16..6379 + num_of_nodes as u16).collect()
+ }
+
+ fn get_no_majority_topology_view(ports: &[u16]) -> Vec<Vec<MockSlotRange>> {
+ let mut result = vec![];
+ let mut full_coverage = true;
+ for i in 0..ports.len() {
+ result.push(generate_topology_view(ports, i + 1, full_coverage));
+ full_coverage = !full_coverage;
+ }
+ result
+ }
+
+ fn get_topology_with_majority(ports: &[u16]) -> Vec<Vec<MockSlotRange>> {
+ let view: Vec<MockSlotRange> = generate_topology_view(ports, 10, true);
+ let result: Vec<_> = ports.iter().map(|_| view.clone()).collect();
+ result
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_topology_after_moved_error_all_nodes_agree_get_succeed() {
+ let ports = get_ports(3);
+ test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+ get_topology_with_majority(&ports),
+ ports,
+ true,
+ );
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_topology_in_client_init_all_nodes_agree_get_succeed() {
+ let ports = get_ports(3);
+ test_async_cluster_refresh_topology_in_client_init_get_succeed(
+ get_topology_with_majority(&ports),
+ ports,
+ );
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_topology_after_moved_error_with_no_majority_get_succeed() {
+ for num_of_nodes in 2..4 {
+ let ports = get_ports(num_of_nodes);
+ test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+ get_no_majority_topology_view(&ports),
+ ports,
+ false,
+ );
+ }
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_topology_in_client_init_with_no_majority_get_succeed() {
+ for num_of_nodes in 2..4 {
+ let ports = get_ports(num_of_nodes);
+ test_async_cluster_refresh_topology_in_client_init_get_succeed(
+ get_no_majority_topology_view(&ports),
+ ports,
+ );
+ }
+ }
+
+ #[test]
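+ // Even with `retries(0)`, a MOVED reply should still trigger a topology refresh:
+ // the first request surfaces the error to the caller, while the follow-up request
+ // succeeds against the refreshed slot map.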
+ fn test_async_cluster_refresh_topology_even_with_zero_retries() {
+ let name = "test_async_cluster_refresh_topology_even_with_zero_retries";
+
+ let should_refresh = atomic::AtomicBool::new(false);
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0)
+ // Disable the rate limiter to refresh slots immediately on the MOVED error.
+ .slots_refresh_rate_limit(Duration::from_secs(0), 0),
+ name,
+ move |cmd: &[u8], port| {
+ if !should_refresh.load(atomic::Ordering::SeqCst) {
+ respond_startup(name, cmd)?;
+ }
+
+ if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+ return Err(Ok(Value::Array(vec![
+ Value::Array(vec![
+ Value::Int(0),
+ Value::Int(1),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6379),
+ ]),
+ ]),
+ Value::Array(vec![
+ Value::Int(2),
+ Value::Int(16383),
+ Value::Array(vec![
+ Value::BulkString(name.as_bytes().to_vec()),
+ Value::Int(6380),
+ ]),
+ ]),
+ ])));
+ }
+
+ if contains_slice(cmd, b"GET") {
+ let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+ match port {
+ 6380 => get_response,
+ // Respond that the key exists on a node that does not yet have a connection:
+ _ => {
+ // Should not attempt to refresh slots more than once:
+ assert!(!should_refresh.swap(true, Ordering::SeqCst));
+ Err(parse_redis_value(
+ format!("-MOVED 123 {name}:6380\r\n").as_bytes(),
+ ))
+ }
+ }
+ } else {
+ panic!("unexpected command {cmd:?}")
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ // The user should receive an initial error, because there are no retries and the first request failed.
+ assert_eq!(
+ value,
+ Err(RedisError::from((
+ ErrorKind::Moved,
+ "An error was signalled by the server",
+ "test_async_cluster_refresh_topology_even_with_zero_retries:6380".to_string()
+ )))
+ );
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_async_cluster_reconnect_even_with_zero_retries() {
+ let name = "test_async_cluster_reconnect_even_with_zero_retries";
+
+ let should_reconnect = atomic::AtomicBool::new(true);
+ let connection_count = Arc::new(atomic::AtomicU16::new(0));
+ let connection_count_clone = connection_count.clone();
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0),
+ name,
+ move |cmd: &[u8], port| {
+ match respond_startup(name, cmd) {
+ Ok(_) => {}
+ Err(err) => {
+ connection_count.fetch_add(1, Ordering::Relaxed);
+ return Err(err);
+ }
+ }
+
+ if contains_slice(cmd, b"ECHO") && port == 6379 {
+ // Should not attempt to refresh slots more than once:
+ if should_reconnect.swap(false, Ordering::SeqCst) {
+ Err(Err(broken_pipe_error()))
+ } else {
+ Err(Ok(Value::BulkString(b"PONG".to_vec())))
+ }
+ } else {
+ panic!("unexpected command {cmd:?}")
+ }
+ },
+ );
+
+ // We expect 6 calls in total. MockEnv creates both synchronous and asynchronous connections, which make the following calls:
+ // - 1 call by the sync connection to `CLUSTER SLOTS` for initializing the client's topology map.
+ // - 3 calls by the async connection to `PING`: one for the user connection when creating the node from initial addresses,
+ // and two more for checking the user and management connections during client initialization in `refresh_slots`.
+ // - 1 call by the async connection to `CLIENT SETNAME` for setting up the management connection name.
+ // - 1 call by the async connection to `CLUSTER SLOTS` for initializing the client's topology map.
+ // Note: If additional nodes or setup calls are added, this number should increase.
+ let expected_init_calls = 6;
+ assert_eq!(
+ connection_count_clone.load(Ordering::Relaxed),
+ expected_init_calls
+ );
+
+ let value = runtime.block_on(connection.route_command(
+ &cmd("ECHO"),
+ RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+ host: name.to_string(),
+ port: 6379,
+ }),
+ ));
+
+ // The user should receive an initial error, because there are no retries and the first request failed.
+ assert_eq!(
+ value.unwrap_err().to_string(),
+ broken_pipe_error().to_string()
+ );
+
+ let value = runtime.block_on(connection.route_command(
+ &cmd("ECHO"),
+ RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+ host: name.to_string(),
+ port: 6379,
+ }),
+ ));
+
+ assert_eq!(value, Ok(Value::BulkString(b"PONG".to_vec())));
+ // `expected_init_calls` plus another PING for a new user connection created from refresh_connections
+ assert_eq!(
+ connection_count_clone.load(Ordering::Relaxed),
+ expected_init_calls + 1
+ );
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_slots_rate_limiter_skips_refresh() {
+ let ports = get_ports(3);
+ test_async_cluster_refresh_slots_rate_limiter_helper(
+ get_topology_with_majority(&ports),
+ ports,
+ true,
+ );
+ }
+
+ #[test]
+ fn test_async_cluster_refresh_slots_rate_limiter_does_refresh_when_wait_duration_passed() {
+ let ports = get_ports(3);
+ test_async_cluster_refresh_slots_rate_limiter_helper(
+ get_topology_with_majority(&ports),
+ ports,
+ false,
+ );
+ }
+
+ #[test]
+ fn test_async_cluster_ask_redirect() {
+ let name = "node";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv {
+ async_connection: mut connection,
+ handler: _handler,
+ runtime,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]),
+ name,
+ {
+ move |cmd: &[u8], port| {
+ respond_startup_two_nodes(name, cmd)?;
+ // Redirect once with ASK, then verify that the client sends ASKING
+ // and routes the retried command to the target node.
+ let count = completed.fetch_add(1, Ordering::SeqCst);
+ match port {
+ 6379 => match count {
+ 0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")),
+ _ => panic!("Node should not be called now"),
+ },
+ 6380 => match count {
+ 1 => {
+ assert!(contains_slice(cmd, b"ASKING"));
+ Err(Ok(Value::Okay))
+ }
+ 2 => {
+ assert!(contains_slice(cmd, b"GET"));
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ _ => panic!("Node should not be called now"),
+ },
+ _ => panic!("Wrong node"),
+ }
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_async_cluster_ask_save_new_connection() {
+ let name = "node";
+ let ping_attempts = Arc::new(AtomicI32::new(0));
+ let ping_attempts_clone = ping_attempts.clone();
+ let MockEnv {
+ async_connection: mut connection,
+ handler: _handler,
+ runtime,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]),
+ name,
+ {
+ move |cmd: &[u8], port| {
+ if port != 6391 {
+ respond_startup_two_nodes(name, cmd)?;
+ return Err(parse_redis_value(b"-ASK 14000 node:6391\r\n"));
+ }
+
+ if contains_slice(cmd, b"PING") {
+ ping_attempts_clone.fetch_add(1, Ordering::Relaxed);
+ }
+ respond_startup_two_nodes(name, cmd)?;
+ Err(Ok(Value::Okay))
+ }
+ },
+ );
+
+ for _ in 0..4 {
+ runtime
+ .block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Value>(&mut connection),
+ )
+ .unwrap();
+ }
+
+ assert_eq!(ping_attempts.load(Ordering::Relaxed), 1);
+ }
+
+ #[test]
+ fn test_async_cluster_reset_routing_if_redirect_fails() {
+ let name = "test_async_cluster_reset_routing_if_redirect_fails";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv {
+ async_connection: mut connection,
+ handler: _handler,
+ runtime,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], port| {
+ if port != 6379 && port != 6380 {
+ return Err(Err(broken_pipe_error()));
+ }
+ respond_startup_two_nodes(name, cmd)?;
+ let count = completed.fetch_add(1, Ordering::SeqCst);
+ match (port, count) {
+ // redirect once to non-existing node
+ (6379, 0) => Err(parse_redis_value(
+ format!("-ASK 14000 {name}:9999\r\n").as_bytes(),
+ )),
+ // accept the next request
+ (6379, 1) => {
+ assert!(contains_slice(cmd, b"GET"));
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ _ => panic!("Wrong node. port: {port}, received count: {count}"),
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_async_cluster_ask_redirect_even_if_original_call_had_no_route() {
+ let name = "node";
+ let completed = Arc::new(AtomicI32::new(0));
+ let MockEnv {
+ async_connection: mut connection,
+ handler: _handler,
+ runtime,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")]),
+ name,
+ {
+ move |cmd: &[u8], port| {
+ respond_startup_two_nodes(name, cmd)?;
+ // Redirect once with ASK, then verify that the client sends ASKING
+ // and routes the retried command to the target node.
+ let count = completed.fetch_add(1, Ordering::SeqCst);
+ if count == 0 {
+ return Err(parse_redis_value(b"-ASK 14000 node:6380\r\n"));
+ }
+ match port {
+ 6380 => match count {
+ 1 => {
+ assert!(
+ contains_slice(cmd, b"ASKING"),
+ "{:?}",
+ std::str::from_utf8(cmd)
+ );
+ Err(Ok(Value::Okay))
+ }
+ 2 => {
+ assert!(contains_slice(cmd, b"EVAL"));
+ Err(Ok(Value::Okay))
+ }
+ _ => panic!("Node should not be called now"),
+ },
+ _ => panic!("Wrong node"),
+ }
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("EVAL") // The EVAL command has no route, and so is directed to a random node
+ .query_async::<_, Value>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Value::Okay));
+ }
+
+ #[test]
+ fn test_async_cluster_ask_error_when_new_node_is_added() {
+ let name = "ask_with_extra_nodes";
+
+ let requests = atomic::AtomicUsize::new(0);
+ let started = atomic::AtomicBool::new(false);
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::new(name, move |cmd: &[u8], port| {
+ if !started.load(atomic::Ordering::SeqCst) {
+ respond_startup(name, cmd)?;
+ }
+ started.store(true, atomic::Ordering::SeqCst);
+
+ if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+
+ let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+ match i {
+ // Respond that the key exists on a node that does not yet have a connection:
+ 0 => Err(parse_redis_value(
+ format!("-ASK 123 {name}:6380\r\n").as_bytes(),
+ )),
+ 1 => {
+ assert_eq!(port, 6380);
+ assert!(contains_slice(cmd, b"ASKING"));
+ Err(Ok(Value::Okay))
+ }
+ 2 => {
+ assert_eq!(port, 6380);
+ assert!(contains_slice(cmd, b"GET"));
+ Err(Ok(Value::BulkString(b"123".to_vec())))
+ }
+ _ => {
+ panic!("Unexpected request: {:?}", cmd);
+ }
+ }
+ });
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+
+ assert_eq!(value, Ok(Some(123)));
+ }
+
+ #[test]
+ fn test_async_cluster_replica_read() {
+ let name = "node";
+
+ // requests should route to replica
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |cmd: &[u8], port| {
+ respond_startup_with_replica(name, cmd)?;
+ match port {
+ 6380 => Err(Ok(Value::BulkString(b"123".to_vec()))),
+ _ => panic!("Wrong node"),
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("GET")
+ .arg("test")
+ .query_async::<_, Option<i32>>(&mut connection),
+ );
+ assert_eq!(value, Ok(Some(123)));
+
+ // requests should route to primary
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |cmd: &[u8], port| {
+ respond_startup_with_replica(name, cmd)?;
+ match port {
+ 6379 => Err(Ok(Value::SimpleString("OK".into()))),
+ _ => panic!("Wrong node"),
+ }
+ },
+ );
+
+ let value = runtime.block_on(
+ cmd("SET")
+ .arg("test")
+ .arg("123")
+ .query_async::<_, Option<Value>>(&mut connection),
+ );
+ assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+ }
+
+ fn test_async_cluster_fan_out(
+ command: &'static str,
+ expected_ports: Vec<u16>,
+ slots_config: Option<Vec<MockSlotRange>>,
+ ) {
+ let name = "node";
+ let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+ let ports_clone = found_ports.clone();
+ let mut cmd = Cmd::new();
+ for arg in command.split_whitespace() {
+ cmd.arg(arg);
+ }
+ let packed_cmd = cmd.get_packed_command();
+ // requests should route to replica
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(
+ name,
+ received_cmd,
+ slots_config.clone(),
+ )?;
+ if received_cmd == packed_cmd {
+ ports_clone.lock().unwrap().push(port);
+ return Err(Ok(Value::SimpleString("OK".into())));
+ }
+ Ok(())
+ },
+ );
+
+ let _ = runtime.block_on(cmd.query_async::<_, Option<()>>(&mut connection));
+ found_ports.lock().unwrap().sort();
+ // MockEnv creates 2 mock connections.
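+ // After sorting, the ports that received the fan-out command should match
+ // `expected_ports` exactly, i.e. every targeted node was hit exactly once.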
+ assert_eq!(*found_ports.lock().unwrap(), expected_ports); + } + + #[test] + fn test_async_cluster_fan_out_to_all_primaries() { + test_async_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); + } + + #[test] + fn test_async_cluster_fan_out_to_all_nodes() { + test_async_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); + } + + #[test] + fn test_async_cluster_fan_out_once_to_each_primary_when_no_replicas_are_available() { + test_async_cluster_fan_out( + "CONFIG SET", + vec![6379, 6381], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: Vec::new(), + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: Vec::new(), + slot_range: (8192..16383), + }, + ]), + ); + } + + #[test] + fn test_async_cluster_fan_out_once_even_if_primary_has_multiple_slot_ranges() { + test_async_cluster_fan_out( + "CONFIG SET", + vec![6379, 6380, 6381, 6382], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..4000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (4001..8191), + }, + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (8192..8200), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8201..16383), + }, + ]), + ); + } + + #[test] + fn test_async_cluster_route_according_to_passed_argument() { + let name = "test_async_cluster_route_according_to_passed_argument"; + + let touched_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let cloned_ports = touched_ports.clone(); + + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + cloned_ports.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + let mut cmd = cmd("GET"); + cmd.arg("test"); + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllMasters, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6381]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6380, 6381, 6382]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6382, + }), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6382]); + touched_ports.clear(); + } + } + + #[test] + fn test_async_cluster_fan_out_and_aggregate_numeric_response_with_min() { + let name = "test_async_cluster_fan_out_and_aggregate_numeric_response"; + let mut cmd = Cmd::new(); + cmd.arg("SLOWLOG").arg("LEN"); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(name, received_cmd, None)?;
+
+ let res = 6383 - port as i64;
+ Err(Ok(Value::Int(res))) // this results in 1,2,3,4
+ },
+ );
+
+ let result = runtime
+ .block_on(cmd.query_async::<_, i64>(&mut connection))
+ .unwrap();
+ assert_eq!(result, 10, "{result}");
+ }
+
+ #[test]
+ fn test_async_cluster_fan_out_and_aggregate_logical_array_response() {
+ let name = "test_async_cluster_fan_out_and_aggregate_logical_array_response";
+ let mut cmd = Cmd::new();
+ cmd.arg("SCRIPT")
+ .arg("EXISTS")
+ .arg("foo")
+ .arg("bar")
+ .arg("baz")
+ .arg("barvaz");
+
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(name, received_cmd, None)?;
+
+ if port == 6381 {
+ return Err(Ok(Value::Array(vec![
+ Value::Int(0),
+ Value::Int(0),
+ Value::Int(1),
+ Value::Int(1),
+ ])));
+ } else if port == 6379 {
+ return Err(Ok(Value::Array(vec![
+ Value::Int(0),
+ Value::Int(1),
+ Value::Int(0),
+ Value::Int(1),
+ ])));
+ }
+
+ panic!("unexpected port {port}");
+ },
+ );
+
+ let result = runtime
+ .block_on(cmd.query_async::<_, Vec<i64>>(&mut connection))
+ .unwrap();
+ assert_eq!(result, vec![0, 0, 0, 1], "{result:?}");
+ }
+
+ #[test]
+ fn test_async_cluster_fan_out_and_return_one_succeeded_response() {
+ let name = "test_async_cluster_fan_out_and_return_one_succeeded_response";
+ let mut cmd = Cmd::new();
+ cmd.arg("SCRIPT").arg("KILL");
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder(
+ ClusterClient::builder(vec![&*format!("redis://{name}")])
+ .retries(0)
+ .read_from_replicas(),
+ name,
+ move |received_cmd: &[u8], port| {
+ respond_startup_with_replica_using_config(name, received_cmd, None)?;
+ if port == 6381 {
+ return Err(Ok(Value::Okay));
+ }
+ Err(Err((
+ ErrorKind::NotBusy,
+ "No scripts in execution right now",
+ )
+ .into()))
+ },
+ );
+
+ let result = runtime
+ .block_on(cmd.query_async::<_, Value>(&mut connection))
+ .unwrap();
+ assert_eq!(result, Value::Okay, "{result:?}");
+ }
+
+ #[test]
+ fn test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes() {
+ let name = "test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes";
+ let mut cmd = Cmd::new();
+ cmd.arg("SCRIPT").arg("KILL");
+ let MockEnv {
+ runtime,
+ async_connection: mut connection,
+ handler: _handler,
+ ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + + Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind()); + } + + #[test] + fn test_async_cluster_fan_out_and_return_all_succeeded_response() { + let name = "test_async_cluster_fan_out_and_return_all_succeeded_response"; + let cmd = cmd("FLUSHALL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Okay, "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure() { + let name = "test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure"; + let cmd = cmd("FLUSHALL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())); + } + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind()); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + let ports = vec![6379, 6380, 6381]; + let slots_config_vec = generate_topology_view(&ports, 1000, true); + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + if port == 6380 { + return Err(Ok(Value::BulkString("foo".as_bytes().to_vec()))); + } else if port == 6381 { + return Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))); + } + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, String>(&mut connection)) + .unwrap(); + assert_eq!(result, "foo", "{result:?}"); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_config(name, received_cmd, None, false)?; + if port == 6380 { + return Err(Ok(Value::Nil)); + } + Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))) + }, + ); + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::ResponseError); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil() { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_config(name, received_cmd, None, false)?; + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Nil, "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_return_map_of_results_for_special_response_policy() { + let name = "foo"; + let mut cmd = Cmd::new(); + cmd.arg("LATENCY").arg("LATEST"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(0)
+            .read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            Err(Ok(Value::BulkString(
+                format!("latency: {port}").into_bytes(),
+            )))
+        },
+    );
+
+    // TODO once RESP3 is in, return this as a map
+    let mut result = runtime
+        .block_on(cmd.query_async::<_, Vec<(String, String)>>(&mut connection))
+        .unwrap();
+    result.sort();
+    assert_eq!(
+        result,
+        vec![
+            (format!("{name}:6379"), "latency: 6379".to_string()),
+            (format!("{name}:6380"), "latency: 6380".to_string()),
+            (format!("{name}:6381"), "latency: 6381".to_string()),
+            (format!("{name}:6382"), "latency: 6382".to_string())
+        ],
+        "{result:?}"
+    );
+}
+
+#[test]
+fn test_async_cluster_fan_out_and_combine_arrays_of_values() {
+    let name = "foo";
+    let cmd = cmd("KEYS");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(0)
+            .read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            Err(Ok(Value::Array(vec![Value::BulkString(
+                format!("key:{port}").into_bytes(),
+            )])))
+        },
+    );
+
+    let mut result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    result.sort();
+    assert_eq!(
+        result,
+        vec!["key:6379".to_string(), "key:6381".to_string(),],
+        "{result:?}"
+    );
+}
+
+#[test]
+fn test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values() {
+    let name = "test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(0)
+            .read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+            let results = ["foo", "bar", "baz"]
+                .iter()
+                .filter_map(|expected_key| {
+                    if cmd_str.contains(expected_key) {
+                        Some(Value::BulkString(
+                            format!("{expected_key}-{port}").into_bytes(),
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Err(Ok(Value::Array(results)))
+        },
+    );
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]);
+}
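+
+    // A minimal, self-contained sketch of the recombination step the MGET test above
+    // relies on: the driver splits a multi-key command per slot, and then stitches the
+    // per-shard replies back into the caller's original argument order. This is a
+    // hypothetical helper shown for illustration only (not used by the tests);
+    // `positions` maps each value in a shard's reply back to its original argument index.
+    #[allow(dead_code)]
+    fn recombine_shard_replies(
+        shard_replies: Vec<(Vec<usize>, Vec<Value>)>,
+        total_keys: usize,
+    ) -> Vec<Value> {
+        let mut combined = vec![Value::Nil; total_keys];
+        for (positions, values) in shard_replies {
+            for (pos, value) in positions.into_iter().zip(values) {
+                combined[pos] = value;
+            }
+        }
+        combined
+    }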
+
+#[test]
+fn test_async_cluster_handle_asking_error_in_split_multi_shard_command() {
+    let name = "test_async_cluster_handle_asking_error_in_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let asking_called = Arc::new(AtomicU16::new(0));
+    let asking_called_cloned = asking_called.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+            if cmd_str.contains("ASKING") && port == 6382 {
+                asking_called_cloned.fetch_add(1, Ordering::Relaxed);
+            }
+            if port == 6380 && cmd_str.contains("baz") {
+                return Err(parse_redis_value(
+                    format!("-ASK 14000 {name}:6382\r\n").as_bytes(),
+                ));
+            }
+            let results = ["foo", "bar", "baz"]
+                .iter()
+                .filter_map(|expected_key| {
+                    if cmd_str.contains(expected_key) {
+                        Some(Value::BulkString(
+                            format!("{expected_key}-{port}").into_bytes(),
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Err(Ok(Value::Array(results)))
+        },
+    );
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6382"]);
+    assert_eq!(asking_called.load(Ordering::Relaxed), 1);
+}
+
+#[test]
+fn test_async_cluster_pass_errors_from_split_multi_shard_command() {
+    let name = "test_async_cluster_pass_errors_from_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::new(name, move |received_cmd: &[u8], port| {
+        respond_startup_with_replica_using_config(name, received_cmd, None)?;
+        let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+        if cmd_str.contains("foo") || cmd_str.contains("baz") {
+            Err(Err((ErrorKind::IoError, "error").into()))
+        } else {
+            Err(Ok(Value::Array(vec![Value::BulkString(
+                format!("{port}").into_bytes(),
+            )])))
+        }
+    });
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap_err();
+    assert_eq!(result.kind(), ErrorKind::IoError);
+}
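+
+    // For illustration: the ASK handling exercised above hinges on parsing the redirect
+    // error into a target node, issuing ASKING on that node, and retrying the affected
+    // keys once. A minimal sketch of the parsing step, assuming the textual
+    // "ASK <slot> <host:port>" shape used by the mock above (real clients work from the
+    // structured error rather than raw string splitting; hypothetical helper, not used
+    // by the tests):
+    #[allow(dead_code)]
+    fn parse_ask_redirect(detail: &str) -> Option<(u16, String)> {
+        let mut parts = detail.split_whitespace();
+        match (parts.next(), parts.next(), parts.next()) {
+            (Some("ASK"), Some(slot), Some(addr)) => {
+                Some((slot.parse().ok()?, addr.to_string()))
+            }
+            _ => None,
+        }
+    }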
+
+#[test]
+fn test_async_cluster_handle_missing_slots_in_split_multi_shard_command() {
+    let name = "test_async_cluster_handle_missing_slots_in_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::new(name, move |received_cmd: &[u8], port| {
+        respond_startup_with_replica_using_config(
+            name,
+            received_cmd,
+            Some(vec![MockSlotRange {
+                primary_port: 6381,
+                replica_ports: vec![6382],
+                slot_range: (8192..16383),
+            }]),
+        )?;
+        Err(Ok(Value::Array(vec![Value::BulkString(
+            format!("{port}").into_bytes(),
+        )])))
+    });
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap_err();
+    assert!(
+        matches!(result.kind(), ErrorKind::ConnectionNotFoundForRoute)
+            || result.is_connection_dropped()
+    );
+}
+
+#[test]
+fn test_async_cluster_with_username_and_password() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .username(RedisCluster::username().to_string())
+                .password(RedisCluster::password().to_string())
+        },
+        false,
+    );
+    cluster.disable_default_user();
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        cmd("SET")
+            .arg("test")
+            .arg("test_data")
+            .query_async(&mut connection)
+            .await?;
+        let res: String = cmd("GET")
+            .arg("test")
+            .clone()
+            .query_async(&mut connection)
+            .await?;
+        assert_eq!(res, "test_data");
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_io_error() {
+    let name = "node";
+    let completed = Arc::new(AtomicI32::new(0));
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_two_nodes(name, cmd)?;
+            // Error twice with io-error, ensure connection is reestablished w/out calling
+            // other node (i.e., not doing a full slot rebuild)
+            match port {
+                6380 => panic!("Node should not be called"),
+                _ => match completed.fetch_add(1, Ordering::SeqCst) {
+                    0..=1 => Err(Err(RedisError::from(std::io::Error::new(
+                        std::io::ErrorKind::ConnectionReset,
+                        "mock-io-error",
+                    )))),
+                    _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                },
+            }
+        },
+    );
+
+    let value = runtime.block_on(
+        cmd("GET")
+            .arg("test")
+            .query_async::<_, Option<i32>>(&mut connection),
+    );
+
+    assert_eq!(value, Ok(Some(123)));
+}
+
+#[test]
+fn test_async_cluster_non_retryable_error_should_not_retry() {
+    let name = "node";
+    let completed = Arc::new(AtomicI32::new(0));
+    let MockEnv {
+        async_connection: mut connection,
+        handler: _handler,
+        runtime,
+        ..
+    } = MockEnv::new(name, {
+        let completed = completed.clone();
+        move |cmd: &[u8], _| {
+            respond_startup_two_nodes(name, cmd)?;
+            // Error twice with io-error, ensure connection is reestablished w/out calling
+            // other node (i.e., not doing a full slot rebuild)
+            completed.fetch_add(1, Ordering::SeqCst);
+            Err(Err((ErrorKind::ReadOnly, "").into()))
+        }
+    });
+
+    let value = runtime.block_on(
+        cmd("GET")
+            .arg("test")
+            .query_async::<_, Option<i32>>(&mut connection),
+    );
+
+    match value {
+        Ok(_) => panic!("result should be an error"),
+        Err(e) => match e.kind() {
+            ErrorKind::ReadOnly => {}
+            _ => panic!("Expected ReadOnly but got {:?}", e.kind()),
+        },
+    }
+    assert_eq!(completed.load(Ordering::SeqCst), 1);
+}
+
+#[test]
+fn test_async_cluster_read_from_primary() {
+    let name = "node";
+    let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+    let ports_clone = found_ports.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::new(name, move |received_cmd: &[u8], port| {
+        respond_startup_with_replica_using_config(
+            name,
+            received_cmd,
+            Some(vec![
+                MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380, 6381],
+                    slot_range: (0..8191),
+                },
+                MockSlotRange {
+                    primary_port: 6382,
+                    replica_ports: vec![6383, 6384],
+                    slot_range: (8192..16383),
+                },
+            ]),
+        )?;
+        ports_clone.lock().unwrap().push(port);
+        Err(Ok(Value::Nil))
+    });
+
+    runtime.block_on(async {
+        cmd("GET")
+            .arg("foo")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("bar")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("foo")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("bar")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+    });
+
+    found_ports.lock().unwrap().sort();
+    assert_eq!(*found_ports.lock().unwrap(), vec![6379, 6379, 6382, 6382]);
+}
+
+#[test]
+fn test_async_cluster_round_robin_read_from_replica() {
+    let name = "node";
+    let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+    let ports_clone = found_ports.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(
+                name,
+                received_cmd,
+                Some(vec![
+                    MockSlotRange {
+                        primary_port: 6379,
+                        replica_ports: vec![6380, 6381],
+                        slot_range: (0..8191),
+                    },
+                    MockSlotRange {
+                        primary_port: 6382,
+                        replica_ports: vec![6383, 6384],
+                        slot_range: (8192..16383),
+                    },
+                ]),
+            )?;
+            ports_clone.lock().unwrap().push(port);
+            Err(Ok(Value::Nil))
+        },
+    );
+
+    runtime.block_on(async {
+        cmd("GET")
+            .arg("foo")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("bar")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("foo")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+        cmd("GET")
+            .arg("bar")
+            .query_async::<_, ()>(&mut connection)
+            .await
+            .unwrap();
+    });
+
+    found_ports.lock().unwrap().sort();
+    assert_eq!(*found_ports.lock().unwrap(), vec![6380, 6381, 6383, 6384]);
+}
+
+fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option<String> {
+    // Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master.
+    // Otherwise, returns None.
+    let get_node_id = |str: &str| {
+        let parts: Vec<&str> = str.split('\n').collect();
+        for node_entry in parts {
+            if node_entry.contains("myself") && node_entry.contains("master") {
+                let node_entry_parts: Vec<&str> = node_entry.split(' ').collect();
+                let node_id = node_entry_parts[0];
+                return Some(node_id.to_string());
+            }
+        }
+        None
+    };
+
+    match cluster_nodes_output {
+        Value::BulkString(val) => match from_utf8(&val) {
+            Ok(str_res) => get_node_id(str_res),
+            Err(e) => panic!("failed to decode CLUSTER NODES response: {:?}", e),
+        },
+        Value::VerbatimString { format: _, text } => get_node_id(&text),
+        _ => panic!("Received unexpected response: {:?}", cluster_nodes_output),
+    }
+}
+
+#[test]
+fn test_async_cluster_handle_complete_server_disconnect_without_panicking() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.retries(2),
+        false,
+    );
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        drop(cluster);
+        for _ in 0..5 {
+            let cmd = cmd("PING");
+            let result = connection
+                .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                .await;
+            // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing
+            assert!(result.is_err());
+            // This will route to all nodes - different path through the code.
+            let result = connection.req_packed_command(&cmd).await;
+            // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing
+            assert!(result.is_err());
+        }
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
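+
+    // For context: `get_queried_node_id_if_master` above scans CLUSTER NODES output in
+    // which each line describes one node, and the flags field marks the queried node
+    // with `myself`. A sample line (node ID and address are made up for illustration):
+    //
+    //   07c37dfeb235213a872192d90877d0cd55635b91 127.0.0.1:6379@16379 myself,master - 0 0 1 connected 0-8191
+    //
+    // so matching a line on both `myself` and `master`, and taking its first
+    // space-separated field, yields the node ID of the master that served the query.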
+
+#[test]
+fn test_async_cluster_test_fast_reconnect() {
+    // Note the 3-second connection check, used to differentiate between notifications and periodic checks
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(0)
+                .periodic_connections_checks(Duration::from_secs(3))
+        },
+        false,
+    );
+
+    // For tokio-comp, do 3 consecutive disconnects and ensure each reconnect succeeds in less than 100ms,
+    // which is more than enough for local connections even with TLS.
+    // More than 1 run is done to ensure it is the fast reconnect notification that triggers the reconnect
+    // and not the periodic interval.
+    // For other async implementations, only the periodic connection check is available, hence,
+    // do 1 run sleeping for the periodic connection check interval, allowing it to reestablish connections
+    block_on_all(async move {
+        let mut disconnecting_con = cluster.async_connection(None).await;
+        let mut monitoring_con = cluster.async_connection(None).await;
+
+        #[cfg(feature = "tokio-comp")]
+        let tries = 0..3;
+        #[cfg(not(feature = "tokio-comp"))]
+        let tries = 0..1;
+
+        for _ in tries {
+            // get connection id
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("ID");
+            let res = disconnecting_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            assert!(res.is_ok());
+            let res = res.unwrap();
+            let id = {
+                match res {
+                    Value::Int(id) => id,
+                    _ => {
+                        panic!("Wrong return value for CLIENT ID command: {:?}", res);
+                    }
+                }
+            };
+
+            // ask server to kill the connection
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("KILL").arg("ID").arg(id).arg("SKIPME").arg("NO");
+            let res = disconnecting_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            // assert server has closed connection
+            assert_eq!(res, Ok(Value::Int(1)));
+
+            #[cfg(feature = "tokio-comp")]
+            // ensure reconnect happened in less than 100ms
+            sleep(futures_time::time::Duration::from_millis(100)).await;
+
+            #[cfg(not(feature = "tokio-comp"))]
+            // no fast notification is available, wait for 1 periodic check + overhead
+            sleep(futures_time::time::Duration::from_secs(3 + 1)).await;
+
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("LIST").arg("TYPE").arg("NORMAL");
+            let res = monitoring_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            assert!(res.is_ok());
+            let res = res.unwrap();
+            let client_list: String = {
+                match res {
+                    // RESP2
+                    Value::BulkString(client_info) => {
+                        // ensure 4 connections - 2 for each client; it's safe to unwrap here
+                        String::from_utf8(client_info).unwrap()
+                    }
+                    // RESP3
+                    Value::VerbatimString { format: _, text } => text,
+                    _ => {
+                        panic!("Wrong return type for CLIENT LIST command: {:?}", res);
+                    }
+                }
+            };
+            assert_eq!(client_list.chars().filter(|&x| x == '\n').count(), 4);
+        }
+        Ok(())
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() {
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    let use_sharded = redis_ver.starts_with("7.");
+
+    let mut client_subscriptions = PubSubSubscriptionInfo::from([(
+        PubSubSubscriptionKind::Exact,
+        HashSet::from([PubSubChannelOrPattern::from("test_channel".as_bytes())]),
+    )]);
+
+    if use_sharded {
+        client_subscriptions.insert(
+            PubSubSubscriptionKind::Sharded,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        );
+    }
+
+    // note topology change detection is not activated since no topology change is expected
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(3)
+                .use_protocol(ProtocolVersion::RESP3)
+                .pubsub_subscriptions(client_subscriptions.clone())
+                .periodic_connections_checks(Duration::from_secs(1))
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+        let mut _listening_con = cluster.async_connection(Some(tx.clone())).await;
+        // Note, publishing connection has the same pubsub config
+        let mut publishing_con = cluster.async_connection(None).await;
+
+        // short sleep to allow the server to push subscription notification
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        // validate subscriptions
+        validate_subscriptions(&client_subscriptions, &mut rx, false);
+
+        // validate PUBLISH
+        let result = cmd("PUBLISH")
+            .arg("test_channel")
+            .arg("test_message")
+            .query_async(&mut publishing_con)
+            .await;
+        assert_eq!(
+            result,
+            Ok(Value::Int(2)) // 2 connections with the same pubsub config
+        );
+
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+        let result = rx.try_recv();
+        assert!(result.is_ok());
+        let PushInfo { kind, data } = result.unwrap();
+        assert_eq!(
+            (kind, data),
+            (
+                PushKind::Message,
+                vec![
+                    Value::BulkString("test_channel".into()),
+                    Value::BulkString("test_message".into()),
+                ]
+            )
+        );
+
+        if use_sharded {
+            // validate SPUBLISH
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        // simulate passive disconnect
+        drop(cluster);
+
+        // recreate the cluster, the assumption is that the cluster is built with exactly the same params (ports, slots map...)
+        let _cluster =
+            TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| builder, false);
+
+        // sleep for 1 periodic_connections_checks + overhead
+        sleep(futures_time::time::Duration::from_secs(1 + 1)).await;
+
+        // new subscription notifications due to resubscriptions
+        validate_subscriptions(&client_subscriptions, &mut rx, true);
+
+        // validate PUBLISH
+        let result = cmd("PUBLISH")
+            .arg("test_channel")
+            .arg("test_message")
+            .query_async(&mut publishing_con)
+            .await;
+        assert_eq!(
+            result,
+            Ok(Value::Int(2)) // 2 connections with the same pubsub config
+        );
+
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+        let result = rx.try_recv();
+        assert!(result.is_ok());
+        let PushInfo { kind, data } = result.unwrap();
+        assert_eq!(
+            (kind, data),
+            (
+                PushKind::Message,
+                vec![
+                    Value::BulkString("test_channel".into()),
+                    Value::BulkString("test_message".into()),
+                ]
+            )
+        );
+
+        if use_sharded {
+            // validate SPUBLISH
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        Ok(())
+    })
+    .unwrap();
+}
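+
+    // For illustration: the scale-out test below pins "test_channel_?" to slot 14212 via
+    // `get_slot`. A minimal sketch of how a cluster key slot can be computed (standard
+    // CRC16-XMODEM modulo 16384; hash-tag extraction is omitted since the channel name
+    // contains no braces). Hypothetical helper, shown only to make the 14212 assertion
+    // less magical:
+    #[allow(dead_code)]
+    fn example_key_slot(key: &[u8]) -> u16 {
+        let mut crc: u16 = 0;
+        for &byte in key {
+            crc ^= (byte as u16) << 8;
+            for _ in 0..8 {
+                crc = if crc & 0x8000 != 0 {
+                    (crc << 1) ^ 0x1021
+                } else {
+                    crc << 1
+                };
+            }
+        }
+        crc % 16384
+    }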
+
+#[test]
+fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() {
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    let use_sharded = redis_ver.starts_with("7.");
+
+    let mut client_subscriptions = PubSubSubscriptionInfo::from([
+        // test_channel_? is used as it maps to slot 14212, which is on the last node in both the 3 and 6 node configs
+        // (assuming slot allocation is monotonically increasing starting from node 0)
+        (
+            PubSubSubscriptionKind::Exact,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        ),
+    ]);
+
+    if use_sharded {
+        client_subscriptions.insert(
+            PubSubSubscriptionKind::Sharded,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        );
+    }
+
+    let slot_14212 = get_slot(b"test_channel_?");
+    assert_eq!(slot_14212, 14212);
+
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(3)
+                .use_protocol(ProtocolVersion::RESP3)
+                .pubsub_subscriptions(client_subscriptions.clone())
+                // periodic connection check is required to detect the disconnect from the last node
+                .periodic_connections_checks(Duration::from_secs(1))
+                // periodic topology check is required to detect topology change
+                .periodic_topology_checks(Duration::from_secs(1))
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+        let mut _listening_con = cluster.async_connection(Some(tx.clone())).await;
+        // Note, publishing connection has the same pubsub config
+        let mut publishing_con = cluster.async_connection(None).await;
+
+        // short sleep to allow the server to push subscription notification
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        // validate subscriptions
+        validate_subscriptions(&client_subscriptions, &mut rx, false);
+
+        // validate PUBLISH
+        let result = cmd("PUBLISH")
+            .arg("test_channel_?")
+            .arg("test_message")
+            .query_async(&mut publishing_con)
+            .await;
+        assert_eq!(
+            result,
+            Ok(Value::Int(2)) // 2 connections with the same pubsub config
+        );
+
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+        let result = rx.try_recv();
+        assert!(result.is_ok());
+        let PushInfo { kind, data } = result.unwrap();
+        assert_eq!(
+            (kind, data),
+            (
+                PushKind::Message,
+                vec![
+                    Value::BulkString("test_channel_?".into()),
+                    Value::BulkString("test_message".into()),
+                ]
+            )
+        );
+
+        if use_sharded {
+            // validate SPUBLISH
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        // drop and recreate a cluster with more nodes
+        drop(cluster);
+
+        // recreate the cluster, the assumption is that the cluster is built with exactly the same params (ports, slots map...)
+        let cluster =
+            TestClusterContext::new_with_cluster_client_builder(6, 0, |builder| builder, false);
+
+        // assume slot 14212 will reside in the last node
+        let last_server_port = {
+            let addr = cluster.cluster.servers.last().unwrap().addr.clone();
+            match addr {
+                redis::ConnectionAddr::TcpTls {
+                    host: _,
+                    port,
+                    insecure: _,
+                    tls_params: _,
+                } => port,
+                redis::ConnectionAddr::Tcp(_, port) => port,
+                _ => {
+                    panic!("Wrong server address type: {:?}", addr);
+                }
+            }
+        };
+
+        // wait for new topology discovery
+        loop {
+            let mut cmd = redis::cmd("INFO");
+            cmd.arg("SERVER");
+            let res = publishing_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        slot_14212,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            assert!(res.is_ok());
+            let res = res.unwrap();
+            match res {
+                Value::VerbatimString { format: _, text } => {
+                    if text.contains(format!("tcp_port:{}", last_server_port).as_str()) {
+                        // new topology rediscovered
+                        break;
+                    }
+                }
+                _ => {
+                    panic!("Wrong return type for INFO SERVER command: {:?}", res);
+                }
+            }
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+        }
+
+        // sleep for one cycle of topology refresh
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        // validate PUBLISH
+        let result = redis::cmd("PUBLISH")
+            .arg("test_channel_?")
+            .arg("test_message")
+            .query_async(&mut publishing_con)
+            .await;
+        assert_eq!(
+            result,
+            Ok(Value::Int(2)) // 2 connections with the same pubsub config
+        );
+
+        // allow message to propagate
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        loop {
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            // ignore disconnection and subscription notifications due to resubscriptions
+            if kind == PushKind::Message {
+                assert_eq!(
+                    data,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                );
+                break;
+            }
+        }
+
+        if use_sharded {
+            // validate SPUBLISH
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            // allow message to propagate
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        drop(publishing_con);
+        drop(_listening_con);
+
+        Ok(())
+    })
+    .unwrap();
+
+    block_on_all(async move {
+        sleep(futures_time::time::Duration::from_secs(10)).await;
+        Ok(())
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_resp3_pubsub() {
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    let use_sharded = redis_ver.starts_with("7.");
+
+    let mut client_subscriptions = PubSubSubscriptionInfo::from([
+        (
+            PubSubSubscriptionKind::Exact,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        ),
+        (
+            PubSubSubscriptionKind::Pattern,
+            HashSet::from([
+                PubSubChannelOrPattern::from("test_*".as_bytes()),
+                PubSubChannelOrPattern::from("*".as_bytes()),
+            ]),
+        ),
+    ]);
+
+    if use_sharded {
+        client_subscriptions.insert(
+            PubSubSubscriptionKind::Sharded,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        );
+    }
+
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(3)
+                .use_protocol(ProtocolVersion::RESP3)
+                .pubsub_subscriptions(client_subscriptions.clone())
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+        let mut connection = cluster.async_connection(Some(tx.clone())).await;
+
+        // short sleep to allow the server to push subscription notification
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        validate_subscriptions(&client_subscriptions, &mut rx, false);
+
+        let slot_14212 = get_slot(b"test_channel_?");
+        assert_eq!(slot_14212, 14212);
+
+        let slot_0_route =
+            redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master);
+        let node_0_route =
+            redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route);
+
+        // node 0 route is used to ensure that the publish is propagated correctly
+        let result = connection
+            .route_command(
+                redis::Cmd::new()
+                    .arg("PUBLISH")
+                    .arg("test_channel_?")
+                    .arg("test_message"),
+                RoutingInfo::SingleNode(node_0_route.clone()),
+            )
+            .await;
+        assert!(result.is_ok());
+
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        let mut pmsg_cnt = 0;
+        let mut msg_cnt = 0;
+        for _ in 0..3 {
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data: _ } = result.unwrap();
+            assert!(kind == PushKind::Message || kind == PushKind::PMessage);
+            if kind == PushKind::Message {
+                msg_cnt += 1;
+            } else {
+                pmsg_cnt += 1;
+            }
+        }
+        assert_eq!(msg_cnt, 1);
+        assert_eq!(pmsg_cnt, 2);
+
+        if use_sharded {
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut connection)
+                .await;
+            assert_eq!(result, Ok(Value::Int(1)));
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        Ok(())
+    })
+    .unwrap();
+}
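+
+    // For illustration: the RESP3 pub/sub tests above share one pattern - drain the
+    // push notifications that the server sent over the unbounded channel and branch on
+    // the push kind. A minimal sketch of that consumption loop (hypothetical helper;
+    // the real assertions are inlined in each test):
+    #[allow(dead_code)]
+    fn drain_push_notifications(
+        rx: &mut tokio::sync::mpsc::UnboundedReceiver<PushInfo>,
+    ) -> (usize, usize) {
+        let (mut channel_messages, mut pattern_messages) = (0, 0);
+        while let Ok(PushInfo { kind, .. }) = rx.try_recv() {
+            match kind {
+                PushKind::Message | PushKind::SMessage => channel_messages += 1,
+                PushKind::PMessage => pattern_messages += 1,
+                // subscription and disconnection notifications are ignored here
+                _ => {}
+            }
+        }
+        (channel_messages, pattern_messages)
+    }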
+
+#[test]
+fn test_async_cluster_periodic_checks_update_topology_after_failover() {
+    // This test aims to validate the functionality of periodic topology checks by detecting and updating topology changes.
+    // We will repeatedly execute CLUSTER NODES commands against the primary node responsible for slot 0, recording its node ID.
+    // Once we've successfully completed commands with the current primary, we will initiate a failover within the same shard.
+    // Since we are not executing key-based commands, we won't encounter MOVED errors that trigger a slot refresh.
+    // Consequently, we anticipate that only the periodic topology check will detect this change and trigger topology refresh.
+    // If successful, the node to which we route the CLUSTER NODES command should be the newly promoted node with a different node ID.
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        6,
+        1,
+        |builder| {
+            builder
+                .periodic_topology_checks(Duration::from_millis(10))
+                // Disable the rate limiter to refresh slots immediately on all MOVED errors
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let mut prev_master_id = "".to_string();
+        let max_requests = 5000;
+        let mut i = 0;
+        loop {
+            if i == 10 {
+                let mut cmd = redis::cmd("CLUSTER");
+                cmd.arg("FAILOVER");
+                cmd.arg("TAKEOVER");
+                let res = connection
+                    .route_command(
+                        &cmd,
+                        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(
+                            Route::new(0, SlotAddr::ReplicaRequired),
+                        )),
+                    )
+                    .await;
+                assert!(res.is_ok());
+            } else if i == max_requests {
+                break;
+            } else {
+                let mut cmd = redis::cmd("CLUSTER");
+                cmd.arg("NODES");
+                let res = connection
+                    .route_command(
+                        &cmd,
+                        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(
+                            Route::new(0, SlotAddr::Master),
+                        )),
+                    )
+                    .await
+                    .expect("Failed executing CLUSTER NODES");
+                let node_id = get_queried_node_id_if_master(res);
+                if let Some(current_master_id) = node_id {
+                    if prev_master_id.is_empty() {
+                        prev_master_id = current_master_id;
+                    } else if prev_master_id != current_master_id {
+                        return Ok::<_, RedisError>(());
+                    }
+                }
+            }
+            i += 1;
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+        }
+        panic!("Topology change wasn't found!");
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_recover_disconnected_management_connections() {
+    // This test aims to verify that the management connections used for periodic checks are reconnected, in case they get killed.
+    // In order to test this, we choose a single node, kill all connections to it which aren't user connections, and then wait until new
+    // connections are created.
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .periodic_topology_checks(Duration::from_millis(10))
+                // Disable the rate limiter to refresh slots immediately
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+            1,
+            SlotAddr::Master,
+        )));
+
+        let mut connection = cluster.async_connection(None).await;
+        let max_requests = 5000;
+
+        let connections =
+            get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+        assert!(connections.contains_key(MANAGEMENT_CONN_NAME));
+        // Get the connection ID of the management connection, then ask the server to kill it
+        let management_conn_id = connections.get(MANAGEMENT_CONN_NAME).unwrap();
+        kill_connection(&mut connection, management_conn_id).await;
+
+        let connections =
+            get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+        assert!(!connections.contains_key(MANAGEMENT_CONN_NAME));
+
+        for _ in 0..max_requests {
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+
+            let connections =
+                get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+            if connections.contains_key(MANAGEMENT_CONN_NAME) {
+                return Ok(());
+            }
+        }
+
+        panic!("Topology connection didn't reconnect!");
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_with_client_name() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.client_name(RedisCluster::client_name().to_string()),
+        false,
+    );
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let client_info: String = cmd("CLIENT")
+            .arg("INFO")
+            .query_async(&mut connection)
+            .await
+            .unwrap();
+
+        let client_attrs = parse_client_info(&client_info);
+
+        assert!(
+            client_attrs.contains_key("name"),
+            "Could not detect the 'name' attribute in CLIENT INFO output"
+        );
+
+        assert_eq!(
+            client_attrs["name"],
+            RedisCluster::client_name(),
+            "Incorrect client name, expecting: {}, got {}",
+            RedisCluster::client_name(),
+            client_attrs["name"]
+        );
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_reroute_from_replica_if_in_loading_state() {
+    /* Test a replica in loading state. The expected behaviour is that the request will be directed to a different replica or the primary,
+    depending on the read-from-replica policy. */
+    let name = "test_async_cluster_reroute_from_replica_if_in_loading_state";
+
+    let load_errors: Arc<_> = Arc::new(std::sync::Mutex::new(vec![]));
+    let load_errors_clone = load_errors.clone();
+
+    // requests should route to replica
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(
+                name,
+                cmd,
+                Some(vec![MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380, 6381],
+                    slot_range: (0..16383),
+                }]),
+            )?;
+            match port {
+                6380 | 6381 => {
+                    load_errors_clone.lock().unwrap().push(port);
+                    Err(parse_redis_value(b"-LOADING\r\n"))
+                }
+                6379 => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                _ => panic!("Wrong node"),
+            }
+        },
+    );
+    for _n in 0..3 {
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    let mut load_errors_guard = load_errors.lock().unwrap();
+    load_errors_guard.sort();
+
+    // We expect to get only 2 loading errors since the 2 replicas are in loading state.
+    // The third iteration will be directed to the primary since the connections of the replicas were removed.
+    assert_eq!(*load_errors_guard, vec![6380, 6381]);
+}
+
+#[test]
+fn test_async_cluster_read_from_primary_when_primary_loading() {
+    // Test primary in loading state. The expected behaviour is that the request will be retried until the primary is no longer in loading state.
+    let name = "test_async_cluster_read_from_primary_when_primary_loading";
+
+    const RETRIES: u32 = 3;
+    const ITERATIONS: u32 = 2;
+    let load_errors = Arc::new(AtomicU32::new(0));
+    let load_errors_clone = load_errors.clone();
+
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(
+                name,
+                cmd,
+                Some(vec![MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380, 6381],
+                    slot_range: (0..16383),
+                }]),
+            )?;
+            match port {
+                6379 => {
+                    let attempts = load_errors_clone.fetch_add(1, Ordering::Relaxed) + 1;
+                    if attempts % RETRIES == 0 {
+                        Err(Ok(Value::BulkString(b"123".to_vec())))
+                    } else {
+                        Err(parse_redis_value(b"-LOADING\r\n"))
+                    }
+                }
+                _ => panic!("Wrong node"),
+            }
+        },
+    );
+    for _n in 0..ITERATIONS {
+        runtime
+            .block_on(
+                cmd("GET")
+                    .arg("test")
+                    .query_async::<_, Value>(&mut connection),
+            )
+            .unwrap();
+    }
+
+    assert_eq!(load_errors.load(Ordering::Relaxed), ITERATIONS * RETRIES);
+}
+
+#[test]
+fn test_async_cluster_can_be_created_with_partial_slot_coverage() {
+    let name = "test_async_cluster_can_be_created_with_partial_slot_coverage";
+    let slots_config = Some(vec![
+        MockSlotRange {
+            primary_port: 6379,
+            replica_ports: vec![],
+            slot_range: (0..8000),
+        },
+        MockSlotRange {
+            primary_port: 6381,
+            replica_ports: vec![],
+            slot_range: (8201..16380),
+        },
+    ]);
+
+    let MockEnv {
+        async_connection: mut connection,
+        handler: _handler,
+        runtime,
+        ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + Err(Ok(Value::SimpleString("PONG".into()))) + }, + ); + + let res = runtime.block_on(connection.req_packed_command(&redis::cmd("PING"))); + assert!(res.is_ok()); + } + + #[test] + fn test_async_cluster_reconnect_after_complete_server_disconnect() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder.retries(2) + // Disable the rate limiter to refresh slots immediately + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + let cmd = cmd("PING"); + + let result = connection + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + // This will route to all nodes - different path through the code. + let result = connection.req_packed_command(&cmd).await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let result = connection.req_packed_command(&cmd).await.unwrap(); + assert_eq!(result, Value::SimpleString("PONG".to_string())); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_reconnect_after_complete_server_disconnect_route_to_many() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + + // recreate cluster + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let cmd = cmd("PING"); + // explicitly route to all primaries and request all succeeded + let result = connection + .route_command( + &cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(redis::cluster_routing::ResponsePolicy::AllSucceeded), + )), + ) + .await; + assert!(result.is_ok()); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_blocking_command_when_cluster_drops() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + futures::future::join( + async { + let res = connection.blpop::<&str, f64>("foo", 0.0).await; + assert!(res.is_err()); + println!("blpop returned error {:?}", res.map_err(|e| e.to_string())); + }, + async { + let _ = sleep(futures_time::time::Duration::from_secs(3)).await; + drop(cluster); + }, + ) + .await; + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_saves_reconnected_connection() { + let name = "test_async_cluster_saves_reconnected_connection"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let get_attempts = AtomicI32::new(0); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: 
_handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1),
+        name,
+        move |cmd: &[u8], port| {
+            if port == 6380 {
+                respond_startup_two_nodes(name, cmd)?;
+                return Err(parse_redis_value(
+                    format!("-MOVED 123 {name}:6379\r\n").as_bytes(),
+                ));
+            }
+
+            if contains_slice(cmd, b"PING") {
+                let connect_attempt = ping_attempts_clone.fetch_add(1, Ordering::Relaxed);
+                let past_get_attempts = get_attempts.load(Ordering::Relaxed);
+                // We want connection checks to fail after the first GET attempt, until it retries. Hence, we wait for 5 PINGs -
+                // 1. initial connection,
+                // 2. refresh slots on client creation,
+                // 3. refresh_connections `check_connection` after first GET failed,
+                // 4. refresh_connections `connect_and_check` after first GET failed,
+                // 5. reconnect on 2nd GET attempt.
+                // more than 5 attempts means that the server reconnects more than once, which is the behavior we're testing against.
+                if past_get_attempts != 1 || connect_attempt > 3 {
+                    respond_startup_two_nodes(name, cmd)?;
+                }
+                if connect_attempt > 5 {
+                    panic!("Too many pings!");
+                }
+                Err(Err(broken_pipe_error()))
+            } else {
+                respond_startup_two_nodes(name, cmd)?;
+                let past_get_attempts = get_attempts.fetch_add(1, Ordering::Relaxed);
+                // we fail the initial GET request, and after that we'll fail the first reconnect attempt, in the `refresh_connections` attempt.
+                if past_get_attempts == 0 {
+                    // Error once with io-error, ensure connection is reestablished w/out calling
+                    // other node (i.e., not doing a full slot rebuild)
+                    Err(Err(broken_pipe_error()))
+                } else {
+                    Err(Ok(Value::BulkString(b"123".to_vec())))
+                }
+            }
+        },
+    );
+
+    for _ in 0..4 {
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+    // If you need to change the number here due to a change in the cluster, you probably also need to adjust the test.
+    // See the PING counts above to explain why 5 is the target number.
+    assert_eq!(ping_attempts.load(Ordering::Acquire), 5);
+}
+
+#[test]
+fn test_async_cluster_periodic_checks_use_management_connection() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .periodic_topology_checks(Duration::from_millis(10))
+                // Disable the rate limiter to refresh slots immediately on the periodic checks
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let mut client_list = "".to_string();
+        let max_requests = 1000;
+        let mut i = 0;
+        loop {
+            if i == max_requests {
+                break;
+            } else {
+                client_list = cmd("CLIENT")
+                    .arg("LIST")
+                    .query_async::<_, String>(&mut connection)
+                    .await
+                    .expect("Failed executing CLIENT LIST");
+                let mut client_list_parts = client_list.split('\n');
+                if client_list_parts
+                    .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster"))
+                    && client_list.matches(MANAGEMENT_CONN_NAME).count() == 1 {
+                    return Ok::<_, RedisError>(());
+                }
+            }
+            i += 1;
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+        }
+        panic!("Couldn't find a management connection or the connection wasn't used to execute CLUSTER SLOTS {:?}", client_list);
+    })
+    .unwrap();
+}
+
+async fn get_clients_names_to_ids(
+    connection: &mut ClusterConnection,
+    routing: Option<RoutingInfo>,
+) -> HashMap<String, String> {
+    let mut client_list_cmd = redis::cmd("CLIENT");
+    client_list_cmd.arg("LIST");
+    let value = match routing {
+        Some(routing) => connection.route_command(&client_list_cmd, routing).await,
+        None => connection.req_packed_command(&client_list_cmd).await,
+    }
+    .unwrap();
+    let string = String::from_owned_redis_value(value).unwrap();
+    string
+        .split('\n')
+        .filter_map(|line| {
+            if line.is_empty() {
+                return None;
+            }
+            let key_values = line
+                .split(' ')
+                .filter_map(|value| {
+                    let mut split = value.split('=');
+                    match (split.next(), split.next()) {
+                        (Some(key), Some(val)) => Some((key, val)),
+                        _ => None,
+                    }
+                })
+                .collect::<HashMap<&str, &str>>();
+            match (key_values.get("name"), key_values.get("id")) {
+                (Some(key), Some(val)) if !val.is_empty() => {
+                    Some((key.to_string(), val.to_string()))
+                }
+                _ => None,
+            }
+        })
+        .collect()
+}
+
+async fn kill_connection(killer_connection: &mut ClusterConnection, connection_to_kill: &str) {
+    let mut cmd = redis::cmd("CLIENT");
+    cmd.arg("KILL");
+    cmd.arg("ID");
+    cmd.arg(connection_to_kill);
+    // Kill the management connection in the primary node that holds slot 0
+    assert!(killer_connection
+        .route_command(
+            &cmd,
+            RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                0,
+                SlotAddr::Master,
+            ))),
+        )
+        .await
+        .is_ok());
+}
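+
+    // For context: `get_clients_names_to_ids` above assumes the space-separated
+    // `key=value` layout of CLIENT LIST replies, one line per connection, e.g.
+    // (field values are made up for illustration):
+    //
+    //   id=42 addr=127.0.0.1:51234 name=user-connection age=1 ... cmd=client|list
+    //
+    // from which only the `name` and `id` fields are extracted into a name -> id map,
+    // skipping entries whose id is empty.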
+
+#[test]
+fn test_async_cluster_only_management_connection_is_reconnected_after_connection_failure() {
+    // This test will check two aspects:
+    // 1. Ensuring that after a disconnection in the management connection, a new management connection is established.
+    // 2. Confirming that a failure in the management connection does not impact the user connection, which should remain intact.
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.periodic_topology_checks(Duration::from_millis(10)),
+        false,
+    );
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let _client_list = "".to_string();
+        let max_requests = 500;
+        let mut i = 0;
+        // Set the name of the client connection to 'user-connection', so we'll be able to identify it later on
+        assert!(cmd("CLIENT")
+            .arg("SETNAME")
+            .arg("user-connection")
+            .query_async::<_, Value>(&mut connection)
+            .await
+            .is_ok());
+        // Get the client list
+        let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode(
+            SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await;
+
+        // Get the connection ID of 'user-connection'
+        let user_conn_id = names_to_ids.get("user-connection").unwrap();
+        // Get the connection ID of the management connection
+        let management_conn_id = names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap();
+        // Get another connection that will be used to kill the management connection
+        let mut killer_connection = cluster.async_connection(None).await;
+        kill_connection(&mut killer_connection, management_conn_id).await;
+        loop {
+            // In this loop we'll wait for the new management connection to be established
+            if i == max_requests {
+                break;
+            } else {
+                let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode(
+                    SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await;
+                if names_to_ids.contains_key(MANAGEMENT_CONN_NAME) {
+                    // A management connection is found
+                    let curr_management_conn_id =
+                        names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap();
+                    let curr_user_conn_id =
+                        names_to_ids.get("user-connection").unwrap();
+                    // Confirm that the management connection has a new connection ID, and verify that the user connection remains unaffected.
+                    if (curr_management_conn_id != management_conn_id)
+                        && (curr_user_conn_id == user_conn_id)
+                    {
+                        return Ok::<_, RedisError>(());
+                    }
+                } else {
+                    i += 1;
+                    let _ = sleep(futures_time::time::Duration::from_millis(50)).await;
+                    continue;
+                }
+            }
+        }
+        panic!(
+            "No reconnection of the management connection found, or there was an unwanted reconnection of the user connection.
+            \nprev_management_conn_id={:?},prev_user_conn_id={:?}\nclient list={:?}",
+            management_conn_id, user_conn_id, names_to_ids
+        );
+    })
+    .unwrap();
+}
+
+#[test]
+fn test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd() {
+    // This test verifies that non-key-based commands do not get routed to a random node
+    // when no connection is found for the given route. Instead, the appropriate error
+    // should be raised.
+    let name = "test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd";
+    let request_counter = Arc::new(AtomicU32::new(0));
+    let cloned_req_counter = request_counter.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |received_cmd: &[u8], _| { + let slots_config_vec = vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0_u16..8000_u16), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + // Don't cover all slots + slot_range: (8001_u16..12000_u16), + }, + ]; + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + // If requests are sent to random nodes, they will be caught and counted here. + request_counter.fetch_add(1, Ordering::Relaxed); + Err(Ok(Value::Nil)) + }, + ); + + runtime + .block_on(async move { + let uncovered_slot = 16000; + let route = redis::cluster_routing::Route::new( + uncovered_slot, + redis::cluster_routing::SlotAddr::Master, + ); + let single_node_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route); + let routing = RoutingInfo::SingleNode(single_node_route); + let res = connection + .route_command(&redis::cmd("FLUSHALL"), routing) + .await; + assert!(res.is_err()); + let res_err = res.unwrap_err(); + assert_eq!( + res_err.kind(), + ErrorKind::ConnectionNotFoundForRoute, + "{:?}", + res_err + ); + assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 0); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_route_to_random_on_key_based_cmd() { + // This test verifies that key-based commands get routed to a random node + // when no connection is found for the given route. The command should + // then be redirected correctly by the server's MOVED error. + let name = "test_async_cluster_route_to_random_on_key_based_cmd"; + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]),
+        name,
+        move |received_cmd: &[u8], _| {
+            let slots_config_vec = vec![
+                MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![],
+                    slot_range: (0_u16..8000_u16),
+                },
+                MockSlotRange {
+                    primary_port: 6380,
+                    replica_ports: vec![],
+                    // Don't cover all slots
+                    slot_range: (8001_u16..12000_u16),
+                },
+            ];
+            respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?;
+            if contains_slice(received_cmd, b"GET") {
+                if request_counter.fetch_add(1, Ordering::Relaxed) == 0 {
+                    return Err(parse_redis_value(
+                        format!("-MOVED 12182 {name}:6380\r\n").as_bytes(),
+                    ));
+                } else {
+                    return Err(Ok(Value::SimpleString("bar".into())));
+                }
+            }
+            panic!("unexpected command {:?}", received_cmd);
+        },
+    );
+
+    runtime
+        .block_on(async move {
+            // The keyslot of "foo" is 12182 and it isn't covered by any node, so we expect the
+            // request to be routed to a random node and then to be redirected to the MOVED node (2 requests in total)
+            let res: String = connection.get("foo").await.unwrap();
+            assert_eq!(res, "bar".to_string());
+            assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 2);
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+}
+
+#[test]
+fn test_async_cluster_do_not_retry_when_receiver_was_dropped() {
+    let name = "test_async_cluster_do_not_retry_when_receiver_was_dropped";
+    let cmd = cmd("FAKE_COMMAND");
+    let packed_cmd = cmd.get_packed_command();
+    let request_counter = Arc::new(AtomicU32::new(0));
+    let cloned_req_counter = request_counter.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(5)
+            .max_retry_wait(2)
+            .min_retry_wait(2),
+        name,
+        move |received_cmd: &[u8], _| {
+            respond_startup(name, received_cmd)?;
+
+            if received_cmd == packed_cmd {
+                cloned_req_counter.fetch_add(1, Ordering::Relaxed);
+                return Err(Err((ErrorKind::TryAgain, "seriously, try again").into()));
+            }
+
+            Err(Ok(Value::Okay))
+        },
+    );
+
+    runtime.block_on(async move {
+        let err = cmd
+            .query_async::<_, Value>(&mut connection)
+            .timeout(futures_time::time::Duration::from_millis(1))
+            .await
+            .unwrap_err();
+        assert_eq!(err.kind(), std::io::ErrorKind::TimedOut);
+
+        // we sleep here, to allow the cluster connection time to retry. We expect it won't, but without this
+        // sleep the test will complete before the runtime gave the connection time to retry, which would've made the
+        // test pass regardless of whether the connection tries retrying or not.
+            sleep(Duration::from_millis(10).into()).await;
+        });
+
+        assert_eq!(request_counter.load(Ordering::Relaxed), 1);
+    }
+
+    #[cfg(feature = "tls-rustls")]
+    mod mtls_test {
+        use crate::support::mtls_test::create_cluster_client_from_cluster;
+        use redis::ConnectionInfo;
+
+        use super::*;
+
+        #[test]
+        fn test_async_cluster_basic_cmd_with_mtls() {
+            let cluster = TestClusterContext::new_with_mtls(3, 0);
+            block_on_all(async move {
+                let client = create_cluster_client_from_cluster(&cluster, true).unwrap();
+                let mut connection = client.get_async_connection(None).await.unwrap();
+                cmd("SET")
+                    .arg("test")
+                    .arg("test_data")
+                    .query_async(&mut connection)
+                    .await?;
+                let res: String = cmd("GET")
+                    .arg("test")
+                    .clone()
+                    .query_async(&mut connection)
+                    .await?;
+                assert_eq!(res, "test_data");
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn test_async_cluster_should_not_connect_without_mtls_enabled() {
+            let cluster = TestClusterContext::new_with_mtls(3, 0);
+            block_on_all(async move {
+                let client = create_cluster_client_from_cluster(&cluster, false).unwrap();
+                let connection = client.get_async_connection(None).await;
+                match cluster.cluster.servers.first().unwrap().connection_info() {
+                    ConnectionInfo {
+                        addr: redis::ConnectionAddr::TcpTls { .. },
+                        ..
+                    } => {
+                        if connection.is_ok() {
+                            panic!("Must NOT be able to connect without client credentials if server accepts TLS");
+                        }
+                    }
+                    _ => {
+                        if let Err(e) = connection {
+                            panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}");
+                        }
+                    }
+                }
+                Ok::<_, RedisError>(())
+            }).unwrap();
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster_scan.rs b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs
new file mode 100644
index 0000000000..29a3c87b48
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs
@@ -0,0 +1,849 @@
+#![cfg(feature = "cluster-async")]
+mod support;
+
+#[cfg(test)]
+mod test_cluster_scan_async {
+    use crate::support::*;
+    use rand::Rng;
+    use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo};
+    use redis::{cmd, from_redis_value, ObjectType, RedisResult, ScanStateRC, Value};
+    use std::time::Duration;
+
+    async fn kill_one_node(
+        cluster: &TestClusterContext,
+        slot_distribution: Vec<(String, String, String, Vec<String>)>,
+    ) -> RoutingInfo {
+        let mut cluster_conn = cluster.async_connection(None).await;
+        let distribution_clone = slot_distribution.clone();
+        let index_of_random_node = rand::thread_rng().gen_range(0..slot_distribution.len());
+        let random_node = distribution_clone.get(index_of_random_node).unwrap();
+        let random_node_route_info = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+            host: random_node.1.clone(),
+            port: random_node.2.parse::<u16>().unwrap(),
+        });
+        let random_node_id = &random_node.0;
+        // Create connections to all nodes
+        for node in &distribution_clone {
+            if random_node_id == &node.0 {
+                continue;
+            }
+            let node_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                host: node.1.clone(),
+                port: node.2.parse::<u16>().unwrap(),
+            });
+
+            let mut forget_cmd = cmd("CLUSTER");
+            forget_cmd.arg("FORGET").arg(random_node_id);
+            let _: RedisResult<Value> = cluster_conn
+                .route_command(&forget_cmd, node_route.clone())
+                .await;
+        }
+        let mut shutdown_cmd = cmd("SHUTDOWN");
+        shutdown_cmd.arg("NOSAVE");
+        let _: RedisResult<Value> = cluster_conn
+            .route_command(&shutdown_cmd, random_node_route_info.clone())
+            .await;
+        random_node_route_info
+    }
+
+    #[tokio::test]
+    async fn
test_async_cluster_scan() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+
+        // Set some keys
+        for i in 0..10 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        for (i, key) in keys.iter().enumerate() {
+            assert_eq!(key.to_owned(), format!("key{}", i));
+        }
+    }
+
+    #[tokio::test] // test cluster scan with slot migration in the middle
+    async fn test_async_cluster_scan_with_migration() {
+        let cluster = TestClusterContext::new(3, 0);
+
+        let mut connection = cluster.async_connection(None).await;
+        // Set some keys
+        let mut expected_keys: Vec<String> = Vec::new();
+
+        for i in 0..1000 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.extend(scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                let mut cluster_nodes = cluster.get_cluster_nodes().await;
+                let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                cluster
+                    .migrate_slots_from_node_to_another(slot_distribution.clone())
+                    .await;
+                for node in &slot_distribution {
+                    let ready = cluster
+                        .wait_for_connection_is_ready(&RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::ByAddress {
+                                host: node.1.clone(),
+                                port: node.2.parse::<u16>().unwrap(),
+                            },
+                        ))
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+
+                cluster_nodes = cluster.get_cluster_nodes().await;
+                // Compare slot distribution before and after migration
+                let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                assert_ne!(slot_distribution, new_slot_distribution);
+            }
+        }
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        assert_eq!(keys, expected_keys);
+    }
+
+    #[tokio::test] // test cluster scan with node fail in the middle
+    async fn test_async_cluster_scan_with_fail() {
+        let cluster = TestClusterContext::new_with_cluster_client_builder(
+            3,
+            0,
+            |builder| builder.retries(1),
+            false,
+        );
+        let mut connection = cluster.async_connection(None).await;
+        // Set some keys
+        for i in 0..1000 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+        }
+
+        // Scan the keys
+        let mut scan_state_rc =
ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        let mut result: RedisResult<Value> = Ok(Value::Nil);
+        loop {
+            count += 1;
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            let (next_cursor, scan_keys) = match scan_response {
+                Ok((cursor, keys)) => (cursor, keys),
+                Err(e) => {
+                    result = Err(e);
+                    break;
+                }
+            };
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                let cluster_nodes = cluster.get_cluster_nodes().await;
+                let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                // simulate node failure
+                let killed_node_routing = kill_one_node(&cluster, slot_distribution.clone()).await;
+                let ready = cluster.wait_for_fail_to_finish(&killed_node_routing).await;
+                match ready {
+                    Ok(_) => {}
+                    Err(e) => {
+                        println!("error: {:?}", e);
+                        break;
+                    }
+                }
+                let cluster_nodes = cluster.get_cluster_nodes().await;
+                let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                assert_ne!(slot_distribution, new_slot_distribution);
+            }
+        }
+        // We expect an error when trying to find the address of the killed node
+        assert!(result.is_err());
+    }
+
+    #[tokio::test] // Test cluster scan with killing all masters during scan
+    async fn test_async_cluster_scan_with_all_masters_down() {
+        let cluster = TestClusterContext::new_with_cluster_client_builder(
+            6,
+            1,
+            |builder| {
+                builder
+                    .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+                    .retries(1)
+            },
+            false,
+        );
+
+        let mut connection = cluster.async_connection(None).await;
+
+        let mut expected_keys: Vec<String> = Vec::new();
+
+        cluster.wait_for_cluster_up();
+
+        let mut cluster_nodes = cluster.get_cluster_nodes().await;
+
+        let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+        let masters = cluster.get_masters(&cluster_nodes).await;
+        let replicas = cluster.get_replicas(&cluster_nodes).await;
+
+        for i in 0..1000 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                for replica in replicas.iter() {
+                    let mut failover_cmd = cmd("CLUSTER");
+                    let _: RedisResult<Value> = connection
+                        .route_command(
+                            failover_cmd.arg("FAILOVER").arg("TAKEOVER"),
+                            RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                                host: replica[1].clone(),
+                                port: replica[2].parse::<u16>().unwrap(),
+                            }),
+                        )
+                        .await;
+                    let ready = cluster
+                        .wait_for_connection_is_ready(&RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::ByAddress {
+                                host: replica[1].clone(),
+                                port: replica[2].parse::<u16>().unwrap(),
+                            },
+                        ))
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+
+                for master in masters.iter() {
+                    for replica in replicas.clone() {
+                        let mut forget_cmd = cmd("CLUSTER");
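+                        // After the replica takeover above, tell every replica to FORGET
+                        // the old masters so they drop out of the topology before the
+                        // SHUTDOWN below.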
forget_cmd.arg("FORGET").arg(master[0].clone()); + let _: RedisResult = connection + .route_command( + &forget_cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }), + ) + .await; + } + } + for master in masters.iter() { + let mut shut_cmd = cmd("SHUTDOWN"); + shut_cmd.arg("NOSAVE"); + let _ = connection + .route_command( + &shut_cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: master[1].clone(), + port: master[2].parse::().unwrap(), + }), + ) + .await; + let ready = cluster + .wait_for_fail_to_finish(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: master[1].clone(), + port: master[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + for replica in replicas.iter() { + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + cluster_nodes = cluster.get_cluster_nodes().await; + let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + assert_ne!(slot_distribution, new_slot_distribution); + } + } + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + assert_eq!(keys, expected_keys); + } + + #[tokio::test] + // Test cluster scan with killing all replicas during scan + async fn test_async_cluster_scan_with_all_replicas_down() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + .retries(1) + }, + false, + ); + + let mut connection = cluster.async_connection(None).await; + + let mut expected_keys: Vec = Vec::new(); + + for server in cluster.cluster.servers.iter() { + let address = server.addr.clone().to_string(); + let host_and_port = address.split(':'); + let host = host_and_port.clone().next().unwrap().to_string(); + let port = host_and_port + .clone() + .last() + .unwrap() + .parse::() + .unwrap(); + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { host, port }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + + let cluster_nodes = cluster.get_cluster_nodes().await; + + let replicas = cluster.get_replicas(&cluster_nodes).await; + + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + for replica in replicas.iter() { + let mut shut_cmd = cmd("SHUTDOWN"); + shut_cmd.arg("NOSAVE"); + let ready: RedisResult = 
+                        .route_command(
+                            &shut_cmd,
+                            RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                                host: replica[1].clone(),
+                                port: replica[2].parse::<u16>().unwrap(),
+                            }),
+                        )
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+                let new_cluster_nodes = cluster.get_cluster_nodes().await;
+                assert_ne!(cluster_nodes, new_cluster_nodes);
+            }
+        }
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        assert_eq!(keys, expected_keys);
+    }
+    #[tokio::test]
+    // Test cluster scan with setting keys for each iteration
+    async fn test_async_cluster_scan_set_in_the_middle() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            let key = format!("key{}", i);
+            i += 1;
+            let res: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            assert!(res.is_ok());
+        }
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() >= expected_keys.len());
+    }
+
+    #[tokio::test]
+    // Test cluster scan with deleting keys for each iteration
+    async fn test_async_cluster_scan_del_in_the_middle() {
+        let cluster = TestClusterContext::new(3, 0);
+
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            i -= 1;
+            let key = format!("key{}", i);
+
+            let res: Result<(), redis::RedisError> = redis::cmd("DEL")
+                .arg(&key)
+                .query_async(&mut connection)
+                .await;
+            assert!(res.is_ok());
+            expected_keys.remove(i as usize);
+        }
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() >= expected_keys.len());
+    }
+
+    #[tokio::test]
+    // Testing cluster scan with Pattern option
+    async fn test_async_cluster_scan_with_pattern() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key:pattern:{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            let non_relevant_key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&non_relevant_key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 500 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan_with_pattern(scan_state_rc, "key:pattern:*", None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
+
+    #[tokio::test]
+    // Testing cluster scan with TYPE option
+    async fn test_async_cluster_scan_with_type() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SADD")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            let key = format!("key-that-is-not-set{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 500 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, Some(ObjectType::Set))
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
+
+    #[tokio::test]
+    // Testing cluster scan with COUNT option
+    async fn test_async_cluster_scan_with_count() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        let mut comparing_times = 0;
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc.clone(), Some(100), None)
+                .await
+                .unwrap();
+            let (_, scan_without_count_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            if !scan_keys.is_empty() && !scan_without_count_keys.is_empty() {
+                assert!(scan_keys.len() >= scan_without_count_keys.len());
+
+                comparing_times += 1;
+            }
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        assert!(comparing_times > 0);
+        // Check if all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check if all keys were scanned
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
+
+    #[tokio::test]
+    // Testing cluster scan when connection fails in the middle and we get an error,
+    // then the cluster comes up again and scanning can continue without any problem
+    async fn test_async_cluster_scan_failover() {
+        let mut cluster = TestClusterContext::new_with_cluster_client_builder(
+            3,
+            0,
+            |builder| builder.retries(1),
+            false,
+        );
+        let mut connection = cluster.async_connection(None).await;
+        let mut i = 0;
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                drop(cluster);
+                let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> = connection
+                    .cluster_scan(scan_state_rc.clone(), None, None)
+                    .await;
+                assert!(scan_response.is_err());
+                break;
+            };
+        }
+        cluster = TestClusterContext::new(3, 0);
+        connection = cluster.async_connection(None).await;
+        loop {
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_geospatial.rs b/glide-core/redis-rs/redis/tests/test_geospatial.rs
new file mode 100644
index 0000000000..8bec9a1d73
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_geospatial.rs
@@ -0,0 +1,197 @@
+#![cfg(feature = "geospatial")]
+
+use assert_approx_eq::assert_approx_eq;
+
+use redis::geo::{Coord, RadiusOptions, RadiusOrder, RadiusSearchResult, Unit};
+use redis::{Commands, RedisResult};
+
+mod support;
+use crate::support::*;
+
+const PALERMO: (&str, &str, &str) = ("13.361389", "38.115556", "Palermo");
+const CATANIA: (&str, &str, &str) = ("15.087269", "37.502669", "Catania");
+const AGRIGENTO: (&str, &str, &str) = ("13.5833332", "37.316667", "Agrigento");
+
+#[test]
+fn test_geoadd_single_tuple() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", PALERMO), Ok(1));
+}
+
+#[test]
+fn test_geoadd_multiple_tuples() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+}
+
+#[test]
+fn test_geodist_existing_members() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let dist: f64 = con
+        .geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Kilometers)
+        .unwrap();
+    assert_approx_eq!(dist, 166.2742, 0.001);
+}
+
+#[test]
+fn test_geodist_support_option() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    // We should be able to extract the value as an Option<_>, so we can detect
+    // if a member is missing
+
+    let result: RedisResult<Option<f64>> = con.geo_dist("my_gis", PALERMO.2, "none", Unit::Meters);
+    assert_eq!(result, Ok(None));
+
+    let result: RedisResult<Option<f64>> =
+        con.geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Meters);
+    assert_ne!(result, Ok(None));
+
+    let dist = result.unwrap().unwrap();
+    assert_approx_eq!(dist, 166_274.151_6, 0.01);
+}
+
+#[test]
+fn test_geohash() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+    let result: RedisResult<Vec<String>> = con.geo_hash("my_gis", PALERMO.2);
+    assert_eq!(result, Ok(vec![String::from("sqc8b49rny0")]));
+
+    let result: RedisResult<Vec<String>> = con.geo_hash("my_gis", &[PALERMO.2, CATANIA.2]);
+    assert_eq!(
+        result,
+        Ok(vec![
+            String::from("sqc8b49rny0"),
+            String::from("sqdtr74hyu0"),
+        ])
+    );
+}
+
+#[test]
+fn test_geopos() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let result: Vec<Vec<f64>> = con.geo_pos("my_gis", &[PALERMO.2]).unwrap();
+    assert_eq!(result.len(), 1);
+
+    assert_approx_eq!(result[0][0], 13.36138, 0.0001);
+    assert_approx_eq!(result[0][1], 38.11555, 0.0001);
+
+    // Using the Coord struct
+    let result: Vec<Coord<f64>> = con.geo_pos("my_gis", &[PALERMO.2, CATANIA.2]).unwrap();
+    assert_eq!(result.len(), 2);
+
+    assert_approx_eq!(result[0].longitude, 13.36138, 0.0001);
+    assert_approx_eq!(result[0].latitude, 38.11555, 0.0001);
+
+    assert_approx_eq!(result[1].longitude, 15.08726, 0.0001);
+    assert_approx_eq!(result[1].latitude, 37.50266, 0.0001);
+}
+
+#[test]
+fn test_use_coord_struct() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(
+        con.geo_add(
+            "my_gis",
+            (Coord::lon_lat(13.361_389, 38.115_556), "Palermo")
+        ),
+        Ok(1)
+    );
+
+    let result: Vec<Coord<f64>> = con.geo_pos("my_gis", "Palermo").unwrap();
+    assert_eq!(result.len(), 1);
+
+    assert_approx_eq!(result[0].longitude, 13.36138, 0.0001);
+    assert_approx_eq!(result[0].latitude, 38.11555, 0.0001);
+}
+
+#[test]
+fn test_georadius() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let mut geo_radius = |opts: RadiusOptions| -> Vec<RadiusSearchResult> {
+        con.geo_radius("my_gis", 15.0, 37.0, 200.0, Unit::Kilometers, opts)
+            .unwrap()
+    };
+
+    // Simple request, without extra data
+    let mut result = geo_radius(RadiusOptions::default());
+    result.sort_by(|a, b| Ord::cmp(&a.name, &b.name));
+
+    assert_eq!(result.len(), 2);
+
+    assert_eq!(result[0].name.as_str(), "Catania");
+    assert_eq!(result[0].coord, None);
+    assert_eq!(result[0].dist, None);
+
+    assert_eq!(result[1].name.as_str(), "Palermo");
+    assert_eq!(result[1].coord, None);
+    assert_eq!(result[1].dist, None);
+
+    // Get data with multiple fields
+    let result = geo_radius(RadiusOptions::default().with_dist().order(RadiusOrder::Asc));
+
+    assert_eq!(result.len(), 2);
+
+    assert_eq!(result[0].name.as_str(), "Catania");
+    assert_eq!(result[0].coord, None);
+    assert_approx_eq!(result[0].dist.unwrap(), 56.4413, 0.001);
+
+    assert_eq!(result[1].name.as_str(), "Palermo");
+    assert_eq!(result[1].coord, None);
+    assert_approx_eq!(result[1].dist.unwrap(), 190.4424, 0.001);
+
+    let result = geo_radius(
+        RadiusOptions::default()
+            .with_coord()
+            .order(RadiusOrder::Desc)
+            .limit(1),
+    );
+
+    assert_eq!(result.len(), 1);
+
+    assert_eq!(result[0].name.as_str(), "Palermo");
+    assert_approx_eq!(result[0].coord.as_ref().unwrap().longitude, 13.361_389);
+    assert_approx_eq!(result[0].coord.as_ref().unwrap().latitude, 38.115_556);
+    assert_eq!(result[0].dist, None);
+}
+
+#[test]
+fn test_georadius_by_member() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA, AGRIGENTO]), Ok(3));
+
+    // Simple request, without extra data
+    let opts = RadiusOptions::default().order(RadiusOrder::Asc);
+    let result: Vec<RadiusSearchResult> = con
+        .geo_radius_by_member("my_gis", AGRIGENTO.2, 100.0, Unit::Kilometers, opts)
+        .unwrap();
+    let names: Vec<_> = result.iter().map(|c| c.name.as_str()).collect();
+
+    assert_eq!(names, vec!["Agrigento", "Palermo"]);
+}
diff --git a/glide-core/redis-rs/redis/tests/test_module_json.rs b/glide-core/redis-rs/redis/tests/test_module_json.rs
new file mode 100644
index 0000000000..08fed23930
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_module_json.rs
@@ -0,0 +1,540 @@
+#![cfg(feature = "json")]
+
+use std::assert_eq;
+use std::collections::HashMap;
+
+use redis::{JsonCommands, ProtocolVersion};
+
+use redis::{
+    ErrorKind, RedisError, RedisResult,
+    Value::{self, *},
+};
+
+use crate::support::*;
+mod support;
+
+use serde::Serialize;
+// adds json! macro for quick json generation on the fly.
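+// For example, json!({"a": [1, 2]}) builds the matching serde_json::Value inline.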
+use serde_json::json;
+
+const TEST_KEY: &str = "my_json";
+
+const MTLS_NOT_ENABLED: bool = false;
+
+#[test]
+fn test_module_json_serialize_error() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    #[derive(Debug, Serialize)]
+    struct InvalidSerializedStruct {
+        // Maps in serde_json must have string-like keys
+        // so numbers and strings, anything else will cause the serialization to fail
+        // this is basically the only way to make a serialization fail at runtime
+        // since rust doesn't provide the necessary ability to enforce this
+        pub invalid_json: HashMap<Option<bool>, i64>,
+    }
+
+    let mut test_invalid_value: InvalidSerializedStruct = InvalidSerializedStruct {
+        invalid_json: HashMap::new(),
+    };
+
+    test_invalid_value.invalid_json.insert(None, 2i64);
+
+    let set_invalid: RedisResult<bool> = con.json_set(TEST_KEY, "$", &test_invalid_value);
+
+    assert_eq!(
+        set_invalid,
+        Err(RedisError::from((
+            ErrorKind::Serialize,
+            "Serialization Error",
+            String::from("key must be string")
+        )))
+    );
+}
+
+#[test]
+fn test_module_json_arr_append() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64], "nested": {"a": [1i64, 2i64]}, "nested2": {"a": 42i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_append: RedisResult<Value> = con.json_arr_append(TEST_KEY, "$..a", &3i64);
+
+    assert_eq!(json_append, Ok(Array(vec![Int(2i64), Int(3i64), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_index() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrindex: RedisResult<Value> = con.json_arr_index(TEST_KEY, "$..a", &2i64);
+
+    assert_eq!(json_arrindex, Ok(Array(vec![Int(1i64), Int(-1i64)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrindex_2: RedisResult<Value> =
+        con.json_arr_index_ss(TEST_KEY, "$..a", &2i64, &0, &0);
+
+    assert_eq!(json_arrindex_2, Ok(Array(vec![Int(1i64), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_insert() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": [3i64 ,4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrinsert: RedisResult<Value> = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64);
+
+    assert_eq!(json_arrinsert, Ok(Array(vec![Int(2), Int(3)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64 ,2i64 ,3i64 ,2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrinsert_2: RedisResult<Value> = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64);
+
+    assert_eq!(json_arrinsert_2, Ok(Array(vec![Int(5), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrlen: RedisResult<Value> = con.json_arr_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_arrlen, Ok(Array(vec![Int(1), Int(2)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrlen_2: RedisResult<Value> = con.json_arr_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_arrlen_2, Ok(Array(vec![Int(4), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_pop() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrpop: RedisResult<Value> = con.json_arr_pop(TEST_KEY, "$..a", -1);
+
+    assert_eq!(
+        json_arrpop,
+        Ok(Array(vec![
+            // convert string 3 to its ascii value as bytes
+            BulkString(Vec::from("3".as_bytes())),
+            BulkString(Vec::from("4".as_bytes()))
+        ]))
+    );
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":["foo", "bar"], "nested": {"a": false}, "nested2": {"a":[]}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrpop_2: RedisResult<Value> = con.json_arr_pop(TEST_KEY, "$..a", -1);
+
+    assert_eq!(
+        json_arrpop_2,
+        Ok(Array(vec![
+            BulkString(Vec::from("\"bar\"".as_bytes())),
+            Nil,
+            Nil
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_arr_trim() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [], "nested": {"a": [1i64, 4u64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrtrim: RedisResult<Value> = con.json_arr_trim(TEST_KEY, "$..a", 1, 1);
+
+    assert_eq!(json_arrtrim, Ok(Array(vec![Int(0), Int(1)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [1i64, 2i64, 3i64, 4i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrtrim_2: RedisResult<Value> = con.json_arr_trim(TEST_KEY, "$..a", 1, 1);
+
+    assert_eq!(json_arrtrim_2, Ok(Array(vec![Int(1), Nil])));
+}
+
+#[test]
+fn test_module_json_clear() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"obj": {"a": 1i64, "b": 2i64}, "arr": [1i64, 2i64, 3i64], "str": "foo", "bool": true, "int": 42i64, "float": std::f64::consts::PI}));
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_clear: RedisResult<i64> = con.json_clear(TEST_KEY, "$.*");
+
+    assert_eq!(json_clear, Ok(4));
+
+    let checking_value: RedisResult<String> = con.json_get(TEST_KEY, "$");
+
+    // float is set to 0 and serde_json serializes 0f64 to 0.0, which is a different string
+    assert_eq!(
+        checking_value,
+        // I found it changes the order?
+        // it's not really a problem if you're just deserializing it anyway, but still
+        // kinda weird
+        Ok("[{\"arr\":[],\"bool\":true,\"float\":0,\"int\":0,\"obj\":{},\"str\":\"foo\"}]".into())
+    );
+}
+
+#[test]
+fn test_module_json_del() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": 1i64, "nested": {"a": 2i64, "b": 3i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_del: RedisResult<i64> = con.json_del(TEST_KEY, "$..a");
+
+    assert_eq!(json_del, Ok(2));
+}
+
+#[test]
+fn test_module_json_get() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":2i64, "b": 3i64, "nested": {"a": 4i64, "b": null}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_get: RedisResult<String> = con.json_get(TEST_KEY, "$..b");
+
+    assert_eq!(json_get, Ok("[3,null]".into()));
+
+    let json_get_multi: RedisResult<String> = con.json_get(TEST_KEY, vec!["..a", "$..b"]);
+
+    if json_get_multi != Ok("{\"$..b\":[3,null],\"..a\":[2,4]}".into())
+        && json_get_multi != Ok("{\"..a\":[2,4],\"$..b\":[3,null]}".into())
+    {
+        panic!("test_error: incorrect response from json_get_multi");
+    }
+}
+
+#[test]
+fn test_module_json_mget() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial_a: RedisResult<bool> = con.json_set(
+        format!("{TEST_KEY}-a"),
+        "$",
+        &json!({"a":1i64, "b": 2i64, "nested": {"a": 3i64, "b": null}}),
+    );
+    let set_initial_b: RedisResult<bool> = con.json_set(
+        format!("{TEST_KEY}-b"),
+        "$",
+        &json!({"a":4i64, "b": 5i64, "nested": {"a": 6i64, "b": null}}),
+    );
+
+    assert_eq!(set_initial_a, Ok(true));
+    assert_eq!(set_initial_b, Ok(true));
+
+    let json_mget: RedisResult<Value> = con.json_get(
+        vec![format!("{TEST_KEY}-a"), format!("{TEST_KEY}-b")],
+        "$..a",
+    );
+
+    assert_eq!(
+        json_mget,
+        Ok(Array(vec![
+            BulkString(Vec::from("[1,3]".as_bytes())),
+            BulkString(Vec::from("[4,6]".as_bytes()))
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_num_incr_by() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"b","b":[{"a":2i64}, {"a":5i64}, {"a":"c"}]}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") {
+        // cannot increment a string
+        let json_numincrby_a: RedisResult<Vec<Value>> = con.json_num_incr_by(TEST_KEY, "$.a", 2);
+        assert_eq!(json_numincrby_a, Ok(vec![Nil]));
+
+        let json_numincrby_b: RedisResult<Vec<Value>> = con.json_num_incr_by(TEST_KEY, "$..a", 2);
+
+        // however numbers can be incremented
+        assert_eq!(json_numincrby_b, Ok(vec![Nil, Int(4), Int(7), Nil]));
+    } else {
+        // cannot increment a string
+        let json_numincrby_a: RedisResult<String> = con.json_num_incr_by(TEST_KEY, "$.a", 2);
+        assert_eq!(json_numincrby_a, Ok("[null]".into()));
+
+        let json_numincrby_b: RedisResult<String> = con.json_num_incr_by(TEST_KEY, "$..a", 2);
+
+        // however numbers can be incremented
+        assert_eq!(json_numincrby_b, Ok("[null,4,7,null]".into()));
+    }
+}
+
+#[test]
+fn test_module_json_obj_keys() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_objkeys: RedisResult<Value> = con.json_obj_keys(TEST_KEY, "$..a");
+
+    assert_eq!(
+        json_objkeys,
+        Ok(Array(vec![
+            Nil,
+            Array(vec![
+                BulkString(Vec::from("b".as_bytes())),
+                BulkString(Vec::from("c".as_bytes()))
+            ])
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_obj_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_objlen: RedisResult<Value> = con.json_obj_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_objlen, Ok(Array(vec![Nil, Int(2)])));
+}
+
+#[test]
+fn test_module_json_set() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"key": "value"}));
+
+    assert_eq!(set, Ok(true));
+}
+
+#[test]
+fn test_module_json_str_append() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_strappend: RedisResult<Value> = con.json_str_append(TEST_KEY, "$..a", "\"baz\"");
+
+    assert_eq!(json_strappend, Ok(Array(vec![Int(6), Int(8), Nil])));
+
+    let json_get_check: RedisResult<String> = con.json_get(TEST_KEY, "$");
+
+    assert_eq!(
+        json_get_check,
+        Ok("[{\"a\":\"foobaz\",\"nested\":{\"a\":\"hellobaz\"},\"nested2\":{\"a\":31}}]".into())
+    );
+}
+
+#[test]
+fn test_module_json_str_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i32}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_strlen: RedisResult<Value> = con.json_str_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_strlen, Ok(Array(vec![Int(3), Int(5), Nil])));
+}
+
+#[test]
+fn test_module_json_toggle() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"bool": true}));
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_toggle_a: RedisResult<Value> = con.json_toggle(TEST_KEY, "$.bool");
+    assert_eq!(json_toggle_a, Ok(Array(vec![Int(0)])));
+
+    let json_toggle_b: RedisResult<Value> = con.json_toggle(TEST_KEY, "$.bool");
+    assert_eq!(json_toggle_b, Ok(Array(vec![Int(1)])));
+}
+
+#[test]
+fn test_module_json_type() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":2i64, "nested": {"a": true}, "foo": "bar"}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_type_a: RedisResult<Value> = con.json_type(TEST_KEY, "$..foo");
+    let json_type_b: RedisResult<Value> = con.json_type(TEST_KEY, "$..a");
+    let json_type_c: RedisResult<Value> = con.json_type(TEST_KEY, "$..dummy");
+
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") {
+        // In RESP3 current RedisJSON always gives response
in an array. + assert_eq!( + json_type_a, + Ok(Array(vec![Array(vec![BulkString(Vec::from( + "string".as_bytes() + ))])])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])])) + ); + assert_eq!(json_type_c, Ok(Array(vec![Array(vec![])]))); + } else { + assert_eq!( + json_type_a, + Ok(Array(vec![BulkString(Vec::from("string".as_bytes()))])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])) + ); + assert_eq!(json_type_c, Ok(Array(vec![]))); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_sentinel.rs b/glide-core/redis-rs/redis/tests/test_sentinel.rs new file mode 100644 index 0000000000..24cd13bd67 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_sentinel.rs @@ -0,0 +1,496 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "sentinel")] +mod support; + +use std::collections::HashMap; + +use redis::{ + sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo}, + Client, Connection, ConnectionAddr, ConnectionInfo, +}; + +use crate::support::*; + +fn parse_replication_info(value: &str) -> HashMap<&str, &str> { + let info_map: std::collections::HashMap<&str, &str> = value + .split("\r\n") + .filter(|line| !line.trim_start().starts_with('#')) + .filter_map(|line| line.split_once(':')) + .collect(); + info_map +} + +fn assert_is_master_role(replication_info: String) { + let info_map = parse_replication_info(&replication_info); + assert_eq!(info_map.get("role"), Some(&"master")); +} + +fn assert_replica_role_and_master_addr(replication_info: String, expected_master: &ConnectionInfo) { + let info_map = parse_replication_info(&replication_info); + + assert_eq!(info_map.get("role"), Some(&"slave")); + + let (master_host, master_port) = match &expected_master.addr { + ConnectionAddr::Tcp(host, port) => (host, port), + ConnectionAddr::TcpTls { + host, + port, + insecure: _, + tls_params: _, + } => (host, port), + ConnectionAddr::Unix(..) => panic!("Unexpected master connection type"), + }; + + assert_eq!(info_map.get("master_host"), Some(&master_host.as_str())); + assert_eq!( + info_map.get("master_port"), + Some(&master_port.to_string().as_str()) + ); +} + +fn assert_is_connection_to_master(conn: &mut Connection) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_is_master_role(info); +} + +fn assert_connection_is_replica_of_correct_master(conn: &mut Connection, master_client: &Client) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_replica_role_and_master_addr(info, master_client.get_connection_info()); +} + +/// Get replica clients from the sentinel in a rotating fashion, asserting that they are +/// indeed replicas of the given master, and returning a list of their addresses. 
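+/// Each rotation is expected to return a replica address that has not been seen
+/// before, which is what the duplicate-address assertion below checks.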
+fn connect_to_all_replicas(
+    sentinel: &mut Sentinel,
+    master_name: &str,
+    master_client: &Client,
+    node_conn_info: &SentinelNodeConnectionInfo,
+    number_of_replicas: u16,
+) -> Vec<ConnectionAddr> {
+    let mut replica_conn_infos = vec![];
+
+    for _ in 0..number_of_replicas {
+        let replica_client = sentinel
+            .replica_rotate_for(master_name, Some(node_conn_info))
+            .unwrap();
+        let mut replica_con = replica_client.get_connection(None).unwrap();
+
+        assert!(!replica_conn_infos.contains(&replica_client.get_connection_info().addr));
+        replica_conn_infos.push(replica_client.get_connection_info().addr.clone());
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, master_client);
+    }
+
+    replica_conn_infos
+}
+
+fn assert_connect_to_known_replicas(
+    sentinel: &mut Sentinel,
+    replica_conn_infos: &[ConnectionAddr],
+    master_name: &str,
+    master_client: &Client,
+    node_conn_info: &SentinelNodeConnectionInfo,
+    number_of_connections: u32,
+) {
+    for _ in 0..number_of_connections {
+        let replica_client = sentinel
+            .replica_rotate_for(master_name, Some(node_conn_info))
+            .unwrap();
+        let mut replica_con = replica_client.get_connection(None).unwrap();
+
+        assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr));
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, master_client);
+    }
+}
+
+#[test]
+fn test_sentinel_connect_to_random_replica() {
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, 3, 3);
+    let node_conn_info: SentinelNodeConnectionInfo = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    let mut replica_con = sentinel
+        .replica_for(master_name, Some(&node_conn_info))
+        .unwrap()
+        .get_connection(None)
+        .unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+    assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client);
+}
+
+#[test]
+fn test_sentinel_connect_to_multiple_replicas() {
+    let number_of_replicas = 3;
+    let master_name = "master1";
+    let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3);
+    let node_conn_info = cluster.sentinel_node_connection_info();
+    let sentinel = cluster.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    let replica_conn_infos = connect_to_all_replicas(
+        sentinel,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        number_of_replicas,
+    );
+
+    assert_connect_to_known_replicas(
+        sentinel,
+        &replica_conn_infos,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        10,
+    );
+}
+
+#[test]
+fn test_sentinel_server_down() {
+    let number_of_replicas = 3;
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, number_of_replicas, 3);
+    let node_conn_info = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    context.cluster.sentinel_servers[0].stop();
+    std::thread::sleep(std::time::Duration::from_millis(25));
+
+    let sentinel = context.sentinel_mut();
+
+    let replica_conn_infos = connect_to_all_replicas(
+        sentinel,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        number_of_replicas,
+    );
+
+    assert_connect_to_known_replicas(
+        sentinel,
+        &replica_conn_infos,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        10,
+    );
+}
+
+#[test]
+fn test_sentinel_client() {
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, 3, 3);
+    let mut master_client = SentinelClient::build(
+        context.sentinels_connection_info().clone(),
+        String::from(master_name),
+        Some(context.sentinel_node_connection_info()),
+        redis::sentinel::SentinelServerType::Master,
+    )
+    .unwrap();
+
+    let mut replica_client = SentinelClient::build(
+        context.sentinels_connection_info().clone(),
+        String::from(master_name),
+        Some(context.sentinel_node_connection_info()),
+        redis::sentinel::SentinelServerType::Replica,
+    )
+    .unwrap();
+
+    let mut master_con = master_client.get_connection().unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    let node_conn_info = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+
+    for _ in 0..20 {
+        let mut replica_con = replica_client.get_connection().unwrap();
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client);
+    }
+}
+
+#[cfg(feature = "aio")]
+pub mod async_tests {
+    use redis::{
+        aio::MultiplexedConnection,
+        sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo},
+        Client, ConnectionAddr, GlideConnectionOptions, RedisError,
+    };
+
+    use crate::{assert_is_master_role, assert_replica_role_and_master_addr, support::*};
+
+    async fn async_assert_is_connection_to_master(conn: &mut MultiplexedConnection) {
+        let info: String = redis::cmd("INFO")
+            .arg("REPLICATION")
+            .query_async(conn)
+            .await
+            .unwrap();
+
+        assert_is_master_role(info);
+    }
+
+    async fn async_assert_connection_is_replica_of_correct_master(
+        conn: &mut MultiplexedConnection,
+        master_client: &Client,
+    ) {
+        let info: String = redis::cmd("INFO")
+            .arg("REPLICATION")
+            .query_async(conn)
+            .await
+            .unwrap();
+
+        assert_replica_role_and_master_addr(info, master_client.get_connection_info());
+    }
+
+    /// Async version of connect_to_all_replicas
+    async fn async_connect_to_all_replicas(
+        sentinel: &mut Sentinel,
+        master_name: &str,
+        master_client: &Client,
+        node_conn_info: &SentinelNodeConnectionInfo,
+        number_of_replicas: u16,
+    ) -> Vec<ConnectionAddr> {
+        let mut replica_conn_infos = vec![];
+
+        for _ in 0..number_of_replicas {
+            let replica_client = sentinel
.async_replica_rotate_for(master_name, Some(node_conn_info)) + .await + .unwrap(); + let mut replica_con = replica_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap(); + + assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr)); + + async_assert_connection_is_replica_of_correct_master(&mut replica_con, master_client) + .await; + } + } + + #[test] + fn test_sentinel_connect_to_random_replica_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + let mut replica_con = sentinel + .async_replica_for(master_name, Some(&node_conn_info)) + .await? + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + async_assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_connect_to_multiple_replicas_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_server_down_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = context.sentinel_node_connection_info(); + + block_on_all(async move { + let sentinel = context.sentinel_mut(); + + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + context.cluster.sentinel_servers[0].stop(); + std::thread::sleep(std::time::Duration::from_millis(25)); + + let sentinel = context.sentinel_mut(); + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_client_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let 
mut master_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Master, + ) + .unwrap(); + + let mut replica_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Replica, + ) + .unwrap(); + + block_on_all(async move { + let mut master_con = master_client.get_async_connection().await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + + // Read commands to the replica node + for _ in 0..20 { + let mut replica_con = replica_client.get_async_connection().await?; + + async_assert_connection_is_replica_of_correct_master( + &mut replica_con, + &master_client, + ) + .await; + } + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_streams.rs b/glide-core/redis-rs/redis/tests/test_streams.rs new file mode 100644 index 0000000000..bf06028b95 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_streams.rs @@ -0,0 +1,627 @@ +#![cfg(feature = "streams")] + +use redis::streams::*; +use redis::{Commands, Connection, RedisResult, ToRedisArgs}; + +mod support; +use crate::support::*; + +use std::collections::BTreeMap; +use std::str; +use std::thread::sleep; +use std::time::Duration; + +fn xadd(con: &mut Connection) { + let _: RedisResult = + con.xadd("k1", "1000-0", &[("hello", "world"), ("redis", "streams")]); + let _: RedisResult = con.xadd("k1", "1000-1", &[("hello", "world2")]); + let _: RedisResult = con.xadd("k2", "2000-0", &[("hello", "world")]); + let _: RedisResult = con.xadd("k2", "2000-1", &[("hello", "world2")]); +} + +fn xadd_keyrange(con: &mut Connection, key: &str, start: i32, end: i32) { + for _i in start..end { + let _: RedisResult = con.xadd(key, "*", &[("h", "w")]); + } +} + +#[test] +fn test_cmd_options() { + // Tests the following command option builders.... + // xclaim_options + // xread_options + // maxlen enum + + // test read options + + let empty = StreamClaimOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let empty = StreamReadOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let opts = StreamClaimOptions::default() + .idle(50) + .time(500) + .retry(3) + .with_force() + .with_justid(); + + assert_args!( + &opts, + "IDLE", + "50", + "TIME", + "500", + "RETRYCOUNT", + "3", + "FORCE", + "JUSTID" + ); + + // test maxlen options + + assert_args!(StreamMaxlen::Approx(10), "MAXLEN", "~", "10"); + assert_args!(StreamMaxlen::Equals(10), "MAXLEN", "=", "10"); + + // test read options + + let opts = StreamReadOptions::default() + .noack() + .block(100) + .count(200) + .group("group-name", "consumer-name"); + + assert_args!( + &opts, + "GROUP", + "group-name", + "consumer-name", + "BLOCK", + "100", + "COUNT", + "200", + "NOACK" + ); + + // should skip noack because of missing group(,) + let opts = StreamReadOptions::default().noack().block(100).count(200); + + assert_args!(&opts, "BLOCK", "100", "COUNT", "200"); +} + +#[test] +fn test_assorted_1() { + // Tests the following commands.... 
+ // xadd + // xadd_map (skip this for now) + // xadd_maxlen + // xread + // xlen + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // smoke test that we get the same id back + let result: RedisResult = con.xadd("k0", "1000-0", &[("x", "y")]); + assert_eq!(result.unwrap(), "1000-0"); + + // xread reply + let reply: StreamReadReply = con.xread(&["k1", "k2", "k3"], &["0", "0", "0"]).unwrap(); + + // verify reply contains 2 keys even though we asked for 3 + assert_eq!(&reply.keys.len(), &2usize); + + // verify first key & first id exist + assert_eq!(&reply.keys[0].key, "k1"); + assert_eq!(&reply.keys[0].ids.len(), &2usize); + assert_eq!(&reply.keys[0].ids[0].id, "1000-0"); + + // lookup the key in StreamId map + let hello: Option = reply.keys[0].ids[0].get("hello"); + assert_eq!(hello, Some("world".to_string())); + + // verify the second key was written + assert_eq!(&reply.keys[1].key, "k2"); + assert_eq!(&reply.keys[1].ids.len(), &2usize); + assert_eq!(&reply.keys[1].ids[0].id, "2000-0"); + + // test xadd_map + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + map.insert("ab", "cd"); + map.insert("ef", "gh"); + map.insert("ij", "kl"); + let _: RedisResult = con.xadd_map("k3", "3000-0", map); + + let reply: StreamRangeReply = con.xrange_all("k3").unwrap(); + assert!(reply.ids[0].contains_key("ab")); + assert!(reply.ids[0].contains_key("ef")); + assert!(reply.ids[0].contains_key("ij")); + + // test xadd w/ maxlength below... + + // add 100 things to k4 + xadd_keyrange(&mut con, "k4", 0, 100); + + // test xlen.. should have 100 items + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(100)); + + // test xadd_maxlen + let _: RedisResult = + con.xadd_maxlen("k4", StreamMaxlen::Equals(10), "*", &[("h", "w")]); + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(10)); +} + +#[test] +fn test_xgroup_create() { + // Tests the following commands.... + // xadd + // xinfo_stream + // xgroup_create + // xinfo_groups + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // no key exists... this call breaks the connection pipe for some reason + let reply: RedisResult = con.xinfo_stream("k10"); + assert!(reply.is_err()); + + // redo the connection because the above error + con = ctx.connection(); + + // key should exist + let reply: StreamInfoStreamReply = con.xinfo_stream("k1").unwrap(); + assert_eq!(&reply.first_entry.id, "1000-0"); + assert_eq!(&reply.last_entry.id, "1000-1"); + assert_eq!(&reply.last_generated_id, "1000-1"); + + // xgroup create (existing stream) + let result: RedisResult = con.xgroup_create("k1", "g1", "$"); + assert!(result.is_ok()); + + // xinfo groups (existing stream) + let result: RedisResult = con.xinfo_groups("k1"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g1"); +} + +#[test] +fn test_assorted_2() { + // Tests the following commands.... 
+ // xadd + // xinfo_stream + // xinfo_groups + // xinfo_consumer + // xgroup_create_mkstream + // xread_options + // xack + // xpending + // xpending_count + // xpending_consumer_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // test xgroup create w/ mkstream @ 0 + let result: RedisResult = con.xgroup_create_mkstream("k99", "g99", "0"); + assert!(result.is_ok()); + + // Since nothing exists on this stream yet, + // it should have the defaults returned by the client + let result: RedisResult = con.xinfo_groups("k99"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g99"); + assert_eq!(&reply.groups[0].last_delivered_id, &"0-0"); + + // call xadd on k99 just so we can read from it + // using consumer g99 and test xinfo_consumers + let _: RedisResult = con.xadd("k99", "1000-0", &[("a", "b"), ("c", "d")]); + let _: RedisResult = con.xadd("k99", "1000-1", &[("e", "f"), ("g", "h")]); + + // test empty PEL + let empty_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(empty_reply.count(), 0); + if let StreamPendingReply::Empty = empty_reply { + // looks good + } else { + panic!("Expected StreamPendingReply::Empty but got Data"); + } + + // passing options w/ group triggers XREADGROUP + // using ID=">" means all undelivered ids + // otherwise, ID="0 | ms-num" means all pending already + // sent to this client + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys[0].ids.len(), 2); + + // read xinfo consumers again, should have 2 messages for the c99 consumer + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 2); + + // ack one of these messages + let result: RedisResult = con.xack("k99", "g99", &["1000-0"]); + assert_eq!(result, Ok(1)); + + // get pending messages already seen by this client + // we should only have one now.. + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &["0"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys.len(), 1); + + // we should also have one pending here... + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 1); + + // add more and read so we can test xpending + let _: RedisResult = con.xadd("k99", "1001-0", &[("i", "j"), ("k", "l")]); + let _: RedisResult = con.xadd("k99", "1001-1", &[("m", "n"), ("o", "p")]); + let _: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + + // call xpending here... 
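+    // without start/end/count arguments this is the summary form of XPENDING;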
+ // this has a different reply from what the count variations return + let data_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(data_reply.count(), 3); + + if let StreamPendingReply::Data(data) = data_reply { + assert_stream_pending_data(data) + } else { + panic!("Expected StreamPendingReply::Data but got Empty"); + } + + // both count variations have the same reply types + let reply: StreamPendingCountReply = con.xpending_count("k99", "g99", "-", "+", 10).unwrap(); + assert_eq!(reply.ids.len(), 3); + + let reply: StreamPendingCountReply = con + .xpending_consumer_count("k99", "g99", "-", "+", 10, "c99") + .unwrap(); + assert_eq!(reply.ids.len(), 3); + + for StreamPendingId { + id, + consumer, + times_delivered, + last_delivered_ms: _, + } in reply.ids + { + assert!(!id.is_empty()); + assert!(!consumer.is_empty()); + assert!(times_delivered > 0); + } +} + +fn assert_stream_pending_data(data: StreamPendingData) { + assert_eq!(data.start_id, "1000-1"); + assert_eq!(data.end_id, "1001-1"); + assert_eq!(data.consumers.len(), 1); + assert_eq!(data.consumers[0].name, "c99"); +} + +#[test] +fn test_xadd_maxlen_map() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + for i in 0..10 { + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + let idx = i.to_string(); + map.insert("idx", &idx); + let _: RedisResult = + con.xadd_maxlen_map("maxlen_map", StreamMaxlen::Equals(3), "*", map); + } + + let result: RedisResult = con.xlen("maxlen_map"); + assert_eq!(result, Ok(3)); + let reply: StreamRangeReply = con.xrange_all("maxlen_map").unwrap(); + + assert_eq!(reply.ids[0].get("idx"), Some("7".to_string())); + assert_eq!(reply.ids[1].get("idx"), Some("8".to_string())); + assert_eq!(reply.ids[2].get("idx"), Some("9".to_string())); +} + +#[test] +fn test_xread_options_deleted_pel_entry() { + // Test xread_options behaviour with deleted entry + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h1", "w1")]); + // read the pending items for this key & group + let result: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h2", "w2")]); + let result_deleted_entry: StreamReadReply = con + .xread_options( + &["k1"], + &["0"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!( + result.keys[0].ids.len(), + result_deleted_entry.keys[0].ids.len() + ); + assert_eq!( + result.keys[0].ids[0].id, + result_deleted_entry.keys[0].ids[0].id + ); +} +#[test] +fn test_xclaim() { + // Tests the following commands.... + // xclaim + // xclaim_options + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // xclaim test basic idea: + // 1. we need to test adding messages to a group + // 2. then xreadgroup needs to define a consumer and read pending + // messages without acking them + // 3. then we need to sleep 5ms and call xpending + // 4. 
from here we should be able to claim message + // past the idle time and read them from a different consumer + + // create the group + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + + // add some keys + xadd_keyrange(&mut con, "k1", 0, 10); + + // read the pending items for this key & group + let reply: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + // verify we have 10 ids + assert_eq!(reply.keys[0].ids.len(), 10); + + // save this StreamId for later + let claim = &reply.keys[0].ids[0]; + let _claim_1 = &reply.keys[0].ids[1]; + let claim_justids = &reply.keys[0] + .ids + .iter() + .map(|msg| &msg.id) + .collect::>(); + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // grab this id if > 4ms + let reply: StreamClaimReply = con + .xclaim("k1", "g1", "c2", 4, &[claim.id.clone()]) + .unwrap(); + assert_eq!(reply.ids.len(), 1); + assert_eq!(reply.ids[0].id, claim.id); + + // grab all pending ids for this key... + // we should 9 in c1 and 1 in c2 + let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap(); + if let StreamPendingReply::Data(data) = reply { + assert_eq!(data.consumers[0].name, "c1"); + assert_eq!(data.consumers[0].pending, 9); + assert_eq!(data.consumers[1].name, "c2"); + assert_eq!(data.consumers[1].pending, 1); + } + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // lets test some of the xclaim_options + // call force on the same claim.id + let _: StreamClaimReply = con + .xclaim_options( + "k1", + "g1", + "c3", + 4, + &[claim.id.clone()], + StreamClaimOptions::default().with_force(), + ) + .unwrap(); + + let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap(); + // we should have 9 w/ c1 and 1 w/ c3 now + if let StreamPendingReply::Data(data) = reply { + assert_eq!(data.consumers[1].name, "c3"); + assert_eq!(data.consumers[1].pending, 1); + } + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // claim and only return JUSTID + let claimed: Vec = con + .xclaim_options( + "k1", + "g1", + "c5", + 4, + claim_justids, + StreamClaimOptions::default().with_force().with_justid(), + ) + .unwrap(); + // we just claimed the original 10 ids + // and only returned the ids + assert_eq!(claimed.len(), 10); +} + +#[test] +fn test_xdel() { + // Tests the following commands.... + // xdel + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // add some keys + xadd(&mut con); + + // delete the first stream item for this key + let result: RedisResult = con.xdel("k1", &["1000-0"]); + // returns the number of items deleted + assert_eq!(result, Ok(1)); + + let result: RedisResult = con.xdel("k2", &["2000-0", "2000-1", "2000-2"]); + // should equal 2 since the last id doesn't exist + assert_eq!(result, Ok(2)); +} + +#[test] +fn test_xtrim() { + // Tests the following commands.... + // xtrim + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // add some keys + xadd_keyrange(&mut con, "k1", 0, 100); + + // trim key to 50 + // returns the number of items remaining in the stream + let result: RedisResult = con.xtrim("k1", StreamMaxlen::Equals(50)); + assert_eq!(result, Ok(50)); + // we should end up with 40 after this call + let result: RedisResult = con.xtrim("k1", StreamMaxlen::Equals(10)); + assert_eq!(result, Ok(40)); +} + +#[test] +fn test_xgroup() { + // Tests the following commands.... 
+ // xgroup_create_mkstream + // xgroup_destroy + // xgroup_delconsumer + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // test xgroup create w/ mkstream @ 0 + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "0"); + assert!(result.is_ok()); + + // destroy this new stream group + let result: RedisResult = con.xgroup_destroy("k1", "g1"); + assert_eq!(result, Ok(1)); + + // add some keys + xadd(&mut con); + + // create the group again using an existing stream + let result: RedisResult = con.xgroup_create("k1", "g1", "0"); + assert!(result.is_ok()); + + // read from the group so we can register the consumer + let reply: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!(reply.keys[0].ids.len(), 2); + + let result: RedisResult = con.xgroup_delconsumer("k1", "g1", "c1"); + // returns the number of pending message this client had open + assert_eq!(result, Ok(2)); + + let result: RedisResult = con.xgroup_destroy("k1", "g1"); + assert_eq!(result, Ok(1)); +} + +#[test] +fn test_xrange() { + // Tests the following commands.... + // xrange (-/+ variations) + // xrange_all + // xrange_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // xrange replies + let reply: StreamRangeReply = con.xrange_all("k1").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrange("k1", "1000-1", "+").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrange("k1", "-", "1000-0").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrange_count("k1", "-", "+", 1).unwrap(); + assert_eq!(reply.ids.len(), 1); +} + +#[test] +fn test_xrevrange() { + // Tests the following commands.... 
+ // xrevrange (+/- variations) + // xrevrange_all + // xrevrange_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // xrange replies + let reply: StreamRangeReply = con.xrevrange_all("k1").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "1000-1", "-").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "+", "1000-1").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrevrange_count("k1", "+", "-", 1).unwrap(); + assert_eq!(reply.ids.len(), 1); +} diff --git a/glide-core/redis-rs/redis/tests/test_types.rs b/glide-core/redis-rs/redis/tests/test_types.rs new file mode 100644 index 0000000000..d5df513efb --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_types.rs @@ -0,0 +1,606 @@ +mod support; + +#[cfg(test)] +mod types { + use redis::{FromRedisValue, ToRedisArgs, Value}; + #[test] + fn test_is_single_arg() { + let sslice: &[_] = &["foo"][..]; + let nestslice: &[_] = &[sslice][..]; + let nestvec = vec![nestslice]; + let bytes = b"Hello World!"; + let twobytesslice: &[_] = &[bytes, bytes][..]; + let twobytesvec = vec![bytes, bytes]; + + assert!("foo".is_single_arg()); + assert!(sslice.is_single_arg()); + assert!(nestslice.is_single_arg()); + assert!(nestvec.is_single_arg()); + assert!(bytes.is_single_arg()); + + assert!(!twobytesslice.is_single_arg()); + assert!(!twobytesvec.is_single_arg()); + } + + /// The `FromRedisValue` trait provides two methods for parsing: + /// - `fn from_redis_value(&Value) -> Result` + /// - `fn from_owned_redis_value(Value) -> Result` + /// The `RedisParseMode` below allows choosing between the two + /// so that test logic does not need to be duplicated for each. + enum RedisParseMode { + Owned, + Ref, + } + + impl RedisParseMode { + /// Calls either `FromRedisValue::from_owned_redis_value` or + /// `FromRedisValue::from_redis_value`. 
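+        /// Each test below runs its assertions under both modes, so both
+        /// parsing code paths get exercised.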
+ fn parse_redis_value( + &self, + value: redis::Value, + ) -> Result { + match self { + Self::Owned => redis::FromRedisValue::from_owned_redis_value(value), + Self::Ref => redis::FromRedisValue::from_redis_value(&value), + } + } + } + + #[test] + fn test_info_dict() { + use redis::{InfoDict, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let d: InfoDict = parse_mode + .parse_redis_value(Value::SimpleString( + "# this is a comment\nkey1:foo\nkey2:42\n".into(), + )) + .unwrap(); + + assert_eq!(d.get("key1"), Some("foo".to_string())); + assert_eq!(d.get("key2"), Some(42i64)); + assert_eq!(d.get::("key3"), None); + } + } + + #[test] + fn test_i32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::BulkString("42".into())); + assert_eq!(i, Ok(42i32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("42x".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_u32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42u32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("-1".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3])); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec)); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'])); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16])); + } + } + + #[test] + fn test_box_slice() { + use redis::{FromRedisValue, Value}; + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3].into_boxed_slice())); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec.into_boxed_slice())); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'].into_boxed_slice())); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16].into_boxed_slice())); + + assert_eq!( + Box::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - 
TypeError: \"Conversion to alloc::boxed::Box<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_arc_slice() { + use redis::{FromRedisValue, Value}; + use std::sync::Arc; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(Arc::from(vec![1i32, 2, 3]))); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(content_vec))); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(vec![b'1']))); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(Arc::from(vec![1_u16]))); + + assert_eq!( + Arc::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - TypeError: \"Conversion to alloc::sync::Arc<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_single_bool_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![true])); + } + } + + #[test] + fn test_single_i32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![1i32])); + } + } + + #[test] + fn test_single_u32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("42".into())); + + assert_eq!(v, Ok(vec![42u32])); + } + } + + #[test] + fn test_single_string_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(vec!["1".to_string()])); + } + } + + #[test] + fn test_tuple() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])])); + + assert_eq!(v, Ok(((1i32, 2, 3,),))); + } + } + + #[test] + fn test_hashmap() { + use fnv::FnvHasher; + use redis::{ErrorKind, Value}; + use std::collections::HashMap; + use std::hash::BuildHasherDefault; + + type Hm = HashMap; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + let mut e: Hm = HashMap::new(); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + type Hasher = BuildHasherDefault; + type HmHasher = HashMap; + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + 
Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + + let fnv = Hasher::default(); + let mut e: HmHasher = HashMap::with_hasher(fnv); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + let v: Result = + parse_mode.parse_redis_value(Value::Array(vec![Value::BulkString("a".into())])); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_bool() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::BulkString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::BulkString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::SimpleString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::SimpleString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v, Ok(true)); + } + } + + #[cfg(feature = "bytes")] + #[test] + fn test_bytes() { + use bytes::Bytes; + use redis::{ErrorKind, RedisResult, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let content_bytes = Bytes::from_static(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(content_bytes)); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() { + use std::str::FromStr; + + use redis::{ErrorKind, FromRedisValue, RedisResult, Value}; + use uuid::Uuid; + + let uuid = Uuid::from_str("abab64b7-e265-4052-a41b-23e1e28674bf").unwrap(); + let bytes = uuid.as_bytes().to_vec(); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::BulkString(bytes)); + assert_eq!(v, Ok(uuid)); + + let v: RedisResult = + FromRedisValue::from_redis_value(&Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = 
FromRedisValue::from_redis_value(&Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + + #[test] + fn test_cstring() { + use redis::{ErrorKind, RedisResult, Value}; + use std::ffi::CString; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(CString::new(content).unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v, Ok(CString::new("garbage").unwrap())); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(CString::new("OK").unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("gar\0bage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_types_to_redis_args() { + use redis::ToRedisArgs; + use std::collections::BTreeMap; + use std::collections::BTreeSet; + use std::collections::HashMap; + use std::collections::HashSet; + + assert!(!5i32.to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!String::from("x").to_redis_args().is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can be used on something HMSET + assert!(![("a", 5), ("b", 6), ("C", 7)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can also be used on something HMSET + assert!(![("d", 8), ("e", 9), ("f", 10)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + } + + #[test] + fn test_large_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = i; + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_u8_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [u8; 1000] = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = (i % 256) as u8; + } + + let vec = (&array).to_redis_args(); + assert_eq!(vec.len(), 1); + assert_eq!(array.len(), vec[0].len()); + + let value = Value::Array(vec[0].iter().map(|val| 
Value::Int(*val as i64)).collect()); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [u8; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_string_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [String; 1000] = [(); 1000].map(|_| String::new()); + for (i, item) in array.iter_mut().enumerate() { + *item = format!("{i}"); + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [String; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_0_length_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let array: [usize; 0] = [0; 0]; + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&Value::Nil).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_attributes() { + use redis::{parse_redis_value, FromRedisValue, Value}; + let bytes: &[u8] = b"*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n"; + let val = parse_redis_value(bytes).unwrap(); + { + // The case user doesn't expect attributes from server + let x: Vec = redis::FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!(x, vec![1, 2, 3]); + } + { + // The case user wants raw value from server + let x: Value = FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!( + x, + Value::Array(vec![ + Value::Int(1), + Value::Int(2), + Value::Attribute { + data: Box::new(Value::Int(3)), + attributes: vec![( + Value::SimpleString("ttl".to_string()), + Value::Int(3600) + )] + } + ]) + ) + } + } +} diff --git a/glide-core/redis-rs/release.sh b/glide-core/redis-rs/release.sh new file mode 100755 index 0000000000..f01241c382 --- /dev/null +++ b/glide-core/redis-rs/release.sh @@ -0,0 +1,15 @@ +#!/bin/sh +set -ex + +LEVEL=$1 +if [ -z "$LEVEL" ]; then + echo "Expected patch, minor or major" + exit 1 +fi + +clog --$LEVEL + +git add CHANGELOG.md +git commit -m "Update changelog" + +cargo release --execute $LEVEL diff --git a/glide-core/redis-rs/rustfmt.toml b/glide-core/redis-rs/rustfmt.toml new file mode 100644 index 0000000000..0d564415cb --- /dev/null +++ b/glide-core/redis-rs/rustfmt.toml @@ -0,0 +1,2 @@ +use_try_shorthand = true +edition = "2018" diff --git a/glide-core/redis-rs/scripts/get_command_info.py b/glide-core/redis-rs/scripts/get_command_info.py new file mode 100644 index 0000000000..dcba666bff --- /dev/null +++ b/glide-core/redis-rs/scripts/get_command_info.py @@ -0,0 +1,227 @@ +import argparse +import json +import os +from os.path import join + +"""Valkey command categorizer + +This script analyzes command info json files and categorizes the commands based on their routing. 
The output can be used +to map commands in the cluster_routing.rs#base_routing function to their RouteBy category. Commands that cannot be +categorized by the script will be listed under the "Uncategorized" section. These commands will need to be manually +categorized. + +To use the script: +1. Clone https://github.com/valkey-io/valkey +2. cd into the cloned valkey repository and checkout the desired version of the code, eg 7.2.5 +3. cd into the directory containing this script +4. run: + python get_command_info.py --commands-dir=/valkey/src/commands +""" + + +class CommandCategory: + def __init__(self, name, description): + self.name = name + self.description = description + self.commands = [] + + def add_command(self, command_name): + self.commands.append(command_name) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyzes command info json and categorizes commands into their RouteBy categories") + parser.add_argument( + "--commands-dir", + type=str, + help="Path to the directory containing the command info json files (example: ../../valkey/src/commands)", + required=True, + ) + + args = parser.parse_args() + commands_dir = args.commands_dir + if not os.path.exists(commands_dir): + raise parser.error("The command info directory passed to the '--commands-dir' argument does not exist") + + all_nodes = CommandCategory("AllNodes", "Commands with an ALL_NODES request policy") + all_primaries = CommandCategory("AllPrimaries", "Commands with an ALL_SHARDS request policy") + multi_shard = CommandCategory("MultiShardNoValues or MultiShardWithValues", + "Commands with a MULTI_SHARD request policy") + first_arg = CommandCategory("FirstKey", "Commands with their first key argument at position 1") + second_arg = CommandCategory("SecondArg", "Commands with their first key argument at position 2") + second_arg_numkeys = ( + CommandCategory("SecondArgAfterKeyCount", + "Commands with their first key argument at position 2, after a numkeys argument")) + # all commands with their first key argument at position 3 have a numkeys argument at position 2, + # so there is a ThirdArgAfterKeyCount category but no ThirdArg category + third_arg_numkeys = ( + CommandCategory("ThirdArgAfterKeyCount", + "Commands with their first key argument at position 3, after a numkeys argument")) + streams_index = CommandCategory("StreamsIndex", "Commands that include a STREAMS token") + second_arg_slot = CommandCategory("SecondArgSlot", "Commands with a slot argument at position 2") + uncategorized = ( + CommandCategory( + "Uncategorized", + "Commands that don't fall into the other categories. These commands will have to be manually categorized.")) + + categories = [all_nodes, all_primaries, multi_shard, first_arg, second_arg, second_arg_numkeys, third_arg_numkeys, + streams_index, second_arg_slot, uncategorized] + + print("Gathering command info...\n") + + for filename in os.listdir(commands_dir): + file_path = join(commands_dir, filename) + _, file_extension = os.path.splitext(file_path) + if file_extension != ".json": + print(f"Note: {filename} is not a json file and will thus be ignored") + continue + + file = open(file_path) + command_json = json.load(file) + if len(command_json) == 0: + raise Exception( + f"The json for {filename} was empty. 
A json object with information about the command was expected.") + + command_name = next(iter(command_json)) + command_info = command_json[command_name] + if "container" in command_info: + # for two-word commands like 'XINFO GROUPS', the `next(iter(command_json))` statement above returns 'GROUPS' + # and `command_info['container']` returns 'XINFO' + command_name = f"{command_info['container']} {command_name}" + + if "command_tips" in command_info: + request_policy = get_request_policy(command_info["command_tips"]) + if request_policy == "ALL_NODES": + all_nodes.add_command(command_name) + continue + elif request_policy == "ALL_SHARDS": + all_primaries.add_command(command_name) + continue + elif request_policy == "MULTI_SHARD": + multi_shard.add_command(command_name) + continue + + if "arguments" not in command_info: + uncategorized.add_command(command_name) + continue + + command_args = command_info["arguments"] + split_name = command_name.split() + if len(split_name) == 0: + raise Exception(f"Encountered json with an empty command name in file '{filename}'") + + json_key_index, is_key_optional = get_first_key_info(command_args) + # cluster_routing.rs can handle optional keys if a keycount of 0 is provided, otherwise the command should + # fall under the "Uncategorized" section to indicate it will need to be manually inspected + if is_key_optional and not is_after_numkeys(command_args, json_key_index): + uncategorized.add_command(command_name) + continue + + if json_key_index == -1: + # the command does not have a key argument, check for a slot argument + json_slot_index, is_slot_optional = get_first_slot_info(command_args) + if is_slot_optional: + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_slot_index = -1 if json_slot_index == -1 else len(split_name) + json_slot_index + if cluster_routing_slot_index == 2: + second_arg_slot.add_command(command_name) + continue + + # the command does not have a slot argument, check for a "STREAMS" token + if has_streams_token(command_args): + streams_index.add_command(command_name) + continue + + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_key_index = -1 if json_key_index == -1 else len(split_name) + json_key_index + if cluster_routing_key_index == 1: + first_arg.add_command(command_name) + continue + elif cluster_routing_key_index == 2: + if is_after_numkeys(command_args, json_key_index): + second_arg_numkeys.add_command(command_name) + continue + else: + second_arg.add_command(command_name) + continue + # there aren't any commands that fall into a ThirdArg category, + # but there are commands that fall under ThirdArgAfterKeyCount category + elif cluster_routing_key_index == 3 and is_after_numkeys(command_args, json_key_index): + third_arg_numkeys.add_command(command_name) + continue + + uncategorized.add_command(command_name) + + print("\nNote: the following information considers each word in the command name to be an argument") + print("For example, for 'XGROUP DESTROY key group':") + print("'XGROUP' is arg0, 'DESTROY' is arg1, 'key' is arg2, and 'group' is arg3.\n") + + for category in categories: + print_category(category) + + +def get_request_policy(command_tips): + for command_tip in command_tips: + if command_tip.startswith("REQUEST_POLICY:"): + return command_tip[len("REQUEST_POLICY:"):] + 
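+    # none of the command tips specified a REQUEST_POLICY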
+ return None + + +def get_first_key_info(args_info_json) -> tuple[int, bool]: + for i in range(len(args_info_json)): + info = args_info_json[i] + if info["type"].lower() == "key": + is_optional = "optional" in info and info["optional"] + return i, is_optional + + return -1, False + + +def get_first_slot_info(args_info_json) -> tuple[int, bool]: + for i in range(len(args_info_json)): + info = args_info_json[i] + if info["name"].lower() == "slot": + is_optional = "optional" in info and info["optional"] + return i, is_optional + + return -1, False + + +def is_after_numkeys(args_info_json, json_index): + return json_index > 0 and args_info_json[json_index - 1]["name"].lower() == "numkeys" + + +def has_streams_token(args_info_json): + for arg_info in args_info_json: + if "token" in arg_info and arg_info["token"].upper() == "STREAMS": + return True + + return False + + +def print_category(category): + print("============================") + print(f"Category: {category.name} commands") + print(f"Description: {category.description}") + print("List of commands in this category:\n") + + if len(category.commands) == 0: + print("(No commands found for this category)") + else: + category.commands.sort() + for command_name in category.commands: + print(f"{command_name}") + + print("\n") + + +if __name__ == "__main__": + main() diff --git a/glide-core/redis-rs/scripts/update-versions.sh b/glide-core/redis-rs/scripts/update-versions.sh new file mode 100755 index 0000000000..f2800985f0 --- /dev/null +++ b/glide-core/redis-rs/scripts/update-versions.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# This script is pretty low tech, but it helps keep the doc version numbers +# up to date. It should be run as a `pre-release-hook` from cargo-release. + +set -eo pipefail + +if [ -z "$PREV_VERSION" ] || [ -z "$NEW_VERSION" ]; then + echo "Missing PREV_VERSION or NEW_VERSION." + echo "This script needs to run as a 'pre-release-hook' from cargo-release." + exit 1 +fi + +for file in README.md; do + sed -i.bak -E \ + -e "s|version=[0-9.]+|version=${NEW_VERSION}|g" \ + -e "s|redis/[0-9.]+|redis/${NEW_VERSION}|g" \ + -e "s|redis = \"[0-9.]+\"|redis = \"${NEW_VERSION}\"|g" \ + "${CRATE_ROOT}/$file" + rm "${CRATE_ROOT}/$file.bak" +done diff --git a/glide-core/redis-rs/upload-docs.sh b/glide-core/redis-rs/upload-docs.sh new file mode 100755 index 0000000000..4f6d01cd0f --- /dev/null +++ b/glide-core/redis-rs/upload-docs.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Make a new repo for the gh-pages branch +rm -rf .gh-pages +mkdir .gh-pages +cd .gh-pages +git init + +# Copy over the documentation +cp -r ../target/doc/* . +cat < index.html + +redis + +EOF + +# Add, commit and push files +git add -f --all . +git commit -m "Built documentation" +git checkout -b gh-pages +git remote add origin git@github.com:mitsuhiko/redis-rs.git +git push -qf origin gh-pages + +# Cleanup +cd .. 
+rm -rf .gh-pages diff --git a/go/Cargo.toml b/go/Cargo.toml index 62872578da..05d34e7108 100644 --- a/go/Cargo.toml +++ b/go/Cargo.toml @@ -9,7 +9,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } protobuf = { version = "3.3.0", features = [] } diff --git a/go/DEVELOPER.md b/go/DEVELOPER.md index 023828a0cf..ad3ded8e57 100644 --- a/go/DEVELOPER.md +++ b/go/DEVELOPER.md @@ -105,7 +105,7 @@ Before starting this step, make sure you've installed all software requirements. git clone --branch ${VERSION} https://github.com/valkey-io/valkey-glide.git cd valkey-glide ``` -2. Initialize git submodule: +2. Initialize git submodules: ```bash git submodule update --init --recursive ``` @@ -163,7 +163,7 @@ go test -race ./... -run TestConnectionRequestProtobufGeneration_allFieldsSet -v After pulling new changes, ensure that you update the submodules by running the following command: ```bash -git submodule update +git submodule update --init --recursive ``` ### Generate protobuf files diff --git a/java/Cargo.toml b/java/Cargo.toml index 6428f67fa6..c8fa49fe3f 100644 --- a/java/Cargo.toml +++ b/java/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../logger_core"} diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md index f71966862e..a3391c3282 100644 --- a/node/DEVELOPER.md +++ b/node/DEVELOPER.md @@ -70,6 +70,7 @@ Before starting this step, make sure you've installed all software requirments. git submodule update --init --recursive ``` 3. Install all node dependencies: + ```bash cd node npm i @@ -77,6 +78,7 @@ Before starting this step, make sure you've installed all software requirments. npm i cd .. ``` + 4. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder): 1. 
Build in release mode, stripped from all debug symbols (optimized and minimized binary size): diff --git a/node/rust-client/Cargo.toml b/node/rust-client/Cargo.toml index e9e2af8851..f9baaf6cc2 100644 --- a/node/rust-client/Cargo.toml +++ b/node/rust-client/Cargo.toml @@ -11,7 +11,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] } glide-core = { path = "../../glide-core", features = ["socket-layer"] } tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } napi = {version = "2.14", features = ["napi4", "napi6"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 16632945bb..3945322cd2 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -13,7 +13,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "^0.20", features = ["extension-module", "num-bigint"] } bytes = { version = "1.6.0" } -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager","tokio-rustls-comp"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager","tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } logger_core = {path = "../logger_core"} diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index abf12dc9a3..a3e5b07237 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -2,15 +2,13 @@ This document describes how to set up your development environment to build and test the Valkey GLIDE Python wrapper. -### Development Overview - The Valkey GLIDE Python wrapper consists of both Python and Rust code. Rust bindings for Python are implemented using [PyO3](https://github.com/PyO3/pyo3), and the Python package is built using [maturin](https://github.com/PyO3/maturin). The Python and Rust components communicate using the [protobuf](https://github.com/protocolbuffers/protobuf) protocol. -### Build from source +# Prerequisites +--- -#### Prerequisites +Before building the package from source, make sure that you have installed the listed dependencies below: -Software Dependencies - python3 virtualenv - git @@ -21,7 +19,10 @@ Software Dependencies - openssl-dev - rustup -**Dependencies installation for Ubuntu** +For your convenience, we wrapped the steps in a "copy-paste" code blocks for common operating systems: + +
+<details>
+<summary>Ubuntu / Debian</summary>
 
 ```bash
 sudo apt update -y
@@ -42,7 +43,10 @@ export PATH="$PATH:$HOME/.local/bin"
 protoc --version
 ```
 
-**Dependencies installation for CentOS**
+</details>
+ +
+<details>
+<summary>CentOS</summary>
 
 ```bash
 sudo yum update -y
@@ -62,7 +66,10 @@ export PATH="$PATH:$HOME/.local/bin"
 protoc --version
 ```
 
-**Dependencies installation for MacOS**
+</details>
+ +
+<details>
+<summary>MacOS</summary>
 
 ```bash
 brew update
@@ -80,112 +87,108 @@ source /Users/$USER/.bash_profile
 protoc --version
 ```
 
-#### Building and installation steps
-
-Before starting this step, make sure you've installed all software requirments.
+</details>
-1. Clone the repository: - ```bash - git clone https://github.com/valkey-io/valkey-glide.git - cd valkey-glide - ``` -2. Initialize git submodule: - ```bash - git submodule update --init --recursive - ``` -3. Generate protobuf files: - ```bash - GLIDE_ROOT_FOLDER_PATH=. - protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto - ``` -4. Create a virtual environment: - ```bash - cd python - python3 -m venv .env - ``` -5. Activate the virtual environment: - ```bash - source .env/bin/activate - ``` -6. Install requirements: - ```bash - pip install -r requirements.txt - ``` -7. Build the Python wrapper in release mode: - ``` - maturin develop --release --strip - ``` - > **Note:** To build the wrapper binary with debug symbols remove the --strip flag. -8. Run tests: - 1. Ensure that you have installed redis-server or valkey-server and redis-cli or valkey-cli on your host. You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). - 2. Validate the activation of the virtual environment from step 4 by ensuring its name (`.env`) is displayed next to your command prompt. - 3. Execute the following command from the python folder: - ```bash - pytest --asyncio-mode=auto - ``` - > **Note:** To run Valkey modules tests, add -k "test_server_modules.py". - -- Install Python development requirements with: +# Building +--- - ```bash - pip install -r python/dev_requirements.txt - ``` +Before starting this step, make sure you've installed all software requirements. -- For a fast build, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the "--release" flag when measuring performance. +## Prepare your environment -### Test +```bash +mkdir -p $HOME/src +cd $_ +git clone https://github.com/valkey-io/valkey-glide.git +cd valkey-glide +GLIDE_ROOT=$(pwd) +protoc -Iprotobuf=${GLIDE_ROOT}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT}/python/python/glide \ + ${GLIDE_ROOT}/glide-core/src/protobuf/*.proto +cd python +python3 -m venv .env +source .env/bin/activate +pip install -r requirements.txt +pip install -r python/dev_requirements.txt +``` -To run tests, use the following command: +## Build the package (in release mode): ```bash -pytest --asyncio-mode=auto +maturin develop --release --strip ``` + +> **Note:** to build the wrapper binary with debug symbols remove the `--strip` flag. + +> **Note 2:** for a faster build time, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the `--release` flag when measuring performance. -To execute a specific test, include the `-k ` option. For example: +# Running tests +--- + +Ensure that you have installed `redis-server` or `valkey-server` along with `redis-cli` or `valkey-cli` on your host. You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). 
You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). + +From a terminal, change directory to the GLIDE source folder and type: ```bash -pytest --asyncio-mode=auto -k test_socket_set_and_get +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto ``` -IT suite starts the server for testing - standalone and cluster installation using `cluster_manager` script. -If you want IT to use already started servers, use the following command line from `python/python` dir: +To run modules tests: ```bash -pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto -k "test_server_modules.py" ``` -### Submodules +**TIP:** to run a specific test, append `-k ` to the `pytest` execution line -After pulling new changes, ensure that you update the submodules by running the following command: +To run tests against an already running servers, change the `pytest` line above to this: ```bash -git submodule update +pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 ``` -### Generate protobuf files +# Generate protobuf files +--- -During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made to the protobuf definition files (.proto files located in `glide-core/src/protofuf`), it becomes necessary to regenerate the Python protobuf files. To do so, run: +During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made +to the protobuf definition files (`.proto` files located in `glide-core/src/protofuf`), it becomes necessary to +regenerate the Python protobuf files. To do so, run: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide -protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. +protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ``` -#### Protobuf interface files +## Protobuf interface files To generate the protobuf files with Python Interface files (pyi) for type-checking purposes, ensure you have installed `mypy-protobuf` with pip, and then execute the following command: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. MYPY_PROTOC_PATH=`which protoc-gen-mypy` -protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} \ + -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ``` -### Linters +# Linters +--- Development on the Python wrapper may involve changes in either the Python or Rust code. 
 
-### Linters
+# Linters
+---
 
 Development on the Python wrapper may involve changes in either the Python or Rust code. Each language has distinct linter tests that must be passed before committing changes.
 
-#### Language-specific Linters
+## Language-specific Linters
 
 **Python:**
 
@@ -199,31 +202,37 @@ Development on the Python wrapper may involve changes in either the Python or Ru
 -   clippy
 -   fmt
 
-#### Running the linters
+## Running the linters
 
 Run from the main `/python` folder
 
 1. Python
 
-    > Note: make sure to [generate protobuf with interface files]("#protobuf-interface-files") before running mypy linter
+    > Note: make sure to [generate protobuf with interface files](#protobuf-interface-files) before running the `mypy` linter
 
    ```bash
+   cd $HOME/src/valkey-glide/python
+   source .env/bin/activate
   pip install -r dev_requirements.txt
   isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env
   black . --exclude python/glide/protobuf --exclude .env
-  flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230
-  flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230
+  flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \
+      --exclude=python/glide/protobuf,.env/* --extend-ignore=E230
+  flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \
+      --statistics --exclude=python/glide/protobuf,.env/* \
+      --extend-ignore=E230
   # run type check
   mypy .
   ```
+
 2. Rust
 
   ```bash
  rustup component add clippy rustfmt
  cargo clippy --all-features --all-targets -- -D warnings
  cargo fmt --manifest-path ./Cargo.toml --all
-  ```
 
-### Recommended extensions for VS Code
+# Recommended extensions for VS Code
+---
 
 -   [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python)
 -   [isort](https://marketplace.visualstudio.com/items?itemName=ms-python.isort)
diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py
index 7566194dcc..c9744157d6 100644
--- a/python/python/tests/test_async_client.py
+++ b/python/python/tests/test_async_client.py
@@ -1,4 +1,5 @@
 # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+# mypy: disable_error_code="arg-type"
 
 from __future__ import annotations
 
diff --git a/submodules/redis-rs b/submodules/redis-rs
deleted file mode 160000
index 396536db31..0000000000
--- a/submodules/redis-rs
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 396536db31fbf2de0f272d8179d68286329fa70e
From 1b3cad345b4f9f85fd38baba2930d0b819e50037 Mon Sep 17 00:00:00 2001
From: Muhammad-awawdi-amazon
Date: Tue, 15 Oct 2024 13:55:31 +0300
Subject: [PATCH 009/180] Python: add JSON.NUMINCRBY command (#2448)

---------

Signed-off-by: Muhammad Awawdi
Signed-off-by: Muhammad-awawdi-amazon
Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
---
 CHANGELOG.md                                  |   3 +
 .../async_commands/server_modules/json.py     |  42 ++++++
 .../tests/tests_server_modules/test_json.py   | 127 ++++++++++++++++++
 3 files changed, 172 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc7d93523e..efa77a972f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,9 @@
 * Python: Python: Added FT.CREATE command([#2413](https://github.com/valkey-io/valkey-glide/pull/2413))
 * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403))
 * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418))
+* Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409))
+* Python: Add 
JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448))
+
 
 * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414))
 
diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py
index 1864132451..31ee9fe10e 100644
--- a/python/python/glide/async_commands/server_modules/json.py
+++ b/python/python/glide/async_commands/server_modules/json.py
@@ -320,6 +320,48 @@ async def forget(
     )
 
 
+async def numincrby(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    number: Union[int, float],
+) -> Optional[bytes]:
+    """
+    Increments or decrements the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        number (Union[int, float]): The number to increment or decrement by.
+
+    Returns:
+        Optional[bytes]:
+            For JSONPath (`path` starts with `$`):
+                Returns a bytes string representation of an array of bulk strings, indicating the new values after incrementing for each matched `path`.
+                If a value is not a number, its corresponding return value will be `null`.
+                If `path` doesn't exist, a byte string representation of an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns a bytes string representation of the resulting value after the increment or decrement.
+                If multiple paths match, the result of the last updated value is returned.
+                If the value at the `path` is not a number or `path` doesn't exist, an error is raised.
+            If `key` does not exist, an error is raised.
+            If the result is out of the range of 64-bit IEEE double, an error is raised.
+
+    Examples:
+        >>> from glide import json
+        >>> await json.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}')
+            'OK'
+        >>> await json.numincrby(client, "doc", "$.d[*]", 10)
+            b'[11,12,13]'  # Increment each element in `d` array by 10.
+        >>> await json.numincrby(client, "doc", ".c[1]", 10)
+            b'12'  # Increment the second element in the `c` array by 10.
+    """
+    args = ["JSON.NUMINCRBY", key, path, str(number)]
+
+    return cast(Optional[bytes], await client.custom_command(args))
+
+
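As an editorial aside (not part of the patch): the bytes replies shown in the docstring can be decoded with the standard-library parser. A minimal sketch, assuming an already-connected client:

```python
import json as stdlib_json  # standard-library parser, distinct from glide's json module

from glide import json as glide_json

async def incr_and_parse(client, key, path, by):
    """Run JSON.NUMINCRBY and decode the bytes reply into Python values.

    A JSONPath reply such as b'[11,12,13]' decodes to [11, 12, 13];
    a legacy-path reply such as b'12' decodes to the single number 12.
    """
    raw = await glide_json.numincrby(client, key, path, by)
    return None if raw is None else stdlib_json.loads(raw)
```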
 async def toggle(
     client: TGlideClient,
     key: TEncodable,
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py
index d21b11686b..18d7ff525b 100644
--- a/python/python/tests/tests_server_modules/test_json.py
+++ b/python/python/tests/tests_server_modules/test_json.py
@@ -359,3 +359,130 @@ async def test_json_clear(self, glide_client: TGlideClient):
 
         with pytest.raises(RequestError):
             await json.clear(glide_client, "non_existing_key", ".")
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_numincrby(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+
+        json_value = {
+            "key1": 1,
+            "key2": 3.5,
+            "key3": {"nested_key": {"key1": [4, 5]}},
+            "key4": [1, 2, 3],
+            "key5": 0,
+            "key6": "hello",
+            "key7": None,
+            "key8": {"nested_key": {"key1": 69}},
+            "key9": 1.7976931348623157e308,
+        }
+
+        # Set the initial JSON document at the key
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        # Test JSONPath
+        # Increment integer value (key1) by 5
+        result = await json.numincrby(glide_client, key, "$.key1", 5)
+        assert result == b"[6]"  # Expect 1 + 5 = 6
+
+        # Increment float value (key2) by 2.5
+        result = await json.numincrby(glide_client, key, "$.key2", 2.5)
+        assert result == b"[6]"  # Expect 3.5 + 2.5 = 6
+
+        # Increment nested value (key3.nested_key.key1[1]) by 7
+        result = await json.numincrby(glide_client, key, "$.key3.nested_key.key1[1]", 7)
+        assert result == b"[12]"  # Expect 5 + 7 = 12
+
+        # Increment array element (key4[1]) by 1
+        result = await json.numincrby(glide_client, key, "$.key4[1]", 1)
+        assert result == b"[3]"  # Expect 2 + 1 = 3
+
+        # Increment zero value (key5) by 10.23 (float number)
+        result = await json.numincrby(glide_client, key, "$.key5", 10.23)
+        assert result == b"[10.23]"  # Expect 0 + 10.23 = 10.23
+
+        # Increment a string value (key6) by a number
+        result = await json.numincrby(glide_client, key, "$.key6", 99)
+        assert result == b"[null]"  # Expect null
+
+        # Increment a None value (key7) by a number
+        result = await json.numincrby(glide_client, key, "$.key7", 51)
+        assert result == b"[null]"  # Expect null
+
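For reference, here is the document state entering the `$..*` step below, reconstructed from the assertions above (plain Python, illustrative only):

```python
# Document state derived from the asserts above; `$..*` then adds 5 to every
# numeric value, while objects, arrays, strings, and None map to null.
doc = {
    "key1": 6,                                  # 1 + 5
    "key2": 6,                                  # 3.5 + 2.5
    "key3": {"nested_key": {"key1": [4, 12]}},  # 5 + 7 at index 1
    "key4": [1, 3, 3],                          # 2 + 1 at index 1
    "key5": 10.23,                              # 0 + 10.23
    "key6": "hello",                            # untouched (string)
    "key7": None,                               # untouched (null)
    "key8": {"nested_key": {"key1": 69}},       # untouched so far
    "key9": 1.7976931348623157e308,             # float max; +5 keeps it representable
}
```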
+        # Check increment for all values in the document using JSONPath;
+        # non-numeric matches (objects, arrays, strings, None) yield null.
+        result = await json.numincrby(glide_client, key, "$..*", 5)
+        assert (
+            result
+            == b"[11,11,null,null,15.23,null,null,null,1.7976931348623157e+308,null,null,9,17,6,8,8,null,74]"
+        )
+
+        # Check for multiple path matches in JSONPath
+        result = await json.numincrby(glide_client, key, "$..key1", 1)
+        assert result == b"[12,null,75]"
+
+        # Check for non-existent path in JSONPath
+        result = await json.numincrby(glide_client, key, "$.key10", 51)
+        assert result == b"[]"  # Expect Empty Array
+
+        # Check for non-existent key in JSONPath
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, "non_existent_key", "$.key10", 51)
+
+        # Check for Overflow in JSONPath
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, "$.key9", 1.7976931348623157e308)
+
+        # Decrement integer value (key1) by 12
+        result = await json.numincrby(glide_client, key, "$.key1", -12)
+        assert result == b"[0]"  # Expect 12 - 12 = 0
+
+        # Decrement integer value (key1) by 0.5
+        result = await json.numincrby(glide_client, key, "$.key1", -0.5)
+        assert result == b"[-0.5]"  # Expect 0 - 0.5 = -0.5
+
+        # Test Legacy Path
+        # Increment float value (key1) by 5 (an integer)
+        result = await json.numincrby(glide_client, key, "key1", 5)
+        assert result == b"4.5"  # Expect -0.5 + 5 = 4.5
+
+        # Decrement float value (key1) by 5.5 (a float number)
+        result = await json.numincrby(glide_client, key, "key1", -5.5)
+        assert result == b"-1"  # Expect 4.5 - 5.5 = -1
+
+        # Increment int value (key2) by 2.5 (a float number)
+        result = await json.numincrby(glide_client, key, "key2", 2.5)
+        assert result == b"13.5"  # Expect 11 + 2.5 = 13.5
+
+        # Increment nested value (key3.nested_key.key1[0]) by 7
+        result = await json.numincrby(glide_client, key, "key3.nested_key.key1[0]", 7)
+        assert result == b"16"  # Expect 9 + 7 = 16
+
+        # Increment array element (key4[1]) by 1
+        result = await json.numincrby(glide_client, key, "key4[1]", 1)
+        assert result == b"9"  # Expect 8 + 1 = 9
+
+        # Increment a float value (key5) by 10.2 (a float number)
+        result = await json.numincrby(glide_client, key, "key5", 10.2)
+        assert result == b"25.43"  # Expect 15.23 + 10.2 = 25.43
+
+        # Check for multiple path matches in legacy and ensure that the result of the last updated value is returned
+        result = await json.numincrby(glide_client, key, "..key1", 1)
+        assert result == b"76"
+
+        # Check if the rest of the key1 path matches were updated and not only the last value
+        result = await json.get(glide_client, key, "$..key1")
+        assert (
+            result == b"[0,[16,17],76]"
+        )  # First is 0 as -1 + 1 = 0, second doesn't change as it's an array type (non-numeric), third is 76 as 75 + 1 = 76
+
+        # Check for non-existent path in legacy
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, ".key10", 51)
+
+        # Check for non-existent key in legacy
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, "non_existent_key", ".key10", 51)
+
+        # Check for Overflow in legacy
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, ".key9", 1.7976931348623157e308)
From 04f091067af59c7ecfe95960824d8c781a8f4dde Mon Sep 17 00:00:00 2001
From: Gilboab <97948000+GilboaAWS@users.noreply.github.com>
Date: Tue, 15 Oct 2024 14:30:36 +0300
Subject: [PATCH 010/180] Added inflightRequestsLimit client config to Node
 (#2452)

* Added inflightRequestsLimit client config to Node

---------

Signed-off-by: GilboaAWS
---
 node/rust-client/src/lib.rs           |  3 ++
 node/src/BaseClient.ts                | 12 +++++++
 node/tests/GlideClient.test.ts        | 48 
++++++++++++++++++++++++++ node/tests/GlideClusterClient.test.ts | 49 +++++++++++++++++++++++++++ 4 files changed, 112 insertions(+) diff --git a/node/rust-client/src/lib.rs b/node/rust-client/src/lib.rs index a6e611c0f6..82e546d295 100644 --- a/node/rust-client/src/lib.rs +++ b/node/rust-client/src/lib.rs @@ -43,6 +43,9 @@ pub const MAX_REQUEST_ARGS_LEN: u32 = MAX_REQUEST_ARGS_LENGTH as u32; pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 = glide_core::client::DEFAULT_RESPONSE_TIMEOUT.as_millis() as u32; +#[napi] +pub const DEFAULT_INFLIGHT_REQUESTS_LIMIT: u32 = glide_core::client::DEFAULT_MAX_INFLIGHT_REQUESTS; + #[napi] struct AsyncClient { #[allow(dead_code)] diff --git a/node/src/BaseClient.ts b/node/src/BaseClient.ts index 2bfb9a3bbd..7923e71d2e 100644 --- a/node/src/BaseClient.ts +++ b/node/src/BaseClient.ts @@ -3,6 +3,7 @@ */ import { ClusterScanCursor, + DEFAULT_INFLIGHT_REQUESTS_LIMIT, DEFAULT_TIMEOUT_IN_MILLISECONDS, Script, StartSocketConnection, @@ -563,6 +564,13 @@ export interface BaseClientConfiguration { * If not set, 'Decoder.String' will be used. */ defaultDecoder?: Decoder; + /** + * The maximum number of concurrent requests allowed to be in-flight (sent but not yet completed). + * This limit is used to control the memory usage and prevent the client from overwhelming the + * server or getting stuck in case of a queue backlog. If not set, a default value of 1000 will be + * used. + */ + inflightRequestsLimit?: number; } /** @@ -707,6 +715,7 @@ export class BaseClient { protected defaultDecoder = Decoder.String; private readonly pubsubFutures: [PromiseFunction, ErrorFunction][] = []; private pendingPushNotification: response.Response[] = []; + private readonly inflightRequestsLimit: number; private config: BaseClientConfiguration | undefined; protected configurePubsub( @@ -873,6 +882,8 @@ export class BaseClient { this.close(); }); this.defaultDecoder = options?.defaultDecoder ?? Decoder.String; + this.inflightRequestsLimit = + options?.inflightRequestsLimit ?? 
DEFAULT_INFLIGHT_REQUESTS_LIMIT; } protected getCallbackIndex(): number { @@ -7505,6 +7516,7 @@ export class BaseClient { clusterModeEnabled: false, readFrom, authenticationInfo, + inflightRequestsLimit: options.inflightRequestsLimit, }; } diff --git a/node/tests/GlideClient.test.ts b/node/tests/GlideClient.test.ts index 4e3a31d195..e0ac65e169 100644 --- a/node/tests/GlideClient.test.ts +++ b/node/tests/GlideClient.test.ts @@ -1501,6 +1501,54 @@ describe("GlideClient", () => { }, ); + it.each([ + [ProtocolVersion.RESP2, 5], + [ProtocolVersion.RESP2, 100], + [ProtocolVersion.RESP2, 1500], + [ProtocolVersion.RESP3, 5], + [ProtocolVersion.RESP3, 100], + [ProtocolVersion.RESP3, 1500], + ])( + "test inflight requests limit of %p with protocol %p", + async (protocol, inflightRequestsLimit) => { + const config = getClientConfigurationOption( + cluster.getAddresses(), + protocol, + { inflightRequestsLimit }, + ); + const client = await GlideClient.createClient(config); + + try { + const key1 = `{nonexistinglist}:1-${uuidv4()}`; + const tasks: Promise<[GlideString, GlideString] | null>[] = []; + + // Start inflightRequestsLimit blocking tasks + for (let i = 0; i < inflightRequestsLimit; i++) { + tasks.push(client.blpop([key1], 0)); + } + + // This task should immediately fail due to reaching the limit + await expect(client.blpop([key1], 0)).rejects.toThrow( + RequestError, + ); + + // Verify that all previous tasks are still pending + const timeoutPromise = new Promise((resolve) => + setTimeout(resolve, 100), + ); + const allTasksStatus = await Promise.race([ + Promise.any( + tasks.map((task) => task.then(() => "resolved")), + ), + timeoutPromise.then(() => "pending"), + ]); + expect(allTasksStatus).toBe("pending"); + } finally { + await client.close(); + } + }, + ); + runBaseTests({ init: async (protocol, configOverrides) => { const config = getClientConfigurationOption( diff --git a/node/tests/GlideClusterClient.test.ts b/node/tests/GlideClusterClient.test.ts index d21ff1a43a..6f29f99884 100644 --- a/node/tests/GlideClusterClient.test.ts +++ b/node/tests/GlideClusterClient.test.ts @@ -24,6 +24,7 @@ import { GeoUnit, GlideClusterClient, GlideReturnType, + GlideString, InfoOptions, ListDirection, ProtocolVersion, @@ -1974,4 +1975,52 @@ describe("GlideClusterClient", () => { }, TIMEOUT, ); + + it.each([ + [ProtocolVersion.RESP2, 5], + [ProtocolVersion.RESP2, 100], + [ProtocolVersion.RESP2, 1500], + [ProtocolVersion.RESP3, 5], + [ProtocolVersion.RESP3, 100], + [ProtocolVersion.RESP3, 1500], + ])( + "test inflight requests limit of %p with protocol %p", + async (protocol, inflightRequestsLimit) => { + const config = getClientConfigurationOption( + cluster.getAddresses(), + protocol, + { inflightRequestsLimit }, + ); + const client = await GlideClusterClient.createClient(config); + + try { + const key1 = `{nonexistinglist}:1-${uuidv4()}`; + const tasks: Promise<[GlideString, GlideString] | null>[] = []; + + // Start inflightRequestsLimit blocking tasks + for (let i = 0; i < inflightRequestsLimit; i++) { + tasks.push(client.blpop([key1], 0)); + } + + // This task should immediately fail due to reaching the limit + await expect(client.blpop([key1], 0)).rejects.toThrow( + RequestError, + ); + + // Verify that all previous tasks are still pending + const timeoutPromise = new Promise((resolve) => + setTimeout(resolve, 100), + ); + const allTasksStatus = await Promise.race([ + Promise.any( + tasks.map((task) => task.then(() => "resolved")), + ), + timeoutPromise.then(() => "pending"), + ]); + 
expect(allTasksStatus).toBe("pending");
+            } finally {
+                await client.close();
+            }
+        },
+    );
 });
From b2551897082af2b84027e3d1261a523677418277 Mon Sep 17 00:00:00 2001
From: Yury-Fridlyand
Date: Tue, 15 Oct 2024 08:31:51 -0700
Subject: [PATCH 011/180] Java: Allow to run modules CI on demand. (#2416)

* Run modules CI on demand.

Signed-off-by: Yury-Fridlyand
---
 .github/workflows/java.yml | 4 ++--
 CHANGELOG.md               | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index d5d0697abb..ca626224f4 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -197,9 +197,9 @@ jobs:
             name: lint java rust

     test-modules:
-        if: github.event.pull_request.head.repo.owner.login == 'valkey-io'
+        if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io'
         environment: AWS_ACTIONS
-        name: Running Module Tests
+        name: Modules Tests
         runs-on: [self-hosted, linux, ARM64]
         timeout-minutes: 15
         steps:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index efa77a972f..bf7bb69d44 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,8 @@

 #### Operational Enhancements

+* Java: Add modules CI ([#2388](https://github.com/valkey-io/valkey-glide/pull/2388), [#2404](https://github.com/valkey-io/valkey-glide/pull/2404), [#2416](https://github.com/valkey-io/valkey-glide/pull/2416))
+
 ## 1.1.0 (2024-09-24)

 #### Changes
From 23a180bb67019188318f6e3a28fe2d61fa9ccfe0 Mon Sep 17 00:00:00 2001
From: prateek-kumar-improving
Date: Tue, 15 Oct 2024 11:33:00 -0700
Subject: [PATCH 012/180] Python FT.DROPINDEX command (#2437)

* Python [FT.DROPINDEX] Added command

---------

Signed-off-by: Prateek Kumar
---
 CHANGELOG.md                                  |  1 +
 .../glide/async_commands/server_modules/ft.py | 30 ++++++-
 .../{ => ft_options}/ft_constants.py          |  3 +-
 .../ft_options/ft_create_options.py           | 78 +++++++++----------
 .../{test_ft.py => search/test_ft_create.py}  |  8 +-
 .../search/test_ft_dropindex.py               | 44 +++++++++++
 6 files changed, 116 insertions(+), 48 deletions(-)
 rename python/python/glide/async_commands/server_modules/{ => ft_options}/ft_constants.py (87%)
 rename python/python/tests/tests_server_modules/{test_ft.py => search/test_ft_create.py} (95%)
 create mode 100644 python/python/tests/tests_server_modules/search/test_ft_dropindex.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bf7bb69d44..ae54d5022e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,5 @@
 #### Changes
+* Python: Add FT.DROPINDEX command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437))
 * Python: Python: Added FT.CREATE command([#2413](https://github.com/valkey-io/valkey-glide/pull/2413))
 * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403))
 * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418))
diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py
index b7c764cd0f..74d75e8953 100644
--- a/python/python/glide/async_commands/server_modules/ft.py
+++ b/python/python/glide/async_commands/server_modules/ft.py
@@ -5,7 +5,7 @@

 from typing import List, Optional, cast

-from glide.async_commands.server_modules.ft_constants import (
+from glide.async_commands.server_modules.ft_options.ft_constants import (
     CommandNames,
     FtCreateKeywords,
 )
@@ -30,10 +30,10 @@ async def create(
         client (TGlideClient): The client to execute the command.
indexName (TEncodable): The index name for the index to be created
         schema (List[Field]): The fields of the index schema, specifying the fields and their types.
-        options (Optional[FtCreateOptions]): Optional arguments for the [FT.CREATE] command.
+        options (Optional[FtCreateOptions]): Optional arguments for the FT.CREATE command. See `FtCreateOptions`.

     Returns:
-        If the index is successfully created, returns "OK".
+        TOK: A simple "OK" response.

     Examples:
         >>> from glide.async_commands.server_modules import ft
@@ -44,7 +44,7 @@ async def create(
         >>> prefixes.append("blog:post:")
         >>> index = "idx"
         >>> result = await ft.create(glide_client, index, schema, FtCreateOptions(DataType.HASH, prefixes))
-            b'OK'  # Indicates successful creation of index named 'idx'
+            'OK'  # Indicates successful creation of index named 'idx'
     """
     args: List[TEncodable] = [CommandNames.FT_CREATE, indexName]
     if options:
@@ -54,3 +54,25 @@ async def create(
     for field in schema:
         args.extend(field.toArgs())
     return cast(TOK, await client.custom_command(args))
+
+
+async def dropindex(client: TGlideClient, indexName: TEncodable) -> TOK:
+    """
+    Drops an index. The index definition and associated content are deleted. Keys are unaffected.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        indexName (TEncodable): The index name for the index to be dropped.
+
+    Returns:
+        TOK: A simple "OK" response.
+
+    Examples:
+        For the following example to work, an index named 'idx' must already be created. If not created, you will get an error.
+        >>> from glide.async_commands.server_modules import ft
+        >>> indexName = "idx"
+        >>> result = await ft.dropindex(glide_client, indexName)
+            'OK'  # Indicates successful deletion/dropping of index named 'idx'
+    """
+    args: List[TEncodable] = [CommandNames.FT_DROPINDEX, indexName]
+    return cast(TOK, await client.custom_command(args))
diff --git a/python/python/glide/async_commands/server_modules/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py
similarity index 87%
rename from python/python/glide/async_commands/server_modules/ft_constants.py
rename to python/python/glide/async_commands/server_modules/ft_options/ft_constants.py
index 3c48f5b67c..d1e8e524eb 100644
--- a/python/python/glide/async_commands/server_modules/ft_constants.py
+++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py
@@ -7,11 +7,12 @@ class CommandNames:
     """

     FT_CREATE = "FT.CREATE"
+    FT_DROPINDEX = "FT.DROPINDEX"


 class FtCreateKeywords:
     """
-    Keywords used in the [FT.CREATE] command statment.
+    Keywords used in the FT.CREATE command statement.
     """

     SCHEMA = "SCHEMA"
diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py
index d3db3dbe75..89ac1d760d 100644
--- a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py
+++ b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Optional

-from glide.async_commands.server_modules.ft_constants import FtCreateKeywords
+from glide.async_commands.server_modules.ft_options.ft_constants import FtCreateKeywords
 from glide.constants import TEncodable


@@ -85,15 +85,15 @@
         self,
         name: TEncodable,
         type: FieldType,
-        alias: Optional[str] = None,
+        alias: Optional[TEncodable] = None,
     ):
         """
         Initialize a new field instance.
Args: name (TEncodable): The name of the field. - type (FieldType): The type of the field. - alias (Optional[str]): An alias for the field. + type (FieldType): The type of the field. See `FieldType`. + alias (Optional[TEncodable]): An alias for the field. """ self.name = name self.type = type @@ -119,13 +119,13 @@ class TextField(Field): Class for defining text fields in a schema. """ - def __init__(self, name: TEncodable, alias: Optional[str] = None): + def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): """ Initialize a new TextField instance. Args: name (TEncodable): The name of the text field. - alias (Optional[str]): An alias for the field. + alias (Optional[TEncodable]): An alias for the field. """ super().__init__(name, FieldType.TEXT, alias) @@ -148,8 +148,8 @@ class TagField(Field): def __init__( self, name: TEncodable, - alias: Optional[str] = None, - separator: Optional[str] = None, + alias: Optional[TEncodable] = None, + separator: Optional[TEncodable] = None, case_sensitive: bool = False, ): """ @@ -157,8 +157,8 @@ def __init__( Args: name (TEncodable): The name of the tag field. - alias (Optional[str]): An alias for the field. - separator (Optional[str]): Specify how text in the attribute is split into individual tags. Must be a single character. + alias (Optional[TEncodable]): An alias for the field. + separator (Optional[TEncodable]): Specify how text in the attribute is split into individual tags. Must be a single character. case_sensitive (bool): Preserve the original letter cases of tags. If set to False, characters are converted to lowercase by default. """ super().__init__(name, FieldType.TAG, alias) @@ -185,13 +185,13 @@ class NumericField(Field): Class for defining the numeric fields in a schema. """ - def __init__(self, name: TEncodable, alias: Optional[str] = None): + def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): """ Initialize a new NumericField instance. Args: name (TEncodable): The name of the numeric field. - alias (Optional[str]): An alias for the field. + alias (Optional[TEncodable]): An alias for the field. """ super().__init__(name, FieldType.NUMERIC, alias) @@ -219,21 +219,21 @@ def __init__(self, dim: int, distance_metric: DistanceMetricType, type: VectorTy Args: dim (int): Number of dimensions in the vector. distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. - type (VectorType): Vector type. The only supported type is FLOAT32. + type (VectorType): Vector type. The only supported type is FLOAT32. See `VectorType`. """ self.dim = dim self.distance_metric = distance_metric self.type = type @abstractmethod - def toArgs(self) -> List[str]: + def toArgs(self) -> List[TEncodable]: """ Get the arguments to be used for the algorithm of the vector field. Returns: - List[str]: A list of arguments. + List[TEncodable]: A list of arguments. """ - args = [] + args: List[TEncodable] = [] if self.dim: args.extend([FtCreateKeywords.DIM, str(self.dim)]) if self.distance_metric: @@ -260,19 +260,19 @@ def __init__( Args: dim (int): Number of dimensions in the vector. - distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. - type (VectorType): Vector type. The only supported type is FLOAT32. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. See `DistanceMetricType`. + type (VectorType): Vector type. 
The only supported type is FLOAT32. See `VectorType`.
             initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024.
             m (Optional[int]): Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is 16, maximum is 512.
             ef_contruction (Optional[int]): Controls the number of vectors examined during index construction. Default value is 200, Maximum value is 4096.
             ef_runtime (Optional[int]): Controls the number of vectors examined during query operations. Default value is 10, Maximum value is 4096.
         """
         super().__init__(dim, distance_metric, type)
         self.initial_cap = initial_cap
         self.m = m
         self.ef_contruction = ef_contruction
         self.ef_runtime = ef_runtime

-    def toArgs(self) -> List[str]:
+    def toArgs(self) -> List[TEncodable]:
         """
         Get the arguments representing the vector field created with HNSW algorithm.

         Returns:
-            List[str]: A list of HNSW algorithm type vector arguments.
+            List[TEncodable]: A list of HNSW algorithm type vector arguments.
         """
         args = super().toArgs()
         if self.initial_cap:
@@ -342,16 +342,16 @@
         name: TEncodable,
         algorithm: VectorAlgorithm,
         attributes: VectorFieldAttributes,
-        alias: Optional[str] = None,
+        alias: Optional[TEncodable] = None,
     ):
         """
         Initialize a new VectorField instance.

         Args:
             name (TEncodable): The name of the vector field.
-            algorithm (VectorAlgorithm): The vector indexing algorithm.
-            alias (Optional[str]): An alias for the field.
-            attributes (VectorFieldAttributes): Additional attributes to be passed with the vector field after the algorithm name.
+            algorithm (VectorAlgorithm): The vector indexing algorithm. See `VectorAlgorithm`.
+            alias (Optional[TEncodable]): An alias for the field.
+            attributes (VectorFieldAttributes): Additional attributes to be passed with the vector field after the algorithm name. See `VectorFieldAttributes`.
         """
         super().__init__(name, FieldType.VECTOR, alias)
         self.algorithm = algorithm
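To see how these classes compose, a hedged sketch follows; it assumes the HNSW attributes class in this module is named `VectorFieldAttributesHnsw`, and the field name is arbitrary:

```python
# Illustrative only: a vector field assembled from the classes above.
vector_field = VectorField(
    name="$.embedding",
    algorithm=VectorAlgorithm.HNSW,
    attributes=VectorFieldAttributesHnsw(  # assumed class name for the HNSW attributes
        dim=128,
        distance_metric=DistanceMetricType.COSINE,
        type=VectorType.FLOAT32,
    ),
)
args = vector_field.toArgs()  # the argument list FT.CREATE will receive for this field
```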
@@ -390,34 +390,34 @@ class DataType(Enum):

 class FtCreateOptions:
     """
-    This class represents the input options to be used in the [FT.CREATE] command.
-    All fields in this class are optional inputs for [FT.CREATE].
+    This class represents the input options to be used in the FT.CREATE command.
+    All fields in this class are optional inputs for FT.CREATE.
     """

     def __init__(
         self,
         data_type: Optional[DataType] = None,
-        prefixes: Optional[List[str]] = None,
+        prefixes: Optional[List[TEncodable]] = None,
     ):
         """
-        Initialize the [FT.CREATE] optional fields.
+        Initialize the FT.CREATE optional fields.

         Args:
-            data_type (Optional[DataType]): The type of data to be indexed using [FT.CREATE].
-            prefixes (Optional[List[str]]): The prefix of the key to be indexed.
+            data_type (Optional[DataType]): The type of data to be indexed using FT.CREATE. See `DataType`.
+            prefixes (Optional[List[TEncodable]]): The prefix of the key to be indexed.
         """
         self.data_type = data_type
         self.prefixes = prefixes

-    def toArgs(self) -> List[str]:
+    def toArgs(self) -> List[TEncodable]:
         """
-        Get the optional arguments for the [FT.CREATE] command.
+        Get the optional arguments for the FT.CREATE command.

         Returns:
-            List[str]:
-                List of [FT.CREATE] optional agruments.
+            List[TEncodable]:
+                List of FT.CREATE optional arguments.
         """
-        args = []
+        args: List[TEncodable] = []
         if self.data_type:
             args.append(FtCreateKeywords.ON)
             args.append(self.data_type.value)
diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/search/test_ft_create.py
similarity index 95%
rename from python/python/tests/tests_server_modules/test_ft.py
rename to python/python/tests/tests_server_modules/search/test_ft_create.py
index 93f9efc9c1..c08346563b 100644
--- a/python/python/tests/tests_server_modules/test_ft.py
+++ b/python/python/tests/tests_server_modules/search/test_ft_create.py
@@ -17,15 +17,15 @@
     VectorType,
 )
 from glide.config import ProtocolVersion
-from glide.constants import OK
+from glide.constants import OK, TEncodable
 from glide.glide_client import GlideClusterClient


 @pytest.mark.asyncio
-class TestVss:
+class TestFtCreate:
     @pytest.mark.parametrize("cluster_mode", [True])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
-    async def test_vss_create(self, glide_client: GlideClusterClient):
+    async def test_ft_create(self, glide_client: GlideClusterClient):
         fields: List[Field] = []
         textFieldTitle: TextField = TextField("$title")
         numberField: NumericField = NumericField("$published_at")
@@ -34,7 +34,7 @@
         fields.append(numberField)
         fields.append(textFieldCategory)

-        prefixes: List[str] = []
+        prefixes: List[TEncodable] = []
         prefixes.append("blog:post:")

         # Create an index with multiple fields with Hash data type.
diff --git a/python/python/tests/tests_server_modules/search/test_ft_dropindex.py b/python/python/tests/tests_server_modules/search/test_ft_dropindex.py
new file mode 100644
index 0000000000..717df38eb8
--- /dev/null
+++ b/python/python/tests/tests_server_modules/search/test_ft_dropindex.py
@@ -0,0 +1,44 @@
+import uuid
+from typing import List
+
+import pytest
+from glide.async_commands.server_modules import ft
+from glide.async_commands.server_modules.ft_options.ft_create_options import (
+    DataType,
+    Field,
+    FtCreateOptions,
+    TextField,
+)
+from glide.config import ProtocolVersion
+from glide.constants import OK, TEncodable
+from glide.exceptions import RequestError
+from glide.glide_client import GlideClusterClient
+
+
+@pytest.mark.asyncio
+class TestFtDropIndex:
+    @pytest.mark.parametrize("cluster_mode", [True])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_ft_dropindex(self, glide_client: GlideClusterClient):
+        # Index name for the index to be dropped.
+        indexName = str(uuid.uuid4())
+
+        fields: List[Field] = []
+        textFieldTitle: TextField = TextField("$title")
+        fields.append(textFieldTitle)
+        prefixes: List[TEncodable] = []
+        prefixes.append("blog:post:")
+
+        # Create an index with multiple fields with Hash data type.
+        result = await ft.create(
+            glide_client, indexName, fields, FtCreateOptions(DataType.HASH, prefixes)
+        )
+        assert result == OK
+
+        # Drop the index. Expects "OK" as a response.
+        result = await ft.dropindex(glide_client, indexName)
+        assert result == OK
+
+        # Drop a non-existent index. Expects a RequestError.
+        with pytest.raises(RequestError):
+            await ft.dropindex(glide_client, indexName)
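Between these two commits, a minimal end-to-end sketch of the FT commands added so far (editorial illustration, not part of the patch; index and prefix names are arbitrary):

```python
from glide.async_commands.server_modules import ft
from glide.async_commands.server_modules.ft_options.ft_create_options import (
    DataType,
    FtCreateOptions,
    TextField,
)
from glide.exceptions import RequestError

async def recreate_index(client, index_name):
    """Drop `index_name` if present, then create it afresh over hash keys."""
    try:
        await ft.dropindex(client, index_name)
    except RequestError:
        pass  # the index did not exist; nothing to drop
    await ft.create(
        client,
        index_name,
        [TextField("$title")],
        FtCreateOptions(DataType.HASH, ["blog:post:"]),
    )
```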
From e51aa6d3d608d4df052fc4a7a22c372abc42d0ea Mon Sep 17 00:00:00 2001
From: Yury-Fridlyand
Date: Tue, 15 Oct 2024 15:22:36 -0700
Subject: [PATCH 013/180] Update routing for commands from server modules.
 (#2461)

* Update routing for commands from server modules.

Signed-off-by: Yury-Fridlyand
---
 CHANGELOG.md                                  |  3 +-
 .../redis-rs/redis/src/cluster_routing.rs     | 31 +++++++++++++++++--
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae54d5022e..f484dbb972 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,9 +5,8 @@
 * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418))
 * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409))
 * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448))
-
-
-* Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414))
+* Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414))
+* Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461))

 #### Breaking Changes

diff --git a/glide-core/redis-rs/redis/src/cluster_routing.rs b/glide-core/redis-rs/redis/src/cluster_routing.rs
index bfe6ae2039..dcd02e6046 100644
--- a/glide-core/redis-rs/redis/src/cluster_routing.rs
+++ b/glide-core/redis-rs/redis/src/cluster_routing.rs
@@ -355,9 +355,14 @@ impl ResponsePolicy {
             Some(ResponsePolicy::AllSucceeded)
         }

-        b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => {
-            Some(ResponsePolicy::CombineArrays)
-        }
+        b"KEYS"
+        | b"FT._ALIASLIST"
+        | b"FT._LIST"
+        | b"MGET"
+        | b"SLOWLOG GET"
+        | b"PUBSUB CHANNELS"
+        | b"PUBSUB SHARDCHANNELS" => Some(ResponsePolicy::CombineArrays),
+
         b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),

         b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),
@@ -429,6 +434,8 @@ fn base_routing(cmd: &[u8]) -> RouteBy {
         b"DBSIZE"
         | b"FLUSHALL"
        | b"FLUSHDB"
+        | b"FT._ALIASLIST"
+        | b"FT._LIST"
        | b"FUNCTION DELETE"
        | b"FUNCTION FLUSH"
        | b"FUNCTION LOAD"
@@ -689,6 +696,14 @@ pub fn is_readonly_cmd(cmd: &[u8]) -> bool {
        | b"EXISTS"
        | b"EXPIRETIME"
        | b"FCALL_RO"
+        | b"FT.AGGREGATE"
+        | b"FT.EXPLAIN"
+        | b"FT.EXPLAINCLI"
+        | b"FT.INFO"
+        | b"FT.PROFILE"
+        | b"FT.SEARCH"
+        | b"FT._ALIASLIST"
+        | b"FT._LIST"
        | b"FUNCTION DUMP"
        | b"FUNCTION KILL"
        | b"FUNCTION LIST"
@@ -712,6 +727,16 @@ pub fn is_readonly_cmd(cmd: &[u8]) -> bool {
        | b"HSCAN"
        | b"HSTRLEN"
        | b"HVALS"
+        | b"JSON.ARRINDEX"
+        | b"JSON.ARRLEN"
+        | b"JSON.DEBUG"
+        | b"JSON.GET"
+        | b"JSON.OBJLEN"
+        | b"JSON.OBJKEYS"
+        | b"JSON.MGET"
+        | b"JSON.RESP"
+        | b"JSON.STRLEN"
+        | b"JSON.TYPE"
        | b"KEYS"
        | b"LCS"
        | b"LINDEX"
From 9b9ccdbfea2ce6fb1e1ee29585d6b3a51997d3af Mon Sep 17 00:00:00 2001
From: Yury-Fridlyand
Date: Tue, 15 Oct 2024 16:53:25 -0700
Subject: [PATCH 014/180] Java: `FT.DROPINDEX`. (#2440)

* `FT.DROPINDEX`.
Signed-off-by: Yury-Fridlyand
Signed-off-by: Andrew Carbonetto
Co-authored-by: Andrew Carbonetto
---
 CHANGELOG.md                                  |  1 +
 .../glide/api/commands/servermodules/FT.java  | 30 +++++++++++++++
 .../java/glide/modules/VectorSearchTests.java | 38 +++++++++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f484dbb972..339684f4ad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409))
 * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448))
 * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414))
+* Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440))
 * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461))

 #### Breaking Changes

diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java
index 51bde7a03d..bff9eeb357 100644
--- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java
+++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java
@@ -140,6 +140,36 @@ public static CompletableFuture<String> create(
         return executeCommand(client, args, false);
     }

+    /**
+     * Deletes an index and associated content. Indexed document keys are unaffected.
+     *
+     * @param indexName The index name.
+     * @return <code>"OK"</code>.
+     * @example
+     *
<pre>{@code
+     * FT.dropindex(client, "hash_idx1").get();
+     * }</pre>
+     */
+    public static CompletableFuture<String> dropindex(
+            @NonNull BaseClient client, @NonNull String indexName) {
+        return executeCommand(client, new GlideString[] {gs("FT.DROPINDEX"), gs(indexName)}, false);
+    }
+
+    /**
+     * Deletes an index and associated content. Indexed document keys are unaffected.
+     *
+     * @param indexName The index name.
+     * @return <code>"OK"</code>.
+     * @example
+     *
<pre>{@code
+     * FT.dropindex(client, gs("hash_idx1")).get();
+     * }</pre>
+     */
+    public static CompletableFuture<String> dropindex(
+            @NonNull BaseClient client, @NonNull GlideString indexName) {
+        return executeCommand(client, new GlideString[] {gs("FT.DROPINDEX"), indexName}, false);
+    }
+
     /**
      * A wrapper for custom command API.
      *
diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java
index 67387026bd..0597fa0023 100644
--- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java
+++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java
@@ -6,6 +6,7 @@
 import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES;
 import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -24,6 +25,8 @@
 import glide.api.models.commands.FlushMode;
 import glide.api.models.commands.InfoOptions.Section;
 import glide.api.models.exceptions.RequestException;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import lombok.SneakyThrows;
@@ -182,4 +185,39 @@ public void ft_create() {
         assertInstanceOf(RequestException.class, exception.getCause());
         assertTrue(exception.getMessage().contains("already exists"));
     }
+
+    @SneakyThrows
+    @Test
+    public void ft_drop() {
+        var index = UUID.randomUUID().toString();
+        assertEquals(
+                OK,
+                FT.create(
+                                client,
+                                index,
+                                new FieldInfo[] {
+                                    new FieldInfo("vec", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build())
+                                })
+                        .get());
+
+        // TODO use FT.LIST when it is done
+        var before =
+                Set.of((Object[]) client.customCommand(new String[] {"FT._LIST"}).get().getSingleValue());
+
+        assertEquals(OK, FT.dropindex(client, index).get());
+
+        // TODO use FT.LIST when it is done
+        var after =
+                new HashSet<>(
+                        Set.of(
+                                (Object[]) client.customCommand(new String[] {"FT._LIST"}).get().getSingleValue()));
+
+        assertFalse(after.contains(index));
+        after.add(index);
+        assertEquals(after, before);
+
+        var exception = assertThrows(ExecutionException.class, () -> FT.dropindex(client, index).get());
+        assertInstanceOf(RequestException.class, exception.getCause());
+        assertTrue(exception.getMessage().contains("Index does not exist"));
+    }
 }
From db785c98f46259ac7a47c36d513c14887ab7b427 Mon Sep 17 00:00:00 2001
From: Muhammad Awawdi
Date: Wed, 16 Oct 2024 12:12:30 +0300
Subject: [PATCH 015/180] Python: Adds JSON.NUMMULTBY Command (#2458)

---------

Signed-off-by: Muhammad Awawdi
Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
---
 CHANGELOG.md                                  |   1 +
 .../async_commands/server_modules/json.py     |  42 ++++++
 .../tests/tests_server_modules/test_json.py   | 138 ++++++++++++++++++
 3 files changed, 181 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 339684f4ad..3cf1a3a7ea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418))
 * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409))
 * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448))
+* Python: Add JSON.NUMMULTBY command 
([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 31ee9fe10e..7da1c6c4aa 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -362,6 +362,48 @@ async def numincrby( return cast(Optional[bytes], await client.custom_command(args)) +async def nummultby( + client: TGlideClient, + key: TEncodable, + path: TEncodable, + number: Union[int, float], +) -> Optional[bytes]: + """ + Multiplies the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (TEncodable): The path within the JSON document. + number (Union[int, float]): The number to multiply by. + + Returns: + Optional[bytes]: + For JSONPath (`path` starts with `$`): + Returns a bytes string representation of an array of bulk strings, indicating the new values after multiplication for each matched `path`. + If a value is not a number, its corresponding return value will be `null`. + If `path` doesn't exist, a byte string representation of an empty array will be returned. + For legacy path (`path` doesn't start with `$`): + Returns a bytes string representation of the resulting value after multiplication. + If multiple paths match, the result of the last updated value is returned. + If the value at the `path` is not a number or `path` doesn't exist, an error is raised. + If `key` does not exist, an error is raised. + If the result is out of the range of 64-bit IEEE double, an error is raised. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}') + 'OK' + >>> await json.nummultby(client, "doc", "$.d[*]", 2) + b'[2,4,6]' # Multiplies each element in the `d` array by 2. + >>> await json.nummultby(client, "doc", ".c[1]", 2) + b'4' # Multiplies the second element in the `c` array by 2. 
+    """
+    args = ["JSON.NUMMULTBY", key, path, str(number)]
+
+    return cast(Optional[bytes], await client.custom_command(args))
+
+
 async def toggle(
     client: TGlideClient,
     key: TEncodable,
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py
index 18d7ff525b..794e885cfe 100644
--- a/python/python/tests/tests_server_modules/test_json.py
+++ b/python/python/tests/tests_server_modules/test_json.py
@@ -486,3 +486,141 @@ async def test_json_numincrby(self, glide_client: TGlideClient):
         # Check for Overflow in legacy
         with pytest.raises(RequestError):
             await json.numincrby(glide_client, key, ".key9", 1.7976931348623157e308)
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_nummultby(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+
+        json_value = {
+            "key1": 1,
+            "key2": 3.5,
+            "key3": {"nested_key": {"key1": [4, 5]}},
+            "key4": [1, 2, 3],
+            "key5": 0,
+            "key6": "hello",
+            "key7": None,
+            "key8": {"nested_key": {"key1": 69}},
+            "key9": 3.5953862697246314e307,
+        }
+
+        # Set the initial JSON document at the key
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        # Test JSONPath
+        # Multiply integer value (key1) by 5
+        result = await json.nummultby(glide_client, key, "$.key1", 5)
+        assert result == b"[5]"  # Expect 1 * 5 = 5
+
+        # Multiply float value (key2) by 2.5
+        result = await json.nummultby(glide_client, key, "$.key2", 2.5)
+        assert result == b"[8.75]"  # Expect 3.5 * 2.5 = 8.75
+
+        # Multiply nested value (key3.nested_key.key1[1]) by 7
+        result = await json.nummultby(glide_client, key, "$.key3.nested_key.key1[1]", 7)
+        assert result == b"[35]"  # Expect 5 * 7 = 35
+
+        # Multiply array element (key4[1]) by 1
+        result = await json.nummultby(glide_client, key, "$.key4[1]", 1)
+        assert result == b"[2]"  # Expect 2 * 1 = 2
+
+        # Multiply zero value (key5) by 10.23 (float number)
+        result = await json.nummultby(glide_client, key, "$.key5", 10.23)
+        assert result == b"[0]"  # Expect 0 * 10.23 = 0
+
+        # Multiply a string value (key6) by a number
+        result = await json.nummultby(glide_client, key, "$.key6", 99)
+        assert result == b"[null]"  # Expect null
+
+        # Multiply a None value (key7) by a number
+        result = await json.nummultby(glide_client, key, "$.key7", 51)
+        assert result == b"[null]"  # Expect null
+
+        # Check multiplication for all numbers in the document using JSON Path
+        # key1: 5 * 5 = 25
+        # key2: 8.75 * 5 = 43.75
+        # key3.nested_key.key1[0]: 4 * 5 = 20
+        # key3.nested_key.key1[1]: 35 * 5 = 175
+        # key4[0]: 1 * 5 = 5
+        # key4[1]: 2 * 5 = 10
+        # key4[2]: 3 * 5 = 15
+        # key5: 0 * 5 = 0
+        # key8.nested_key.key1: 69 * 5 = 345
+        # key9: 3.5953862697246314e307 * 5 = 1.7976931348623157e308
+        result = await json.nummultby(glide_client, key, "$..*", 5)
+        assert (
+            result
+            == b"[25,43.75,null,null,0,null,null,null,1.7976931348623157e+308,null,null,20,175,5,10,15,null,345]"
+        )
+
+        # Check for multiple path matches in JSONPath
+        # key1: 25 * 2 = 50
+        # key8.nested_key.key1: 345 * 2 = 690
+        result = await json.nummultby(glide_client, key, "$..key1", 2)
+        assert result == b"[50,null,690]"  # After previous multiplications
+
+        # Check for non-existent path in JSONPath
+        result = await json.nummultby(glide_client, key, "$.key10", 51)
+        assert result == b"[]"  # Expect Empty Array
+
+        # Check for non-existent key in JSONPath
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, "non_existent_key", "$.key10", 51)
+
+        # Check for Overflow in JSONPath
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, "$.key9", 1.7976931348623157e308)
+
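The overflow constant exercised above is exactly the largest finite IEEE-754 double; a quick standard-library check (illustrative, not part of the patch) shows where these magic numbers come from:

```python
import sys

print(sys.float_info.max)           # 1.7976931348623157e+308, the overflow bound above
print(3.5953862697246314e307 * 5)   # 1.7976931348623157e+308 -> key9 lands exactly on the max
```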
+        # Multiply integer value (key1) by -12
+        result = await json.nummultby(glide_client, key, "$.key1", -12)
+        assert result == b"[-600]"  # Expect 50 * -12 = -600
+
+        # Multiply integer value (key1) by -0.5
+        result = await json.nummultby(glide_client, key, "$.key1", -0.5)
+        assert result == b"[300]"  # Expect -600 * -0.5 = 300
+
+        # Test Legacy Path
+        # Multiply int value (key1) by 5 (integer)
+        result = await json.nummultby(glide_client, key, "key1", 5)
+        assert result == b"1500"  # Expect 300 * 5 = 1500
+
+        # Multiply int value (key1) by -5.5 (float number)
+        result = await json.nummultby(glide_client, key, "key1", -5.5)
+        assert result == b"-8250"  # Expect 1500 * -5.5 = -8250
+
+        # Multiply float value (key2) by 2.5 (a float number)
+        result = await json.nummultby(glide_client, key, "key2", 2.5)
+        assert result == b"109.375"  # Expect 43.75 * 2.5 = 109.375
+
+        # Multiply nested value (key3.nested_key.key1[0]) by 7
+        result = await json.nummultby(glide_client, key, "key3.nested_key.key1[0]", 7)
+        assert result == b"140"  # Expect 20 * 7 = 140
+
+        # Multiply array element (key4[1]) by 1
+        result = await json.nummultby(glide_client, key, "key4[1]", 1)
+        assert result == b"10"  # Expect 10 * 1 = 10
+
+        # Multiply a float value (key5) by 10.2 (a float number)
+        result = await json.nummultby(glide_client, key, "key5", 10.2)
+        assert result == b"0"  # Expect 0 * 10.2 = 0
+
+        # Check for multiple path matches in legacy and ensure that the result of the last updated value is returned
+        # last updated value is key8.nested_key.key1: 690 * 2 = 1380
+        result = await json.nummultby(glide_client, key, "..key1", 2)
+        assert result == b"1380"  # Expect the last updated key1 value multiplied by 2
+
+        # Check if the rest of the key1 path matches were updated and not only the last value
+        result = await json.get(glide_client, key, "$..key1")
+        assert result == b"[-16500,[140,175],1380]"
+
+        # Check for non-existent path in legacy
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, ".key10", 51)
+
+        # Check for non-existent key in legacy
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, "non_existent_key", ".key10", 51)
+
+        # Check for Overflow in legacy
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, ".key9", 1.7976931348623157e308)
From 15d57b23c667cc5b4076630617a649fc709f1742 Mon Sep 17 00:00:00 2001
From: eifrah-aws
Date: Wed, 16 Oct 2024 13:52:59 +0300
Subject: [PATCH 016/180] `redis-rs` code cleanup: keep a single async runtime:
 `tokio` (#2459)

Keep a single Runtime: `tokio`

- Removed dead code for `async-std`
- Fixed `redis-rs` tests so they could be run with a simple `cargo test`
  command
- Changed the default features to include: "tokio-comp",
  "tokio-rustls-comp", "connection-manager", "cluster", "cluster-async"
- Fixed a flaky `glide-core` test

Signed-off-by: Eran Ifrah
---
 glide-core/redis-rs/redis/Cargo.toml          |  30 +-
 .../redis-rs/redis/src/aio/async_std.rs       | 269 --------------
 .../redis-rs/redis/src/aio/connection.rs      |  23 +-
 .../redis/src/aio/connection_manager.rs       |   3 +-
 glide-core/redis-rs/redis/src/aio/mod.rs      |   5 -
 .../redis/src/aio/multiplexed_connection.rs   |   8 +-
 glide-core/redis-rs/redis/src/aio/runtime.rs  |  33 +-
 glide-core/redis-rs/redis/src/client.rs       | 166 +--------
 .../redis-rs/redis/src/cluster_async/mod.rs   
| 40 --- glide-core/redis-rs/redis/src/sentinel.rs | 5 +- .../redis-rs/redis/tests/support/mod.rs | 26 -- .../redis/tests/test_async_async_std.rs | 328 ------------------ glide-core/redis-rs/redis/tests/test_basic.rs | 59 ++++ .../redis-rs/redis/tests/test_cluster.rs | 35 ++ .../redis/tests/test_cluster_async.rs | 111 ++++-- .../redis-rs/redis/tests/test_cluster_scan.rs | 17 +- glide-core/tests/test_standalone_client.rs | 34 +- glide-core/tests/utilities/mocks.rs | 94 +++-- 18 files changed, 308 insertions(+), 978 deletions(-) delete mode 100644 glide-core/redis-rs/redis/src/aio/async_std.rs delete mode 100644 glide-core/redis-rs/redis/tests/test_async_async_std.rs diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml index fd79ff079e..46f6fe9231 100644 --- a/glide-core/redis-rs/redis/Cargo.toml +++ b/glide-core/redis-rs/redis/Cargo.toml @@ -66,13 +66,9 @@ derivative = { version = "2.2.0", optional = true } # Only needed for async cluster dashmap = { version = "6.0", optional = true } -# Only needed for async_std support -async-std = { version = "1.8.0", optional = true } async-trait = { version = "0.1.24", optional = true } -# To avoid conflicts, backoff-std-async.version != backoff-tokio.version so we could run tests with --all-features -backoff-std-async = { package = "backoff", version = "0.3.0", optional = true, features = ["async-std"] } -# Only needed for tokio support +# Only needed for tokio support backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = ["tokio"] } # Only needed for native tls @@ -108,7 +104,18 @@ arcstr = "1.1.5" uuid = { version = "1.6.1", optional = true } [features] -default = ["acl", "streams", "geospatial", "script", "keep-alive"] +default = [ + "acl", + "streams", + "geospatial", + "script", + "keep-alive", + "tokio-comp", + "tokio-rustls-comp", + "connection-manager", + "cluster", + "cluster-async" +] acl = [] aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "combine/tokio", "async-trait", "fast-math", "dispose"] geospatial = [] @@ -119,9 +126,6 @@ tls-native-tls = ["native-tls"] tls-rustls = ["rustls", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types"] tls-rustls-insecure = ["tls-rustls"] tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] -async-std-comp = ["aio", "async-std", "backoff-std-async"] -async-std-native-tls-comp = ["async-std-comp", "async-native-tls", "tls-native-tls"] -async-std-rustls-comp = ["async-std-comp", "futures-rustls", "tls-rustls"] tokio-comp = ["aio", "tokio/net", "backoff-tokio"] tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"] tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"] @@ -139,7 +143,6 @@ disable-client-setinfo = [] # Deprecated features tls = ["tls-native-tls"] # use "tls-native-tls" instead -async-std-tls-comp = ["async-std-native-tls-comp"] # use "async-std-native-tls-comp" instead [dev-dependencies] rand = "0.8" @@ -156,15 +159,12 @@ tempfile = "=3.6.0" once_cell = "1" anyhow = "1" sscanf = "0.4.1" +serial_test = "2" [[test]] name = "test_async" required-features = ["tokio-comp"] -[[test]] -name = "test_async_async_std" -required-features = ["async-std-comp"] - [[test]] name = "parser" required-features = ["aio"] @@ -178,7 +178,7 @@ required-features = ["json", "serde/derive"] [[test]] name = "test_cluster_async" -required-features = ["cluster-async"] +required-features = ["cluster-async", 
"tokio-comp"] [[test]] name = "test_async_cluster_connections_logic" diff --git a/glide-core/redis-rs/redis/src/aio/async_std.rs b/glide-core/redis-rs/redis/src/aio/async_std.rs deleted file mode 100644 index 19c54d3b31..0000000000 --- a/glide-core/redis-rs/redis/src/aio/async_std.rs +++ /dev/null @@ -1,269 +0,0 @@ -#[cfg(unix)] -use std::path::Path; -#[cfg(feature = "tls-rustls")] -use std::sync::Arc; -use std::{ - future::Future, - io, - net::SocketAddr, - pin::Pin, - task::{self, Poll}, -}; - -use crate::aio::{AsyncStream, RedisRuntime}; -use crate::types::RedisResult; - -#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] -use async_native_tls::{TlsConnector, TlsStream}; - -#[cfg(feature = "tls-rustls")] -use crate::connection::create_rustls_config; -#[cfg(feature = "tls-rustls")] -use futures_rustls::{client::TlsStream, TlsConnector}; - -use async_std::net::TcpStream; -#[cfg(unix)] -use async_std::os::unix::net::UnixStream; -use async_trait::async_trait; -use futures_util::ready; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -#[inline(always)] -async fn connect_tcp(addr: &SocketAddr) -> io::Result { - let socket = TcpStream::connect(addr).await?; - #[cfg(feature = "tcp_nodelay")] - socket.set_nodelay(true)?; - #[cfg(feature = "keep-alive")] - { - //For now rely on system defaults - const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); - //these are useless error that not going to happen - let mut std_socket = std::net::TcpStream::try_from(socket)?; - let socket2: socket2::Socket = std_socket.into(); - socket2.set_tcp_keepalive(&KEEP_ALIVE)?; - std_socket = socket2.into(); - Ok(std_socket.into()) - } - #[cfg(not(feature = "keep-alive"))] - { - Ok(socket) - } -} -#[cfg(feature = "tls-rustls")] -use crate::tls::TlsConnParams; - -#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] -use crate::connection::TlsConnParams; - -pin_project_lite::pin_project! { - /// Wraps the async_std `AsyncRead/AsyncWrite` in order to implement the required the tokio traits - /// for it - pub struct AsyncStdWrapped { #[pin] inner: T } -} - -impl AsyncStdWrapped { - pub(super) fn new(inner: T) -> Self { - Self { inner } - } -} - -impl AsyncWrite for AsyncStdWrapped -where - T: async_std::io::Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut core::task::Context, - buf: &[u8], - ) -> std::task::Poll> { - async_std::io::Write::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut core::task::Context, - ) -> std::task::Poll> { - async_std::io::Write::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut core::task::Context, - ) -> std::task::Poll> { - async_std::io::Write::poll_close(self.project().inner, cx) - } -} - -impl AsyncRead for AsyncStdWrapped -where - T: async_std::io::Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut core::task::Context, - buf: &mut ReadBuf<'_>, - ) -> std::task::Poll> { - let n = ready!(async_std::io::Read::poll_read( - self.project().inner, - cx, - buf.initialize_unfilled() - ))?; - buf.advance(n); - std::task::Poll::Ready(Ok(())) - } -} - -/// Represents an AsyncStd connectable -pub enum AsyncStd { - /// Represents an Async_std TCP connection. - Tcp(AsyncStdWrapped), - /// Represents an Async_std TLS encrypted TCP connection. - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - TcpTls(AsyncStdWrapped>>), - /// Represents an Async_std Unix connection. 
- #[cfg(unix)] - Unix(AsyncStdWrapped), -} - -impl AsyncWrite for AsyncStd { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut task::Context, - buf: &[u8], - ) -> Poll> { - match &mut *self { - AsyncStd::Tcp(r) => Pin::new(r).poll_write(cx, buf), - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - AsyncStd::TcpTls(r) => Pin::new(r).poll_write(cx, buf), - #[cfg(unix)] - AsyncStd::Unix(r) => Pin::new(r).poll_write(cx, buf), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { - match &mut *self { - AsyncStd::Tcp(r) => Pin::new(r).poll_flush(cx), - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - AsyncStd::TcpTls(r) => Pin::new(r).poll_flush(cx), - #[cfg(unix)] - AsyncStd::Unix(r) => Pin::new(r).poll_flush(cx), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { - match &mut *self { - AsyncStd::Tcp(r) => Pin::new(r).poll_shutdown(cx), - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - AsyncStd::TcpTls(r) => Pin::new(r).poll_shutdown(cx), - #[cfg(unix)] - AsyncStd::Unix(r) => Pin::new(r).poll_shutdown(cx), - } - } -} - -impl AsyncRead for AsyncStd { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match &mut *self { - AsyncStd::Tcp(r) => Pin::new(r).poll_read(cx, buf), - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - AsyncStd::TcpTls(r) => Pin::new(r).poll_read(cx, buf), - #[cfg(unix)] - AsyncStd::Unix(r) => Pin::new(r).poll_read(cx, buf), - } - } -} - -#[async_trait] -impl RedisRuntime for AsyncStd { - async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult { - Ok(connect_tcp(&socket_addr) - .await - .map(|con| Self::Tcp(AsyncStdWrapped::new(con)))?) - } - - #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] - async fn connect_tcp_tls( - hostname: &str, - socket_addr: SocketAddr, - insecure: bool, - _tls_params: &Option, - ) -> RedisResult { - let tcp_stream = connect_tcp(&socket_addr).await?; - let tls_connector = if insecure { - TlsConnector::new() - .danger_accept_invalid_certs(true) - .danger_accept_invalid_hostnames(true) - .use_sni(false) - } else { - TlsConnector::new() - }; - Ok(tls_connector - .connect(hostname, tcp_stream) - .await - .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) - } - - #[cfg(feature = "tls-rustls")] - async fn connect_tcp_tls( - hostname: &str, - socket_addr: SocketAddr, - insecure: bool, - tls_params: &Option, - ) -> RedisResult { - let tcp_stream = connect_tcp(&socket_addr).await?; - - let config = create_rustls_config(insecure, tls_params.clone())?; - let tls_connector = TlsConnector::from(Arc::new(config)); - - Ok(tls_connector - .connect( - rustls_pki_types::ServerName::try_from(hostname)?.to_owned(), - tcp_stream, - ) - .await - .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) - } - - #[cfg(unix)] - async fn connect_unix(path: &Path) -> RedisResult { - Ok(UnixStream::connect(path) - .await - .map(|con| Self::Unix(AsyncStdWrapped::new(con)))?) 
- } - - fn spawn(f: impl Future + Send + 'static) { - async_std::task::spawn(f); - } - - fn boxed(self) -> Pin> { - match self { - AsyncStd::Tcp(x) => Box::pin(x), - #[cfg(any( - feature = "async-std-native-tls-comp", - feature = "async-std-rustls-comp" - ))] - AsyncStd::TcpTls(x) => Box::pin(x), - #[cfg(unix)] - AsyncStd::Unix(x) => Box::pin(x), - } - } -} diff --git a/glide-core/redis-rs/redis/src/aio/connection.rs b/glide-core/redis-rs/redis/src/aio/connection.rs index 6b1f6e657a..5adef7869f 100644 --- a/glide-core/redis-rs/redis/src/aio/connection.rs +++ b/glide-core/redis-rs/redis/src/aio/connection.rs @@ -1,7 +1,5 @@ #![allow(deprecated)] -#[cfg(feature = "async-std-comp")] -use super::async_std; use super::ConnectionLike; use super::{setup_connection, AsyncStream, RedisRuntime}; use crate::cmd::{cmd, Cmd}; @@ -9,12 +7,10 @@ use crate::connection::{ resp2_is_pub_sub_state_cleared, resp3_is_pub_sub_state_cleared, ConnectionAddr, ConnectionInfo, Msg, RedisConnectionInfo, }; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +#[cfg(any(feature = "tokio-comp"))] use crate::parser::ValueCodec; use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value}; use crate::{from_owned_redis_value, ProtocolVersion, ToRedisArgs}; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use ::async_std::net::ToSocketAddrs; use ::tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; #[cfg(feature = "tokio-comp")] use ::tokio::net::lookup_host; @@ -26,7 +22,7 @@ use futures_util::{ }; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +#[cfg(feature = "tokio-comp")] use tokio_util::codec::Decoder; use tracing::info; @@ -194,19 +190,6 @@ where } } -#[cfg(feature = "async-std-comp")] -#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] -impl Connection> -where - C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send, -{ - /// Constructs a new `Connection` out of a `async_std::io::AsyncRead + async_std::io::AsyncWrite` object - /// and a `RedisConnectionInfo` - pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult { - Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await - } -} - pub(crate) async fn connect( connection_info: &ConnectionInfo, socket_addr: Option, @@ -436,8 +419,6 @@ pub(crate) async fn get_socket_addrs( ) -> RedisResult + Send + '_> { #[cfg(feature = "tokio-comp")] let socket_addrs = lookup_host((host, port)).await?; - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - let socket_addrs = (host, port).to_socket_addrs().await?; let mut socket_addrs = socket_addrs.peekable(); match socket_addrs.peek() { diff --git a/glide-core/redis-rs/redis/src/aio/connection_manager.rs b/glide-core/redis-rs/redis/src/aio/connection_manager.rs index 61df9bc31a..dce7b254a5 100644 --- a/glide-core/redis-rs/redis/src/aio/connection_manager.rs +++ b/glide-core/redis-rs/redis/src/aio/connection_manager.rs @@ -7,8 +7,7 @@ use crate::{ aio::{ConnectionLike, MultiplexedConnection, Runtime}, Client, }; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use ::async_std::net::ToSocketAddrs; + use arc_swap::ArcSwap; use futures::{ future::{self, Shared}, diff --git a/glide-core/redis-rs/redis/src/aio/mod.rs b/glide-core/redis-rs/redis/src/aio/mod.rs index ffe2c9e3a2..34c098d600 100644 --- a/glide-core/redis-rs/redis/src/aio/mod.rs +++ b/glide-core/redis-rs/redis/src/aio/mod.rs @@ 
-14,11 +14,6 @@ use std::path::Path; use std::pin::Pin; use std::time::Duration; -/// Enables the async_std compatibility -#[cfg(feature = "async-std-comp")] -#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] -pub mod async_std; - #[cfg(feature = "tls-rustls")] use crate::tls::TlsConnParams; diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs index 1067bc2df5..fb1b62f8a1 100644 --- a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -3,7 +3,7 @@ use crate::aio::setup_connection; use crate::aio::DisconnectNotifier; use crate::client::GlideConnectionOptions; use crate::cmd::Cmd; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +#[cfg(feature = "tokio-comp")] use crate::parser::ValueCodec; use crate::push_manager::PushManager; use crate::types::{RedisError, RedisFuture, RedisResult, Value}; @@ -29,7 +29,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{self, Poll}; use std::time::Duration; -#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +#[cfg(feature = "tokio-comp")] use tokio_util::codec::Decoder; // Senders which the result of a single request are sent through @@ -448,8 +448,8 @@ impl MultiplexedConnection { Box::pin(f) } - #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] - compile_error!("tokio-comp or async-std-comp features required for aio feature"); + #[cfg(not(feature = "tokio-comp"))] + compile_error!("tokio-comp feature is required for aio feature"); let redis_connection_info = &connection_info.redis; let codec = ValueCodec::default() diff --git a/glide-core/redis-rs/redis/src/aio/runtime.rs b/glide-core/redis-rs/redis/src/aio/runtime.rs index 5755f62c9f..2222783ed8 100644 --- a/glide-core/redis-rs/redis/src/aio/runtime.rs +++ b/glide-core/redis-rs/redis/src/aio/runtime.rs @@ -2,8 +2,6 @@ use std::{io, time::Duration}; use futures_util::Future; -#[cfg(feature = "async-std-comp")] -use super::async_std; #[cfg(feature = "tokio-comp")] use super::tokio; use super::RedisRuntime; @@ -13,34 +11,17 @@ use crate::types::RedisError; pub(crate) enum Runtime { #[cfg(feature = "tokio-comp")] Tokio, - #[cfg(feature = "async-std-comp")] - AsyncStd, } impl Runtime { pub(crate) fn locate() -> Self { - #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))] + #[cfg(not(feature = "tokio-comp"))] { - Runtime::Tokio - } - - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - { - Runtime::AsyncStd + compile_error!("tokio-comp feature is required for aio feature") } - - #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "tokio-comp")] { - if ::tokio::runtime::Handle::try_current().is_ok() { - Runtime::Tokio - } else { - Runtime::AsyncStd - } - } - - #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] - { - compile_error!("tokio-comp or async-std-comp features required for aio feature") + Runtime::Tokio } } @@ -49,8 +30,6 @@ impl Runtime { match self { #[cfg(feature = "tokio-comp")] Runtime::Tokio => tokio::Tokio::spawn(f), - #[cfg(feature = "async-std-comp")] - Runtime::AsyncStd => async_std::AsyncStd::spawn(f), } } @@ -64,10 +43,6 @@ impl Runtime { Runtime::Tokio => ::tokio::time::timeout(duration, future) .await .map_err(|_| Elapsed(())), - #[cfg(feature = "async-std-comp")] - Runtime::AsyncStd => ::async_std::future::timeout(duration, future) - .await - 
.map_err(|_| Elapsed(())), } } } diff --git a/glide-core/redis-rs/redis/src/client.rs b/glide-core/redis-rs/redis/src/client.rs index 5e3f144e71..fd8c4c08b4 100644 --- a/glide-core/redis-rs/redis/src/client.rs +++ b/glide-core/redis-rs/redis/src/client.rs @@ -88,13 +88,12 @@ pub struct GlideConnectionOptions { pub disconnect_notifier: Option>, } -/// To enable async support you need to chose one of the supported runtimes and active its -/// corresponding feature: `tokio-comp` or `async-std-comp` +/// To enable async support you need to enable the feature: `tokio-comp` #[cfg(feature = "aio")] #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] impl Client { /// Returns an async connection from the client. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "tokio-comp")] #[deprecated( note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead." )] @@ -109,11 +108,6 @@ impl Client { self.get_simple_async_connection::(None) .await? } - #[cfg(feature = "async-std-comp")] - Runtime::AsyncStd => { - self.get_simple_async_connection::(None) - .await? - } }; crate::aio::Connection::new(&self.connection_info.redis, con).await @@ -136,27 +130,8 @@ impl Client { } /// Returns an async connection from the client. - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - #[deprecated( - note = "aio::Connection is deprecated. Use client::get_multiplexed_async_std_connection instead." - )] - #[allow(deprecated)] - pub async fn get_async_std_connection(&self) -> RedisResult { - use crate::aio::RedisRuntime; - Ok( - crate::aio::connect::(&self.connection_info, None) - .await? - .map(RedisRuntime::boxed), - ) - } - - /// Returns an async connection from the client. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] pub async fn get_multiplexed_async_connection( &self, glide_connection_options: GlideConnectionOptions, @@ -170,11 +145,8 @@ impl Client { } /// Returns an async connection from the client. 
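For orientation, a minimal caller-side sketch (not part of this patch) of the single remaining async path, now that the async-std variants are removed. The address is a placeholder, and `GlideConnectionOptions::default()` mirrors the parameter shown in the signatures above:

use redis::{GlideConnectionOptions, RedisResult};

#[tokio::main]
async fn main() -> RedisResult<()> {
    let client = redis::Client::open("redis://127.0.0.1:6379/")?;
    // tokio-comp is now the only async backend, so no runtime selection is needed.
    let mut con = client
        .get_multiplexed_async_connection(GlideConnectionOptions::default())
        .await?;
    let _: () = redis::cmd("SET").arg("key1").arg("foo").query_async(&mut con).await?;
    let value: String = redis::cmd("GET").arg("key1").query_async(&mut con).await?;
    assert_eq!(value, "foo");
    Ok(())
}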
- #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] pub async fn get_multiplexed_async_connection_with_timeouts( &self, response_timeout: std::time::Duration, @@ -194,18 +166,6 @@ impl Client { ) .await } - #[cfg(feature = "async-std-comp")] - rt @ Runtime::AsyncStd => { - rt.timeout( - connection_timeout, - self.get_multiplexed_async_connection_inner::( - response_timeout, - None, - glide_connection_options, - ), - ) - .await - } }; match result { @@ -218,11 +178,8 @@ impl Client { /// For TCP connections: returns (async connection, Some(the direct IP address)) /// For Unix connections, returns (async connection, None) - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] - #[cfg_attr( - docsrs, - doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) - )] + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] pub async fn get_multiplexed_async_connection_and_ip( &self, glide_connection_options: GlideConnectionOptions, @@ -237,15 +194,6 @@ impl Client { ) .await } - #[cfg(feature = "async-std-comp")] - Runtime::AsyncStd => { - self.get_multiplexed_async_connection_inner::( - Duration::MAX, - None, - glide_connection_options, - ) - .await - } } } @@ -297,54 +245,6 @@ impl Client { .await } - /// Returns an async multiplexed connection from the client. - /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - pub async fn get_multiplexed_async_std_connection_with_timeouts( - &self, - response_timeout: std::time::Duration, - connection_timeout: std::time::Duration, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult { - let result = Runtime::locate() - .timeout( - connection_timeout, - self.get_multiplexed_async_connection_inner::( - response_timeout, - None, - glide_connection_options, - ), - ) - .await; - - match result { - Ok(Ok((connection, _ip))) => Ok(connection), - Ok(Err(e)) => Err(e), - Err(elapsed) => Err(elapsed.into()), - } - } - - /// Returns an async multiplexed connection from the client. - /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - pub async fn get_multiplexed_async_std_connection( - &self, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult { - self.get_multiplexed_async_std_connection_with_timeouts( - std::time::Duration::MAX, - std::time::Duration::MAX, - glide_connection_options, - ) - .await - } - /// Returns an async multiplexed connection from the client and a future which must be polled /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). /// @@ -392,52 +292,6 @@ impl Client { .map(|conn_res| (conn_res.0, conn_res.1)) } - /// Returns an async multiplexed connection from the client and a future which must be polled - /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). 
- /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - pub async fn create_multiplexed_async_std_connection_with_response_timeout( - &self, - response_timeout: std::time::Duration, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult<( - crate::aio::MultiplexedConnection, - impl std::future::Future, - )> { - self.create_multiplexed_async_connection_inner::( - response_timeout, - None, - glide_connection_options, - ) - .await - .map(|(conn, driver, _ip)| (conn, driver)) - } - - /// Returns an async multiplexed connection from the client and a future which must be polled - /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). - /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - #[cfg(feature = "async-std-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] - pub async fn create_multiplexed_async_std_connection( - &self, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult<( - crate::aio::MultiplexedConnection, - impl std::future::Future, - )> { - self.create_multiplexed_async_std_connection_with_response_timeout( - std::time::Duration::MAX, - glide_connection_options, - ) - .await - } - /// Returns an async [`ConnectionManager`][connection-manager] from the client. /// /// The connection manager wraps a @@ -785,7 +639,7 @@ impl Client { } /// Returns an async receiver for pub-sub messages. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "tokio-comp")] // TODO - do we want to type-erase pubsub using a trait, to allow us to replace it with a different implementation later? pub async fn get_async_pubsub(&self) -> RedisResult { #[allow(deprecated)] @@ -795,7 +649,7 @@ impl Client { } /// Returns an async receiver for monitor messages. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "tokio-comp")] // TODO - do we want to type-erase monitor using a trait, to allow us to replace it with a different implementation later? 
pub async fn get_async_monitor(&self) -> RedisResult { #[allow(deprecated)] diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs index be7beb79b7..c8628c16bb 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -38,11 +38,7 @@ use crate::{ commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, FromRedisValue, InfoDict, ToRedisArgs, }; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use async_std::task::{spawn, JoinHandle}; use dashmap::DashMap; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use futures::executor::block_on; use std::{ collections::{HashMap, HashSet}, fmt, io, mem, @@ -84,13 +80,6 @@ use crate::{ use futures::stream::{FuturesUnordered, StreamExt}; use std::time::Duration; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use crate::aio::{async_std::AsyncStd, RedisRuntime}; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use backoff_std_async::future::retry; -#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] -use backoff_std_async::{Error as BackoffError, ExponentialBackoff}; - #[cfg(feature = "tokio-comp")] use async_trait::async_trait; #[cfg(feature = "tokio-comp")] @@ -142,9 +131,6 @@ where }; #[cfg(feature = "tokio-comp")] tokio::spawn(stream); - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - AsyncStd::spawn(stream); - ClusterConnection(tx) }) } @@ -510,14 +496,10 @@ pub(crate) struct ClusterConnInner { impl Dispose for ClusterConnInner { fn dispose(self) { if let Some(handle) = self.periodic_checks_handler { - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - block_on(handle.cancel()); #[cfg(feature = "tokio-comp")] handle.abort() } if let Some(handle) = self.connections_validation_handler { - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - block_on(handle.cancel()); #[cfg(feature = "tokio-comp")] handle.abort() } @@ -657,9 +639,6 @@ fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult> fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> { #[cfg(feature = "tokio-comp")] return Box::pin(tokio::time::sleep(duration)); - - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - return Box::pin(async_std::task::sleep(duration)); } pub(crate) enum Response { @@ -1080,10 +1059,6 @@ where { connection.periodic_checks_handler = Some(tokio::spawn(periodic_task)); } - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - { - connection.periodic_checks_handler = Some(spawn(periodic_task)); - } } let connections_validation_interval = cluster_params.connections_validation_interval; @@ -1095,11 +1070,6 @@ where connection.connections_validation_handler = Some(tokio::spawn(connections_validation_handler)); } - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - { - connection.connections_validation_handler = - Some(spawn(connections_validation_handler)); - } } Ok(Disposable::new(connection)) @@ -2555,16 +2525,6 @@ impl Connect for MultiplexedConnection { ) .await? } - #[cfg(feature = "async-std-comp")] - rt @ Runtime::AsyncStd => { - rt.timeout(connection_timeout,client - .get_multiplexed_async_connection_inner::( - response_timeout, - socket_addr, - glide_connection_options, - )) - .await? 
- } } } .boxed() diff --git a/glide-core/redis-rs/redis/src/sentinel.rs b/glide-core/redis-rs/redis/src/sentinel.rs index ac6aac65cc..569ab2fe0f 100644 --- a/glide-core/redis-rs/redis/src/sentinel.rs +++ b/glide-core/redis-rs/redis/src/sentinel.rs @@ -746,8 +746,7 @@ impl SentinelClient { } } -/// To enable async support you need to chose one of the supported runtimes and active its -/// corresponding feature: `tokio-comp` or `async-std-comp` +/// To enable async support you need to enable the feature: `tokio-comp` #[cfg(feature = "aio")] #[cfg_attr(docsrs, doc(cfg(feature = "aio")))] impl SentinelClient { @@ -768,7 +767,7 @@ impl SentinelClient { /// Returns an async connection from the client, using the same logic from /// `SentinelClient::get_connection`. - #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg(feature = "tokio-comp")] pub async fn get_async_connection(&mut self) -> RedisResult { let client = self.async_get_client().await?; client diff --git a/glide-core/redis-rs/redis/tests/support/mod.rs b/glide-core/redis-rs/redis/tests/support/mod.rs index 335cd045de..72dc7c9a78 100644 --- a/glide-core/redis-rs/redis/tests/support/mod.rs +++ b/glide-core/redis-rs/redis/tests/support/mod.rs @@ -85,14 +85,6 @@ where res } -#[cfg(feature = "async-std-comp")] -pub fn block_on_all_using_async_std(f: F) -> F::Output -where - F: Future, -{ - async_std::task::block_on(f) -} - #[cfg(any(feature = "cluster", feature = "cluster-async"))] mod cluster; @@ -514,15 +506,6 @@ impl TestContext { self.client.get_async_pubsub().await } - #[cfg(feature = "async-std-comp")] - pub async fn async_connection_async_std( - &self, - ) -> redis::RedisResult { - self.client - .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) - .await - } - pub fn stop_server(&mut self) { self.server.stop(); } @@ -543,15 +526,6 @@ impl TestContext { .await } - #[cfg(feature = "async-std-comp")] - pub async fn multiplexed_async_connection_async_std( - &self, - ) -> redis::RedisResult { - self.client - .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) - .await - } - pub fn get_version(&self) -> Version { let mut conn = self.connection(); get_version(&mut conn) diff --git a/glide-core/redis-rs/redis/tests/test_async_async_std.rs b/glide-core/redis-rs/redis/tests/test_async_async_std.rs deleted file mode 100644 index 656d1979f6..0000000000 --- a/glide-core/redis-rs/redis/tests/test_async_async_std.rs +++ /dev/null @@ -1,328 +0,0 @@ -#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] -use futures::prelude::*; - -use crate::support::*; - -use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; - -mod support; - -#[test] -fn test_args() { - let ctx = TestContext::new(); - let connect = ctx.async_connection_async_std(); - - block_on_all_using_async_std(connect.and_then(|mut con| async move { - redis::cmd("SET") - .arg("key1") - .arg(b"foo") - .query_async(&mut con) - .await?; - redis::cmd("SET") - .arg(&["key2", "bar"]) - .query_async(&mut con) - .await?; - let result = redis::cmd("MGET") - .arg(&["key1", "key2"]) - .query_async(&mut con) - .await; - assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); - result - })) - .unwrap(); -} - -#[test] -fn test_args_async_std() { - let ctx = TestContext::new(); - let connect = ctx.async_connection_async_std(); - - block_on_all_using_async_std(connect.and_then(|mut con| async move { - redis::cmd("SET") - .arg("key1") - .arg(b"foo") - .query_async(&mut con) - .await?; - redis::cmd("SET") - 
.arg(&["key2", "bar"]) - .query_async(&mut con) - .await?; - let result = redis::cmd("MGET") - .arg(&["key1", "key2"]) - .query_async(&mut con) - .await; - assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); - result - })) - .unwrap(); -} - -#[test] -fn dont_panic_on_closed_multiplexed_connection() { - let ctx = TestContext::new(); - let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_std_connection(GlideConnectionOptions::default()); - drop(ctx); - - block_on_all_using_async_std(async move { - connect - .and_then(|con| async move { - let cmd = move || { - let mut con = con.clone(); - async move { - redis::cmd("SET") - .arg("key1") - .arg(b"foo") - .query_async(&mut con) - .await - } - }; - let result: RedisResult<()> = cmd().await; - assert_eq!( - result.as_ref().unwrap_err().kind(), - redis::ErrorKind::IoError, - "{}", - result.as_ref().unwrap_err() - ); - cmd().await - }) - .map(|result| { - assert_eq!( - result.as_ref().unwrap_err().kind(), - redis::ErrorKind::IoError, - "{}", - result.as_ref().unwrap_err() - ); - }) - .await - }); -} - -#[test] -fn test_pipeline_transaction() { - let ctx = TestContext::new(); - block_on_all_using_async_std(async move { - let mut con = ctx.async_connection_async_std().await?; - let mut pipe = redis::pipe(); - pipe.atomic() - .cmd("SET") - .arg("key_1") - .arg(42) - .ignore() - .cmd("SET") - .arg("key_2") - .arg(43) - .ignore() - .cmd("MGET") - .arg(&["key_1", "key_2"]); - pipe.query_async(&mut con) - .map_ok(|((k1, k2),): ((i32, i32),)| { - assert_eq!(k1, 42); - assert_eq!(k2, 43); - }) - .await - }) - .unwrap(); -} - -fn test_cmd(con: &MultiplexedConnection, i: i32) -> impl Future> + Send { - let mut con = con.clone(); - async move { - let key = format!("key{i}"); - let key_2 = key.clone(); - let key2 = format!("key{i}_2"); - let key2_2 = key2.clone(); - - let foo_val = format!("foo{i}"); - - redis::cmd("SET") - .arg(&key[..]) - .arg(foo_val.as_bytes()) - .query_async(&mut con) - .await?; - redis::cmd("SET") - .arg(&[&key2, "bar"]) - .query_async(&mut con) - .await?; - redis::cmd("MGET") - .arg(&[&key_2, &key2_2]) - .query_async(&mut con) - .map(|result| { - assert_eq!(Ok((foo_val, b"bar".to_vec())), result); - Ok(()) - }) - .await - } -} - -fn test_error(con: &MultiplexedConnection) -> impl Future> { - let mut con = con.clone(); - async move { - redis::cmd("SET") - .query_async(&mut con) - .map(|result| match result { - Ok(()) => panic!("Expected redis to return an error"), - Err(_) => Ok(()), - }) - .await - } -} - -#[test] -fn test_args_multiplexed_connection() { - let ctx = TestContext::new(); - block_on_all_using_async_std(async move { - ctx.multiplexed_async_connection_async_std() - .and_then(|con| { - let cmds = (0..100).map(move |i| test_cmd(&con, i)); - future::try_join_all(cmds).map_ok(|results| { - assert_eq!(results.len(), 100); - }) - }) - .map_err(|err| panic!("{}", err)) - .await - }) - .unwrap(); -} - -#[test] -fn test_args_with_errors_multiplexed_connection() { - let ctx = TestContext::new(); - block_on_all_using_async_std(async move { - ctx.multiplexed_async_connection_async_std() - .and_then(|con| { - let cmds = (0..100).map(move |i| { - let con = con.clone(); - async move { - if i % 2 == 0 { - test_cmd(&con, i).await - } else { - test_error(&con).await - } - } - }); - future::try_join_all(cmds).map_ok(|results| { - assert_eq!(results.len(), 100); - }) - }) - .map_err(|err| panic!("{}", err)) - .await - }) - .unwrap(); -} - -#[test] -fn test_transaction_multiplexed_connection() { - let ctx = 
TestContext::new(); - block_on_all_using_async_std(async move { - ctx.multiplexed_async_connection_async_std() - .and_then(|con| { - let cmds = (0..100).map(move |i| { - let mut con = con.clone(); - async move { - let foo_val = i; - let bar_val = format!("bar{i}"); - - let mut pipe = redis::pipe(); - pipe.atomic() - .cmd("SET") - .arg("key") - .arg(foo_val) - .ignore() - .cmd("SET") - .arg(&["key2", &bar_val[..]]) - .ignore() - .cmd("MGET") - .arg(&["key", "key2"]); - - pipe.query_async(&mut con) - .map(move |result| { - assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); - result - }) - .await - } - }); - future::try_join_all(cmds) - }) - .map_ok(|results| { - assert_eq!(results.len(), 100); - }) - .map_err(|err| panic!("{}", err)) - .await - }) - .unwrap(); -} - -#[test] -#[cfg(feature = "script")] -fn test_script() { - use redis::RedisError; - - // Note this test runs both scripts twice to test when they have already been loaded - // into Redis and when they need to be loaded in - let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); - let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); - let script3 = redis::Script::new("return redis.call('KEYS', '*')"); - - let ctx = TestContext::new(); - - block_on_all_using_async_std(async move { - let mut con = ctx.multiplexed_async_connection_async_std().await?; - script1 - .key("key1") - .arg("foo") - .invoke_async(&mut con) - .await?; - let val: String = script2.key("key1").invoke_async(&mut con).await?; - assert_eq!(val, "foo"); - let keys: Vec = script3.invoke_async(&mut con).await?; - assert_eq!(keys, ["key1"]); - script1 - .key("key1") - .arg("bar") - .invoke_async(&mut con) - .await?; - let val: String = script2.key("key1").invoke_async(&mut con).await?; - assert_eq!(val, "bar"); - let keys: Vec = script3.invoke_async(&mut con).await?; - assert_eq!(keys, ["key1"]); - Ok::<_, RedisError>(()) - }) - .unwrap(); -} - -#[test] -#[cfg(feature = "script")] -fn test_script_load() { - let ctx = TestContext::new(); - let script = redis::Script::new("return 'Hello World'"); - - block_on_all(async move { - let mut con = ctx.multiplexed_async_connection_async_std().await.unwrap(); - - let hash = script.prepare_invoke().load_async(&mut con).await.unwrap(); - assert_eq!(hash, script.get_hash().to_string()); - Ok(()) - }) - .unwrap(); -} - -#[test] -#[cfg(feature = "script")] -fn test_script_returning_complex_type() { - let ctx = TestContext::new(); - block_on_all_using_async_std(async { - let mut con = ctx.multiplexed_async_connection_async_std().await?; - redis::Script::new("return {1, ARGV[1], true}") - .arg("hello") - .invoke_async(&mut con) - .map_ok(|(i, s, b): (i32, String, bool)| { - assert_eq!(i, 1); - assert_eq!(s, "hello"); - assert!(b); - }) - .await - }) - .unwrap(); -} diff --git a/glide-core/redis-rs/redis/tests/test_basic.rs b/glide-core/redis-rs/redis/tests/test_basic.rs index e31c33384c..4c9ad3aae8 100644 --- a/glide-core/redis-rs/redis/tests/test_basic.rs +++ b/glide-core/redis-rs/redis/tests/test_basic.rs @@ -19,6 +19,7 @@ mod basic { use crate::{assert_args, support::*}; #[test] + #[serial_test::serial] fn test_parse_redis_url() { let redis_url = "redis://127.0.0.1:1234/0".to_string(); redis::parse_redis_url(&redis_url).unwrap(); @@ -27,11 +28,13 @@ mod basic { } #[test] + #[serial_test::serial] fn test_redis_url_fromstr() { let _info: ConnectionInfo = "redis://127.0.0.1:1234/0".parse().unwrap(); } #[test] + #[serial_test::serial] fn test_args() { let ctx = TestContext::new(); let 
mut con = ctx.connection(); @@ -46,6 +49,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_getset() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -62,6 +66,7 @@ mod basic { //unit test for key_type function #[test] + #[serial_test::serial] fn test_key_type() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -107,6 +112,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_client_tracking_doesnt_block_execution() { //It checks if the library distinguish a push-type message from the others and continues its normal operation. let ctx = TestContext::new(); @@ -144,6 +150,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_incr() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -153,6 +160,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_getdel() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -168,6 +176,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_getex() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -202,6 +211,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_info() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -218,6 +228,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_hash_ops() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -248,6 +259,7 @@ mod basic { // Not supported with the current appveyor/windows binary deployed. #[cfg(not(target_os = "windows"))] #[test] + #[serial_test::serial] fn test_unlink() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -262,6 +274,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_set_ops() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -287,6 +300,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_scan() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -305,6 +319,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_optionals() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -327,6 +342,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_scanning() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -354,6 +370,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_filtered_scanning() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -380,6 +397,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -403,6 +421,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline_with_err() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -440,6 +459,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_empty_pipeline() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -450,6 +470,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline_transaction() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -474,6 +495,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline_transaction_with_errors() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -502,6 +524,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline_reuse_query() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -540,6 +563,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pipeline_reuse_query_clear() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -578,6 +602,7 @@ mod basic { } #[test] + #[serial_test::serial] fn 
test_real_transaction() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -612,6 +637,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_real_transaction_highlevel() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -635,6 +661,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub() { use std::sync::{Arc, Barrier}; let ctx = TestContext::new(); @@ -672,6 +699,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub_unsubscribe() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -693,6 +721,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub_subscribe_while_messages_are_sent() { let ctx = TestContext::new(); let mut conn_external = ctx.connection(); @@ -751,6 +780,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub_unsubscribe_no_subs() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -766,6 +796,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub_unsubscribe_one_sub() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -782,6 +813,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_pubsub_unsubscribe_one_sub_one_psub() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -799,6 +831,7 @@ mod basic { } #[test] + #[serial_test::serial] fn scoped_pubsub() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -847,6 +880,7 @@ mod basic { } #[test] + #[serial_test::serial] #[cfg(feature = "script")] fn test_script() { let ctx = TestContext::new(); @@ -869,6 +903,7 @@ mod basic { } #[test] + #[serial_test::serial] #[cfg(feature = "script")] fn test_script_load() { let ctx = TestContext::new(); @@ -882,6 +917,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_tuple_args() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -908,6 +944,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_nice_api() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -931,6 +968,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_auto_m_versions() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -942,6 +980,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_nice_hash_api() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -995,6 +1034,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_nice_list_api() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1023,6 +1063,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_tuple_decoding_regression() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1041,6 +1082,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_bit_operations() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1050,6 +1092,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_redis_server_down() { let mut ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1067,6 +1110,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_zinterstore_weights() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1122,6 +1166,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_zunionstore_weights() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1189,6 +1234,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_zrembylex() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1221,6 +1267,7 @@ mod basic { // Not supported with the current appveyor/windows binary deployed. 
#[cfg(not(target_os = "windows"))] #[test] + #[serial_test::serial] fn test_zrandmember() { use redis::ProtocolVersion; @@ -1271,6 +1318,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_sismember() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1289,6 +1337,7 @@ mod basic { // Not supported with the current appveyor/windows binary deployed. #[cfg(not(target_os = "windows"))] #[test] + #[serial_test::serial] fn test_smismember() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1300,6 +1349,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_object_commands() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1335,6 +1385,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_mget() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1355,6 +1406,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_variable_length_get() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1367,6 +1419,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_multi_generics() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1378,6 +1431,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_set_options_with_get() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1392,6 +1446,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_set_options_options() { let empty = SetOptions::default(); assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); @@ -1428,6 +1483,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_blocking_sorted_set_api() { let ctx = TestContext::new(); let mut con = ctx.connection(); @@ -1484,6 +1540,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_set_client_name_by_config() { const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; @@ -1507,6 +1564,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_push_manager() { let ctx = TestContext::new(); if ctx.protocol == ProtocolVersion::RESP2 { @@ -1562,6 +1620,7 @@ mod basic { } #[test] + #[serial_test::serial] fn test_push_manager_disconnection() { let ctx = TestContext::new(); if ctx.protocol == ProtocolVersion::RESP2 { diff --git a/glide-core/redis-rs/redis/tests/test_cluster.rs b/glide-core/redis-rs/redis/tests/test_cluster.rs index cbeddd2fe4..38b3019edb 100644 --- a/glide-core/redis-rs/redis/tests/test_cluster.rs +++ b/glide-core/redis-rs/redis/tests/test_cluster.rs @@ -16,6 +16,7 @@ mod cluster { }; #[test] + #[serial_test::serial] fn test_cluster_basics() { let cluster = TestClusterContext::new(3, 0); let mut con = cluster.connection(); @@ -35,6 +36,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_with_username_and_password() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -65,6 +67,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_with_bad_password() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -80,6 +83,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_read_from_replicas() { let cluster = TestClusterContext::new_with_cluster_client_builder( 6, @@ -106,6 +110,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_eval() { let cluster = TestClusterContext::new(3, 0); let mut con = cluster.connection(); @@ -127,6 +132,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_resp3() { if use_protocol() == ProtocolVersion::RESP2 { return; @@ -159,6 +165,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn 
test_cluster_multi_shard_commands() { let cluster = TestClusterContext::new(3, 0); @@ -173,6 +180,7 @@ mod cluster { } #[test] + #[serial_test::serial] #[cfg(feature = "script")] fn test_cluster_script() { let cluster = TestClusterContext::new(3, 0); @@ -191,6 +199,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_pipeline() { let cluster = TestClusterContext::new(3, 0); cluster.wait_for_cluster_up(); @@ -207,6 +216,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_pipeline_multiple_keys() { use redis::FromRedisValue; let cluster = TestClusterContext::new(3, 0); @@ -244,6 +254,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_pipeline_invalid_command() { let cluster = TestClusterContext::new(3, 0); cluster.wait_for_cluster_up(); @@ -272,6 +283,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; @@ -298,6 +310,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; @@ -322,6 +335,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name( ) { let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name"; @@ -358,6 +372,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_pipeline_command_ordering() { let cluster = TestClusterContext::new(3, 0); cluster.wait_for_cluster_up(); @@ -383,6 +398,7 @@ mod cluster { } #[test] + #[serial_test::serial] #[ignore] // Flaky fn test_cluster_pipeline_ordering_with_improper_command() { let cluster = TestClusterContext::new(3, 0); @@ -417,6 +433,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_retries() { let name = "tryagain"; @@ -444,6 +461,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_exhaust_retries() { let name = "tryagain_exhaust_retries"; @@ -479,6 +497,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_move_error_when_new_node_is_added() { let name = "rebuild_with_extra_nodes"; @@ -536,6 +555,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_ask_redirect() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -580,6 +600,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_ask_error_when_new_node_is_added() { let name = "ask_with_extra_nodes"; @@ -629,6 +650,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_replica_read() { let name = "node"; @@ -682,6 +704,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_io_error() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -715,6 +738,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_non_retryable_error_should_not_retry() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -785,16 +809,19 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_fan_out_to_all_primaries() { test_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); } #[test] + #[serial_test::serial] fn test_cluster_fan_out_to_all_nodes() { test_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); } #[test] + #[serial_test::serial] 
fn test_cluster_fan_out_out_once_to_each_primary_when_no_replicas_are_available() { test_cluster_fan_out( "CONFIG SET", @@ -815,6 +842,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_fan_out_out_once_even_if_primary_has_multiple_slot_ranges() { test_cluster_fan_out( "CONFIG SET", @@ -845,6 +873,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() { let name = "test_cluster_split_multi_shard_command_and_combine_arrays_of_values"; let mut cmd = cmd("MGET"); @@ -882,6 +911,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests() { let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests"; let mut pipeline = redis::pipe(); @@ -931,6 +961,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2() { let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2"; let mut pipeline = redis::pipe(); @@ -974,6 +1005,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_with_client_name() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -1001,6 +1033,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_can_be_created_with_partial_slot_coverage() { let name = "test_cluster_can_be_created_with_partial_slot_coverage"; let slots_config = Some(vec![ @@ -1046,6 +1079,7 @@ mod cluster { use redis::ConnectionInfo; #[test] + #[serial_test::serial] fn test_cluster_basics_with_mtls() { let cluster = TestClusterContext::new_with_mtls(3, 0); @@ -1067,6 +1101,7 @@ mod cluster { } #[test] + #[serial_test::serial] fn test_cluster_should_not_connect_without_mtls() { let cluster = TestClusterContext::new_with_mtls(3, 0); diff --git a/glide-core/redis-rs/redis/tests/test_cluster_async.rs b/glide-core/redis-rs/redis/tests/test_cluster_async.rs index b690ed87b5..e6a5984fa7 100644 --- a/glide-core/redis-rs/redis/tests/test_cluster_async.rs +++ b/glide-core/redis-rs/redis/tests/test_cluster_async.rs @@ -99,6 +99,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_basic_cmd() { let cluster = TestClusterContext::new(3, 0); @@ -121,6 +122,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_basic_eval() { let cluster = TestClusterContext::new(3, 0); @@ -140,6 +142,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_basic_script() { let cluster = TestClusterContext::new(3, 0); @@ -159,6 +162,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_route_flush_to_specific_node() { let cluster = TestClusterContext::new(3, 0); @@ -194,6 +198,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_route_flush_to_node_by_address() { let cluster = TestClusterContext::new(3, 0); @@ -234,6 +239,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_route_info_to_nodes() { let cluster = TestClusterContext::new(12, 1); @@ -313,6 +319,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_resp3() { if use_protocol() == ProtocolVersion::RESP2 { return; @@ -352,6 +359,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_basic_pipe() { let cluster = TestClusterContext::new(3, 0); @@ -371,6 +379,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn 
test_async_cluster_multi_shard_commands() { let cluster = TestClusterContext::new(3, 0); @@ -389,6 +398,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_basic_failover() { block_on_all(async move { test_failover(&TestClusterContext::new(6, 1), 10, 123, false).await; @@ -582,6 +592,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_error_in_inner_connection() { let cluster = TestClusterContext::new(3, 0); @@ -606,30 +617,7 @@ mod cluster_async { } #[test] - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - fn test_async_cluster_async_std_basic_cmd() { - let cluster = TestClusterContext::new(3, 0); - - block_on_all_using_async_std(async { - let mut connection = cluster.async_connection(None).await; - redis::cmd("SET") - .arg("test") - .arg("test_data") - .query_async(&mut connection) - .await?; - redis::cmd("GET") - .arg("test") - .clone() - .query_async(&mut connection) - .map_ok(|res: String| { - assert_eq!(res, "test_data"); - }) - .await - }) - .unwrap(); - } - - #[test] + #[serial_test::serial] fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; @@ -665,6 +653,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; @@ -697,6 +686,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_cannot_connect_to_server_with_unknown_host_name() { let name = "test_async_cluster_cannot_connect_to_server_with_unknown_host_name"; let handler = move |cmd: &[u8], _| { @@ -727,6 +717,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name( ) { let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name"; @@ -772,6 +763,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_retries() { let name = "tryagain"; @@ -804,6 +796,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_tryagain_exhaust_retries() { let name = "tryagain_exhaust_retries"; @@ -862,6 +855,7 @@ mod cluster_async { } } #[test] + #[serial_test::serial] fn test_async_cluster_move_error_when_new_node_is_added() { let name = "rebuild_with_extra_nodes"; @@ -1008,7 +1002,7 @@ mod cluster_async { .query_async::<_, Option>(&mut connection) .await; assert_eq!(res, Ok(Some(123))); - // If there is a majority in the topology views, or if it's a 2-nodes cluster, we shall be able to calculate the topology on the first try, + // If there is a majority in the topology views, or if it's a 2-nodes cluster, we shall be able to calculate the topology on the first try, // so each node will be queried only once with CLUSTER SLOTS. // Otherwise, if we don't have a majority, we expect to see the refresh_slots function being called with the maximum retry number. 
let expected_calls = if has_a_majority || num_of_nodes == 2 {num_of_nodes} else {DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES * num_of_nodes}; @@ -1022,8 +1016,6 @@ mod cluster_async { #[cfg(feature = "tokio-comp")] tokio::time::sleep(sleep_duration).await; - #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - async_std::task::sleep(sleep_duration).await; } } panic!("Failed to reach to the expected topology refresh retries. Found={refreshed_calls}, Expected={expected_calls}") @@ -1244,6 +1236,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_topology_after_moved_error_all_nodes_agree_get_succeed() { let ports = get_ports(3); test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries( @@ -1254,6 +1247,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_topology_in_client_init_all_nodes_agree_get_succeed() { let ports = get_ports(3); test_async_cluster_refresh_topology_in_client_init_get_succeed( @@ -1263,6 +1257,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_topology_after_moved_error_with_no_majority_get_succeed() { for num_of_nodes in 2..4 { let ports = get_ports(num_of_nodes); @@ -1275,6 +1270,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_topology_in_client_init_with_no_majority_get_succeed() { for num_of_nodes in 2..4 { let ports = get_ports(num_of_nodes); @@ -1286,6 +1282,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_topology_even_with_zero_retries() { let name = "test_async_cluster_refresh_topology_even_with_zero_retries"; @@ -1297,7 +1294,7 @@ mod cluster_async { handler: _handler, .. } = MockEnv::with_client_builder( - ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0) + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0) // Disable the rate limiter to refresh slots immediately on the MOVED error. 
.slots_refresh_rate_limit(Duration::from_secs(0), 0), name, @@ -1376,6 +1373,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_reconnect_even_with_zero_retries() { let name = "test_async_cluster_reconnect_even_with_zero_retries"; @@ -1457,6 +1455,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_slots_rate_limiter_skips_refresh() { let ports = get_ports(3); test_async_cluster_refresh_slots_rate_limiter_helper( @@ -1467,6 +1466,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_refresh_slots_rate_limiter_does_refresh_when_wait_duration_passed() { let ports = get_ports(3); test_async_cluster_refresh_slots_rate_limiter_helper( @@ -1477,6 +1477,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_ask_redirect() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -1526,6 +1527,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_ask_save_new_connection() { let name = "node"; let ping_attempts = Arc::new(AtomicI32::new(0)); @@ -1568,6 +1570,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_reset_routing_if_redirect_fails() { let name = "test_async_cluster_reset_routing_if_redirect_fails"; let completed = Arc::new(AtomicI32::new(0)); @@ -1606,6 +1609,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_ask_redirect_even_if_original_call_had_no_route() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -1657,6 +1661,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_ask_error_when_new_node_is_added() { let name = "ask_with_extra_nodes"; @@ -1711,6 +1716,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_replica_read() { let name = "node"; @@ -1815,16 +1821,19 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_to_all_primaries() { test_async_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_to_all_nodes() { test_async_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_once_to_each_primary_when_no_replicas_are_available() { test_async_cluster_fan_out( "CONFIG SET", @@ -1845,6 +1854,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_once_even_if_primary_has_multiple_slot_ranges() { test_async_cluster_fan_out( "CONFIG SET", @@ -1875,6 +1885,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_route_according_to_passed_argument() { let name = "test_async_cluster_route_according_to_passed_argument"; @@ -1939,6 +1950,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_aggregate_numeric_response_with_min() { let name = "test_async_cluster_fan_out_and_aggregate_numeric_response"; let mut cmd = Cmd::new(); @@ -1969,6 +1981,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_aggregate_logical_array_response() { let name = "test_async_cluster_fan_out_and_aggregate_logical_array_response"; let mut cmd = Cmd::new(); @@ -2019,6 +2032,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_return_one_succeeded_response() { let name = "test_async_cluster_fan_out_and_return_one_succeeded_response"; let mut cmd = Cmd::new(); @@ 
-2053,6 +2067,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes() { let name = "test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes"; let mut cmd = Cmd::new(); @@ -2085,6 +2100,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_return_all_succeeded_response() { let name = "test_async_cluster_fan_out_and_return_all_succeeded_response"; let cmd = cmd("FLUSHALL"); @@ -2111,6 +2127,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure() { let name = "test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure"; let cmd = cmd("FLUSHALL"); @@ -2144,6 +2161,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps( ) { let name = @@ -2182,6 +2200,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors( ) { let name = @@ -2215,6 +2234,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil() { let name = "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil"; @@ -2242,6 +2262,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_return_map_of_results_for_special_response_policy() { let name = "foo"; let mut cmd = Cmd::new(); @@ -2282,6 +2303,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_fan_out_and_combine_arrays_of_values() { let name = "foo"; let cmd = cmd("KEYS"); @@ -2315,6 +2337,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values() { let name = "test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values"; let mut cmd = cmd("MGET"); @@ -2355,6 +2378,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_handle_asking_error_in_split_multi_shard_command() { let name = "test_async_cluster_handle_asking_error_in_split_multi_shard_command"; let mut cmd = cmd("MGET"); @@ -2404,6 +2428,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_pass_errors_from_split_multi_shard_command() { let name = "test_async_cluster_pass_errors_from_split_multi_shard_command"; let mut cmd = cmd("MGET"); @@ -2431,6 +2456,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_handle_missing_slots_in_split_multi_shard_command() { let name = "test_async_cluster_handle_missing_slots_in_split_multi_shard_command"; let mut cmd = cmd("MGET"); @@ -2464,6 +2490,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_with_username_and_password() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -2496,6 +2523,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_io_error() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -2534,6 +2562,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_non_retryable_error_should_not_retry() { let name = "node"; let completed = Arc::new(AtomicI32::new(0)); @@ -2570,6 +2599,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] 
fn test_async_cluster_read_from_primary() { let name = "node"; let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); @@ -2627,6 +2657,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_round_robin_read_from_replica() { let name = "node"; let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); @@ -2713,6 +2744,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_handle_complete_server_disconnect_without_panicking() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -2741,6 +2773,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_test_fast_reconnect() { // Note the 3 seconds connection check to differentiate between notifications and periodic let cluster = TestClusterContext::new_with_cluster_client_builder( @@ -2851,6 +2884,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() { let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); let use_sharded = redis_ver.starts_with("7."); @@ -3020,6 +3054,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() { let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); let use_sharded = redis_ver.starts_with("7."); @@ -3258,6 +3293,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_resp3_pubsub() { let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); let use_sharded = redis_ver.starts_with("7."); @@ -3372,6 +3408,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_periodic_checks_update_topology_after_failover() { // This test aims to validate the functionality of periodic topology checks by detecting and updating topology changes. // We will repeatedly execute CLUSTER NODES commands against the primary node responsible for slot 0, recording its node ID. @@ -3442,6 +3479,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_recover_disconnected_management_connections() { // This test aims to verify that the management connections used for periodic checks are reconnected, in case that they get killed. // In order to test this, we choose a single node, kill all connections to it which aren't user connections, and then wait until new @@ -3494,6 +3532,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_with_client_name() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -3530,6 +3569,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_reroute_from_replica_if_in_loading_state() { /* Test replica in loading state. The expected behaviour is that the request will be directed to a different replica or the primary. depends on the read from replica policy. */ @@ -3585,6 +3625,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_read_from_primary_when_primary_loading() { // Test primary in loading state. The expected behaviour is that the request will be retried until the primary is no longer in loading state. 
let name = "test_async_cluster_read_from_primary_when_primary_loading"; @@ -3639,6 +3680,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_can_be_created_with_partial_slot_coverage() { let name = "test_async_cluster_can_be_created_with_partial_slot_coverage"; let slots_config = Some(vec![ @@ -3679,6 +3721,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_reconnect_after_complete_server_disconnect() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -3722,6 +3765,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_reconnect_after_complete_server_disconnect_route_to_many() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -3760,6 +3804,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_blocking_command_when_cluster_drops() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -3787,6 +3832,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_saves_reconnected_connection() { let name = "test_async_cluster_saves_reconnected_connection"; let ping_attempts = Arc::new(AtomicI32::new(0)); @@ -3856,6 +3902,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_periodic_checks_use_management_connection() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -3884,7 +3931,7 @@ mod cluster_async { .expect("Failed executing CLIENT LIST"); let mut client_list_parts = client_list.split('\n'); if client_list_parts - .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster")) + .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster")) && client_list.matches(MANAGEMENT_CONN_NAME).count() == 1 { return Ok::<_, RedisError>(()); } @@ -3954,6 +4001,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_only_management_connection_is_reconnected_after_connection_failure() { // This test will check two aspects: // 1. Ensuring that after a disconnection in the management connection, a new management connection is established. @@ -4023,6 +4071,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd() { // This test verifies that non-key-based commands do not get routed to a random node // when no connection is found for the given route. Instead, the appropriate error @@ -4087,6 +4136,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_route_to_random_on_key_based_cmd() { // This test verifies that key-based commands get routed to a random node // when no connection is found for the given route. 
The command should @@ -4143,6 +4193,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_do_not_retry_when_receiver_was_dropped() { let name = "test_async_cluster_do_not_retry_when_receiver_was_dropped"; let cmd = cmd("FAKE_COMMAND"); @@ -4196,6 +4247,7 @@ mod cluster_async { use super::*; #[test] + #[serial_test::serial] fn test_async_cluster_basic_cmd_with_mtls() { let cluster = TestClusterContext::new_with_mtls(3, 0); block_on_all(async move { @@ -4218,6 +4270,7 @@ mod cluster_async { } #[test] + #[serial_test::serial] fn test_async_cluster_should_not_connect_without_mtls_enabled() { let cluster = TestClusterContext::new_with_mtls(3, 0); block_on_all(async move { diff --git a/glide-core/redis-rs/redis/tests/test_cluster_scan.rs b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs index 29a3c87b48..cfc4bae594 100644 --- a/glide-core/redis-rs/redis/tests/test_cluster_scan.rs +++ b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs @@ -47,6 +47,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] async fn test_async_cluster_scan() { let cluster = TestClusterContext::new(3, 0); let mut connection = cluster.async_connection(None).await; @@ -87,7 +88,8 @@ mod test_cluster_scan_async { } } - #[tokio::test] // test cluster scan with slot migration in the middle + #[tokio::test] + #[serial_test::serial] // test cluster scan with slot migration in the middle async fn test_async_cluster_scan_with_migration() { let cluster = TestClusterContext::new(3, 0); @@ -162,7 +164,8 @@ mod test_cluster_scan_async { assert_eq!(keys, expected_keys); } - #[tokio::test] // test cluster scan with node fail in the middle + #[tokio::test] + #[serial_test::serial] // test cluster scan with node fail in the middle async fn test_async_cluster_scan_with_fail() { let cluster = TestClusterContext::new_with_cluster_client_builder( 3, @@ -224,7 +227,8 @@ mod test_cluster_scan_async { assert!(result.is_err()); } - #[tokio::test] // Test cluster scan with killing all masters during scan + #[tokio::test] + #[serial_test::serial] // Test cluster scan with killing all masters during scan async fn test_async_cluster_scan_with_all_masters_down() { let cluster = TestClusterContext::new_with_cluster_client_builder( 6, @@ -378,6 +382,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Test cluster scan with killing all replicas during scan async fn test_async_cluster_scan_with_all_replicas_down() { let cluster = TestClusterContext::new_with_cluster_client_builder( @@ -482,6 +487,7 @@ mod test_cluster_scan_async { assert_eq!(keys, expected_keys); } #[tokio::test] + #[serial_test::serial] // Test cluster scan with setting keys for each iteration async fn test_async_cluster_scan_set_in_the_middle() { let cluster = TestClusterContext::new(3, 0); @@ -541,6 +547,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Test cluster scan with deleting keys for each iteration async fn test_async_cluster_scan_dell_in_the_middle() { let cluster = TestClusterContext::new(3, 0); @@ -603,6 +610,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Testing cluster scan with Pattern option async fn test_async_cluster_scan_with_pattern() { let cluster = TestClusterContext::new(3, 0); @@ -661,6 +669,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Testing cluster scan with TYPE option async fn test_async_cluster_scan_with_type() { let cluster = TestClusterContext::new(3, 0); @@ -719,6 
+728,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Testing cluster scan with COUNT option async fn test_async_cluster_scan_with_count() { let cluster = TestClusterContext::new(3, 0); @@ -782,6 +792,7 @@ mod test_cluster_scan_async { } #[tokio::test] + #[serial_test::serial] // Testing cluster scan when connection fails in the middle and we get an error // then cluster up again and scanning can continue without any problem async fn test_async_cluster_scan_failover() { diff --git a/glide-core/tests/test_standalone_client.rs b/glide-core/tests/test_standalone_client.rs index 448fa0faa0..5b269dd42c 100644 --- a/glide-core/tests/test_standalone_client.rs +++ b/glide-core/tests/test_standalone_client.rs @@ -193,20 +193,19 @@ mod standalone_client_tests { } fn test_read_from_replica(config: ReadFromReplicaTestConfig) { - let mut mocks = create_primary_mock_with_replicas( + let mut servers = create_primary_mock_with_replicas( config.number_of_initial_replicas - config.number_of_missing_replicas, ); let mut cmd = redis::cmd("GET"); cmd.arg("foo"); - for mock in mocks.iter() { + for server in servers.iter() { for _ in 0..3 { - mock.add_response(&cmd, "$-1\r\n".to_string()); + server.add_response(&cmd, "$-1\r\n".to_string()); } } - let mut addresses = get_mock_addresses(&mocks); - + let mut addresses = get_mock_addresses(&servers); for i in 4 - config.number_of_missing_replicas..4 { addresses.push(redis::ConnectionAddr::Tcp( "foo".to_string(), @@ -221,19 +220,32 @@ mod standalone_client_tests { let mut client = StandaloneClient::create_client(connection_request.into(), None) .await .unwrap(); - for mock in mocks.drain(1..config.number_of_replicas_dropped_after_connection + 1) { - mock.close().await; + logger_core::log_info( + "Test", + format!( + "Closing {} servers after connection established", + config.number_of_replicas_dropped_after_connection + ), + ); + for server in servers.drain(1..config.number_of_replicas_dropped_after_connection + 1) { + server.close().await; } + logger_core::log_info( + "Test", + format!("sending {} messages", config.number_of_requests_sent), + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; for _ in 0..config.number_of_requests_sent { let _ = client.send_command(&cmd).await; } }); assert_eq!( - mocks[0].get_number_of_received_commands(), + servers[0].get_number_of_received_commands(), config.expected_primary_reads ); - let mut replica_reads: Vec<_> = mocks + let mut replica_reads: Vec<_> = servers .iter() .skip(1) .map(|mock| mock.get_number_of_received_commands()) @@ -294,7 +306,9 @@ mod standalone_client_tests { test_read_from_replica(ReadFromReplicaTestConfig { read_from: ReadFrom::PreferReplica, expected_primary_reads: 0, - expected_replica_reads: vec![2, 3], + // Since we drop 1 replica after connection establishment + // we expect all reads to be handled by the remaining replicas + expected_replica_reads: vec![3, 3], number_of_replicas_dropped_after_connection: 1, number_of_requests_sent: 6, ..Default::default() diff --git a/glide-core/tests/utilities/mocks.rs b/glide-core/tests/utilities/mocks.rs index 160e8a3189..33b8ae4121 100644 --- a/glide-core/tests/utilities/mocks.rs +++ b/glide-core/tests/utilities/mocks.rs @@ -5,14 +5,15 @@ use futures_intrusive::sync::ManualResetEvent; use redis::{Cmd, ConnectionAddr, Value}; use std::collections::HashMap; use std::io; +use std::io::Read; +use std::io::Write; use std::net::TcpListener; +use std::net::TcpStream as StdTcpStream; use std::str::from_utf8; use 
std::sync::{ atomic::{AtomicU16, Ordering}, Arc, }; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpStream; use tokio::sync::mpsc::UnboundedSender; pub struct MockedRequest { @@ -29,20 +30,24 @@ pub struct ServerMock { closing_completed_signal: Arc, } -async fn read_from_socket(buffer: &mut Vec, socket: &mut TcpStream) -> Option { - let _ = socket.readable().await; - - loop { - match socket.try_read_buf(buffer) { +fn read_from_socket( + buffer: &mut [u8], + socket: &mut StdTcpStream, + closing_signal: &Arc, +) -> Option { + while !closing_signal.is_set() { + let read_res = socket.read(buffer); // read() is using timeout + match read_res { Ok(0) => { return None; } - Ok(size) => return Some(size), + Ok(size) => { + return Some(size); + } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::Interrupted => { - tokio::task::yield_now().await; continue; } Err(_) => { @@ -50,43 +55,53 @@ async fn read_from_socket(buffer: &mut Vec, socket: &mut TcpStream) -> Optio } } } + // If we reached here, it means we got a signal to terminate + None +} + +/// Escape and print a RESP message +fn log_resp_message(msg: &str) { + logger_core::log_info( + "Test", + format!( + "{:?} {}", + std::thread::current().id(), + msg.replace('\r', "\\r").replace('\n', "\\n") + ), + ); } -async fn receive_and_respond_to_next_message( +fn receive_and_respond_to_next_message( receiver: &mut tokio::sync::mpsc::UnboundedReceiver, - socket: &mut TcpStream, + socket: &mut StdTcpStream, received_commands: &Arc, constant_responses: &HashMap, closing_signal: &Arc, ) -> bool { - let mut buffer = Vec::with_capacity(1024); - let size = tokio::select! { - size = read_from_socket(&mut buffer, socket) => { - let Some(size) = size else { - return false; - }; - size - }, - _ = closing_signal.wait() => { + let mut buffer = vec![0; 1024]; + let size = match read_from_socket(&mut buffer, socket, closing_signal) { + Some(size) => size, + None => { return false; } }; - let message = from_utf8(&buffer[..size]).unwrap().to_string(); + log_resp_message(&message); + let setinfo_count = message.matches("SETINFO").count(); if setinfo_count > 0 { let mut buffer = Vec::new(); for _ in 0..setinfo_count { super::encode_value(&Value::Okay, &mut buffer).unwrap(); } - socket.write_all(&buffer).await.unwrap(); + socket.write_all(&buffer).unwrap(); return true; } if let Some(response) = constant_responses.get(&message) { let mut buffer = Vec::new(); super::encode_value(response, &mut buffer).unwrap(); - socket.write_all(&buffer).await.unwrap(); + socket.write_all(&buffer).unwrap(); return true; } let Ok(request) = receiver.try_recv() else { @@ -94,7 +109,7 @@ async fn receive_and_respond_to_next_message( }; received_commands.fetch_add(1, Ordering::AcqRel); assert_eq!(message, request.expected_message); - socket.write_all(request.response.as_bytes()).await.unwrap(); + socket.write_all(request.response.as_bytes()).unwrap(); true } @@ -127,15 +142,11 @@ impl ServerMock { let closing_signal_clone = closing_signal.clone(); let closing_completed_signal = Arc::new(ManualResetEvent::new(false)); let closing_completed_signal_clone = closing_completed_signal.clone(); - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .thread_name(format!("ServerMock - {address}")) - .enable_all() - .build() - .unwrap(); - runtime.spawn(async move { - let listener = tokio::net::TcpListener::from_std(listener).unwrap(); - let mut socket = listener.accept().await.unwrap().0; + let address_clone = address.clone(); + 
std::thread::spawn(move || { + logger_core::log_info("Test", format!("ServerMock started on: {}", address_clone)); + let mut socket: StdTcpStream = listener.accept().unwrap().0; + let _ = socket.set_read_timeout(Some(std::time::Duration::from_millis(10))); while receive_and_respond_to_next_message( &mut receiver, @@ -143,17 +154,25 @@ impl ServerMock { &received_commands_clone, &constant_responses, &closing_signal_clone, - ) - .await - {} + ) {} + + // Terminate the connection + let _ = socket.shutdown(std::net::Shutdown::Both); + // Now notify exit completed closing_completed_signal_clone.set(); + + logger_core::log_info( + "Test", + format!("{:?} ServerMock exited", std::thread::current().id()), + ); }); + Self { request_sender, address, received_commands, - runtime: Some(runtime), + runtime: None, closing_signal, closing_completed_signal, } @@ -186,6 +205,5 @@ impl Mock for ServerMock { impl Drop for ServerMock { fn drop(&mut self) { self.closing_signal.set(); - self.runtime.take().unwrap().shutdown_background(); } } From 1b6db03e2904c39ae02981d1fff8e56e88b83284 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Wed, 16 Oct 2024 14:20:14 -0700 Subject: [PATCH 017/180] Fix python CI: fix linter installation (#2465) fix ci Signed-off-by: Yury-Fridlyand --- .github/workflows/python.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 45d2c0cf0d..c85045df07 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -216,9 +216,9 @@ jobs: - name: Install dependencies if: always() - working-directory: ./python - run: | - sudo apt install -y python3-pip python3 flake8 isort black + uses: threeal/pipx-install-action@latest + with: + packages: flake8 isort black - name: Lint python with isort if: always() From c252cbbd41c07b4af362622819a2e66413133a39 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Thu, 17 Oct 2024 11:30:06 -0700 Subject: [PATCH 018/180] Java: `FT.SEARCH` (#2439) * `FT.CREATE` Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + glide-core/src/client/value_conversion.rs | 50 ++++++- .../glide/api/commands/servermodules/FT.java | 135 ++++++++++++++++++ .../models/commands/FT/FTSearchOptions.java | 131 +++++++++++++++++ .../java/glide/modules/VectorSearchTests.java | 110 ++++++++++++++ 5 files changed, 426 insertions(+), 1 deletion(-) create mode 100644 java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cf1a3a7ea..ad4c80a577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Python: Add JSON.NUMMULTBY command ([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) +* Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) #### Breaking Changes diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index 4a43da7da7..e89c92adbe 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -22,6 +22,7 @@ pub(crate) enum ExpectedReturnType<'a> { ArrayOfStrings, ArrayOfBools, ArrayOfDoubleOrNull, + FTSearchReturnType, Lolwut, ArrayOfStringAndArrays, 
ArrayOfArraysOfDoubleOrNull, @@ -891,7 +892,53 @@ pub(crate) fn convert_to_expected_type( format!("(response was {:?})", get_value_type(&value)), ) .into()), - } + }, + ExpectedReturnType::FTSearchReturnType => match value { + /* + Example of the response + 1) (integer) 2 + 2) "json:2" + 3) 1) "__VEC_score" + 2) "11.1100006104" + 3) "$" + 4) "{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}" + 4) "json:0" + 5) 1) "__VEC_score" + 2) "91" + 3) "$" + 4) "{\"vec\":[1,2,3,4,5,6]}" + + Converting response to + 1) (integer) 2 + 2) 1# "json:2" => + 1# "__VEC_score" => "11.1100006104" + 2# "$" => "{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}" + 2# "json:0" => + 1# "__VEC_score" => "91" + 2# "$" => "{\"vec\":[1,2,3,4,5,6]}" + + Response may contain only 1 element, no conversion in that case. + */ + Value::Array(ref array) if array.len() == 1 => Ok(value), + Value::Array(mut array) => { + Ok(Value::Array(vec![ + array.remove(0), + convert_to_expected_type(Value::Array(array), Some(ExpectedReturnType::Map { + key_type: &Some(ExpectedReturnType::BulkString), + value_type: &Some(ExpectedReturnType::Map { + key_type: &Some(ExpectedReturnType::BulkString), + value_type: &Some(ExpectedReturnType::BulkString), + }), + }))? + ])) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted to Pair", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) + }, } } @@ -1256,6 +1303,7 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { key_type: &None, value_type: &None, }), + b"FT.SEARCH" => Some(ExpectedReturnType::FTSearchReturnType), _ => None, } } diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index bff9eeb357..12a20a0ff4 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -2,6 +2,7 @@ package glide.api.commands.servermodules; import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.concatenateArrays; import glide.api.BaseClient; import glide.api.GlideClient; @@ -10,6 +11,7 @@ import glide.api.models.GlideString; import glide.api.models.commands.FT.FTCreateOptions; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; +import glide.api.models.commands.FT.FTSearchOptions; import java.util.Arrays; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -140,6 +142,139 @@ public static CompletableFuture create( return executeCommand(client, args, false); } + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @param options The search options - see {@link FTSearchOptions}. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes.
+ * If {@link FTSearchOptions.FTSearchOptionsBuilder#count()} or {@link
+ * FTSearchOptions.FTSearchOptionsBuilder#limit(int, int)} with values 0, 0 is
+ * set, the command returns an array with only one element - the count of the documents.
+ * @example
+ *
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, "json_idx1", "*=>[KNN 2 @VEC $query_vec]",
+     *         FTSearchOptions.builder().params(Map.of(gs("query_vec"), gs(vector))).build())
+     *     .get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("__VEC_score"), gs("11.1100006104"), gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("__VEC_score"), gs("91"), gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
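+     *
+     * A sketch of one way to build the query vector bytes, assuming the index stores each
+     * component as a little-endian IEEE-754 float32:
+     * 
{@code
+     * float[] components = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+     * ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * components.length).order(ByteOrder.LITTLE_ENDIAN);
+     * for (float c : components) buf.putFloat(c);
+     * byte[] vector = buf.array();
+     * }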
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, + @NonNull String indexName, + @NonNull String query, + @NonNull FTSearchOptions options) { + var args = + concatenateArrays( + new GlideString[] {gs("FT.SEARCH"), gs(indexName), gs(query)}, options.toArgs()); + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @param options The search options - see {@link FTSearchOptions}. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes.
+ * If {@link FTSearchOptions.FTSearchOptionsBuilder#count()} or {@link
+ * FTSearchOptions.FTSearchOptionsBuilder#limit(int, int)} with values 0, 0 is
+ * set, the command returns an array with only one element - the count of the documents.
+ * @example
+ *
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, gs("json_idx1"), gs("*=>[KNN 2 @VEC $query_vec]"),
+     *         FTSearchOptions.builder().params(Map.of(gs("query_vec"), gs(vector))).build())
+     *     .get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("__VEC_score"), gs("11.1100006104"), gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("__VEC_score"), gs("91"), gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
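+     *
+     * In the query above, {@code *} is the pre-filter matching every document, {@code KNN 2}
+     * requests the two nearest neighbors of the {@code VEC} field, and {@code $query_vec} is
+     * bound to the raw vector bytes through {@code params}.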
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull GlideString query, + @NonNull FTSearchOptions options) { + var args = + concatenateArrays(new GlideString[] {gs("FT.SEARCH"), indexName, query}, options.toArgs()); + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes. + * @example + *
{@code
+     * var result = FT.search(client, "json_idx1", "*").get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
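+     *
+     * Without options the server falls back to its default pagination, so only the first page
+     * of matches is returned; later pages can be requested through {@link FTSearchOptions}, for
+     * example (hypothetical follow-up):
+     * 
{@code
+     * // fetch results 10..19 for the same query
+     * FT.search(client, "json_idx1", "*", FTSearchOptions.builder().limit(10, 10).build()).get();
+     * }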
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + var args = new GlideString[] {gs("FT.SEARCH"), gs(indexName), gs(query)}; + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes. + * @example + *
{@code
+     * var result = FT.search(client, gs("json_idx1"), gs("*")).get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
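+     *
+     * When only the number of matches is needed, a count-only query avoids transferring the
+     * documents themselves, e.g. (sketch using the options overload):
+     * 
{@code
+     * Object[] countOnly = FT.search(client, gs("json_idx1"), gs("*"),
+     *     FTSearchOptions.builder().count().build()).get();
+     * // countOnly[0] holds the number of matching documents
+     * }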
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + var args = new GlideString[] {gs("FT.SEARCH"), indexName, query}; + return executeCommand(client, args, false); + } + /** * Deletes an index and associated content. Indexed document keys are unaffected. * diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java new file mode 100644 index 0000000000..990eab2cb3 --- /dev/null +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java @@ -0,0 +1,131 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.models.commands.FT; + +import static glide.api.models.GlideString.gs; + +import glide.api.commands.servermodules.FT; +import glide.api.models.GlideString; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import lombok.Builder; +import org.apache.commons.lang3.tuple.Pair; + +/** Mandatory parameters for {@link FT#search}. */ +@Builder +public class FTSearchOptions { + + @Builder.Default private final Map identifiers = new HashMap<>(); + + /** Query timeout in milliseconds. */ + private final Integer timeout; + + private final Pair limit; + + @Builder.Default private final boolean count = false; + + /** + * Query parameters, which could be referenced in the query by $ sign, followed by + * the parameter name. + */ + @Builder.Default private final Map params = new HashMap<>(); + + // TODO maxstale? + // dialect is no-op + + /** Convert to module API. */ + public GlideString[] toArgs() { + var args = new ArrayList(); + if (!identifiers.isEmpty()) { + args.add(gs("RETURN")); + int tokenCount = 0; + for (var pair : identifiers.entrySet()) { + tokenCount++; + args.add(pair.getKey()); + if (pair.getValue() != null) { + tokenCount += 2; + args.add(gs("AS")); + args.add(pair.getValue()); + } + } + args.add(1, gs(Integer.toString(tokenCount))); + } + if (timeout != null) { + args.add(gs("TIMEOUT")); + args.add(gs(timeout.toString())); + } + if (!params.isEmpty()) { + args.add(gs("PARAMS")); + args.add(gs(Integer.toString(params.size() * 2))); + params.forEach( + (name, value) -> { + args.add(name); + args.add(value); + }); + } + if (limit != null) { + args.add(gs("LIMIT")); + args.add(gs(Integer.toString(limit.getLeft()))); + args.add(gs(Integer.toString(limit.getRight()))); + } + if (count) { + args.add(gs("COUNT")); + } + return args.toArray(GlideString[]::new); + } + + public static class FTSearchOptionsBuilder { + + // private - hiding this API from user + void limit(Pair limit) {} + + void count(boolean count) {} + + void identifiers(Map identifiers) {} + + /** Add a field to be returned. */ + public FTSearchOptionsBuilder addReturnField(String field) { + this.identifiers$value.put(gs(field), null); + return this; + } + + /** Add a field with an alias to be returned. */ + public FTSearchOptionsBuilder addReturnField(String field, String alias) { + this.identifiers$value.put(gs(field), gs(alias)); + return this; + } + + /** Add a field to be returned. */ + public FTSearchOptionsBuilder addReturnField(GlideString field) { + this.identifiers$value.put(field, null); + return this; + } + + /** Add a field with an alias to be returned. 
*/ + public FTSearchOptionsBuilder addReturnField(GlideString field, GlideString alias) { + this.identifiers$value.put(field, alias); + return this; + } + + /** + * Configure query pagination. By default only first 10 documents are returned. + * + * @param offset Zero-based offset. + * @param count Number of elements to return. + */ + public FTSearchOptionsBuilder limit(int offset, int count) { + this.limit = Pair.of(offset, count); + return this; + } + + /** + * Once set, the query will return only number of documents in the result set without actually + * returning them. + */ + public FTSearchOptionsBuilder count() { + this.count$value = true; + this.count$set = true; + return this; + } + } +} diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 0597fa0023..33b9bd9dd6 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -3,8 +3,10 @@ import static glide.TestUtilities.commonClusterClientConfig; import static glide.api.BaseClient.OK; +import static glide.api.models.GlideString.gs; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; @@ -22,10 +24,12 @@ import glide.api.models.commands.FT.FTCreateOptions.TextField; import glide.api.models.commands.FT.FTCreateOptions.VectorFieldFlat; import glide.api.models.commands.FT.FTCreateOptions.VectorFieldHnsw; +import glide.api.models.commands.FT.FTSearchOptions; import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; import glide.api.models.exceptions.RequestException; import java.util.HashSet; +import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; @@ -186,6 +190,112 @@ public void ft_create() { assertTrue(exception.getMessage().contains("already exists")); } + @SneakyThrows + @Test + public void ft_search() { + String prefix = "{" + UUID.randomUUID() + "}:"; + String index = prefix + "index"; + + assertEquals( + OK, + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo("vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) + }, + FTCreateOptions.builder() + .indexType(IndexType.HASH) + .prefixes(new String[] {prefix}) + .build()) + .get()); + + assertEquals( + 1L, + client + .hset( + gs(prefix + 0), + Map.of( + gs("vec"), + gs( + new byte[] { + (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, + (byte) 0 + }))) + .get()); + assertEquals( + 1L, + client + .hset( + gs(prefix + 1), + Map.of( + gs("vec"), + gs( + new byte[] { + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0x80, + (byte) 0xBF + }))) + .get()); + + var ftsearch = + FT.search( + client, + index, + "*=>[KNN 2 @VEC $query_vec]", + FTSearchOptions.builder() + .params( + Map.of( + gs("query_vec"), + gs( + new byte[] { + (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, + (byte) 0, (byte) 0 + }))) + .build()) + .get(); + + assertArrayEquals( + new Object[] { + 2L, + Map.of( + gs(prefix + 0), + 
Map.of(gs("__VEC_score"), gs("0"), gs("vec"), gs("\0\0\0\0\0\0\0\0")), + gs(prefix + 1), + Map.of( + gs("__VEC_score"), + gs("1"), + gs("vec"), + gs( + new byte[] { + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0x80, + (byte) 0xBF + }))) + }, + ftsearch); + + // TODO more tests with json index + + // querying non-existing index + var exception = + assertThrows( + ExecutionException.class, + () -> FT.search(client, UUID.randomUUID().toString(), "*").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + @SneakyThrows @Test public void ft_drop() { From 1125eba9c77a2b2c8fe21f1f0a7781d921b762a2 Mon Sep 17 00:00:00 2001 From: prateek-kumar-improving Date: Thu, 17 Oct 2024 13:53:54 -0700 Subject: [PATCH 019/180] Python: Add commands `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` (#2471) * Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE --------- Signed-off-by: Prateek Kumar --- CHANGELOG.md | 1 + .../glide/async_commands/server_modules/ft.py | 66 +++++++++++++ .../server_modules/ft_options/ft_constants.py | 3 + .../tests/tests_server_modules/test_ft.py | 97 +++++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 python/python/tests/tests_server_modules/test_ft.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ad4c80a577..e28f1941ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE([#2471](https://github.com/valkey-io/valkey-glide/pull/2471)) * Python: Python FT.DROPINDEX command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437)) * Python: Python: Added FT.CREATE command([#2413](https://github.com/valkey-io/valkey-glide/pull/2413)) * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index 74d75e8953..88dcb9b58e 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -76,3 +76,69 @@ async def dropindex(client: TGlideClient, indexName: TEncodable) -> TOK: """ args: List[TEncodable] = [CommandNames.FT_DROPINDEX, indexName] return cast(TOK, await client.custom_command(args)) + + +async def aliasadd( + client: TGlideClient, alias: TEncodable, indexName: TEncodable +) -> TOK: + """ + Add an alias for an index. The new alias name can be used anywhere that an index name is required. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The alias to be added to an index. + indexName (TEncodable): The index name for which the alias has to be added. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide.async_commands.server_modules import ft + >>> result = await ft.aliasadd(glide_client, "myalias", "myindex") + 'OK' # Indicates the successful addition of the alias named "myalias" for the index. + """ + args: List[TEncodable] = [CommandNames.FT_ALIASADD, alias, indexName] + return cast(TOK, await client.custom_command(args)) + + +async def aliasdel(client: TGlideClient, alias: TEncodable) -> TOK: + """ + Delete an existing alias for an index. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The exisiting alias to be deleted for an index. + + Returns: + TOK: A simple "OK" response. 
+ + Examples: + >>> from glide.async_commands.server_modules import ft + >>> result = await ft.aliasdel(glide_client, "myalias") + 'OK' # Indicates the successful deletion of the alias named "myalias" + """ + args: List[TEncodable] = [CommandNames.FT_ALIASDEL, alias] + return cast(TOK, await client.custom_command(args)) + + +async def aliasupdate( + client: TGlideClient, alias: TEncodable, indexName: TEncodable +) -> TOK: + """ + Update an existing alias to point to a different physical index. This command only affects future references to the alias. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The alias name. This alias will now be pointed to a different index. + indexName (TEncodable): The index name for which an existing alias has to be updated. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide.async_commands.server_modules import ft + >>> result = await ft.aliasupdate(glide_client, "myalias", "myindex") + 'OK' # Indicates the successful update of the alias to point to the index named "myindex" + """ + args: List[TEncodable] = [CommandNames.FT_ALIASUPDATE, alias, indexName] + return cast(TOK, await client.custom_command(args)) diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index d1e8e524eb..14fef2a681 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -8,6 +8,9 @@ class CommandNames: FT_CREATE = "FT.CREATE" FT_DROPINDEX = "FT.DROPINDEX" + FT_ALIASADD = "FT.ALIASADD" + FT_ALIASDEL = "FT.ALIASDEL" + FT_ALIASUPDATE = "FT.ALIASUPDATE" class FtCreateKeywords: diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py new file mode 100644 index 0000000000..39b068246b --- /dev/null +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -0,0 +1,97 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +import uuid +from typing import List + +import pytest +from glide.async_commands.server_modules import ft +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + Field, + FtCreateOptions, + TextField, +) +from glide.config import ProtocolVersion +from glide.constants import OK, TEncodable +from glide.exceptions import RequestError +from glide.glide_client import GlideClusterClient + + +@pytest.mark.asyncio +class TestFt: + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasadd(self, glide_client: GlideClusterClient): + indexName: str = str(uuid.uuid4()) + alias: str = "alias" + # Test ft.aliasadd throws an error if index does not exist. + with pytest.raises(RequestError): + await ft.aliasadd(glide_client, alias, indexName) + + # Test ft.aliasadd successfully adds an alias to an existing index. + await TestFt.create_test_index_hash_type(self, glide_client, indexName) + assert await ft.aliasadd(glide_client, alias, indexName) == OK + + # Test ft.aliasadd for input of bytes type. 
+ indexNameString = str(uuid.uuid4()) + indexNameBytes = bytes(indexNameString, "utf-8") + aliasNameBytes = b"alias-bytes" + await TestFt.create_test_index_hash_type(self, glide_client, indexNameString) + assert await ft.aliasadd(glide_client, aliasNameBytes, indexNameBytes) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasdel(self, glide_client: GlideClusterClient): + indexName: TEncodable = str(uuid.uuid4()) + alias: str = "alias" + await TestFt.create_test_index_hash_type(self, glide_client, indexName) + + # Test if deleting a non existent alias throws an error. + with pytest.raises(RequestError): + await ft.aliasdel(glide_client, alias) + + # Test if an existing alias is deleted successfully. + assert await ft.aliasadd(glide_client, alias, indexName) == OK + assert await ft.aliasdel(glide_client, alias) == OK + + # Test if an existing alias is deleted successfully for bytes type input. + assert await ft.aliasadd(glide_client, alias, indexName) == OK + assert await ft.aliasdel(glide_client, bytes(alias, "utf-8")) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): + indexName: str = str(uuid.uuid4()) + alias: str = "alias" + await TestFt.create_test_index_hash_type(self, glide_client, indexName) + assert await ft.aliasadd(glide_client, alias, indexName) == OK + newAliasName: str = "newAlias" + newIndexName: str = str(uuid.uuid4()) + + await TestFt.create_test_index_hash_type(self, glide_client, newIndexName) + assert await ft.aliasadd(glide_client, newAliasName, newIndexName) == OK + + # Test if updating an already existing alias to point to an existing index returns "OK". + assert await ft.aliasupdate(glide_client, newAliasName, indexName) == OK + assert ( + await ft.aliasupdate( + glide_client, bytes(alias, "utf-8"), bytes(newIndexName, "utf-8") + ) + == OK + ) + + async def create_test_index_hash_type( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating a basic index with hash data type with one text field. 
+ fields: List[Field] = [] + text_field_title: TextField = TextField("$title") + fields.append(text_field_title) + + prefix = "{json-search-" + str(uuid.uuid4()) + "}:" + prefixes: List[TEncodable] = [] + prefixes.append(prefix) + + result = await ft.create( + glide_client, index_name, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + assert result == OK From adcc76f3a390b180dc454765f95320b40d774cd2 Mon Sep 17 00:00:00 2001 From: James Xin Date: Thu, 17 Oct 2024 14:52:36 -0700 Subject: [PATCH 020/180] Java: add JSON.SET and JSON.GET (#2462) * Java: add JSON.SET and JSON.GET --------- Signed-off-by: James Xin --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 460 ++++++++++++++++++ .../models/commands/json/JsonGetOptions.java | 64 +++ .../commands/json/JsonGetOptionsBinary.java | 67 +++ java/client/src/main/java/module-info.java | 1 + .../api/commands/servermodules/JsonTest.java | 345 +++++++++++++ java/integTest/build.gradle | 3 + .../test/java/glide/modules/JsonTests.java | 140 +++++- 8 files changed, 1078 insertions(+), 3 deletions(-) create mode 100644 java/client/src/main/java/glide/api/commands/servermodules/Json.java create mode 100644 java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java create mode 100644 java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java create mode 100644 java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index e28f1941ce..a90f628944 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) +* Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java new file mode 100644 index 0000000000..5aeb9fd851 --- /dev/null +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -0,0 +1,460 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.commands.servermodules; + +import static glide.api.models.GlideString.gs; + +import glide.api.BaseClient; +import glide.api.GlideClient; +import glide.api.GlideClusterClient; +import glide.api.models.ClusterValue; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.json.JsonGetOptions; +import glide.api.models.commands.json.JsonGetOptionsBinary; +import glide.utils.ArgsBuilder; +import glide.utils.ArrayTransformUtils; +import java.util.concurrent.CompletableFuture; +import lombok.NonNull; + +/** Module for JSON commands. */ +public class Json { + + public static final String JSON_PREFIX = "JSON."; + public static final String JSON_SET = JSON_PREFIX + "SET"; + public static final String JSON_GET = JSON_PREFIX + "GET"; + + private Json() {} + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. 
+ * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted string. + * @return A simple "OK" response if the value is successfully set. + * @example + *
{@code
+     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}").get();
+     * assert value.equals("OK");
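+     * // A hedged follow-up sketch, not part of the original example: the same overload can
+     * // target a nested path ("$.a" here is illustrative) once "doc" exists.
+     * String nested = Json.set(client, "doc", "$.a", "3.5").get();
+     * assert nested.equals("OK");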
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String value) { + return executeCommand(client, new String[] {JSON_SET, key, path, value}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted GlideString. + * @return A simple "OK" response if the value is successfully set. + * @example + *
{@code
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}")).get();
+     * assert value.equals("OK");
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString value) { + return executeCommand(client, new GlideString[] {gs(JSON_SET), key, path, value}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted string. + * @param setCondition Set the value only if the given condition is met (within the key or path). + * @return A simple "OK" response if the value is successfully set. If value isn't + * set because of setCondition, returns null. + * @example + *
{@code
+     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * assert value.equals("OK");
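+     * // Hedged follow-up sketch: now that "doc" exists, repeating the call with
+     * // ONLY_IF_DOES_NOT_EXIST should return null, per the @return note above.
+     * String skipped = Json.set(client, "doc", ".", "{'a': 2.0}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * assert skipped == null;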
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String value, + @NonNull ConditionalChange setCondition) { + return executeCommand( + client, new String[] {JSON_SET, key, path, value, setCondition.getValkeyApi()}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted GlideString. + * @param setCondition Set the value only if the given condition is met (within the key or path). + * @return A simple "OK" response if the value is successfully set. If value isn't + * set because of setCondition, returns null. + * @example + *
{@code
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * assert value.equals("OK");
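+     * // Hedged sketch: ONLY_IF_EXISTS against a key that was never created (gs("noDoc") is an
+     * // illustrative name) is expected to return null rather than "OK".
+     * String missing = Json.set(client, gs("noDoc"), gs("."), gs("{'a': 1}"), ConditionalChange.ONLY_IF_EXISTS).get();
+     * assert missing == null;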
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString value, + @NonNull ConditionalChange setCondition) { + return executeCommand( + client, + new GlideString[] {gs(JSON_SET), key, path, value, gs(setCondition.getValkeyApi())}); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * String value = Json.get(client, "doc").get();
+     * assert value.equals("{'a': 1.0, 'b': 2}");
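+     * // Hedged sketch: per the @return note above, a non-existing key ("nonExistingKey" is an
+     * // illustrative name) yields null.
+     * assert Json.get(client, "nonExistingKey").get() == null;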
+     * }
+ */ + public static CompletableFuture get(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_GET, key}); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * GlideString value = Json.get(client, gs("doc")).get();
+     * assert value.equals(gs("{'a': 1.0, 'b': 2}"));
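+     * // Hedged sketch: a GlideString result can be converted when plain-string handling is needed.
+     * String asString = value.getString();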
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_GET), key}); + } + + /** + * Retrieves the JSON value at the specified paths stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @return + *
+     *     <ul>
+     *       <li>If one path is given:
+     *           <ul>
+     *             <li>For JSONPath (path starts with $): Returns a stringified JSON list of
+     *                 replies for every possible path, or a string representation of an empty
+     *                 array if path doesn't exist. If key doesn't exist, returns null.
+     *             <li>For legacy path (path doesn't start with $): Returns a string
+     *                 representation of the value in paths. If paths doesn't exist, an error is
+     *                 raised. If key doesn't exist, returns null.
+     *           </ul>
+     *       <li>If multiple paths are given: Returns a stringified JSON object, in which each
+     *           path is a key and its corresponding value is the value as if the path was
+     *           executed in the command as a single path.
+     *     </ul>
+     *     In case of multiple paths, where the paths are a mix of both JSONPath and legacy
+     *     paths, the command behaves as if all are JSONPath paths.
+     * @example
{@code
+     * String value = Json.get(client, "doc", new String[] {"$"}).get();
+     * assert value.equals("[{'a': 1.0, 'b': 2}]");
+     * String values = Json.get(client, "doc", new String[] {"$.a", "$.b"}).get();
+     * assert values.equals("{\"$.a\": [1.0], \"$.b\": [2]}");
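+     * // Hedged sketch: a JSONPath with no matches ("$.d" is illustrative) returns a stringified
+     * // empty array rather than raising an error.
+     * String empty = Json.get(client, "doc", new String[] {"$.d"}).get();
+     * assert empty.equals("[]");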
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull String key, @NonNull String[] paths) { + return executeCommand( + client, ArrayTransformUtils.concatenateArrays(new String[] {JSON_GET, key}, paths)); + } + + /** + * Retrieves the JSON value at the specified paths stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @return + *
+     *     <ul>
+     *       <li>If one path is given:
+     *           <ul>
+     *             <li>For JSONPath (path starts with $): Returns a stringified JSON list of
+     *                 replies for every possible path, or a string representation of an empty
+     *                 array if path doesn't exist. If key doesn't exist, returns null.
+     *             <li>For legacy path (path doesn't start with $): Returns a string
+     *                 representation of the value in paths. If paths doesn't exist, an error is
+     *                 raised. If key doesn't exist, returns null.
+     *           </ul>
+     *       <li>If multiple paths are given: Returns a stringified JSON object, in which each
+     *           path is a key and its corresponding value is the value as if the path was
+     *           executed in the command as a single path.
+     *     </ul>
+     *     In case of multiple paths, where the paths are a mix of both JSONPath and legacy
+     *     paths, the command behaves as if all are JSONPath paths.
+     * @example
{@code
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$")}).get();
+     * assert value.equals(gs("[{'a': 1.0, 'b': 2}]"));
+     * GlideString values = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}).get();
+     * assert values.equals(gs("{\"$.a\": [1.0], \"$.b\": [2]}"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString[] paths) { + return executeCommand( + client, + ArrayTransformUtils.concatenateArrays(new GlideString[] {gs(JSON_GET), key}, paths)); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * String value = Json.get(client, "doc", options).get();
+     * assert value.equals("{\n \"a\": \n  1.0\n ,\n \"b\": \n  2\n }");
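+     * // Hedged sketch: NOESCAPE is accepted for legacy compatibility only and does not change
+     * // the output, so the following options behave like the defaults.
+     * JsonGetOptions legacy = JsonGetOptions.builder().noescape(true).build();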
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull String key, @NonNull JsonGetOptions options) { + return executeCommand( + client, + ArrayTransformUtils.concatenateArrays(new String[] {JSON_GET, key}, options.toArgs())); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * GlideString value = Json.get(client, gs("doc"), options).get();
+     * assert value.equals(gs("{\n \"a\": \n  1.0\n ,\n \"b\": \n  2\n }"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull JsonGetOptionsBinary options) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).toArray()); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return + *
+     *     <ul>
+     *       <li>If one path is given:
+     *           <ul>
+     *             <li>For JSONPath (path starts with $): Returns a stringified JSON list of
+     *                 replies for every possible path, or a string representation of an empty
+     *                 array if path doesn't exist. If key doesn't exist, returns null.
+     *             <li>For legacy path (path doesn't start with $): Returns a string
+     *                 representation of the value in paths. If paths doesn't exist, an error is
+     *                 raised. If key doesn't exist, returns null.
+     *           </ul>
+     *       <li>If multiple paths are given: Returns a stringified JSON object, in which each
+     *           path is a key and its corresponding value is the value as if the path was
+     *           executed in the command as a single path.
+     *     </ul>
+     *     In case of multiple paths, where the paths are a mix of both JSONPath and legacy
+     *     paths, the command behaves as if all are JSONPath paths.
+     * @example
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}, options).get();
+     * assert value.equals("{\n \"$.a\": [\n  1.0\n ],\n \"$.b\": [\n  2\n ]\n}");
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String[] paths, + @NonNull JsonGetOptions options) { + return executeCommand( + client, + ArrayTransformUtils.concatenateArrays( + new String[] {JSON_GET, key}, options.toArgs(), paths)); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return + *
+     *     <ul>
+     *       <li>If one path is given:
+     *           <ul>
+     *             <li>For JSONPath (path starts with $): Returns a stringified JSON list of
+     *                 replies for every possible path, or a string representation of an empty
+     *                 array if path doesn't exist. If key doesn't exist, returns null.
+     *             <li>For legacy path (path doesn't start with $): Returns a string
+     *                 representation of the value in paths. If paths doesn't exist, an error is
+     *                 raised. If key doesn't exist, returns null.
+     *           </ul>
+     *       <li>If multiple paths are given: Returns a stringified JSON object, in which each
+     *           path is a key and its corresponding value is the value as if the path was
+     *           executed in the command as a single path.
+     *     </ul>
+     *     In case of multiple paths, where the paths are a mix of both JSONPath and legacy
+     *     paths, the command behaves as if all are JSONPath paths.
+     * @example
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}, options).get();
+     * assert value.equals(gs("{\n \"$.a\": [\n  1.0\n ],\n \"$.b\": [\n  2\n ]\n}"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString[] paths, + @NonNull JsonGetOptionsBinary options) { + return executeCommand( + client, + new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).add(paths).toArray()); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + */ + private static CompletableFuture executeCommand(BaseClient client, String[] args) { + return executeCommand(client, args, false); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + * @param returnsMap - true if command returns a map + */ + @SuppressWarnings({"unchecked", "SameParameterValue"}) + private static CompletableFuture executeCommand( + BaseClient client, String[] args, boolean returnsMap) { + if (client instanceof GlideClient) { + return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r); + } else if (client instanceof GlideClusterClient) { + return ((GlideClusterClient) client) + .customCommand(args) + .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue) + .thenApply(r -> (T) r); + } + throw new IllegalArgumentException( + "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`"); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + */ + private static CompletableFuture executeCommand(BaseClient client, GlideString[] args) { + return executeCommand(client, args, false); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + * @param returnsMap - true if command returns a map + */ + @SuppressWarnings({"unchecked", "SameParameterValue"}) + private static CompletableFuture executeCommand( + BaseClient client, GlideString[] args, boolean returnsMap) { + if (client instanceof GlideClient) { + return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r); + } else if (client instanceof GlideClusterClient) { + return ((GlideClusterClient) client) + .customCommand(args) + .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue) + .thenApply(r -> (T) r); + } + throw new IllegalArgumentException( + "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`"); + } +} diff --git a/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java new file mode 100644 index 0000000000..5273e9c8c1 --- /dev/null +++ b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java @@ -0,0 +1,64 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.models.commands.json; + +import glide.api.commands.servermodules.Json; +import java.util.ArrayList; +import java.util.List; +import lombok.Builder; + +/** Additional parameters for {@link Json#get} command. 
 */
+@Builder
+public final class JsonGetOptions {
+    /** Valkey API string to designate INDENT */
+    public static final String INDENT_VALKEY_API = "INDENT";
+
+    /** Valkey API string to designate NEWLINE */
+    public static final String NEWLINE_VALKEY_API = "NEWLINE";
+
+    /** Valkey API string to designate SPACE */
+    public static final String SPACE_VALKEY_API = "SPACE";
+
+    /** Valkey API string to designate NOESCAPE */
+    public static final String NOESCAPE_VALKEY_API = "NOESCAPE";
+
+    /** Sets an indentation string for nested levels. */
+    private String indent;
+
+    /** Sets a string that's printed at the end of each line. */
+    private String newline;
+
+    /** Sets a string that's put between a key and a value. */
+    private String space;
+
+    /** Allowed to be present for legacy compatibility and has no other effect. */
+    private boolean noescape;
+
+    /**
+     * Converts JsonGetOptions into a String[].
+     *
+     * @return String[]
+     */
+    public String[] toArgs() {
+        List<String> args = new ArrayList<>();
+        if (indent != null) {
+            args.add(INDENT_VALKEY_API);
+            args.add(indent);
+        }
+
+        if (newline != null) {
+            args.add(NEWLINE_VALKEY_API);
+            args.add(newline);
+        }
+
+        if (space != null) {
+            args.add(SPACE_VALKEY_API);
+            args.add(space);
+        }
+
+        if (noescape) {
+            args.add(NOESCAPE_VALKEY_API);
+        }
+
+        return args.toArray(new String[0]);
+    }
+}
diff --git a/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java
new file mode 100644
index 0000000000..634b4d298e
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java
@@ -0,0 +1,67 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.json;
+
+import static glide.api.models.GlideString.gs;
+
+import glide.api.commands.servermodules.Json;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.List;
+import lombok.Builder;
+
+/** GlideString version of additional parameters for {@link Json#get} command. */
+@Builder
+public final class JsonGetOptionsBinary {
+    /** Valkey API string to designate INDENT */
+    public static final GlideString INDENT_VALKEY_API = gs("INDENT");
+
+    /** Valkey API string to designate NEWLINE */
+    public static final GlideString NEWLINE_VALKEY_API = gs("NEWLINE");
+
+    /** Valkey API string to designate SPACE */
+    public static final GlideString SPACE_VALKEY_API = gs("SPACE");
+
+    /** Valkey API string to designate NOESCAPE */
+    public static final GlideString NOESCAPE_VALKEY_API = gs("NOESCAPE");
+
+    /** Sets an indentation string for nested levels. */
+    private GlideString indent;
+
+    /** Sets a string that's printed at the end of each line. */
+    private GlideString newline;
+
+    /** Sets a string that's put between a key and a value. */
+    private GlideString space;
+
+    /** Allowed to be present for legacy compatibility and has no other effect. */
+    private boolean noescape;
+
+    /**
+     * Converts JsonGetOptionsBinary into a GlideString[].
+ * + * @return GlideString[] + */ + public GlideString[] toArgs() { + List args = new ArrayList<>(); + if (indent != null) { + args.add(INDENT_VALKEY_API); + args.add(indent); + } + + if (newline != null) { + args.add(NEWLINE_VALKEY_API); + args.add(newline); + } + + if (space != null) { + args.add(SPACE_VALKEY_API); + args.add(space); + } + + if (noescape) { + args.add(NOESCAPE_VALKEY_API); + } + + return args.toArray(new GlideString[0]); + } +} diff --git a/java/client/src/main/java/module-info.java b/java/client/src/main/java/module-info.java index 183e6c0410..fc280da076 100644 --- a/java/client/src/main/java/module-info.java +++ b/java/client/src/main/java/module-info.java @@ -10,6 +10,7 @@ exports glide.api.models.commands.scan; exports glide.api.models.commands.stream; exports glide.api.models.commands.FT; + exports glide.api.models.commands.json; exports glide.api.models.configuration; exports glide.api.models.exceptions; exports glide.api.commands.servermodules; diff --git a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java new file mode 100644 index 0000000000..81754474a3 --- /dev/null +++ b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java @@ -0,0 +1,345 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.commands.servermodules; + +import static glide.api.models.GlideString.gs; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import glide.api.GlideClient; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.json.JsonGetOptions; +import glide.api.models.commands.json.JsonGetOptionsBinary; +import glide.utils.ArgsBuilder; +import glide.utils.ArrayTransformUtils; +import java.util.ArrayList; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class JsonTest { + + private GlideClient glideClient; + + @BeforeEach + void setUp() { + glideClient = mock(GlideClient.class, RETURNS_DEEP_STUBS); + } + + @Test + @SneakyThrows + void set_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {Json.JSON_SET, key, path, jsonValue})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.set(glideClient, key, path, jsonValue); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_binary_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + GlideString jsonValue = gs("{\"a\": 1.0, \"b\": 2}"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + String 
expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs(Json.JSON_SET), key, path, jsonValue})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.set(glideClient, key, path, jsonValue); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_with_condition_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + ConditionalChange setCondition = ConditionalChange.ONLY_IF_DOES_NOT_EXIST; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq(new String[] {Json.JSON_SET, key, path, jsonValue, setCondition.getValkeyApi()})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = + Json.set(glideClient, key, path, jsonValue, setCondition); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_binary_with_condition_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + GlideString jsonValue = gs("{\"a\": 1.0, \"b\": 2}"); + ConditionalChange setCondition = ConditionalChange.ONLY_IF_DOES_NOT_EXIST; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + new GlideString[] { + gs(Json.JSON_SET), key, path, jsonValue, gs(setCondition.getValkeyApi()) + })) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = + Json.set(glideClient, key, path, jsonValue, setCondition); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {Json.JSON_GET, key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs(Json.JSON_GET), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key); + GlideString 
actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_multiple_paths_returns_success() { + // setup + String key = "testKey"; + String path1 = ".firstName"; + String path2 = ".lastName"; + String[] paths = new String[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {Json.JSON_GET, key, path1, path2})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_multiple_paths_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path1 = gs(".firstName"); + GlideString path2 = gs(".lastName"); + GlideString[] paths = new GlideString[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs(Json.JSON_GET), key, path1, path2})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_no_path_and_options_returns_success() { + // setup + String key = "testKey"; + JsonGetOptions options = JsonGetOptions.builder().indent("\t").space(" ").newline("\n").build(); + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + ArrayTransformUtils.concatenateArrays( + new String[] {Json.JSON_GET, key}, options.toArgs()))) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, options); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_no_path_and_options_returns_success() { + // setup + GlideString key = gs("testKey"); + JsonGetOptionsBinary options = + JsonGetOptionsBinary.builder().indent(gs("\t")).space(gs(" ")).newline(gs("\n")).build(); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + new ArgsBuilder() + .add(new GlideString[] {gs(Json.JSON_GET), key}) + .add(options.toArgs()) + .toArray())) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, options); + GlideString actualResponseValue = actualResponse.get(); + + // verify + 
assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_multiple_paths_and_options_returns_success() { + // setup + String key = "testKey"; + String path1 = ".firstName"; + String path2 = ".lastName"; + JsonGetOptions options = JsonGetOptions.builder().indent("\t").newline("\n").space(" ").build(); + String[] paths = new String[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + ArrayList argsList = new ArrayList<>(); + argsList.add(Json.JSON_GET); + argsList.add(key); + Collections.addAll(argsList, options.toArgs()); + Collections.addAll(argsList, paths); + when(glideClient.customCommand(eq(argsList.toArray(new String[0]))).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths, options); + String actualResponseValue = actualResponse.get(); + + // verify + assertArrayEquals( + new String[] {"INDENT", "\t", "NEWLINE", "\n", "SPACE", " "}, options.toArgs()); + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_multiple_paths_and_options_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path1 = gs(".firstName"); + GlideString path2 = gs(".lastName"); + JsonGetOptionsBinary options = + JsonGetOptionsBinary.builder().indent(gs("\t")).newline(gs("\n")).space(gs(" ")).build(); + GlideString[] paths = new GlideString[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + GlideString[] args = + new ArgsBuilder().add(Json.JSON_GET).add(key).add(options.toArgs()).add(paths).toArray(); + when(glideClient.customCommand(eq(args)).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths, options); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } +} diff --git a/java/integTest/build.gradle b/java/integTest/build.gradle index c2032d05d1..70fdf18915 100644 --- a/java/integTest/build.gradle +++ b/java/integTest/build.gradle @@ -28,6 +28,9 @@ dependencies { //lombok testCompileOnly 'org.projectlombok:lombok:1.18.32' testAnnotationProcessor 'org.projectlombok:lombok:1.18.32' + + // jsonassert + testImplementation 'org.skyscreamer:jsonassert:1.5.3' } def standaloneHosts = '' diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 5d6880ae2e..3bf1c93823 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -2,22 +2,156 @@ package glide.modules; import static glide.TestUtilities.commonClusterClientConfig; +import static glide.api.BaseClient.OK; +import static glide.api.models.GlideString.gs; +import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import glide.api.GlideClusterClient; +import glide.api.commands.servermodules.Json; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; +import glide.api.models.commands.json.JsonGetOptions; +import java.util.UUID; import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.skyscreamer.jsonassert.JSONAssert; +import org.skyscreamer.jsonassert.JSONCompareMode; public class JsonTests { - @Test + + private static GlideClusterClient client; + + @BeforeAll @SneakyThrows - public void check_module_loaded() { - var client = + public static void init() { + client = GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(5000).build()) .get(); + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); + } + + @AfterAll + @SneakyThrows + public static void teardown() { + client.close(); + } + + @Test + @SneakyThrows + public void check_module_loaded() { var info = client.info(new Section[] {Section.MODULES}, RANDOM).get().getSingleValue(); assertTrue(info.contains("# json_core_metrics")); } + + @Test + @SneakyThrows + public void json_set_get() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": 1.0,\"b\": 2}"; + + assertEquals(OK, Json.set(client, key, "$", jsonValue).get()); + + String getResult = Json.get(client, key).get(); + + JSONAssert.assertEquals(jsonValue, getResult, JSONCompareMode.LENIENT); + + String getResultWithMultiPaths = Json.get(client, key, new String[] {"$.a", "$.b"}).get(); + + JSONAssert.assertEquals( + "{\"$.a\":[1.0],\"$.b\":[2]}", getResultWithMultiPaths, JSONCompareMode.LENIENT); + + assertNull(Json.get(client, "non_existing_key").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$.d"}).get()); + } + + @Test + @SneakyThrows + public void json_set_get_multiple_values() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": {\"c\": 1, \"d\": 4}, \"b\": {\"c\": 2}, \"c\": true}"; + + assertEquals(OK, Json.set(client, gs(key), gs("$"), gs(jsonValue)).get()); + + GlideString getResult = Json.get(client, gs(key), new GlideString[] {gs("$..c")}).get(); + + JSONAssert.assertEquals("[true, 1, 2]", getResult.getString(), JSONCompareMode.LENIENT); + + String getResultWithMultiPaths = Json.get(client, key, new String[] {"$..c", "$.c"}).get(); + + JSONAssert.assertEquals( + "{\"$..c\": [True, 1, 2], \"$.c\": [True]}", + getResultWithMultiPaths, + JSONCompareMode.LENIENT); + + assertEquals(OK, Json.set(client, key, "$..c", "\"new_value\"").get()); + String getResultAfterSetNewValue = Json.get(client, key, new String[] {"$..c"}).get(); + JSONAssert.assertEquals( + "[\"new_value\", \"new_value\", \"new_value\"]", + getResultAfterSetNewValue, + JSONCompareMode.LENIENT); + } + + @Test + @SneakyThrows + public void json_set_get_conditional_set() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + + assertNull(Json.set(client, key, "$", jsonValue, ConditionalChange.ONLY_IF_EXISTS).get()); + assertEquals( + OK, Json.set(client, key, "$", jsonValue, ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get()); + assertNull(Json.set(client, key, "$.a", "4.5", 
ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get()); + assertEquals("1.0", Json.get(client, key, new String[] {".a"}).get()); + assertEquals(OK, Json.set(client, key, "$.a", "4.5", ConditionalChange.ONLY_IF_EXISTS).get()); + assertEquals("4.5", Json.get(client, key, new String[] {".a"}).get()); + } + + @Test + @SneakyThrows + public void json_set_get_formatting() { + String key = UUID.randomUUID().toString(); + + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": 2, \"c\": {\"d\": 3, \"e\": 4}}").get()); + + String expectedGetResult = + "[\n" + + " {\n" + + " \"a\": 1.0,\n" + + " \"b\": 2,\n" + + " \"c\": {\n" + + " \"d\": 3,\n" + + " \"e\": 4\n" + + " }\n" + + " }\n" + + "]"; + String actualGetResult = + Json.get( + client, + key, + new String[] {"$"}, + JsonGetOptions.builder().indent(" ").newline("\n").space(" ").build()) + .get(); + assertEquals(expectedGetResult, actualGetResult); + + String expectedGetResult2 = + "[\n~{\n~~\"a\":*1.0,\n~~\"b\":*2,\n~~\"c\":*{\n~~~\"d\":*3,\n~~~\"e\":*4\n~~}\n~}\n]"; + String actualGetResult2 = + Json.get( + client, + key, + new String[] {"$"}, + JsonGetOptions.builder().indent("~").newline("\n").space("*").build()) + .get(); + assertEquals(expectedGetResult2, actualGetResult2); + } } From fa12563eab6f242a7ac1fca3a175952f908d3de3 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Fri, 18 Oct 2024 08:38:43 -0700 Subject: [PATCH 021/180] Node: Add CI support for server modules (#2472) * Node: add server modules CI support --------- Signed-off-by: TJ Zhang Signed-off-by: Yury-Fridlyand Signed-off-by: Chloe Co-authored-by: TJ Zhang Co-authored-by: Yury-Fridlyand Co-authored-by: Chloe --- .../workflows/build-node-wrapper/action.yml | 2 +- .../install-shared-dependencies/action.yml | 8 +- .github/workflows/node.yml | 78 +++++++++++++++---- CHANGELOG.md | 1 + node/package.json | 5 +- node/tests/ServerModules.test.ts | 60 ++++++++++++++ node/tests/TestUtilities.ts | 1 + 7 files changed, 131 insertions(+), 24 deletions(-) create mode 100644 node/tests/ServerModules.test.ts diff --git a/.github/workflows/build-node-wrapper/action.yml b/.github/workflows/build-node-wrapper/action.yml index 98246df22f..aa1200fbd5 100644 --- a/.github/workflows/build-node-wrapper/action.yml +++ b/.github/workflows/build-node-wrapper/action.yml @@ -31,7 +31,7 @@ inputs: required: true engine-version: description: "Engine version to install" - required: true + required: false type: string publish: description: "Enable building the wrapper in release mode" diff --git a/.github/workflows/install-shared-dependencies/action.yml b/.github/workflows/install-shared-dependencies/action.yml index abca1966cd..1cb56e63f0 100644 --- a/.github/workflows/install-shared-dependencies/action.yml +++ b/.github/workflows/install-shared-dependencies/action.yml @@ -22,10 +22,9 @@ inputs: - aarch64-unknown-linux-musl - x86_64-unknown-linux-musl engine-version: - description: "Engine version to install" - required: true - type: string - + description: "Engine version to install" + required: false + type: string github-token: description: "GITHUB_TOKEN, GitHub App installation access token" required: true @@ -72,6 +71,7 @@ runs: github-token: ${{ inputs.github-token }} - name: Install Valkey + if: ${{ inputs.engine-version != '' }} uses: ./.github/workflows/install-valkey with: engine-version: ${{ inputs.engine-version }} diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index 32db45e5c5..634219ea15 100644 --- 
a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -15,6 +15,7 @@ on: - .github/workflows/lint-rust/action.yml - .github/workflows/install-valkey/action.yml - .github/json_matrices/build-matrix.json + - .github/workflows/start-self-hosted-runner/action.yml pull_request: paths: - glide-core/src/** @@ -28,6 +29,7 @@ on: - .github/workflows/lint-rust/action.yml - .github/workflows/install-valkey/action.yml - .github/json_matrices/build-matrix.json + - .github/workflows/start-self-hosted-runner/action.yml workflow_dispatch: concurrency: @@ -39,17 +41,17 @@ env: jobs: load-engine-matrix: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.load-engine-matrix.outputs.matrix }} - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Load the engine matrix - id: load-engine-matrix - shell: bash - run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.load-engine-matrix.outputs.matrix }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Load the engine matrix + id: load-engine-matrix + shell: bash + run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT test-ubuntu-latest: runs-on: ubuntu-latest @@ -84,18 +86,18 @@ jobs: - name: test hybrid node modules - commonjs run: | - npm install --package-lock-only - npm ci - npm run build-and-test + npm install --package-lock-only + npm ci + npm run build-and-test working-directory: ./node/hybrid-node-tests/commonjs-test env: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-commonjs.html - name: test hybrid node modules - ecma run: | - npm install --package-lock-only - npm ci - npm run build-and-test + npm install --package-lock-only + npm ci + npm run build-and-test working-directory: ./node/hybrid-node-tests/ecmascript-test env: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-ecma.html @@ -269,3 +271,45 @@ jobs: node/test-report*.html utils/clusters/** benchmarks/results/** + + test-modules: + if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io' + environment: AWS_ACTIONS + name: Running Module Tests + runs-on: [self-hosted, linux, ARM64] + timeout-minutes: 15 + + steps: + - name: Setup self-hosted runner access + run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide + + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Use Node.js 18.x + uses: actions/setup-node@v4 + with: + node-version: 18.x + + - name: Build Node wrapper + uses: ./.github/workflows/build-node-wrapper + with: + os: ubuntu + named_os: linux + arch: arm64 + target: aarch64-unknown-linux-gnu + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: test + run: npm run test-modules -- --cluster-endpoints=${{ secrets.MEMDB_MODULES_ENDPOINT }} --tls=true + working-directory: ./node + + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: test-report-node-modules-ubuntu + path: | + node/test-report*.html diff --git a/CHANGELOG.md b/CHANGELOG.md index a90f628944..6dd52a4859 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ #### Operational Enhancements * Java: Add modules CI ([#2388](https://github.com/valkey-io/valkey-glide/pull/2388), [#2404](https://github.com/valkey-io/valkey-glide/pull/2404), [#2416](https://github.com/valkey-io/valkey-glide/pull/2416)) +* Node: Add modules CI 
([#2472](https://github.com/valkey-io/valkey-glide/pull/2472)) ## 1.1.0 (2024-09-24) diff --git a/node/package.json b/node/package.json index 20dbeea80b..1979e408e4 100644 --- a/node/package.json +++ b/node/package.json @@ -31,14 +31,15 @@ "build-protobuf": "npm run compile-protobuf-files && npm run fix-protobuf-file", "compile-protobuf-files": "cd src && pbjs -t static-module -o ProtobufMessage.js ../../glide-core/src/protobuf/*.proto && pbts -o ProtobufMessage.d.ts ProtobufMessage.js", "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? writer.fork() : writer).ldelim' src/ProtobufMessage.js", - "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='RedisModules'", + "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='ServerModules'", "build-test-utils": "cd ../utils && npm i && npm run build", "lint:fix": "npm run install-linting && npx eslint -c ../eslint.config.mjs --fix && npm run prettier:format", "lint": "npm run install-linting && npx eslint -c ../eslint.config.mjs && npm run prettier:check:ci", "install-linting": "cd ../ & npm install", "prepack": "npmignore --auto", "prettier:check:ci": "npx prettier --check . --ignore-unknown '!**/*.{js,d.ts}'", - "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'" + "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'", + "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='ServerModules'" }, "devDependencies": { "@jest/globals": "^29.7.0", diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts new file mode 100644 index 0000000000..6a11884385 --- /dev/null +++ b/node/tests/ServerModules.test.ts @@ -0,0 +1,60 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ +import { + afterAll, + afterEach, + beforeAll, + describe, + expect, + it, +} from "@jest/globals"; +import { GlideClusterClient, InfoOptions, ProtocolVersion } from ".."; +import { ValkeyCluster } from "../../utils/TestUtils"; +import { + flushAndCloseClient, + getClientConfigurationOption, + getServerVersion, + parseCommandLineArgs, + parseEndpoints, +} from "./TestUtilities"; + +const TIMEOUT = 50000; +describe("GlideJson", () => { + const testsFailed = 0; + let cluster: ValkeyCluster; + let client: GlideClusterClient; + beforeAll(async () => { + const clusterAddresses = parseCommandLineArgs()["cluster-endpoints"]; + cluster = await ValkeyCluster.initFromExistingCluster( + true, + parseEndpoints(clusterAddresses), + getServerVersion, + ); + }, 20000); + + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); + + afterAll(async () => { + if (testsFailed === 0) { + await cluster.close(); + } + }, TIMEOUT); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "ServerModules check modules loaded", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const info = await client.info({ + sections: [InfoOptions.Modules], + route: "randomNode", + }); + expect(info).toContain("# json_core_metrics"); + expect(info).toContain("# search_index_stats"); + }, + ); +}); diff --git a/node/tests/TestUtilities.ts b/node/tests/TestUtilities.ts index 6da0a39f00..c3fac91e09 100644 --- a/node/tests/TestUtilities.ts +++ b/node/tests/TestUtilities.ts @@ -387,6 +387,7 @@ 
export const getClientConfigurationOption = ( port, })), protocol, + useTLS: parseCommandLineArgs()["tls"] == "true", ...configOverrides, }; }; From 87412e8fb76b2f3d68e067368b7b38aa2e7b941d Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Fri, 18 Oct 2024 12:03:57 -0700 Subject: [PATCH 022/180] Java: `FT.AGGREGATE`. (#2468) * `FT.AGGREGATE`. Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + glide-core/src/client/value_conversion.rs | 66 +++ .../glide/api/commands/servermodules/FT.java | 169 ++++++++ .../commands/FT/FTAggregateOptions.java | 315 ++++++++++++++ .../java/glide/modules/VectorSearchTests.java | 383 ++++++++++++++++++ 5 files changed, 934 insertions(+) create mode 100644 java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6dd52a4859..429e2d2316 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) +* Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index e89c92adbe..64b853af9f 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -22,6 +22,7 @@ pub(crate) enum ExpectedReturnType<'a> { ArrayOfStrings, ArrayOfBools, ArrayOfDoubleOrNull, + FTAggregateReturnType, FTSearchReturnType, Lolwut, ArrayOfStringAndArrays, @@ -893,6 +894,70 @@ pub(crate) fn convert_to_expected_type( ) .into()), }, + ExpectedReturnType::FTAggregateReturnType => match value { + /* + Example of the response + 1) "3" + 2) 1) "condition" + 2) "refurbished" + 3) "bicylces" + 4) 1) "bicycle:9" + 3) 1) "condition" + 2) "used" + 3) "bicylces" + 4) 1) "bicycle:1" + 2) "bicycle:2" + 3) "bicycle:3" + 4) "bicycle:4" + 4) 1) "condition" + 2) "new" + 3) "bicylces" + 4) 1) "bicycle:5" + 2) "bicycle:6" + + Converting response to (array of maps) + 1) 1# "condition" => "refurbished" + 2# "bicylces" => + 1) "bicycle:9" + 2) 1# "condition" => "used" + 2# "bicylces" => + 1) "bicycle:1" + 2) "bicycle:2" + 3) "bicycle:3" + 4) "bicycle:4" + 3) 1# "condition" => "new" + 2# "bicylces" => + 1) "bicycle:5" + 2) "bicycle:6" + + Very first element in the response is meaningless and should be ignored. 
+ */ + Value::Array(array) => { + let mut res = Vec::with_capacity(array.len() - 1); + for aggregation in array.into_iter().skip(1) { + let Value::Array(fields) = aggregation else { + return Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.AGGREGATION", + format!("(`fields` was {:?})", get_value_type(&aggregation)), + ) + .into()); + }; + res.push(convert_array_to_map_by_type( + fields, + None, + None, + )?); + } + Ok(Value::Array(res)) + } + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted to FT.AGGREGATION", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()), + }, ExpectedReturnType::FTSearchReturnType => match value { /* Example of the response @@ -1303,6 +1368,7 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { key_type: &None, value_type: &None, }), + b"FT.AGGREGATE" => Some(ExpectedReturnType::FTAggregateReturnType), b"FT.SEARCH" => Some(ExpectedReturnType::FTSearchReturnType), _ => None, } diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 12a20a0ff4..106d540f8c 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -2,6 +2,7 @@ package glide.api.commands.servermodules; import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.castArray; import static glide.utils.ArrayTransformUtils.concatenateArrays; import glide.api.BaseClient; @@ -9,10 +10,12 @@ import glide.api.GlideClusterClient; import glide.api.models.ClusterValue; import glide.api.models.GlideString; +import glide.api.models.commands.FT.FTAggregateOptions; import glide.api.models.commands.FT.FTCreateOptions; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; import glide.api.models.commands.FT.FTSearchOptions; import java.util.Arrays; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import lombok.NonNull; @@ -305,6 +308,172 @@ public static CompletableFuture dropindex( return executeCommand(client, new GlideString[] {gs("FT.DROPINDEX"), indexName}, false); } + /** + * Runs a search query on an index, and perform aggregate transformations on the results. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param query The text query to search. + * @return Results of the last stage of the pipeline. + * @example + *
{@code
+     * // example of using the API:
+     * FT.aggregate(client, "myIndex", "*").get();
+     * // the response contains data in the following format:
+     * Map<GlideString, Object>[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
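+     * // Hedged sketch: each group is a plain map keyed by the grouped field; the key names here
+     * // ("condition", "bicycles") come from this example's index, not from the API.
+     * Object[] usedBikes = (Object[]) response[1].get(gs("bicycles"));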
+     * }
+     */
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) {
+        return aggregate(client, gs(indexName), gs(query));
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @param options Additional parameters for the command - see {@link FTAggregateOptions}.
+     * @return Results of the last stage of the pipeline.
+     * @example
{@code
+     * // example of using the API:
+     * FTAggregateOptions options = FTAggregateOptions.builder()
+     *     .loadFields(new String[] {"__key"})
+     *     .addExpression(
+     *             new FTAggregateOptions.GroupBy(
+     *                     new String[] {"@condition"},
+     *                     new Reducer[] {
+     *                         new Reducer("TOLIST", new String[] {"__key"}, "bicycles")
+     *                     }))
+     *     .build();
+     * FT.aggregate(client, "myIndex", "*", options).get();
+     * // the response contains data in the following format:
+     * Map<GlideString, Object>[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client,
+            @NonNull String indexName,
+            @NonNull String query,
+            @NonNull FTAggregateOptions options) {
+        return aggregate(client, gs(indexName), gs(query), options);
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @return Results of the last stage of the pipeline.
+     * @example
+     *
{@code
+     * // example of using the API:
+     * FT.aggregate(client, gs("myIndex"), gs("*")).get();
+     * // the response contains data in the following format:
+     * Map<GlideString, Object>[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    @SuppressWarnings("unchecked")
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) {
+        var args = new GlideString[] {gs("FT.AGGREGATE"), indexName, query};
+        return FT.<Object[]>executeCommand(client, args, false)
+                .thenApply(res -> castArray(res, Map.class));
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @param options Additional parameters for the command - see {@link FTAggregateOptions}.
+     * @return Results of the last stage of the pipeline.
+     * @example
+     *
{@code
+     * // example of using the API:
+     * FTAggregateOptions options = FTAggregateOptions.builder()
+     *     .loadFields(new String[] {"__key"})
+     *     .addExpression(
+     *             new FTAggregateOptions.GroupBy(
+     *                     new String[] {"@condition"},
+     *                     new Reducer[] {
+     *                         new Reducer("TOLIST", new String[] {"__key"}, "bicycles")
+     *                     }))
+     *     .build();
+     * FT.aggregate(client, gs("myIndex"), gs("*"), options).get();
+     * // the response contains data in the following format:
+     * Map<GlideString, Object>[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
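For reference, a sketch of the raw arguments the options object above produces. The expected output in the comment is inferred from the `toArgs()` implementations added in `FTAggregateOptions.java` below; the class name is illustrative:

```java
import glide.api.models.commands.FT.FTAggregateOptions;
import glide.api.models.commands.FT.FTAggregateOptions.GroupBy;
import glide.api.models.commands.FT.FTAggregateOptions.GroupBy.Reducer;
import java.util.Arrays;

public class AggregateArgsSketch {
    public static void main(String[] args) {
        FTAggregateOptions options =
                FTAggregateOptions.builder()
                        .loadFields(new String[] {"__key"})
                        .addExpression(
                                new GroupBy(
                                        new String[] {"@condition"},
                                        new Reducer[] {
                                            new Reducer("TOLIST", new String[] {"__key"}, "bicycles")
                                        }))
                        .build();
        // Expected: [LOAD, 1, __key, GROUPBY, 1, @condition, REDUCE, TOLIST, 1, __key, AS, bicycles]
        System.out.println(Arrays.toString(options.toArgs()));
    }
}
```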
+ */ + @SuppressWarnings("unchecked") + public static CompletableFuture[]> aggregate( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull GlideString query, + @NonNull FTAggregateOptions options) { + var args = + concatenateArrays( + new GlideString[] {gs("FT.AGGREGATE"), indexName, query}, options.toArgs()); + return FT.executeCommand(client, args, false) + .thenApply(res -> castArray(res, Map.class)); + } + /** * A wrapper for custom command API. * diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java new file mode 100644 index 0000000000..73ffdbf412 --- /dev/null +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java @@ -0,0 +1,315 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.models.commands.FT; + +import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.concatenateArrays; +import static glide.utils.ArrayTransformUtils.toGlideStringArray; + +import glide.api.BaseClient; +import glide.api.commands.servermodules.FT; +import glide.api.models.GlideString; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import lombok.Builder; + +/** + * Additional arguments for {@link FT#aggregate(BaseClient, String, String, FTAggregateOptions)} + * command. + */ +@Builder +public class FTAggregateOptions { + /** Query timeout in milliseconds. */ + private final Integer timeout; + + private final boolean loadAll; + + private final GlideString[] loadFields; + + private final List expressions; + + /** Convert to module API. */ + public GlideString[] toArgs() { + var args = new ArrayList(); + if (loadAll) { + args.add(gs("LOAD")); + args.add(gs("*")); + } else if (loadFields != null) { + args.add(gs("LOAD")); + args.add(gs(Integer.toString(loadFields.length))); + args.addAll(List.of(loadFields)); + } + if (timeout != null) { + args.add(gs("TIMEOUT")); + args.add(gs(timeout.toString())); + } + if (!params.isEmpty()) { + args.add(gs("PARAMS")); + args.add(gs(Integer.toString(params.size() * 2))); + params.forEach( + (name, value) -> { + args.add(gs(name)); + args.add(value); + }); + } + if (expressions != null) { + for (var expression : expressions) { + args.addAll(List.of(expression.toArgs())); + } + } + return args.toArray(GlideString[]::new); + } + + /** + * Query parameters, which could be referenced in the query by $ sign, followed by + * the parameter name. 
+ */ + @Builder.Default private final Map params = new HashMap<>(); + + public static class FTAggregateOptionsBuilder { + // private - hiding this API from user + void loadAll(boolean loadAll) {} + + void expressions(List expressions) {} + + public FTAggregateOptionsBuilder loadAll() { + loadAll = true; + return this; + } + + public FTAggregateOptionsBuilder loadFields(String[] fields) { + loadFields = toGlideStringArray(fields); + loadAll = false; + return this; + } + + public FTAggregateOptionsBuilder loadFields(GlideString[] fields) { + loadFields = fields; + loadAll = false; + return this; + } + + public FTAggregateOptionsBuilder addExpression(FTAggregateExpression expression) { + if (expressions == null) expressions = new ArrayList<>(); + expressions.add(expression); + return this; + } + } + + public abstract static class FTAggregateExpression { + abstract GlideString[] toArgs(); + } + + enum ExpressionType { + LIMIT, + FILTER, + GROUPBY, + SORTBY, + REDUCE, + APPLY + } + + /** Configure results limiting. */ + public static class Limit extends FTAggregateExpression { + private final int offset; + private final int count; + + public Limit(int offset, int count) { + this.offset = offset; + this.count = count; + } + + @Override + GlideString[] toArgs() { + return new GlideString[] { + gs(ExpressionType.LIMIT.toString()), + gs(Integer.toString(offset)), + gs(Integer.toString(count)) + }; + } + } + + /** + * Filter the results using predicate expression relating to values in each result. It is applied + * post query and relate to the current state of the pipeline. + */ + public static class Filter extends FTAggregateExpression { + private final GlideString expression; + + public Filter(GlideString expression) { + this.expression = expression; + } + + public Filter(String expression) { + this.expression = gs(expression); + } + + @Override + GlideString[] toArgs() { + return new GlideString[] {gs(ExpressionType.FILTER.toString()), expression}; + } + } + + /** + * Filter the results using predicate expression relating to values in each result. It is applied + * post query and relate to the current state of the pipeline. + */ + public static class GroupBy extends FTAggregateExpression { + private final GlideString[] properties; + private final Reducer[] reducers; + + public GroupBy(GlideString[] properties, Reducer[] reducers) { + this.properties = properties; + this.reducers = reducers; + } + + public GroupBy(String[] properties, Reducer[] reducers) { + this.properties = toGlideStringArray(properties); + this.reducers = reducers; + } + + public GroupBy(GlideString[] properties) { + this.properties = properties; + this.reducers = new Reducer[0]; + } + + public GroupBy(String[] properties) { + this.properties = toGlideStringArray(properties); + this.reducers = new Reducer[0]; + } + + @Override + GlideString[] toArgs() { + return concatenateArrays( + new GlideString[] { + gs(ExpressionType.GROUPBY.toString()), gs(Integer.toString(properties.length)) + }, + properties, + Stream.of(reducers).map(Reducer::toArgs).flatMap(Stream::of).toArray(GlideString[]::new)); + } + + /** + * A function that handles the group entries, either counting them, or performing multiple + * aggregate operations. 
+ */ + public static class Reducer { + private final String function; + private final GlideString[] args; + private final String alias; + + public Reducer(String function, GlideString[] args, String alias) { + this.function = function; + this.args = args; + this.alias = alias; + } + + public Reducer(String function, GlideString[] args) { + this.function = function; + this.args = args; + this.alias = null; + } + + public Reducer(String function, String[] args, String alias) { + this.function = function; + this.args = toGlideStringArray(args); + this.alias = alias; + } + + public Reducer(String function, String[] args) { + this.function = function; + this.args = toGlideStringArray(args); + this.alias = null; + } + + GlideString[] toArgs() { + return concatenateArrays( + new GlideString[] { + gs(ExpressionType.REDUCE.toString()), gs(function), gs(Integer.toString(args.length)) + }, + args, + alias == null ? new GlideString[0] : new GlideString[] {gs("AS"), gs(alias)}); + } + } + } + + /** Sort the pipeline using a list of properties. */ + public static class SortBy extends FTAggregateExpression { + + private final SortProperty[] properties; + private final Integer max; + + public SortBy(SortProperty[] properties) { + this.properties = properties; + this.max = null; + } + + public SortBy(SortProperty[] properties, int max) { + this.properties = properties; + this.max = max; + } + + @Override + GlideString[] toArgs() { + return concatenateArrays( + new GlideString[] { + gs(ExpressionType.SORTBY.toString()), gs(Integer.toString(properties.length * 2)), + }, + Stream.of(properties) + .map(SortProperty::toArgs) + .flatMap(Stream::of) + .toArray(GlideString[]::new), + max == null ? new GlideString[0] : new GlideString[] {gs("MAX"), gs(max.toString())}); + } + + public enum SortOrder { + ASC, + DESC + } + + /** A sorting parameter. */ + public static class SortProperty { + private final GlideString property; + private final SortOrder order; + + public SortProperty(GlideString property, SortOrder order) { + this.property = property; + this.order = order; + } + + public SortProperty(String property, SortOrder order) { + this.property = gs(property); + this.order = order; + } + + GlideString[] toArgs() { + return new GlideString[] {property, gs(order.toString())}; + } + } + } + + /** + * Apply a 1-to-1 transformation on one or more properties and either stores the result as a new + * property down the pipeline or replace any property using this transformation. 
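To tie the expression classes together, a sketch of a multi-stage pipeline built from `Apply`, `Filter`, `SortBy`, and `Limit`, with the encoding each stage contributes (per the `toArgs()` methods in this file) noted in comments; the field names are illustrative:

```java
import glide.api.models.commands.FT.FTAggregateOptions;
import glide.api.models.commands.FT.FTAggregateOptions.Apply;
import glide.api.models.commands.FT.FTAggregateOptions.Filter;
import glide.api.models.commands.FT.FTAggregateOptions.Limit;
import glide.api.models.commands.FT.FTAggregateOptions.SortBy;
import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortOrder;
import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortProperty;

public class PipelineSketch {
    static FTAggregateOptions buildPipeline() {
        return FTAggregateOptions.builder()
                .loadAll() // LOAD *
                .addExpression(new Apply("ceil(@rating)", "r_rating")) // APPLY ceil(@rating) AS r_rating
                .addExpression(new Filter("@r_rating >= 9")) // FILTER @r_rating >= 9
                .addExpression(
                        new SortBy(
                                new SortProperty[] {
                                    new SortProperty("@r_rating", SortOrder.DESC)
                                })) // SORTBY 2 @r_rating DESC
                .addExpression(new Limit(0, 10)) // LIMIT 0 10
                .build();
    }
}
```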
+ */ + public static class Apply extends FTAggregateExpression { + private final GlideString expression; + private final GlideString alias; + + public Apply(GlideString expression, GlideString alias) { + this.expression = expression; + this.alias = alias; + } + + public Apply(String expression, String alias) { + this.expression = gs(expression); + this.alias = gs(alias); + } + + @Override + GlideString[] toArgs() { + return new GlideString[] {gs(ExpressionType.APPLY.toString()), expression, gs("AS"), alias}; + } + } +} diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 33b9bd9dd6..fbc3eab196 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -1,6 +1,7 @@ /** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ package glide.modules; +import static glide.TestUtilities.assertDeepEquals; import static glide.TestUtilities.commonClusterClientConfig; import static glide.api.BaseClient.OK; import static glide.api.models.GlideString.gs; @@ -15,6 +16,13 @@ import glide.api.GlideClusterClient; import glide.api.commands.servermodules.FT; +import glide.api.models.commands.FT.FTAggregateOptions; +import glide.api.models.commands.FT.FTAggregateOptions.Apply; +import glide.api.models.commands.FT.FTAggregateOptions.GroupBy; +import glide.api.models.commands.FT.FTAggregateOptions.GroupBy.Reducer; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortOrder; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortProperty; import glide.api.models.commands.FT.FTCreateOptions; import glide.api.models.commands.FT.FTCreateOptions.DistanceMetric; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; @@ -42,6 +50,9 @@ public class VectorSearchTests { private static GlideClusterClient client; + /** Waiting interval to let server process the data before querying */ + private static final int DATA_PROCESSING_TIMEOUT = 1000; // ms + @BeforeAll @SneakyThrows public static void init() { @@ -330,4 +341,376 @@ public void ft_drop() { assertInstanceOf(RequestException.class, exception.getCause()); assertTrue(exception.getMessage().contains("Index does not exist")); } + + @SneakyThrows + @Test + public void ft_aggregate() { + var prefixBicycles = "{bicycles}:"; + var indexBicycles = prefixBicycles + UUID.randomUUID(); + var prefixMovies = "{movies}:"; + var indexMovies = prefixMovies + UUID.randomUUID(); + + // FT.CREATE idx:bicycle ON JSON PREFIX 1 bicycle: SCHEMA $.model AS model TEXT $.description AS + // description TEXT $.price AS price NUMERIC $.condition AS condition TAG SEPARATOR , + assertEquals( + OK, + FT.create( + client, + indexBicycles, + new FieldInfo[] { + new FieldInfo("$.model", "model", new TextField()), + new FieldInfo("$.description", "description", new TextField()), + new FieldInfo("$.price", "price", new NumericField()), + new FieldInfo("$.condition", "condition", new TagField(',')), + }, + FTCreateOptions.builder() + .indexType(IndexType.JSON) + .prefixes(new String[] {prefixBicycles}) + .build()) + .get()); + + // TODO use JSON module API + // TODO check JSON module loaded + + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 0, + ".", + "{\"brand\": \"Velorim\", \"model\": \"Jigger\", \"price\": 270, \"description\":" + + " \"Small and powerful, the 
Jigger is the best ride for the smallest of tikes!" + + " This is the tiniest kids\\u2019 pedal bike on the market available without a" + + " coaster brake, the Jigger is the vehicle of choice for the rare tenacious" + + " little rider raring to go.\", \"condition\": \"new\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 1, + ".", + "{\"brand\": \"Bicyk\", \"model\": \"Hillcraft\", \"price\": 1200, \"description\":" + + " \"Kids want to ride with as little weight as possible. Especially on an" + + " incline! They may be at the age when a 27.5\\\" wheel bike is just too clumsy" + + " coming off a 24\\\" bike. The Hillcraft 26 is just the solution they need!\"," + + " \"condition\": \"used\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 2, + ".", + "{\"brand\": \"Nord\", \"model\": \"Chook air 5\", \"price\": 815, \"description\":" + + " \"The Chook Air 5 gives kids aged six years and older a durable and" + + " uberlight mountain bike for their first experience on tracks and easy" + + " cruising through forests and fields. The lower top tube makes it easy to" + + " mount and dismount in any situation, giving your kids greater safety on the" + + " trails.\", \"condition\": \"used\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 3, + ".", + "{\"brand\": \"Eva\", \"model\": \"Eva 291\", \"price\": 3400, \"description\": \"The" + + " sister company to Nord, Eva launched in 2005 as the first and only" + + " women-dedicated bicycle brand. Designed by women for women, allEva bikes are" + + " optimized for the feminine physique using analytics from a body metrics" + + " database. If you like 29ers, try the Eva 291. It\\u2019s a brand new bike for" + + " 2022.. This full-suspension, cross-country ride has been designed for" + + " velocity. The 291 has 100mm of front and rear travel, a superlight aluminum" + + " frame and fast-rolling 29-inch wheels. Yippee!\", \"condition\": \"used\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 4, + ".", + "{\"brand\": \"Noka Bikes\", \"model\": \"Kahuna\", \"price\": 3200, \"description\":" + + " \"Whether you want to try your hand at XC racing or are looking for a lively" + + " trail bike that's just as inspiring on the climbs as it is over rougher" + + " ground, the Wilder is one heck of a bike built specifically for short women." + + " Both the frames and components have been tweaked to include a women\\u2019s" + + " saddle, different bars and unique colourway.\", \"condition\": \"used\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 5, + ".", + "{\"brand\": \"Breakout\", \"model\": \"XBN 2.1 Alloy\", \"price\": 810," + + " \"description\": \"The XBN 2.1 Alloy is our entry-level road bike \\u2013 but" + + " that\\u2019s not to say that it\\u2019s a basic machine. With an internal" + + " weld aluminium frame, a full carbon fork, and the slick-shifting Claris gears" + + " from Shimano\\u2019s, this is a bike which doesn\\u2019t break the bank and" + + " delivers craved performance.\", \"condition\": \"new\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 6, + ".", + "{\"brand\": \"ScramBikes\", \"model\": \"WattBike\", \"price\": 2300," + + " \"description\": \"The WattBike is the best e-bike for people who still feel" + + " young at heart. 
It has a Bafang 1000W mid-drive system and a 48V 17.5AH" + + " Samsung Lithium-Ion battery, allowing you to ride for more than 60 miles on" + + " one charge. It\\u2019s great for tackling hilly terrain or if you just fancy" + + " a more leisurely ride. With three working modes, you can choose between" + + " E-bike, assisted bicycle, and normal bike modes.\", \"condition\": \"new\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 7, + ".", + "{\"brand\": \"Peaknetic\", \"model\": \"Secto\", \"price\": 430, \"description\":" + + " \"If you struggle with stiff fingers or a kinked neck or back after a few" + + " minutes on the road, this lightweight, aluminum bike alleviates those issues" + + " and allows you to enjoy the ride. From the ergonomic grips to the" + + " lumbar-supporting seat position, the Roll Low-Entry offers incredible" + + " comfort. The rear-inclined seat tube facilitates stability by allowing you to" + + " put a foot on the ground to balance at a stop, and the low step-over frame" + + " makes it accessible for all ability and mobility levels. The saddle is very" + + " soft, with a wide back to support your hip joints and a cutout in the center" + + " to redistribute that pressure. Rim brakes deliver satisfactory braking" + + " control, and the wide tires provide a smooth, stable ride on paved roads and" + + " gravel. Rack and fender mounts facilitate setting up the Roll Low-Entry as" + + " your preferred commuter, and the BMX-like handlebar offers space for mounting" + + " a flashlight, bell, or phone holder.\", \"condition\": \"new\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 8, + ".", + "{\"brand\": \"nHill\", \"model\": \"Summit\", \"price\": 1200, \"description\":" + + " \"This budget mountain bike from nHill performs well both on bike paths and" + + " on the trail. The fork with 100mm of travel absorbs rough terrain. Fat Kenda" + + " Booster tires give you grip in corners and on wet trails. The Shimano Tourney" + + " drivetrain offered enough gears for finding a comfortable pace to ride" + + " uphill, and the Tektro hydraulic disc brakes break smoothly. Whether you want" + + " an affordable bike that you can take to work, but also take trail in" + + " mountains on the weekends or you\\u2019re just after a stable, comfortable" + + " ride for the bike path, the Summit gives a good value for money.\"," + + " \"condition\": \"new\"}" + }) + .get(); + client + .customCommand( + new String[] { + "JSON.SET", + prefixBicycles + 9, + ".", + "{\"model\": \"ThrillCycle\", \"brand\": \"BikeShind\", \"price\": 815," + + " \"description\": \"An artsy, retro-inspired bicycle that\\u2019s as" + + " functional as it is pretty: The ThrillCycle steel frame offers a smooth ride." + + " A 9-speed drivetrain has enough gears for coasting in the city, but we" + + " wouldn\\u2019t suggest taking it to the mountains. Fenders protect you from" + + " mud, and a rear basket lets you transport groceries, flowers and books. 
The" + + " ThrillCycle comes with a limited lifetime warranty, so this little guy will" + + " last you long past graduation.\", \"condition\": \"refurbished\"}" + }) + .get(); + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index + + // FT.AGGREGATE idx:bicycle "*" LOAD 1 "__key" GROUPBY 1 "@condition" REDUCE COUNT 0 AS bicylces + var aggreg = + FT.aggregate( + client, + indexBicycles, + "*", + FTAggregateOptions.builder() + .loadFields(new String[] {"__key"}) + .addExpression( + new GroupBy( + new String[] {"@condition"}, + new Reducer[] {new Reducer("COUNT", new String[0], "bicycles")})) + .build()) + .get(); + // elements (maps in array) could be reordered, comparing as sets + assertDeepEquals( + Set.of( + Map.of(gs("condition"), gs("new"), gs("bicycles"), 5.), + Map.of(gs("condition"), gs("used"), gs("bicycles"), 4.), + Map.of(gs("condition"), gs("refurbished"), gs("bicycles"), 1.)), + Set.of(aggreg)); + + // FT.CREATE idx:movie ON hash PREFIX 1 "movie:" SCHEMA title TEXT release_year NUMERIC rating + // NUMERIC genre TAG votes NUMERIC + assertEquals( + OK, + FT.create( + client, + indexMovies, + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("release_year", new NumericField()), + new FieldInfo("rating", new NumericField()), + new FieldInfo("genre", new TagField()), + new FieldInfo("votes", new NumericField()), + }, + FTCreateOptions.builder() + .indexType(IndexType.HASH) + .prefixes(new String[] {prefixMovies}) + .build()) + .get()); + + client + .hset( + prefixMovies + 11002, + Map.of( + "title", + "Star Wars: Episode V - The Empire Strikes Back", + "plot", + "After the Rebels are brutally overpowered by the Empire on the ice planet Hoth," + + " Luke Skywalker begins Jedi training with Yoda, while his friends are" + + " pursued by Darth Vader and a bounty hunter named Boba Fett all over the" + + " galaxy.", + "release_year", + "1980", + "genre", + "Action", + "rating", + "8.7", + "votes", + "1127635", + "imdb_id", + "tt0080684")) + .get(); + client + .hset( + prefixMovies + 11003, + Map.of( + "title", + "The Godfather", + "plot", + "The aging patriarch of an organized crime dynasty transfers control of his" + + " clandestine empire to his reluctant son.", + "release_year", + "1972", + "genre", + "Drama", + "rating", + "9.2", + "votes", + "1563839", + "imdb_id", + "tt0068646")) + .get(); + client + .hset( + prefixMovies + 11004, + Map.of( + "title", + "Heat", + "plot", + "A group of professional bank robbers start to feel the heat from police when they" + + " unknowingly leave a clue at their latest heist.", + "release_year", + "1995", + "genre", + "Thriller", + "rating", + "8.2", + "votes", + "559490", + "imdb_id", + "tt0113277")) + .get(); + client + .hset( + prefixMovies + 11005, + Map.of( + "title", + "Star Wars: Episode VI - Return of the Jedi", + "genre", + "Action", + "votes", + "906260", + "rating", + "8.3", + "release_year", + "1983", + "plot", + "The Rebels dispatch to Endor to destroy the second Empire's Death Star.", + "ibmdb_id", + "tt0086190")) + .get(); + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index + + // FT.AGGREGATE idx:movie * LOAD * APPLY ceil(@rating) as r_rating GROUPBY 1 @genre REDUCE + // COUNT 0 AS nb_of_movies REDUCE SUM 1 votes AS nb_of_votes REDUCE AVG 1 r_rating AS avg_rating + // SORTBY 4 @avg_rating DESC @nb_of_votes DESC + aggreg = + FT.aggregate( + client, + indexMovies, + "*", + FTAggregateOptions.builder() + .loadAll() + .addExpression(new 
Apply("ceil(@rating)", "r_rating")) + .addExpression( + new GroupBy( + new String[] {"@genre"}, + new Reducer[] { + new Reducer("COUNT", new String[0], "nb_of_movies"), + new Reducer("SUM", new String[] {"votes"}, "nb_of_votes"), + new Reducer("AVG", new String[] {"r_rating"}, "avg_rating") + })) + .addExpression( + new SortBy( + new SortProperty[] { + new SortProperty("@avg_rating", SortOrder.DESC), + new SortProperty("@nb_of_votes", SortOrder.DESC) + })) + .build()) + .get(); + // elements (maps in array) could be reordered, comparing as sets + assertDeepEquals( + Set.of( + Map.of( + gs("genre"), + gs("Drama"), + gs("nb_of_movies"), + 1., + gs("nb_of_votes"), + 1563839., + gs("avg_rating"), + 10.), + Map.of( + gs("genre"), + gs("Action"), + gs("nb_of_movies"), + 2., + gs("nb_of_votes"), + 2033895., + gs("avg_rating"), + 9.), + Map.of( + gs("genre"), + gs("Thriller"), + gs("nb_of_movies"), + 1., + gs("nb_of_votes"), + 559490., + gs("avg_rating"), + 9.)), + Set.of(aggreg)); + } } From d291d87bf140efc2f9b9b28a2887b5ffdd8a124c Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Fri, 18 Oct 2024 13:01:53 -0700 Subject: [PATCH 023/180] Java: add `FT.INFO` (#2405) * `FT.INFO` Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + glide-core/src/client/value_conversion.rs | 105 +++++++++++++++- .../glide/api/commands/servermodules/FT.java | 113 ++++++++++++++++++ .../java/glide/modules/VectorSearchTests.java | 69 ++++++++++- 4 files changed, 284 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 429e2d2316..cf0761a6ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448)) * Python: Add JSON.NUMMULTBY command ([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) +* Java: Added `FT.INFO` ([#2405](https://github.com/valkey-io/valkey-glide/pull/2441)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index 64b853af9f..d16d4ef939 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -24,6 +24,7 @@ pub(crate) enum ExpectedReturnType<'a> { ArrayOfDoubleOrNull, FTAggregateReturnType, FTSearchReturnType, + FTInfoReturnType, Lolwut, ArrayOfStringAndArrays, ArrayOfArraysOfDoubleOrNull, @@ -999,7 +1000,108 @@ pub(crate) fn convert_to_expected_type( }, _ => Err(( ErrorKind::TypeError, - "Response couldn't be converted to Pair", + "Response couldn't be converted for FT.SEARCH", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) + }, + ExpectedReturnType::FTInfoReturnType => match value { + /* + Example of the response + 1) index_name + 2) "957fa3ca-2280-467d-873f-8763a36fbd5a" + 3) creation_timestamp + 4) (integer) 1728348101740745 + 5) key_type + 6) HASH + 7) key_prefixes + 8) 1) "blog:post:" + 9) fields + 10) 1) 1) identifier + 2) category + 3) field_name + 4) category + 5) type + 6) TAG + 7) option + 8) + 2) 1) identifier + 2) vec + 3) field_name + 4) VEC + 5) type + 6) VECTOR + 7) option + 8) + 9) vector_params + 10) 1) algorithm + 2) HNSW + 3) data_type + 4) FLOAT32 + 5) dimension + 6) 
(integer) 2 + ... + + Converting response to + 1# "index_name" => "957fa3ca-2280-467d-873f-8763a36fbd5a" + 2# "creation_timestamp" => 1728348101740745 + 3# "key_type" => "HASH" + 4# "key_prefixes" => + 1) "blog:post:" + 5# "fields" => + 1) 1# "identifier" => "category" + 2# "field_name" => "category" + 3# "type" => "TAG" + 4# "option" => "" + 2) 1# "identifier" => "vec" + 2# "field_name" => "VEC" + 3# "type" => "TAVECTORG" + 4# "option" => "" + 5# "vector_params" => + 1# "algorithm" => "HNSW" + 2# "data_type" => "FLOAT32" + 3# "dimension" => 2 + ... + + Map keys (odd array elements) are simple strings, not bulk strings. + */ + Value::Array(_) => { + let Value::Map(mut map) = convert_to_expected_type(value, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))? else { unreachable!() }; + let Some(fields_pair) = map.iter_mut().find(|(key, _)| { + *key == Value::SimpleString("fields".into()) + }) else { return Ok(Value::Map(map)) }; + let (fields_key, fields_value) = std::mem::replace(fields_pair, (Value::Nil, Value::Nil)); + let Value::Array(fields) = fields_value else { + return Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.INFO", + format!("(`fields` was {:?})", get_value_type(&fields_value)), + ).into()); + }; + let fields = fields.into_iter().map(|field| { + let Value::Map(mut field_params) = convert_to_expected_type(field, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))? else { unreachable!() }; + let Some(vector_params_pair) = field_params.iter_mut().find(|(key, _)| { + *key == Value::SimpleString("vector_params".into()) + }) else { return Ok(Value::Map(field_params)) }; + let (vector_params_key, vector_params_value) = std::mem::replace(vector_params_pair, (Value::Nil, Value::Nil)); + let _ = std::mem::replace(vector_params_pair, (vector_params_key, convert_to_expected_type(vector_params_value, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))?)); + Ok(Value::Map(field_params)) + }).collect::>>()?; + let _ = std::mem::replace(fields_pair, (fields_key, Value::Array(fields))); + Ok(Value::Map(map)) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.INFO", format!("(response was {:?})", get_value_type(&value)), ) .into()) @@ -1370,6 +1472,7 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { }), b"FT.AGGREGATE" => Some(ExpectedReturnType::FTAggregateReturnType), b"FT.SEARCH" => Some(ExpectedReturnType::FTSearchReturnType), + b"FT.INFO" => Some(ExpectedReturnType::FTInfoReturnType), _ => None, } } diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 106d540f8c..7a5b514816 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -474,6 +474,119 @@ public static CompletableFuture[]> aggregate( .thenApply(res -> castArray(res, Map.class)); } + /** + * Returns information about a given index. + * + * @param indexName The index name. + * @return Nested maps with info about the index. See example for more details. + * @example + *
{@code
+     * // example of using the API:
+     * Map<String, Object> response = FT.info(client, "myIndex").get();
+     * // the response contains data in the following format:
+     * Map<String, Object> data = Map.of(
+     *     "index_name", gs("bcd97d68-4180-4bc5-98fe-5125d0abbcb8"),
+     *     "index_status", gs("AVAILABLE"),
+     *     "key_type", gs("JSON"),
+     *     "creation_timestamp", 1728348101728771L,
+     *     "key_prefixes", new GlideString[] { gs("json:") },
+     *     "num_indexed_vectors", 0L,
+     *     "space_usage", 653471L,
+     *     "num_docs", 0L,
+     *     "vector_space_usage", 653471L,
+     *     "index_degradation_percentage", 0L,
+     *     "fulltext_space_usage", 0L,
+     *     "current_lag", 0L,
+     *     "fields", new Object [] {
+     *         Map.of(
+     *             gs("identifier"), gs("$.vec"),
+     *             gs("type"), gs("VECTOR"),
+     *             gs("field_name"), gs("VEC"),
+     *             gs("option"), gs(""),
+     *             gs("vector_params"), Map.of(
+     *                 gs("data_type"), gs("FLOAT32"),
+     *                 gs("initial_capacity"), 1000L,
+     *                 gs("current_capacity"), 1000L,
+     *                 gs("distance_metric"), gs("L2"),
+     *                 gs("dimension"), 6L,
+     *                 gs("block_size"), 1024L,
+     *                 gs("algorithm"), gs("FLAT")
+     *           )
+     *         ),
+     *         Map.of(
+     *             gs("identifier"), gs("name"),
+     *             gs("type"), gs("TEXT"),
+     *             gs("field_name"), gs("name"),
+     *             gs("option"), gs("")
+     *         ),
+     *     }
+     * );
+     * }
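A sketch of drilling into this nested reply; the casts mirror the ones used in the integration test added below, and the method and index names are illustrative:

```java
import static glide.api.models.GlideString.gs;

import glide.api.BaseClient;
import glide.api.commands.servermodules.FT;
import glide.api.models.GlideString;
import java.util.Map;

public class InfoSketch {
    // The outer map is keyed by String, the per-field maps by GlideString (see the TODO below).
    @SuppressWarnings("unchecked")
    static void printFields(BaseClient client, String index) throws Exception {
        Map<String, Object> info = FT.info(client, index).get();
        System.out.println("index: " + info.get("index_name"));
        for (Object field : (Object[]) info.get("fields")) {
            Map<GlideString, Object> f = (Map<GlideString, Object>) field;
            System.out.println(f.get(gs("field_name")) + " : " + f.get(gs("type")));
        }
    }
}
```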
+     */
+    public static CompletableFuture<Map<String, Object>> info(
+            @NonNull BaseClient client, @NonNull String indexName) {
+        // TODO inconsistency: the outer map is `Map<String, Object>`,
+        //   while inner maps are `Map<GlideString, Object>`
+        //   The outer map converted from `Map<GlideString, Object>` in ClusterValue::ofMultiValueBinary
+        // TODO server returns all strings as `SimpleString`, we're safe to convert all
+        //   `GlideString`s to `String`
+        return executeCommand(client, new GlideString[] {gs("FT.INFO"), gs(indexName)}, true);
+    }
+
+    /**
+     * Returns information about a given index.
+     *
+     * @param indexName The index name.
+     * @return Nested maps with info about the index. See example for more details.
+     * @example
+     *
{@code
+     * // example of using the API:
+     * Map<String, Object> response = FT.info(client, gs("myIndex")).get();
+     * // the response contains data in the following format:
+     * Map<String, Object> data = Map.of(
+     *     "index_name", gs("bcd97d68-4180-4bc5-98fe-5125d0abbcb8"),
+     *     "index_status", gs("AVAILABLE"),
+     *     "key_type", gs("JSON"),
+     *     "creation_timestamp", 1728348101728771L,
+     *     "key_prefixes", new GlideString[] { gs("json:") },
+     *     "num_indexed_vectors", 0L,
+     *     "space_usage", 653471L,
+     *     "num_docs", 0L,
+     *     "vector_space_usage", 653471L,
+     *     "index_degradation_percentage", 0L,
+     *     "fulltext_space_usage", 0L,
+     *     "current_lag", 0L,
+     *     "fields", new Object [] {
+     *         Map.of(
+     *             gs("identifier"), gs("$.vec"),
+     *             gs("type"), gs("VECTOR"),
+     *             gs("field_name"), gs("VEC"),
+     *             gs("option"), gs(""),
+     *             gs("vector_params"), Map.of(
+     *                 gs("data_type"), gs("FLOAT32"),
+     *                 gs("initial_capacity"), 1000L,
+     *                 gs("current_capacity"), 1000L,
+     *                 gs("distance_metric"), gs("L2"),
+     *                 gs("dimension"), 6L,
+     *                 gs("block_size"), 1024L,
+     *                 gs("algorithm"), gs("FLAT")
+     *           )
+     *         ),
+     *         Map.of(
+     *             gs("identifier"), gs("name"),
+     *             gs("type"), gs("TEXT"),
+     *             gs("field_name"), gs("name"),
+     *             gs("option"), gs("")
+     *         ),
+     *     }
+     * );
+     * }
+ */ + public static CompletableFuture> info( + @NonNull BaseClient client, @NonNull GlideString indexName) { + return executeCommand(client, new GlideString[] {gs("FT.INFO"), indexName}, true); + } + /** * A wrapper for custom command API. * diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index fbc3eab196..6e690e77c9 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -16,6 +16,7 @@ import glide.api.GlideClusterClient; import glide.api.commands.servermodules.FT; +import glide.api.models.GlideString; import glide.api.models.commands.FT.FTAggregateOptions; import glide.api.models.commands.FT.FTAggregateOptions.Apply; import glide.api.models.commands.FT.FTAggregateOptions.GroupBy; @@ -142,12 +143,12 @@ public void ft_create() { .get()); // create an index with multiple prefixes - var name = UUID.randomUUID().toString(); + var index = UUID.randomUUID().toString(); assertEquals( OK, FT.create( client, - name, + index, new FieldInfo[] { new FieldInfo("author_id", new TagField()), new FieldInfo("author_ids", new TagField()), @@ -167,7 +168,7 @@ public void ft_create() { () -> FT.create( client, - name, + index, new FieldInfo[] { new FieldInfo("title", new TextField()), new FieldInfo("name", new TextField()) @@ -713,4 +714,66 @@ public void ft_aggregate() { 9.)), Set.of(aggreg)); } + + @SuppressWarnings("unchecked") + @Test + @SneakyThrows + public void ft_info() { + // TODO use FT.LIST when it is done + var indices = (Object[]) client.customCommand(new String[] {"FT._LIST"}).get().getSingleValue(); + + // check that we can get a response for all indices (no crashes on value conversion or so) + for (var idx : indices) { + FT.info(client, (String) idx).get(); + } + + var index = UUID.randomUUID().toString(); + assertEquals( + OK, + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo( + "$.vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.COSINE, 42).build()), + new FieldInfo("$.name", new TextField()), + }, + FTCreateOptions.builder() + .indexType(IndexType.JSON) + .prefixes(new String[] {"123"}) + .build()) + .get()); + + var response = FT.info(client, index).get(); + assertEquals(gs(index), response.get("index_name")); + assertEquals(gs("JSON"), response.get("key_type")); + assertArrayEquals(new GlideString[] {gs("123")}, (Object[]) response.get("key_prefixes")); + var fields = (Object[]) response.get("fields"); + assertEquals(2, fields.length); + var f1 = (Map) fields[1]; + assertEquals(gs("$.vec"), f1.get(gs("identifier"))); + assertEquals(gs("VECTOR"), f1.get(gs("type"))); + assertEquals(gs("VEC"), f1.get(gs("field_name"))); + var f1params = (Map) f1.get(gs("vector_params")); + assertEquals(gs("COSINE"), f1params.get(gs("distance_metric"))); + assertEquals(42L, f1params.get(gs("dimension"))); + + assertEquals( + Map.of( + gs("identifier"), + gs("$.name"), + gs("type"), + gs("TEXT"), + gs("field_name"), + gs("$.name"), + gs("option"), + gs("")), + fields[0]); + + // querying a missing index + assertEquals(OK, FT.dropindex(client, index).get()); + var exception = assertThrows(ExecutionException.class, () -> FT.info(client, index).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } } From b1689723896c10ec3d43493ed37c5849029502a2 Mon Sep 17 00:00:00 2001 From: Chloe Yip 
<168601573+cyip10@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:04:47 -0700 Subject: [PATCH 024/180] Java: FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE (#2442) Signed-off-by: Andrew Carbonetto Co-authored-by: Andrew Carbonetto --- CHANGELOG.md | 1 + .../glide/api/commands/servermodules/FT.java | 110 ++++++++++++++++++ .../java/glide/modules/VectorSearchTests.java | 49 ++++++++ .../glide/async_commands/server_modules/ft.py | 8 +- 4 files changed, 164 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf0761a6ed..fd185405b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) +* Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 7a5b514816..38b0ec7096 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -587,6 +587,116 @@ public static CompletableFuture> info( return executeCommand(client, new GlideString[] {gs("FT.INFO"), indexName}, true); } + /** + * Adds an alias for an index. The new alias name can be used anywhere that an index name is + * required. + * + * @param client The client to execute the command. + * @param aliasName The alias to be added to an index. + * @param indexName The index name for which the alias has to be added. + * @return "OK". + * @example + *
{@code
+     * FT.aliasadd(client, "myalias", "myindex").get(); // "OK"
+     * }
+     */
+    public static CompletableFuture<String> aliasadd(
+            @NonNull BaseClient client, @NonNull String aliasName, @NonNull String indexName) {
+        return aliasadd(client, gs(aliasName), gs(indexName));
+    }
+
+    /**
+     * Adds an alias for an index. The new alias name can be used anywhere that an index name is
+     * required.
+     *
+     * @param client The client to execute the command.
+     * @param aliasName The alias to be added to an index.
+     * @param indexName The index name for which the alias has to be added.
+     * @return "OK".
+     * @example
+     *
{@code
+     * FT.aliasadd(client, gs("myalias"), gs("myindex")).get(); // "OK"
+     * }
+     */
+    public static CompletableFuture<String> aliasadd(
+            @NonNull BaseClient client, @NonNull GlideString aliasName, @NonNull GlideString indexName) {
+        var args = new GlideString[] {gs("FT.ALIASADD"), aliasName, indexName};
+
+        return executeCommand(client, args, false);
+    }
+
+    /**
+     * Deletes an existing alias for an index.
+     *
+     * @param client The client to execute the command.
+     * @param aliasName The existing alias to be deleted for an index.
+     * @return "OK".
+     * @example
+     *
{@code
+     * FT.aliasdel(client, "myalias").get(); // "OK"
+     * }
+     */
+    public static CompletableFuture<String> aliasdel(
+            @NonNull BaseClient client, @NonNull String aliasName) {
+        return aliasdel(client, gs(aliasName));
+    }
+
+    /**
+     * Deletes an existing alias for an index.
+     *
+     * @param client The client to execute the command.
+     * @param aliasName The existing alias to be deleted for an index.
+     * @return "OK".
+     * @example
+     *
{@code
+     * FT.aliasdel(client, gs("myalias")).get(); // "OK"
+     * }
+     */
+    public static CompletableFuture<String> aliasdel(
+            @NonNull BaseClient client, @NonNull GlideString aliasName) {
+        var args = new GlideString[] {gs("FT.ALIASDEL"), aliasName};
+
+        return executeCommand(client, args, false);
+    }
+
+    /**
+     * Updates an existing alias to point to a different physical index. This command only affects
+     * future references to the alias.
+     *
+     * @param client The client to execute the command.
+     * @param aliasName The alias name. This alias will now be pointed to a different index.
+     * @param indexName The index name for which an existing alias has to be updated.
+     * @return "OK".
+     * @example
+     *
{@code
+     * FT.aliasupdate(client, "myalias", "myindex").get(); // "OK"
+     * }
+     */
+    public static CompletableFuture<String> aliasupdate(
+            @NonNull BaseClient client, @NonNull String aliasName, @NonNull String indexName) {
+        return aliasupdate(client, gs(aliasName), gs(indexName));
+    }
+
+    /**
+     * Updates an existing alias to point to a different physical index. This command only affects
+     * future references to the alias.
+     *
+     * @param client The client to execute the command.
+     * @param aliasName The alias name. This alias will now be pointed to a different index.
+     * @param indexName The index name for which an existing alias has to be updated.
+     * @return "OK".
+     * @example
+     *
{@code
+     * FT.aliasupdate(client, gs("myalias"), gs("myindex")).get(); // "OK"
+     * }
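A sketch of the full alias lifecycle using the three commands added in this patch; the alias and index names are illustrative:

```java
import glide.api.BaseClient;
import glide.api.commands.servermodules.FT;

public class AliasLifecycleSketch {
    static void demo(BaseClient client, String indexName) throws Exception {
        FT.aliasadd(client, "reports", indexName).get();    // "OK": queries can now use "reports"
        FT.aliasupdate(client, "reports", indexName).get(); // "OK": repoint the alias (here, to the same index)
        FT.aliasdel(client, "reports").get();               // "OK": the alias no longer resolves
    }
}
```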
+ */ + public static CompletableFuture aliasupdate( + @NonNull BaseClient client, @NonNull GlideString aliasName, @NonNull GlideString indexName) { + var args = new GlideString[] {gs("FT.ALIASUPDATE"), aliasName, indexName}; + + return executeCommand(client, args, false); + } + /** * A wrapper for custom command API. * diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 6e690e77c9..107edaf7b1 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -776,4 +776,53 @@ public void ft_info() { assertInstanceOf(RequestException.class, exception.getCause()); assertTrue(exception.getMessage().contains("Index not found")); } + + @SneakyThrows + @Test + public void ft_aliasadd_aliasdel_aliasupdate() { + + var alias1 = "alias1"; + var alias2 = "a2"; + var indexName = "{" + UUID.randomUUID() + "-index}"; + + // create some indices + assertEquals( + OK, + FT.create( + client, + indexName, + new FieldInfo[] { + new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build()) + }) + .get()); + + assertEquals(OK, FT.aliasadd(client, alias1, indexName).get()); + + // error with adding the same alias to the same index + var exception = + assertThrows(ExecutionException.class, () -> FT.aliasadd(client, alias1, indexName).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Alias already exists")); + + assertEquals(OK, FT.aliasupdate(client, alias2, indexName).get()); + assertEquals(OK, FT.aliasdel(client, alias2).get()); + + // with GlideString: + assertEquals(OK, FT.aliasupdate(client, gs(alias1), gs(indexName)).get()); + assertEquals(OK, FT.aliasdel(client, gs(alias1)).get()); + assertEquals(OK, FT.aliasadd(client, gs(alias2), gs(indexName)).get()); + assertEquals(OK, FT.aliasdel(client, gs(alias2)).get()); + + // exception with calling `aliasdel` on an alias that doesn't exist + exception = assertThrows(ExecutionException.class, () -> FT.aliasdel(client, alias2).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Alias does not exist")); + + // exception with calling `aliasadd` with a nonexisting index + exception = + assertThrows( + ExecutionException.class, () -> FT.aliasadd(client, alias1, "nonexistent_index").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index does not exist")); + } } diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index 88dcb9b58e..ccd1fd8735 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -82,7 +82,7 @@ async def aliasadd( client: TGlideClient, alias: TEncodable, indexName: TEncodable ) -> TOK: """ - Add an alias for an index. The new alias name can be used anywhere that an index name is required. + Adds an alias for an index. The new alias name can be used anywhere that an index name is required. Args: client (TGlideClient): The client to execute the command. @@ -103,11 +103,11 @@ async def aliasadd( async def aliasdel(client: TGlideClient, alias: TEncodable) -> TOK: """ - Delete an existing alias for an index. + Deletes an existing alias for an index. 
Args: client (TGlideClient): The client to execute the command. - alias (TEncodable): The exisiting alias to be deleted for an index. + alias (TEncodable): The existing alias to be deleted for an index. Returns: TOK: A simple "OK" response. @@ -125,7 +125,7 @@ async def aliasupdate( client: TGlideClient, alias: TEncodable, indexName: TEncodable ) -> TOK: """ - Update an existing alias to point to a different physical index. This command only affects future references to the alias. + Updates an existing alias to point to a different physical index. This command only affects future references to the alias. Args: client (TGlideClient): The client to execute the command. From 91cafe8a3af44fb29debfe9ad9b9257dd091dde2 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:56:10 -0700 Subject: [PATCH 025/180] Node: add commands JSON.GET and JSON.SET (#2427) * Node: add commands JSON.GET and JSON.SET Signed-off-by: TJ Zhang --- CHANGELOG.md | 1 + node/index.ts | 1 + node/npm/glide/index.ts | 4 + node/package.json | 4 +- node/src/server-modules/GlideJson.ts | 180 +++++++++++++++++++++++++++ node/tests/ServerModules.test.ts | 174 +++++++++++++++++++++++++- 6 files changed, 360 insertions(+), 4 deletions(-) create mode 100644 node/src/server-modules/GlideJson.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index fd185405b6..d61e36f02a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) +* Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) #### Breaking Changes diff --git a/node/index.ts b/node/index.ts index e1f04fd9c2..2127178d07 100644 --- a/node/index.ts +++ b/node/index.ts @@ -9,4 +9,5 @@ export * from "./src/Errors"; export * from "./src/GlideClient"; export * from "./src/GlideClusterClient"; export * from "./src/Logger"; +export * from "./src/server-modules/GlideJson"; export * from "./src/Transaction"; diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 98171bfef4..245ddcd2e4 100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -117,8 +117,10 @@ function initialize() { GlideClient, GlideClusterClient, GlideClientConfiguration, + GlideJson, GlideRecord, GlideString, + JsonGetOptions, SortedSetDataType, StreamEntryDataType, HashDataType, @@ -227,7 +229,9 @@ function initialize() { DecoderOption, GeoAddOptions, GlideRecord, + GlideJson, GlideString, + JsonGetOptions, SortedSetDataType, StreamEntryDataType, HashDataType, diff --git a/node/package.json b/node/package.json index 1979e408e4..b7595bc79d 100644 --- a/node/package.json +++ b/node/package.json @@ -32,14 +32,14 @@ "compile-protobuf-files": "cd src && pbjs -t static-module -o ProtobufMessage.js ../../glide-core/src/protobuf/*.proto && pbts -o ProtobufMessage.d.ts ProtobufMessage.js", "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? 
writer.fork() : writer).ldelim' src/ProtobufMessage.js", "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='ServerModules'", + "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testPathPattern='ServerModules'", "build-test-utils": "cd ../utils && npm i && npm run build", "lint:fix": "npm run install-linting && npx eslint -c ../eslint.config.mjs --fix && npm run prettier:format", "lint": "npm run install-linting && npx eslint -c ../eslint.config.mjs && npm run prettier:check:ci", "install-linting": "cd ../ & npm install", "prepack": "npmignore --auto", "prettier:check:ci": "npx prettier --check . --ignore-unknown '!**/*.{js,d.ts}'", - "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'", - "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='ServerModules'" + "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'" }, "devDependencies": { "@jest/globals": "^29.7.0", diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts new file mode 100644 index 0000000000..6dd57b16d3 --- /dev/null +++ b/node/src/server-modules/GlideJson.ts @@ -0,0 +1,180 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ + +import { BaseClient, DecoderOption, GlideString } from "../BaseClient"; +import { ConditionalChange } from "../Commands"; +import { GlideClient } from "../GlideClient"; +import { GlideClusterClient, RouteOption } from "../GlideClusterClient"; + +export type ReturnTypeJson = GlideString | (GlideString | null)[]; + +/** + * Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. + */ +export interface JsonGetOptions { + /** The path or list of paths within the JSON document. Default is root `$`. */ + paths?: GlideString[]; + /** Sets an indentation string for nested levels. */ + indent?: GlideString; + /** Sets a string that's printed at the end of each line. */ + newline?: GlideString; + /** Sets a string that's put between a key and a value. */ + space?: GlideString; + /** Optional, allowed to be present for legacy compatibility and has no other effect */ + noescape?: boolean; +} + +/** + * @internal + */ +function _jsonGetOptionsToArgs(options: JsonGetOptions): GlideString[] { + const result: GlideString[] = []; + + if (options.paths !== undefined) { + result.push(...options.paths); + } + + if (options.indent !== undefined) { + result.push("INDENT", options.indent); + } + + if (options.newline !== undefined) { + result.push("NEWLINE", options.newline); + } + + if (options.space !== undefined) { + result.push("SPACE", options.space); + } + + if (options.noescape !== undefined) { + result.push("NOESCAPE"); + } + + return result; +} + +/** + * @internal + */ +function _executeCommand( + client: BaseClient, + args: GlideString[], + options?: RouteOption & DecoderOption, +): Promise { + if (client instanceof GlideClient) { + return (client as GlideClient).customCommand( + args, + options, + ) as Promise; + } else { + return (client as GlideClusterClient).customCommand( + args, + options, + ) as Promise; + } +} + +/** Module for JSON commands. */ +export class GlideJson { + /** + * Sets the JSON value at the specified `path` stored at `key`. + * + * @param key - The key of the JSON document. + * @param path - Represents the path within the JSON document where the value will be set. 
+ * The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added. + * @param value - The value to set at the specific path, in JSON formatted bytes or str. + * @param options - (Optional) Additional parameters: + * - (Optional) `conditionalChange` - Set the value only if the given condition is met (within the key or path). + * Equivalent to [`XX` | `NX`] in the module API. Defaults to null. + * - (Optional) `decoder`: see {@link DecoderOption}. + * + * @returns If the value is successfully set, returns `"OK"`. + * If `value` isn't set because of `conditionalChange`, returns `null`. + * + * @example + * ```typescript + * const value = {a: 1.0, b:2}; + * const jsonStr = JSON.stringify(value); + * const result = await GlideJson.set("doc", "$", jsonStr); + * console.log(result); // 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * + * const jsonGetStr = await GlideJson.get(client, "doc", "$"); // Returns the value at path '$' in the JSON document stored at `doc` as JSON string. + * console.log(jsonGetStr); // '[{"a":1.0,"b":2}]' + * console.log(JSON.stringify(jsonGetStr)); // [{"a": 1.0, "b": 2}] # JSON object retrieved from the key `doc` + * ``` + */ + static async set( + client: BaseClient, + key: GlideString, + path: GlideString, + value: GlideString, + options?: { conditionalChange: ConditionalChange } & DecoderOption, + ): Promise<"OK" | null> { + const args: GlideString[] = ["JSON.SET", key, path, value]; + + if (options?.conditionalChange !== undefined) { + args.push(options.conditionalChange); + } + + return _executeCommand<"OK" | null>(client, args, options); + } + + /** + * Retrieves the JSON value at the specified `paths` stored at `key`. + * + * @param key - The key of the JSON document. + * @param options - Options for formatting the byte representation of the JSON data. See {@link JsonGetOptions}. + * @returns ReturnTypeJson: + * - If one path is given: + * - For JSONPath (path starts with `$`): + * - Returns a stringified JSON list of bytes replies for every possible path, + * or a byte string representation of an empty array, if path doesn't exist. + * If `key` doesn't exist, returns null. + * - For legacy path (path doesn't start with `$`): + * Returns a byte string representation of the value in `path`. + * If `path` doesn't exist, an error is raised. + * If `key` doesn't exist, returns null. + * - If multiple paths are given: + * Returns a stringified JSON object in bytes, in which each path is a key, and it's corresponding value, is the value as if the path was executed in the command as a single path. + * In case of multiple paths, and `paths` are a mix of both JSONPath and legacy path, the command behaves as if all are JSONPath paths. + * + * @example + * ```typescript + * const jsonStr = await client.jsonGet('doc', '$'); + * console.log(JSON.parse(jsonStr as string)); + * // Output: [{"a": 1.0, "b" :2}] - JSON object retrieved from the key `doc`. + * + * const jsonData = await client.jsonGet('doc', '$'); + * console.log(jsonData); + * // Output: '[{"a":1.0,"b":2}]' - Returns the value at path '$' in the JSON document stored at `doc`. 
+ * + * const formattedJson = await client.jsonGet('doc', { + * ['$.a', '$.b'] + * indent: " ", + * newline: "\n", + * space: " " + * }); + * console.log(formattedJson); + * // Output: "{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}" - Returns values at paths '$.a' and '$.b' with custom format. + * + * const nonExistingPath = await client.jsonGet('doc', '$.non_existing_path'); + * console.log(nonExistingPath); + * // Output: "[]" - Empty array since the path does not exist in the JSON document. + * ``` + */ + static async get( + client: GlideClient | GlideClusterClient, + key: GlideString, + options?: JsonGetOptions & DecoderOption, + ): Promise { + const args = ["JSON.GET", key]; + + if (options) { + const optionArgs = _jsonGetOptionsToArgs(options); + args.push(...optionArgs); + } + + return _executeCommand(client, args, options); + } +} diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 6a11884385..855dfd4305 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -9,7 +9,15 @@ import { expect, it, } from "@jest/globals"; -import { GlideClusterClient, InfoOptions, ProtocolVersion } from ".."; +import { v4 as uuidv4 } from "uuid"; +import { + ConditionalChange, + GlideClusterClient, + GlideJson, + InfoOptions, + JsonGetOptions, + ProtocolVersion, +} from ".."; import { ValkeyCluster } from "../../utils/TestUtils"; import { flushAndCloseClient, @@ -44,7 +52,7 @@ describe("GlideJson", () => { }, TIMEOUT); it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "ServerModules check modules loaded", + "check modules loaded", async (protocol) => { client = await GlideClusterClient.createClient( getClientConfigurationOption(cluster.getAddresses(), protocol), @@ -57,4 +65,166 @@ describe("GlideJson", () => { expect(info).toContain("# search_index_stats"); }, ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set and json.get tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: 2 }; + + // JSON.set + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // JSON.get + let result = await GlideJson.get(client, key, { paths: ["."] }); + expect(JSON.parse(result.toString())).toEqual(jsonValue); + + // JSON.get with array of paths + result = await GlideJson.get(client, key, { + paths: ["$.a", "$.b"], + }); + expect(JSON.parse(result.toString())).toEqual({ + "$.a": [1.0], + "$.b": [2], + }); + + // JSON.get with non-existing key + expect( + await GlideJson.get(client, "non_existing_key", { + paths: ["$"], + }), + ); + + // JSON.get with non-existing path + result = await GlideJson.get(client, key, { paths: ["$.d"] }); + expect(result).toEqual("[]"); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set and json.get tests with multiple value", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const key = uuidv4(); + + // JSON.set with complex object + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify({ a: { c: 1, d: 4 }, b: { c: 2 }, c: true }), + ), + ).toBe("OK"); + + // JSON.get with deep path + let result = await GlideJson.get(client, key, { paths: ["$..c"] }); + expect(JSON.parse(result.toString())).toEqual([true, 1, 2]); + + // JSON.set with deep path 
+ expect( + await GlideJson.set(client, key, "$..c", '"new_value"'), + ).toBe("OK"); + + // verify JSON.set result + result = await GlideJson.get(client, key, { paths: ["$..c"] }); + expect(JSON.parse(result.toString())).toEqual([ + "new_value", + "new_value", + "new_value", + ]); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set conditional set", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const key = uuidv4(); + const value = JSON.stringify({ a: 1.0, b: 2 }); + + expect( + await GlideJson.set(client, key, "$", value, { + conditionalChange: ConditionalChange.ONLY_IF_EXISTS, + }), + ).toBeNull(); + + expect( + await GlideJson.set(client, key, "$", value, { + conditionalChange: ConditionalChange.ONLY_IF_DOES_NOT_EXIST, + }), + ).toBe("OK"); + + expect( + await GlideJson.set(client, key, "$.a", "4.5", { + conditionalChange: ConditionalChange.ONLY_IF_DOES_NOT_EXIST, + }), + ).toBeNull(); + let result = await GlideJson.get(client, key, { paths: [".a"] }); + expect(result).toEqual("1"); + + expect( + await GlideJson.set(client, key, "$.a", "4.5", { + conditionalChange: ConditionalChange.ONLY_IF_EXISTS, + }), + ).toBe("OK"); + result = await GlideJson.get(client, key, { paths: [".a"] }); + expect(result).toEqual("4.5"); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.get formatting", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const key = uuidv4(); + // Set initial JSON value + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify({ a: 1.0, b: 2, c: { d: 3, e: 4 } }), + ), + ).toBe("OK"); + // JSON.get with formatting options + let result = await GlideJson.get(client, key, { + paths: ["$"], + indent: " ", + newline: "\n", + space: " ", + } as JsonGetOptions); + + const expectedResult1 = + '[\n {\n "a": 1,\n "b": 2,\n "c": {\n "d": 3,\n "e": 4\n }\n }\n]'; + expect(result).toEqual(expectedResult1); + // JSON.get with different formatting options + result = await GlideJson.get(client, key, { + paths: ["$"], + indent: "~", + newline: "\n", + space: "*", + } as JsonGetOptions); + + const expectedResult2 = + '[\n~{\n~~"a":*1,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]'; + expect(result).toEqual(expectedResult2); + }, + ); }); From 56b4401e6678ed79daa22ea03d06c73417898ffd Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Fri, 18 Oct 2024 19:23:37 -0700 Subject: [PATCH 026/180] Java: `JSON.ARRINSERT` and `JSON.ARRLEN` (#2476) * `JSON.ARRINSERT` and `JSON.ARRLEN` Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 279 ++++++++++++++++-- java/integTest/build.gradle | 1 + .../test/java/glide/modules/JsonTests.java | 113 +++++++ .../java/glide/modules/VectorSearchTests.java | 2 +- 5 files changed, 369 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d61e36f02a..71d007cce5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) +* Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) * Java: Added 
`FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 5aeb9fd851..8398a9f168 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -2,6 +2,7 @@ package glide.api.commands.servermodules; import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.concatenateArrays; import glide.api.BaseClient; import glide.api.GlideClient; @@ -12,16 +13,17 @@ import glide.api.models.commands.json.JsonGetOptions; import glide.api.models.commands.json.JsonGetOptionsBinary; import glide.utils.ArgsBuilder; -import glide.utils.ArrayTransformUtils; import java.util.concurrent.CompletableFuture; import lombok.NonNull; /** Module for JSON commands. */ public class Json { - public static final String JSON_PREFIX = "JSON."; + private static final String JSON_PREFIX = "JSON."; public static final String JSON_SET = JSON_PREFIX + "SET"; public static final String JSON_GET = JSON_PREFIX + "GET"; + private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; + private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; private Json() {} @@ -38,7 +40,7 @@ private Json() {} * @return A simple "OK" response if the value is successfully set. * @example *
{@code
-     * String value = Json.set(client, "doc", , ".", "{'a': 1.0, 'b': 2}").get();
+     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}").get();
      * assert value.equals("OK");
      * }
*/ @@ -63,7 +65,7 @@ public static CompletableFuture set( * @return A simple "OK" response if the value is successfully set. * @example *
{@code
-     * String value = client.Json.set(client, gs("doc"), , gs("."), gs("{'a': 1.0, 'b': 2}")).get();
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}")).get();
      * assert value.equals("OK");
      * }
*/ @@ -90,7 +92,7 @@ public static CompletableFuture set( * set because of setCondition, returns null. * @example *
{@code
-     * String value = client.Json.set(client, "doc", , ".", "{'a': 1.0, 'b': 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
      * assert value.equals("OK");
      * }
*/ @@ -119,7 +121,7 @@ public static CompletableFuture set( * set because of setCondition, returns null. * @example *
{@code
-     * String value = client.Json.set(client, gs("doc"), , gs("."), gs("{'a': 1.0, 'b': 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
      * assert value.equals("OK");
      * }
*/ @@ -143,7 +145,7 @@ public static CompletableFuture set( * exist, returns null. * @example *
{@code
-     * String value = client.Json.get(client, "doc").get();
+     * String value = Json.get(client, "doc").get();
      * assert value.equals("{'a': 1.0, 'b': 2}");
      * }
*/ @@ -160,7 +162,7 @@ public static CompletableFuture get(@NonNull BaseClient client, @NonNull * exist, returns null. * @example *
{@code
-     * GlideString value = client.Json.get(client, gs("doc")).get();
+     * GlideString value = Json.get(client, gs("doc")).get();
      * assert value.equals(gs("{'a': 1.0, 'b': 2}"));
      * }
*/ @@ -195,16 +197,15 @@ public static CompletableFuture get( * path, the command behaves as if all are JSONPath paths. * @example *
{@code
-     * String value = client.Json.get(client, "doc", new String[] {"$"}).get();
+     * String value = Json.get(client, "doc", new String[] {"$"}).get();
      * assert value.equals("{'a': 1.0, 'b': 2}");
-     * String value = client.Json.get(client, "doc", new String[] {"$.a", "$.b"}).get();
+     * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}).get();
      * assert value.equals("{\"$.a\": [1.0], \"$.b\": [2]}");
      * }
*/ public static CompletableFuture get( @NonNull BaseClient client, @NonNull String key, @NonNull String[] paths) { - return executeCommand( - client, ArrayTransformUtils.concatenateArrays(new String[] {JSON_GET, key}, paths)); + return executeCommand(client, concatenateArrays(new String[] {JSON_GET, key}, paths)); } /** @@ -233,17 +234,15 @@ public static CompletableFuture get( * path, the command behaves as if all are JSONPath paths. * @example *
{@code
-     * GlideString value = client.Json.get(client, gs("doc"), new GlideString[] {gs("$")}).get();
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$")}).get();
      * assert value.equals(gs("{'a': 1.0, 'b': 2}"));
-     * GlideString value = client.Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}).get();
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}).get();
      * assert value.equals(gs("{\"$.a\": [1.0], \"$.b\": [2]}"));
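+     * // Illustrative only: when JSONPath and legacy paths are mixed, the command treats all
+     * // given paths as JSONPath, as noted in the return description.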
      * }
*/ public static CompletableFuture get( @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString[] paths) { - return executeCommand( - client, - ArrayTransformUtils.concatenateArrays(new GlideString[] {gs(JSON_GET), key}, paths)); + return executeCommand(client, concatenateArrays(new GlideString[] {gs(JSON_GET), key}, paths)); } /** @@ -262,15 +261,14 @@ public static CompletableFuture get( * .space(" ") * .newline("\n") * .build(); - * String value = client.Json.get(client, "doc", "$", options).get(); + * String value = Json.get(client, "doc", "$", options).get(); * assert value.equals("{\n \"a\": \n 1.0\n ,\n \"b\": \n 2\n }"); * } */ public static CompletableFuture get( @NonNull BaseClient client, @NonNull String key, @NonNull JsonGetOptions options) { return executeCommand( - client, - ArrayTransformUtils.concatenateArrays(new String[] {JSON_GET, key}, options.toArgs())); + client, concatenateArrays(new String[] {JSON_GET, key}, options.toArgs())); } /** @@ -289,7 +287,7 @@ public static CompletableFuture get( * .space(" ") * .newline("\n") * .build(); - * GlideString value = client.Json.get(client, gs("doc"), gs("$"), options).get(); + * GlideString value = Json.get(client, gs("doc"), gs("$"), options).get(); * assert value.equals(gs("{\n \"a\": \n 1.0\n ,\n \"b\": \n 2\n }")); * } */ @@ -332,7 +330,7 @@ public static CompletableFuture get( * .space(" ") * .newline("\n") * .build(); - * String value = client.Json.get(client, "doc", new String[] {"$.a", "$.b"}, options).get(); + * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}, options).get(); * assert value.equals("{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}"); * } */ @@ -342,9 +340,7 @@ public static CompletableFuture get( @NonNull String[] paths, @NonNull JsonGetOptions options) { return executeCommand( - client, - ArrayTransformUtils.concatenateArrays( - new String[] {JSON_GET, key}, options.toArgs(), paths)); + client, concatenateArrays(new String[] {JSON_GET, key}, options.toArgs(), paths)); } /** @@ -380,7 +376,7 @@ public static CompletableFuture get( * .space(" ") * .newline("\n") * .build(); - * GlideString value = client.Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}, options).get(); + * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}, options).get(); * assert value.equals(gs("{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}")); * } */ @@ -394,6 +390,237 @@ public static CompletableFuture get( new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).add(paths).toArray()); } + /** + * Inserts one or more values into the array at the specified path within the JSON + * document stored at key, before the given index. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The array index before which values are inserted. + * @param values The JSON values to be inserted into the array, in JSON formatted bytes or str. + * JSON string values must be wrapped with quotes. For example, to append "foo", + * pass "\"foo\"". + * @return + *
+     *     <ul>
+     *       <li>For JSONPath (path starts with $):<br>
+     *           Returns an Object[] with a list of integers for every possible path,
+     *           indicating the new length of the array, or null for JSON values matching
+     *           the path that are not an array. If path does not exist, an empty array
+     *           will be returned.
+     *       <li>For legacy path (path doesn't start with $):<br>
+     *           Returns an integer representing the new length of the array. If multiple paths are
+     *           matched, returns the length of the first modified array. If path doesn't
+     *           exist or the value at path is not an array, an error is raised.
+     *     </ul>
+     *     If the index is out of bounds or key doesn't exist, an error is raised.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
+     * var newValues = new String[] { "\"c\"", "{\"key\": \"value\"}", "true", "null", "[\"bar\"]" };
+     * var res = Json.arrinsert(client, "doc", "$[*]", 0, newValues).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 5, 6, 7 }); // New lengths of arrays after insertion
+     * var doc = Json.get(client, "doc").get();
+     * assert doc.equals("[[\"c\", {\"key\": \"value\"}, true, null, [\"bar\"]], [\"c\", {\"key\": \"value\"}, "
+     *     + "true, null, [\"bar\"], \"a\"], [\"c\", {\"key\": \"value\"}, true, null, [\"bar\"], \"a\", \"b\"]]");
+     *
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
+     * res = Json.arrinsert(client, "doc", ".", 0, new String[] { "\"c\"" }).get();
+     * assert res == 4; // New length of the root array after insertion
+     * doc = Json.get(client, "doc").get();
+     * assert doc.equals("[\"c\", [], [\"a\"], [\"a\", \"b\"]]");
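+     *
+     * // Illustrative only (hypothetical key name): an out-of-bounds index or a missing key
+     * // causes the returned future to complete exceptionally, per the contract above.
+     * // Json.arrinsert(client, "missing_doc", ".", 99, new String[] { "\"c\"" }).get(); // raises an error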
+     * }
+ */ + public static CompletableFuture arrinsert( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + int index, + @NonNull String[] values) { + return executeCommand( + client, + concatenateArrays( + new String[] {JSON_ARRINSERT, key, path, Integer.toString(index)}, values)); + } + + /** + * Inserts one or more values into the array at the specified path within the JSON + * document stored at key, before the given index. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The array index before which values are inserted. + * @param values The JSON values to be inserted into the array, in JSON formatted bytes or str. + * JSON string values must be wrapped with quotes. For example, to append "foo", + * pass "\"foo\"". + * @return + *
+     *     <ul>
+     *       <li>For JSONPath (path starts with $):<br>
+     *           Returns an Object[] with a list of integers for every possible path,
+     *           indicating the new length of the array, or null for JSON values matching
+     *           the path that are not an array. If path does not exist, an empty array
+     *           will be returned.
+     *       <li>For legacy path (path doesn't start with $):<br>
+     *           Returns an integer representing the new length of the array. If multiple paths are
+     *           matched, returns the length of the first modified array. If path doesn't
+     *           exist or the value at path is not an array, an error is raised.
+     *     </ul>
+     *     If the index is out of bounds or key doesn't exist, an error is raised.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
+     * var newValues = new GlideString[] { gs("\"c\""), gs("{\"key\": \"value\"}"), gs("true"), gs("null"), gs("[\"bar\"]") };
+     * var res = Json.arrinsert(client, gs("doc"), gs("$[*]"), 0, newValues).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 5, 6, 7 }); // New lengths of arrays after insertion
+     * var doc = Json.get(client, "doc").get();
+     * assert doc.equals("[[\"c\", {\"key\": \"value\"}, true, null, [\"bar\"]], [\"c\", {\"key\": \"value\"}, "
+     *     + "true, null, [\"bar\"], \"a\"], [\"c\", {\"key\": \"value\"}, true, null, [\"bar\"], \"a\", \"b\"]]");
+     *
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
+     * res = Json.arrinsert(client, gs("doc"), gs("."), 0, new GlideString[] { gs("\"c\"") }).get();
+     * assert res == 4; // New length of the root array after insertion
+     * doc = Json.get(client, "doc").get();
+     * assert doc.equals("[\"c\", [], [\"a\"], [\"a\", \"b\"]]");
+     * }
+ */ + public static CompletableFuture arrinsert( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + int index, + @NonNull GlideString[] values) { + return executeCommand( + client, + new ArgsBuilder() + .add(gs(JSON_ARRINSERT)) + .add(key) + .add(path) + .add(Integer.toString(index)) + .add(values) + .toArray()); + } + + /** + * Retrieves the length of the array at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
+     *     <ul>
+     *       <li>For JSONPath (path starts with $):<br>
+     *           Returns an Object[] with a list of integers for every possible path,
+     *           indicating the length of the array, or null for JSON values matching the
+     *           path that are not an array. If path does not exist, an empty array will
+     *           be returned.
+     *       <li>For legacy path (path doesn't start with $):<br>
+     *           Returns an integer representing the length of the array. If multiple paths are
+     *           matched, returns the length of the first matching array. If path doesn't
+     *           exist or the value at path is not an array, an error is raised.
+     *     </ul>
+     *     If key doesn't exist, returns null.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}").get();
+     * var res = Json.arrlen(client, "doc", "$").get();
+     * assert Arrays.equals((Object[]) res, new Object[] { null }); // No array at the root path.
+     * res = Json.arrlen(client, "doc", "$.a").get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3 }); // Retrieves the length of the array at path $.a.
+     * res = Json.arrlen(client, "doc", "$..a").get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3, 2, null }); // Retrieves lengths of arrays found at all levels of the path `..a`.
+     * res = Json.arrlen(client, "doc", "..a").get();
+     * assert res == 3; // Legacy path retrieves the first array match at path `..a`.
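+     *
+     * // Illustrative only (hypothetical key name): a non-existing key returns null rather
+     * // than raising an error.
+     * // assert Json.arrlen(client, "non_existing_key", "$.a").get() == null;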
+     * }
+ */ + public static CompletableFuture arrlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_ARRLEN, key, path}); + } + + /** + * Retrieves the length of the array at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
+     *     <ul>
+     *       <li>For JSONPath (path starts with $):<br>
+     *           Returns an Object[] with a list of integers for every possible path,
+     *           indicating the length of the array, or null for JSON values matching the
+     *           path that are not an array. If path does not exist, an empty array will
+     *           be returned.
+     *       <li>For legacy path (path doesn't start with $):<br>
+     *           Returns an integer representing the length of the array. If multiple paths are
+     *           matched, returns the length of the first matching array. If path doesn't
+     *           exist or the value at path is not an array, an error is raised.
+     *     </ul>
+     *     If key doesn't exist, returns null.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}").get();
+     * var res = Json.arrlen(client, gs("doc"), gs("$")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { null }); // No array at the root path.
+     * res = Json.arrlen(client, gs("doc"), gs("$.a")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3 }); // Retrieves the length of the array at path $.a.
+     * res = Json.arrlen(client, gs("doc"), gs("$..a")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3, 2, null }); // Retrieves lengths of arrays found at all levels of the path `..a`.
+     * res = Json.arrlen(client, gs("doc"), gs("..a")).get();
+     * assert res == 3; // Legacy path retrieves the first array match at path `..a`.
+     * }
+     */
+    public static CompletableFuture arrlen(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key, path});
+    }
+
+    /**
+     * Retrieves the length of the array at the root of the JSON document stored at key.<br>
+     * Equivalent to {@link #arrlen(BaseClient, String, String)} with path set to
+     * ".".
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @return The array length stored at the root of the document. If document root is not an array,
+     *     an error is raised.<br>
+     *     If key doesn't exist, returns null.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "[1, 2, true, null, \"tree\"]").get();
+     * var res = Json.arrlen(client, "doc").get();
+     * assert res == 5;
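+     *
+     * // Illustrative only: if the document root is not an array, e.g. "{\"a\": 1}", this
+     * // overload raises an error instead of returning a length.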
+     * }
+     */
+    public static CompletableFuture arrlen(@NonNull BaseClient client, @NonNull String key) {
+        return executeCommand(client, new String[] {JSON_ARRLEN, key});
+    }
+
+    /**
+     * Retrieves the length of the array at the root of the JSON document stored at key.<br>
+     * Equivalent to {@link #arrlen(BaseClient, GlideString, GlideString)} with path set
+     * to gs(".").
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @return The array length stored at the root of the document. If document root is not an array,
+     *     an error is raised.<br>
+     *     If key doesn't exist, returns null.
+     * @example
+     *
{@code
+     * Json.set(client, "doc", "$", "[1, 2, true, null, \"tree\"]").get();
+     * var res = Json.arrlen(client, gs("doc")).get();
+     * assert res == 5;
+     * }
+ */ + public static CompletableFuture arrlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key}); + } + /** * A wrapper for custom command API. * diff --git a/java/integTest/build.gradle b/java/integTest/build.gradle index 70fdf18915..8c6b48a3cc 100644 --- a/java/integTest/build.gradle +++ b/java/integTest/build.gradle @@ -11,6 +11,7 @@ dependencies { implementation project(':client') implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' + implementation 'com.google.code.gson:gson:2.10.1' // https://github.com/netty/netty/wiki/Native-transports // At the moment, Windows is not supported diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 3bf1c93823..07af34ab3e 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -6,10 +6,12 @@ import static glide.api.models.GlideString.gs; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.google.gson.JsonParser; import glide.api.GlideClusterClient; import glide.api.commands.servermodules.Json; import glide.api.models.GlideString; @@ -154,4 +156,115 @@ public void json_set_get_formatting() { .get(); assertEquals(expectedGetResult2, actualGetResult2); } + + @Test + @SneakyThrows + public void arrinsert() { + String key = UUID.randomUUID().toString(); + + String doc = + "{" + + "\"a\": []," + + "\"b\": { \"a\": [1, 2, 3, 4] }," + + "\"c\": { \"a\": \"not an array\" }," + + "\"d\": [{ \"a\": [\"x\", \"y\"] }, { \"a\": [[\"foo\"]] }]," + + "\"e\": [{ \"a\": 42 }, { \"a\": {} }]," + + "\"f\": { \"a\": [true, false, null] }" + + "}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + String[] values = + new String[] { + "\"string_value\"", "123", "{\"key\": \"value\"}", "true", "null", "[\"bar\"]" + }; + var res = Json.arrinsert(client, key, "$..a", 0, values).get(); + + doc = Json.get(client, key).get(); + var expected = + "{" + + " \"a\": [\"string_value\", 123, {\"key\": \"value\"}, true, null, [\"bar\"]]," + + " \"b\": {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " 1," + + " 2," + + " 3," + + " 4" + + " ]" + + " }," + + " \"c\": {\"a\": \"not an array\"}," + + " \"d\": [" + + " {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " \"x\"," + + " \"y\"" + + " ]" + + " }," + + " {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " [\"foo\"]" + + " ]" + + " }" + + " ]," + + " \"e\": [{\"a\": 42}, {\"a\": {}}]," + + " \"f\": {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " true," + + " false," + + " null" + + " ]" + + " }" + + "}"; + + assertEquals(JsonParser.parseString(expected), JsonParser.parseString(doc)); + } + + @Test + 
@SneakyThrows + public void arrlen() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.arrlen(client, key, "$.a").get(); + assertArrayEquals(new Object[] {3L}, (Object[]) res); + + res = Json.arrlen(client, key, "$..a").get(); + assertArrayEquals(new Object[] {3L, 2L, null}, (Object[]) res); + + // Legacy path retrieves the first array match at ..a + res = Json.arrlen(client, key, "..a").get(); + assertEquals(3L, res); + + doc = "[1, 2, true, null, \"tree\"]"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + // no path + res = Json.arrlen(client, key).get(); + assertEquals(5L, res); + } } diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 107edaf7b1..bb39afe19c 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -254,7 +254,7 @@ public void ft_search() { (byte) 0xBF }))) .get()); - + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index var ftsearch = FT.search( client, From 6a3a33f52d1ae4a3c271b72d5cb8fbd97ea9ee87 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Sun, 20 Oct 2024 02:51:58 -0700 Subject: [PATCH 027/180] Run CI on PR merge on release branches (#2475) * fix CI Signed-off-by: Yury-Fridlyand * fix CI Signed-off-by: Yury-Fridlyand --------- Signed-off-by: Yury-Fridlyand --- .github/workflows/codeql.yml | 22 +++++++------ .github/workflows/csharp.yml | 5 ++- .github/workflows/go.yml | 5 ++- .github/workflows/java.yml | 5 ++- .github/workflows/lint-ts.yml | 5 ++- .github/workflows/node.yml | 5 ++- .github/workflows/python.yml | 5 ++- .github/workflows/rust.yml | 5 ++- .github/workflows/semgrep.yml | 59 ++++++++++++++++++----------------- 9 files changed, 71 insertions(+), 45 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1a2b90083d..36ac59f664 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -3,18 +3,20 @@ name: "CodeQL" on: push: branches: - - "main" - - "v.?[0-9]+.[0-9]+.[0-9]+" - - "v.?[0-9]+.[0-9]+" - - "v?[0-9]+.[0-9]+.[0-9]+" - - "v?[0-9]+.[0-9]+" + - "main" + - "v.?[0-9]+.[0-9]+.[0-9]+" + - "v.?[0-9]+.[0-9]+" + - "v?[0-9]+.[0-9]+.[0-9]+" + - "v?[0-9]+.[0-9]+" + - release-* pull_request: branches: - - "main" - - "v.?[0-9]+.[0-9]+.[0-9]+" - - "v.?[0-9]+.[0-9]+" - - "v?[0-9]+.[0-9]+.[0-9]+" - - "v?[0-9]+.[0-9]+" + - "main" + - "v.?[0-9]+.[0-9]+.[0-9]+" + - "v.?[0-9]+.[0-9]+" + - "v?[0-9]+.[0-9]+.[0-9]+" + - "v?[0-9]+.[0-9]+" + - release-* schedule: - cron: "37 18 * * 6" diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index 36b380c3e0..aa85d9a991 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -2,7 +2,10 @@ name: C# tests on: push: - branches: ["main"] + branches: + - main + - release-* + - v* paths: - csharp/** - glide-core/src/** diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7cdfedef59..6eaf3d1d19 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,7 +2,10 @@ name: Go CI on: push: - branches: [ "main" ] + branches: + - main + - release-* + - v* paths: - glide-core/src/** - submodules/** diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index ca626224f4..ebc6a06169 100644 
--- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -2,7 +2,10 @@ name: Java CI on: push: - branches: ["main"] + branches: + - main + - release-* + - v* paths: - glide-core/src/** - submodules/** diff --git a/.github/workflows/lint-ts.yml b/.github/workflows/lint-ts.yml index 686bdd1183..cd324ba3cf 100644 --- a/.github/workflows/lint-ts.yml +++ b/.github/workflows/lint-ts.yml @@ -2,7 +2,10 @@ name: lint-ts on: push: - branches: ["main"] + branches: + - main + - release-* + - v* paths: - benchmarks/node/** - node/** diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index 634219ea15..a9b6b4be18 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -2,7 +2,10 @@ name: Node on: push: - branches: ["main"] + branches: + - main + - release-* + - v* paths: - glide-core/src/** - submodules/** diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c85045df07..c3aa78072b 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -2,7 +2,10 @@ name: Python tests on: push: - branches: ["main"] + branches: + - main + - release-* + - v* paths: - python/** - glide-core/src/** diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c022e3e419..c632880a2b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,7 +2,10 @@ name: Rust tests on: push: - branches: [main] + branches: + - main + - release-* + - v* paths: - logger_core/** - glide-core/** diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 4bfd9e12ac..58bb7cb238 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -1,36 +1,39 @@ name: Semgrep on: - # Scan changed files in PRs (diff-aware scanning): - pull_request: {} - # Scan on-demand through GitHub Actions interface: - workflow_dispatch: - inputs: - branch: - description: 'The branch to run against the semgrep tool' - required: true - push: - branches: ["main"] - # Schedule the CI job (this method uses cron syntax): - schedule: - - cron: '0 8 * * *' # Sets Semgrep to scan every day at 08:00 UTC. + # Scan changed files in PRs (diff-aware scanning): + pull_request: {} + # Scan on-demand through GitHub Actions interface: + workflow_dispatch: + inputs: + branch: + description: "The branch to run against the semgrep tool" + required: true + push: + branches: + - main + - release-* + - v* + # Schedule the CI job (this method uses cron syntax): + schedule: + - cron: "0 8 * * *" # Sets Semgrep to scan every day at 08:00 UTC. jobs: - semgrep: - # User definable name of this GitHub Actions job. - name: semgrep/ci - # If you are self-hosting, change the following `runs-on` value: - runs-on: ubuntu-latest + semgrep: + # User definable name of this GitHub Actions job. + name: semgrep/ci + # If you are self-hosting, change the following `runs-on` value: + runs-on: ubuntu-latest - container: - # A Docker image with Semgrep installed. Do not change this. - image: semgrep/semgrep + container: + # A Docker image with Semgrep installed. Do not change this. + image: semgrep/semgrep - # Skip any PR created by dependabot to avoid permission issues: - if: (github.actor != 'dependabot[bot]') + # Skip any PR created by dependabot to avoid permission issues: + if: (github.actor != 'dependabot[bot]') - steps: - # Fetch project source with GitHub Actions Checkout. - - uses: actions/checkout@v3 - # Run the "semgrep ci" command on the command line of the docker image. 
- - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key + steps: + # Fetch project source with GitHub Actions Checkout. + - uses: actions/checkout@v3 + # Run the "semgrep ci" command on the command line of the docker image. + - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key From 914fd60fc74a5ceb7d0c24ff8df55995d32f7042 Mon Sep 17 00:00:00 2001 From: Bar Shaul <88437685+barshaul@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:07:28 +0300 Subject: [PATCH 028/180] =?UTF-8?q?Avoid=20retrying=20on=20IO=20errors=20w?= =?UTF-8?q?hen=20it=E2=80=99s=20unclear=20if=20the=20server=20received=20t?= =?UTF-8?q?he=20request=20(#2479)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Avoid retrying on IO errors when it’s unclear if the server received the request Signed-off-by: barshaul --- .../redis-rs/redis/src/aio/connection.rs | 2 +- .../redis/src/aio/connection_manager.rs | 10 +- .../redis/src/aio/multiplexed_connection.rs | 42 ++-- glide-core/redis-rs/redis/src/cluster.rs | 21 +- .../cluster_async/connections_container.rs | 35 +-- .../redis-rs/redis/src/cluster_async/mod.rs | 85 +++++-- glide-core/redis-rs/redis/src/types.rs | 23 ++ glide-core/redis-rs/redis/tests/test_async.rs | 2 +- .../redis/tests/test_cluster_async.rs | 220 ++++++++++++++---- glide-core/tests/test_socket_listener.rs | 2 +- 10 files changed, 319 insertions(+), 123 deletions(-) diff --git a/glide-core/redis-rs/redis/src/aio/connection.rs b/glide-core/redis-rs/redis/src/aio/connection.rs index 5adef7869f..2b32a7ced3 100644 --- a/glide-core/redis-rs/redis/src/aio/connection.rs +++ b/glide-core/redis-rs/redis/src/aio/connection.rs @@ -7,7 +7,7 @@ use crate::connection::{ resp2_is_pub_sub_state_cleared, resp3_is_pub_sub_state_cleared, ConnectionAddr, ConnectionInfo, Msg, RedisConnectionInfo, }; -#[cfg(any(feature = "tokio-comp"))] +#[cfg(feature = "tokio-comp")] use crate::parser::ValueCodec; use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value}; use crate::{from_owned_redis_value, ProtocolVersion, ToRedisArgs}; diff --git a/glide-core/redis-rs/redis/src/aio/connection_manager.rs b/glide-core/redis-rs/redis/src/aio/connection_manager.rs index dce7b254a5..83d680ae53 100644 --- a/glide-core/redis-rs/redis/src/aio/connection_manager.rs +++ b/glide-core/redis-rs/redis/src/aio/connection_manager.rs @@ -78,12 +78,12 @@ macro_rules! reconnect_if_dropped { }; } -/// Handle a connection result. If there's an I/O error, reconnect. +/// Handle a connection result. If the connection has dropped, reconnect. /// Propagate any error. -macro_rules! reconnect_if_io_error { +macro_rules! 
reconnect_if_conn_dropped { ($self:expr, $result:expr, $current:expr) => { if let Err(e) = $result { - if e.is_io_error() { + if e.is_connection_dropped() { $self.reconnect($current); } return Err(e); @@ -249,7 +249,7 @@ impl ConnectionManager { .clone() .await .map_err(|e| e.clone_mostly("Reconnecting failed")); - reconnect_if_io_error!(self, connection_result, guard); + reconnect_if_conn_dropped!(self, connection_result, guard); let result = connection_result?.send_packed_command(cmd).await; reconnect_if_dropped!(self, &result, guard); result @@ -270,7 +270,7 @@ impl ConnectionManager { .clone() .await .map_err(|e| e.clone_mostly("Reconnecting failed")); - reconnect_if_io_error!(self, connection_result, guard); + reconnect_if_conn_dropped!(self, connection_result, guard); let result = connection_result? .send_packed_commands(cmd, offset, count) .await; diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs index fb1b62f8a1..c23d4dfca4 100644 --- a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -349,7 +349,7 @@ where &mut self, item: SinkItem, timeout: Duration, - ) -> Result> { + ) -> Result { self.send_recv(item, None, timeout).await } @@ -359,7 +359,7 @@ where // If `None`, this is a single request, not a pipeline of multiple requests. pipeline_response_count: Option, timeout: Duration, - ) -> Result> { + ) -> Result { let (sender, receiver) = oneshot::channel(); self.sender @@ -369,15 +369,29 @@ where output: sender, }) .await - .map_err(|_| None)?; + .map_err(|err| { + // If an error occurs here, it means the request never reached the server, as guaranteed + // by the 'send' function. Since the server did not receive the data, it is safe to retry + // the request. + RedisError::from(( + crate::ErrorKind::FatalSendError, + "Failed to send the request to the server", + err.to_string(), + )) + })?; match Runtime::locate().timeout(timeout, receiver).await { - Ok(Ok(result)) => result.map_err(Some), - Ok(Err(_)) => { - // The `sender` was dropped which likely means that the stream part - // failed for one reason or another - Err(None) + Ok(Ok(result)) => result, + Ok(Err(err)) => { + // The `sender` was dropped, likely indicating a failure in the stream. + // This error suggests that it's unclear whether the server received the request before the connection failed, + // making it unsafe to retry. For example, retrying an INCR request could result in double increments. 
+ Err(RedisError::from(( + crate::ErrorKind::FatalReceiveError, + "Failed to receive a response due to a fatal error", + err.to_string(), + ))) } - Err(elapsed) => Err(Some(elapsed.into())), + Err(elapsed) => Err(elapsed.into()), } } @@ -503,10 +517,7 @@ impl MultiplexedConnection { let result = self .pipeline .send_single(cmd.get_packed_command(), self.response_timeout) - .await - .map_err(|err| { - err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) - }); + .await; if self.protocol != ProtocolVersion::RESP2 { if let Err(e) = &result { if e.is_connection_dropped() { @@ -537,10 +548,7 @@ impl MultiplexedConnection { Some(offset + count), self.response_timeout, ) - .await - .map_err(|err| { - err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) - }); + .await; if self.protocol != ProtocolVersion::RESP2 { if let Err(e) = &result { diff --git a/glide-core/redis-rs/redis/src/cluster.rs b/glide-core/redis-rs/redis/src/cluster.rs index f9c76f5161..170cac47b3 100644 --- a/glide-core/redis-rs/redis/src/cluster.rs +++ b/glide-core/redis-rs/redis/src/cluster.rs @@ -43,7 +43,9 @@ use std::time::Duration; use rand::{seq::IteratorRandom, thread_rng}; +pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; use crate::cluster_pipeline::UNROUTABLE_ERROR; +pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; use crate::cluster_routing::{ MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, }; @@ -54,7 +56,7 @@ use crate::connection::{ connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo, }; use crate::parser::parse_redis_value; -use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, Value}; +use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, RetryMethod, Value}; pub use crate::TlsMode; // Pub for backwards compatibility use crate::{ cluster_client::ClusterParams, @@ -62,9 +64,6 @@ use crate::{ IntoConnectionInfo, PushInfo, }; -pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; -pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; - use tokio::sync::mpsc; #[cfg(feature = "tls-rustls")] @@ -749,12 +748,12 @@ where retries += 1; match err.retry_method() { - crate::types::RetryMethod::AskRedirect => { + RetryMethod::AskRedirect => { redirected = err .redirect_node() .map(|(node, _slot)| Redirect::Ask(node.to_string())); } - crate::types::RetryMethod::MovedRedirect => { + RetryMethod::MovedRedirect => { // Refresh slots. self.refresh_slots()?; // Request again. @@ -762,8 +761,8 @@ where .redirect_node() .map(|(node, _slot)| Redirect::Moved(node.to_string())); } - crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica - | crate::types::RetryMethod::WaitAndRetry => { + RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica + | RetryMethod::WaitAndRetry => { // Sleep and retry. 
let sleep_time = self .cluster_params @@ -771,7 +770,7 @@ where .wait_time_for_retry(retries); thread::sleep(sleep_time); } - crate::types::RetryMethod::Reconnect => { + RetryMethod::Reconnect | RetryMethod::ReconnectAndRetry => { if *self.auto_reconnect.borrow() { if let Ok(mut conn) = self.connect(&addr) { if conn.check_connection() { @@ -780,10 +779,10 @@ where } } } - crate::types::RetryMethod::NoRetry => { + RetryMethod::NoRetry => { return Err(err); } - crate::types::RetryMethod::RetryImmediately => {} + RetryMethod::RetryImmediately => {} } } } diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs index 2bfbb8b934..d89d063b78 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs @@ -255,16 +255,18 @@ where &self, amount: usize, conn_type: ConnectionType, - ) -> impl Iterator> + '_ { - self.connection_map - .iter() - .choose_multiple(&mut rand::thread_rng(), amount) - .into_iter() - .map(move |item| { - let (address, node) = (item.key(), item.value()); - let conn = node.get_connection(&conn_type); - (address.clone(), conn) - }) + ) -> Option> + '_> { + (!self.connection_map.is_empty()).then_some({ + self.connection_map + .iter() + .choose_multiple(&mut rand::thread_rng(), amount) + .into_iter() + .map(move |item| { + let (address, node) = (item.key(), item.value()); + let conn = node.get_connection(&conn_type); + (address.clone(), conn) + }) + }) } pub(crate) fn replace_or_add_connection_for_address( @@ -633,6 +635,7 @@ mod tests { let random_connections: HashSet<_> = container .random_connections(3, ConnectionType::User) + .expect("No connections found") .map(|pair| pair.1) .collect(); @@ -647,12 +650,9 @@ mod tests { let container = create_container(); remove_all_connections(&container); - assert_eq!( - 0, - container - .random_connections(1, ConnectionType::User) - .count() - ); + assert!(container + .random_connections(1, ConnectionType::User) + .is_none()); } #[test] @@ -665,6 +665,7 @@ mod tests { ); let random_connections: Vec<_> = container .random_connections(1, ConnectionType::User) + .expect("No connections found") .collect(); assert_eq!(vec![(address, 4)], random_connections); @@ -675,6 +676,7 @@ mod tests { let container = create_container(); let mut random_connections: Vec<_> = container .random_connections(1000, ConnectionType::User) + .expect("No connections found") .map(|pair| pair.1) .collect(); random_connections.sort(); @@ -687,6 +689,7 @@ mod tests { let container = create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, true); let mut random_connections: Vec<_> = container .random_connections(1000, ConnectionType::PreferManagement) + .expect("No connections found") .map(|pair| pair.1) .collect(); random_connections.sort(); diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs index c8628c16bb..aa9f02e1e6 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -845,6 +845,7 @@ impl Future for Request { let request = this.request.as_mut().unwrap(); // TODO - would be nice if we didn't need to repeat this code twice, with & without retries. 
if request.retry >= this.retry_params.number_of_retries { + let retry_method = err.retry_method(); let next = if err.kind() == ErrorKind::AllConnectionsUnavailable { Next::ReconnectToInitialNodes { request: None }.into() } else if matches!(err.retry_method(), crate::types::RetryMethod::MovedRedirect) @@ -855,7 +856,9 @@ impl Future for Request { sleep_duration: None, } .into() - } else if matches!(err.retry_method(), crate::types::RetryMethod::Reconnect) { + } else if matches!(retry_method, crate::types::RetryMethod::Reconnect) + || matches!(retry_method, crate::types::RetryMethod::ReconnectAndRetry) + { if let OperationTarget::Node { address } = target { Next::Reconnect { request: None, @@ -934,13 +937,18 @@ impl Future for Request { }); self.poll(cx) } - crate::types::RetryMethod::Reconnect => { + crate::types::RetryMethod::Reconnect + | crate::types::RetryMethod::ReconnectAndRetry => { let mut request = this.request.take().unwrap(); // TODO should we reset the redirect here? request.info.reset_routing(); warn!("disconnected from {:?}", address); + let should_retry = matches!( + err.retry_method(), + crate::types::RetryMethod::ReconnectAndRetry + ); Next::Reconnect { - request: Some(request), + request: should_retry.then_some(request), target: address, } .into() @@ -1177,8 +1185,11 @@ where Ok(connections.0) } - fn reconnect_to_initial_nodes(&mut self) -> impl Future { - let inner = self.inner.clone(); + // Reconnet to the initial nodes provided by the user in the creation of the client, + // and try to refresh the slots based on the initial connections. + // Being used when all cluster connections are unavailable. + fn reconnect_to_initial_nodes(inner: Arc>) -> impl Future { + let inner = inner.clone(); async move { let connection_map = match Self::create_initial_connections( &inner.initial_nodes, @@ -1680,7 +1691,9 @@ where Self::refresh_slots_inner(inner, curr_retry) .await .map_err(|err| { - if curr_retry > DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES { + if curr_retry > DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES + || err.kind() == ErrorKind::AllConnectionsUnavailable + { BackoffError::Permanent(err) } else { BackoffError::from(err) @@ -2073,14 +2086,22 @@ where } ConnectionCheck::RandomConnection => { let read_guard = core.conn_lock.read().await; - let (random_address, random_conn_future) = read_guard + read_guard .random_connections(1, ConnectionType::User) - .next() - .ok_or(RedisError::from(( - ErrorKind::AllConnectionsUnavailable, - "No random connection found", - )))?; - return Ok((random_address, random_conn_future.await)); + .and_then(|mut random_connections| { + random_connections.next().map( + |(random_address, random_conn_future)| async move { + (random_address, random_conn_future.await) + }, + ) + }) + .ok_or_else(|| { + RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No random connection found", + )) + })? 
+ .await } }; @@ -2104,10 +2125,19 @@ where } Err(err) => { trace!("Recover slots failed!"); - *future = Box::pin(Self::refresh_slots_and_subscriptions_with_retries( - self.inner.clone(), - &RefreshPolicy::Throttable, - )); + let next_state = if err.kind() == ErrorKind::AllConnectionsUnavailable { + ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::reconnect_to_initial_nodes(self.inner.clone()), + ))) + } else { + ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( + Self::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + ), + ))) + }; + self.state = next_state; Poll::Ready(Err(err)) } }, @@ -2226,9 +2256,7 @@ where })); } } - Next::Reconnect { - request, target, .. - } => { + Next::Reconnect { request, target } => { poll_flush_action = poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target])); if let Some(request) = request { @@ -2371,7 +2399,7 @@ where } PollFlushAction::ReconnectFromInitialConnections => { self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( - self.reconnect_to_initial_nodes(), + ClusterConnInner::reconnect_to_initial_nodes(self.inner.clone()), ))); } } @@ -2413,8 +2441,19 @@ async fn calculate_topology_from_random_nodes<'a, C>( where C: ConnectionLike + Connect + Clone + Send + Sync + 'static, { - let requested_nodes = - read_guard.random_connections(num_of_nodes_to_query, ConnectionType::PreferManagement); + let requested_nodes = if let Some(random_conns) = + read_guard.random_connections(num_of_nodes_to_query, ConnectionType::PreferManagement) + { + random_conns + } else { + return ( + Err(RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No available connections to refresh slots from", + ))), + vec![], + ); + }; let topology_join_results = futures::future::join_all(requested_nodes.map(|(addr, conn)| async move { let mut conn: C = conn.await; diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs index a024f16a7d..7e303df68a 100644 --- a/glide-core/redis-rs/redis/src/types.rs +++ b/glide-core/redis-rs/redis/src/types.rs @@ -118,6 +118,14 @@ pub enum ErrorKind { /// not native to the system. This is usually the case if /// the cause is another error. IoError, + /// An error indicating that a fatal error occurred while attempting to send a request to the server, + /// meaning the connection was closed before the request was transmitted. Since the server did not process the request, + /// it is safe to retry the request. + FatalSendError, + /// An error indicating that a fatal error occurred while trying to receive a response, + /// likely due to the closure of the underlying connection. It is unclear whether + /// the server processed the request, making it unsafe to retry the request. + FatalReceiveError, /// An error raised that was identified on the client before execution. ClientError, /// An extension error. 
This is an error created by the server @@ -802,6 +810,7 @@ impl fmt::Debug for RedisError { pub(crate) enum RetryMethod { Reconnect, + ReconnectAndRetry, NoRetry, RetryImmediately, WaitAndRetry, @@ -870,6 +879,10 @@ impl RedisError { ErrorKind::CrossSlot => "cross-slot", ErrorKind::MasterDown => "master down", ErrorKind::IoError => "I/O error", + ErrorKind::FatalSendError => { + "failed to send the request to the server due to a fatal error - the request was not transmitted" + } + ErrorKind::FatalReceiveError => "a fatal error occurred while attempting to receive a response from the server", ErrorKind::ExtensionError => "extension error", ErrorKind::ClientError => "client error", ErrorKind::ReadOnly => "read-only", @@ -942,6 +955,12 @@ impl RedisError { /// Returns true if error was caused by a dropped connection. pub fn is_connection_dropped(&self) -> bool { + if matches!( + self.kind(), + ErrorKind::FatalSendError | ErrorKind::FatalReceiveError + ) { + return true; + } match self.repr { ErrorRepr::IoError(ref err) => matches!( err.kind(), @@ -957,6 +976,7 @@ impl RedisError { pub fn is_unrecoverable_error(&self) -> bool { match self.retry_method() { RetryMethod::Reconnect => true, + RetryMethod::ReconnectAndRetry => true, RetryMethod::NoRetry => false, RetryMethod::RetryImmediately => false, @@ -1064,12 +1084,15 @@ impl RedisError { io::ErrorKind::PermissionDenied => RetryMethod::NoRetry, io::ErrorKind::Unsupported => RetryMethod::NoRetry, + io::ErrorKind::TimedOut => RetryMethod::NoRetry, _ => RetryMethod::RetryImmediately, }, _ => RetryMethod::RetryImmediately, }, ErrorKind::NotAllSlotsCovered => RetryMethod::NoRetry, + ErrorKind::FatalReceiveError => RetryMethod::Reconnect, + ErrorKind::FatalSendError => RetryMethod::ReconnectAndRetry, } } } diff --git a/glide-core/redis-rs/redis/tests/test_async.rs b/glide-core/redis-rs/redis/tests/test_async.rs index d16f1e0694..73d14de022 100644 --- a/glide-core/redis-rs/redis/tests/test_async.rs +++ b/glide-core/redis-rs/redis/tests/test_async.rs @@ -569,7 +569,7 @@ mod basic_async { Err(err) => break err, } }; - assert_eq!(err.kind(), ErrorKind::IoError); // Shouldn't this be IoError? + assert_eq!(err.kind(), ErrorKind::FatalSendError); } #[tokio::test] diff --git a/glide-core/redis-rs/redis/tests/test_cluster_async.rs b/glide-core/redis-rs/redis/tests/test_cluster_async.rs index e6a5984fa7..971a31a809 100644 --- a/glide-core/redis-rs/redis/tests/test_cluster_async.rs +++ b/glide-core/redis-rs/redis/tests/test_cluster_async.rs @@ -1015,7 +1015,6 @@ mod cluster_async { let sleep_duration = core::time::Duration::from_millis(100); #[cfg(feature = "tokio-comp")] tokio::time::sleep(sleep_duration).await; - } } panic!("Failed to reach to the expected topology refresh retries. 
Found={refreshed_calls}, Expected={expected_calls}") @@ -2542,8 +2541,8 @@ mod cluster_async { match port { 6380 => panic!("Node should not be called"), _ => match completed.fetch_add(1, Ordering::SeqCst) { - 0..=1 => Err(Err(RedisError::from(std::io::Error::new( - std::io::ErrorKind::ConnectionReset, + 0..=1 => Err(Err(RedisError::from(( + ErrorKind::FatalSendError, "mock-io-error", )))), _ => Err(Ok(Value::BulkString(b"123".to_vec()))), @@ -2598,6 +2597,81 @@ mod cluster_async { assert_eq!(completed.load(Ordering::SeqCst), 1); } + #[test] + #[serial_test::serial] + fn test_async_cluster_non_retryable_io_error_should_not_retry() { + let name = "test_async_cluster_non_retryable_io_error_should_not_retry"; + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(3), + name, + move |cmd: &[u8], _port| { + respond_startup_two_nodes(name, cmd)?; + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + match i { + 0 => Err(Err(RedisError::from((ErrorKind::IoError, "io-error")))), + _ => { + panic!("Expected not to be retried!") + } + } + }, + ); + runtime + .block_on(async move { + let res = cmd("INCR") + .arg("foo") + .query_async::<_, Option>(&mut connection) + .await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert!(err.is_io_error()); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_retry_safe_io_error_should_be_retried() { + let name = "test_async_cluster_retry_safe_io_error_should_be_retried"; + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(3), + name, + move |cmd: &[u8], _port| { + respond_startup_two_nodes(name, cmd)?; + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + match i { + 0 => Err(Err(RedisError::from(( + ErrorKind::FatalSendError, + "server didn't receive the request, safe to retry", + )))), + _ => Err(Ok(Value::Int(1))), + } + }, + ); + runtime + .block_on(async move { + let res = cmd("INCR") + .arg("foo") + .query_async::<_, i32>(&mut connection) + .await; + assert!(res.is_ok()); + let value = res.unwrap(); + assert_eq!(value, 1); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + #[test] #[serial_test::serial] fn test_async_cluster_read_from_primary() { @@ -3186,10 +3260,17 @@ mod cluster_async { }; // wait for new topology discovery + let max_requests = 5; + let mut i = 0; + let mut cmd = redis::cmd("INFO"); + cmd.arg("SERVER"); loop { - let mut cmd = redis::cmd("INFO"); - cmd.arg("SERVER"); - let res = publishing_con + if i == max_requests { + panic!("Failed to recover and discover new topology"); + } + i += 1; + + if let Ok(res) = publishing_con .route_command( &cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( @@ -3197,21 +3278,21 @@ mod cluster_async { SlotAddr::Master, ))), ) - .await; - assert!(res.is_ok()); - let res = res.unwrap(); - match res { - Value::VerbatimString { format: _, text } => { - if text.contains(format!("tcp_port:{}", last_server_port).as_str()) { - // new topology rediscovered - break; + .await + { + match res { + Value::VerbatimString { format: _, text } => { + if text.contains(format!("tcp_port:{}", last_server_port).as_str()) { + // new topology rediscovered + break; + } + } + _ => { + panic!("Wrong return 
type for INFO SERVER command: {:?}", res); } } - _ => { - panic!("Wrong return type for INFO SERVER command: {:?}", res); - } + sleep(futures_time::time::Duration::from_secs(1)).await; } - sleep(futures_time::time::Duration::from_secs(1)).await; } // sleep for one one cycle of topology refresh @@ -3250,7 +3331,7 @@ mod cluster_async { if use_sharded { // validate SPUBLISH - let result = cmd("SPUBLISH") + let result = redis::cmd("SPUBLISH") .arg("test_channel_?") .arg("test_message") .query_async(&mut publishing_con) @@ -3757,9 +3838,26 @@ mod cluster_async { false, ); - let result = connection.req_packed_command(&cmd).await.unwrap(); - assert_eq!(result, Value::SimpleString("PONG".to_string())); - Ok::<_, RedisError>(()) + let max_requests = 5; + let mut i = 0; + let mut last_err = None; + loop { + if i == max_requests { + break; + } + i += 1; + match connection.req_packed_command(&cmd).await { + Ok(result) => { + assert_eq!(result, Value::SimpleString("PONG".to_string())); + return Ok::<_, RedisError>(()); + } + Err(err) => { + last_err = Some(err); + let _ = sleep(futures_time::time::Duration::from_secs(1)).await; + } + } + } + panic!("Failed to recover after all nodes went down. Last error: {last_err:?}"); }) .unwrap(); } @@ -3786,19 +3884,37 @@ mod cluster_async { ); let cmd = cmd("PING"); - // explicitly route to all primaries and request all succeeded - let result = connection - .route_command( - &cmd, - RoutingInfo::MultiNode(( - MultipleNodeRoutingInfo::AllMasters, - Some(redis::cluster_routing::ResponsePolicy::AllSucceeded), - )), - ) - .await; - assert!(result.is_ok()); - Ok::<_, RedisError>(()) + let max_requests = 5; + let mut i = 0; + let mut last_err = None; + loop { + if i == max_requests { + break; + } + i += 1; + // explicitly route to all primaries and request all succeeded + match connection + .route_command( + &cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(redis::cluster_routing::ResponsePolicy::AllSucceeded), + )), + ) + .await + { + Ok(result) => { + assert_eq!(result, Value::SimpleString("PONG".to_string())); + return Ok::<_, RedisError>(()); + } + Err(err) => { + last_err = Some(err); + let _ = sleep(futures_time::time::Duration::from_secs(1)).await; + } + } + } + panic!("Failed to recover after all nodes went down. 
Last error: {last_err:?}"); }) .unwrap(); } @@ -3871,7 +3987,10 @@ mod cluster_async { if connect_attempt > 5 { panic!("Too many pings!"); } - Err(Err(broken_pipe_error())) + Err(Err(RedisError::from(( + ErrorKind::FatalSendError, + "mock-io-error", + )))) } else { respond_startup_two_nodes(name, cmd)?; let past_get_attempts = get_attempts.fetch_add(1, Ordering::Relaxed); @@ -3879,7 +3998,10 @@ mod cluster_async { if past_get_attempts == 0 { // Error once with io-error, ensure connection is reestablished w/out calling // other node (i.e., not doing a full slot rebuild) - Err(Err(broken_pipe_error())) + Err(Err(RedisError::from(( + ErrorKind::FatalSendError, + "mock-io-error", + )))) } else { Err(Ok(Value::BulkString(b"123".to_vec()))) } @@ -3931,7 +4053,7 @@ mod cluster_async { .expect("Failed executing CLIENT LIST"); let mut client_list_parts = client_list.split('\n'); if client_list_parts - .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster")) + .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster")) && client_list.matches(MANAGEMENT_CONN_NAME).count() == 1 { return Ok::<_, RedisError>(()); } @@ -3983,21 +4105,23 @@ mod cluster_async { } async fn kill_connection(killer_connection: &mut ClusterConnection, connection_to_kill: &str) { + let default_routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode( + Route::new(0, SlotAddr::Master), + )); + kill_connection_with_routing(killer_connection, connection_to_kill, default_routing).await; + } + + async fn kill_connection_with_routing( + killer_connection: &mut ClusterConnection, + connection_to_kill: &str, + routing: RoutingInfo, + ) { let mut cmd = redis::cmd("CLIENT"); cmd.arg("KILL"); cmd.arg("ID"); cmd.arg(connection_to_kill); - // Kill the management connection in the primary node that holds slot 0 - assert!(killer_connection - .route_command( - &cmd, - RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( - 0, - SlotAddr::Master, - )),), - ) - .await - .is_ok()); + // Kill the management connection for the routing node + assert!(killer_connection.route_command(&cmd, routing).await.is_ok()); } #[test] diff --git a/glide-core/tests/test_socket_listener.rs b/glide-core/tests/test_socket_listener.rs index a242eb80d1..c6e2b7b15d 100644 --- a/glide-core/tests/test_socket_listener.rs +++ b/glide-core/tests/test_socket_listener.rs @@ -172,7 +172,7 @@ mod socket_listener { } fn read_from_socket(buffer: &mut Vec, socket: &mut UnixStream) -> usize { - buffer.resize(100, 0_u8); + buffer.resize(300, 0_u8); socket.read(buffer).unwrap() } From 695697a43be66cd7c2548f2601a55340fb583dbb Mon Sep 17 00:00:00 2001 From: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Date: Tue, 22 Oct 2024 00:46:21 +0300 Subject: [PATCH 029/180] Rust - code cleanup, part 1 (#2478) Some cleaning for unnecessary code to set the floor for next PR Signed-off-by: avifenesh --- .vscode/settings.json | 2 +- glide-core/Cargo.toml | 31 ++++-- glide-core/redis-rs/redis/Cargo.toml | 41 +++++-- .../redis-rs/redis/src/aio/connection.rs | 32 ------ glide-core/redis-rs/redis/src/client.rs | 102 ------------------ .../cluster_async/connections_container.rs | 4 +- .../src/cluster_async/connections_logic.rs | 6 +- .../redis-rs/redis/src/cluster_async/mod.rs | 8 +- .../redis-rs/redis/src/cluster_client.rs | 8 +- glide-core/redis-rs/redis/src/commands/mod.rs | 7 +- glide-core/redis-rs/redis/src/types.rs | 6 ++ node/DEVELOPER.md | 12 +-- node/tests/SharedTests.ts | 20 ++-- 13 files changed, 96 
insertions(+), 183 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ef488543df..72bcb0d6f7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ "node/rust-client/Cargo.toml", "logger_core/Cargo.toml", "csharp/lib/Cargo.toml", - "submodules/redis-rs/Cargo.toml", + "glide-core/redis-rs/Cargo.toml", "benchmarks/rust/Cargo.toml", "java/Cargo.toml" ], diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index 51de808bd2..28cad8e646 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -10,27 +10,42 @@ authors = ["Valkey GLIDE Maintainers"] [dependencies] bytes = "1" futures = "^0.3" -redis = { path = "./redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } +redis = { path = "./redis-rs/redis", features = [ + "aio", + "tokio-comp", + "tokio-rustls-comp", + "connection-manager", + "cluster", + "cluster-async", +] } tokio = { version = "1", features = ["macros", "time"] } -logger_core = {path = "../logger_core"} +logger_core = { path = "../logger_core" } dispose = "0.5.0" -tokio-util = {version = "^0.7", features = ["rt"], optional = true} +tokio-util = { version = "^0.7", features = ["rt"], optional = true } num_cpus = { version = "^1.15", optional = true } tokio-retry = "0.3.0" -protobuf = { version= "3", features = ["bytes", "with-bytes"], optional = true } +protobuf = { version = "3", features = [ + "bytes", + "with-bytes", +], optional = true } integer-encoding = { version = "4.0.0", optional = true } thiserror = "1" rand = { version = "0.8.5" } futures-intrusive = "0.5.0" directories = { version = "4.0", optional = true } once_cell = "1.18.0" -arcstr = "1.1.5" sha1_smol = "1.0.0" nanoid = "0.4.0" async-trait = { version = "0.1.24" } [features] -socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util"] +socket-layer = [ + "directories", + "integer-encoding", + "num_cpus", + "protobuf", + "tokio-util", +] standalone_heartbeat = [] [dev-dependencies] @@ -45,7 +60,9 @@ ctor = "0.2.2" redis = { path = "./redis-rs/redis", features = ["tls-rustls-insecure"] } iai-callgrind = "0.9" tokio = { version = "1", features = ["rt-multi-thread"] } -glide-core = { path = ".", features = ["socket-layer"] } # always enable this feature in tests. +glide-core = { path = ".", features = [ + "socket-layer", +] } # always enable this feature in tests. 
[lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(standalone_heartbeat)'] } diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml index 46f6fe9231..0179669eaf 100644 --- a/glide-core/redis-rs/redis/Cargo.toml +++ b/glide-core/redis-rs/redis/Cargo.toml @@ -69,19 +69,19 @@ dashmap = { version = "6.0", optional = true } async-trait = { version = "0.1.24", optional = true } # Only needed for tokio support -backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = ["tokio"] } +backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = [ + "tokio", +] } # Only needed for native tls native-tls = { version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } -async-native-tls = { version = "0.4", optional = true } # Only needed for rustls rustls = { version = "0.22", optional = true } webpki-roots = { version = "0.26", optional = true } rustls-native-certs = { version = "0.7", optional = true } tokio-rustls = { version = "0.25", optional = true } -futures-rustls = { version = "0.25", optional = true } rustls-pemfile = { version = "2", optional = true } rustls-pki-types = { version = "1", optional = true } @@ -98,7 +98,6 @@ num-bigint = "0.4.4" ahash = { version = "0.8.11", optional = true } tracing = "0.1" -arcstr = "1.1.5" # Optional uuid support uuid = { version = "1.6.1", optional = true } @@ -114,16 +113,34 @@ default = [ "tokio-rustls-comp", "connection-manager", "cluster", - "cluster-async" + "cluster-async", ] acl = [] -aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "combine/tokio", "async-trait", "fast-math", "dispose"] +aio = [ + "bytes", + "pin-project-lite", + "futures-util", + "futures-util/alloc", + "futures-util/sink", + "tokio/io-util", + "tokio-util", + "tokio-util/codec", + "combine/tokio", + "async-trait", + "fast-math", + "dispose", +] geospatial = [] json = ["serde", "serde/derive", "serde_json"] cluster = ["crc16", "rand", "derivative"] script = ["sha1_smol"] tls-native-tls = ["native-tls"] -tls-rustls = ["rustls", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types"] +tls-rustls = [ + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "rustls-pki-types", +] tls-rustls-insecure = ["tls-rustls"] tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] tokio-comp = ["aio", "tokio/net", "backoff-tokio"] @@ -154,7 +171,12 @@ futures-time = "3" criterion = "0.4" partial-io = { version = "0.5", features = ["tokio", "quickcheck1"] } quickcheck = "1.0.3" -tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } +tokio = { version = "1", features = [ + "rt", + "macros", + "rt-multi-thread", + "time", +] } tempfile = "=3.6.0" once_cell = "1" anyhow = "1" @@ -225,3 +247,6 @@ required-features = ["connection-manager"] [[example]] name = "streams" required-features = ["streams"] + +[package.metadata.cargo-machete] +ignored = ["strum"] diff --git a/glide-core/redis-rs/redis/src/aio/connection.rs b/glide-core/redis-rs/redis/src/aio/connection.rs index 2b32a7ced3..a18521c974 100644 --- a/glide-core/redis-rs/redis/src/aio/connection.rs +++ b/glide-core/redis-rs/redis/src/aio/connection.rs @@ -51,27 +51,6 @@ fn test() { assert_sync::(); } -impl Connection { - pub(crate) fn map(self, f: impl FnOnce(C) -> D) -> Connection { - let Self { - con, - buf, - decoder, - db, - pubsub, - protocol, - } = self; - Connection { - con: 
f(con), - buf, - decoder, - db, - pubsub, - protocol, - } - } -} - impl Connection where C: Unpin + AsyncRead + AsyncWrite + Send, @@ -190,17 +169,6 @@ where } } -pub(crate) async fn connect( - connection_info: &ConnectionInfo, - socket_addr: Option, -) -> RedisResult> -where - C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send, -{ - let (con, _ip) = connect_simple::(connection_info, socket_addr).await?; - Connection::new(&connection_info.redis, con).await -} - impl ConnectionLike for Connection where C: Unpin + AsyncRead + AsyncWrite + Send, diff --git a/glide-core/redis-rs/redis/src/client.rs b/glide-core/redis-rs/redis/src/client.rs index fd8c4c08b4..f6a7b4ef91 100644 --- a/glide-core/redis-rs/redis/src/client.rs +++ b/glide-core/redis-rs/redis/src/client.rs @@ -113,22 +113,6 @@ impl Client { crate::aio::Connection::new(&self.connection_info.redis, con).await } - /// Returns an async connection from the client. - #[cfg(feature = "tokio-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] - #[deprecated( - note = "aio::Connection is deprecated. Use client::get_multiplexed_tokio_connection instead." - )] - #[allow(deprecated)] - pub async fn get_tokio_connection(&self) -> RedisResult { - use crate::aio::RedisRuntime; - Ok( - crate::aio::connect::(&self.connection_info, None) - .await? - .map(RedisRuntime::boxed), - ) - } - /// Returns an async connection from the client. #[cfg(feature = "tokio-comp")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] @@ -245,53 +229,6 @@ impl Client { .await } - /// Returns an async multiplexed connection from the client and a future which must be polled - /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). - /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. - #[cfg(feature = "tokio-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] - pub async fn create_multiplexed_tokio_connection_with_response_timeout( - &self, - response_timeout: std::time::Duration, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult<( - crate::aio::MultiplexedConnection, - impl std::future::Future, - )> { - self.create_multiplexed_async_connection_inner::( - response_timeout, - None, - glide_connection_options, - ) - .await - .map(|(conn, driver, _ip)| (conn, driver)) - } - - /// Returns an async multiplexed connection from the client and a future which must be polled - /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). - /// - /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - #[cfg(feature = "tokio-comp")] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] - pub async fn create_multiplexed_tokio_connection( - &self, - glide_connection_options: GlideConnectionOptions, - ) -> RedisResult<( - crate::aio::MultiplexedConnection, - impl std::future::Future, - )> { - self.create_multiplexed_tokio_connection_with_response_timeout( - std::time::Duration::MAX, - glide_connection_options, - ) - .await - .map(|conn_res| (conn_res.0, conn_res.1)) - } - /// Returns an async [`ConnectionManager`][connection-manager] from the client. 
/// /// The connection manager wraps a @@ -375,45 +312,6 @@ impl Client { .await } - /// Returns an async [`ConnectionManager`][connection-manager] from the client. - /// - /// The connection manager wraps a - /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that - /// connection fails with a connection error, then a new connection is - /// established in the background and the error is returned to the caller. - /// - /// This means that on connection loss at least one command will fail, but - /// the connection will be re-established automatically if possible. Please - /// refer to the [`ConnectionManager`][connection-manager] docs for - /// detailed reconnecting behavior. - /// - /// A connection manager can be cloned, allowing requests to be be sent concurrently - /// on the same underlying connection (tcp/unix socket). - /// - /// [connection-manager]: aio/struct.ConnectionManager.html - /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html - #[cfg(feature = "connection-manager")] - #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] - #[deprecated(note = "use get_connection_manager_with_backoff_and_timeouts instead")] - pub async fn get_tokio_connection_manager_with_backoff_and_timeouts( - &self, - exponent_base: u64, - factor: u64, - number_of_retries: usize, - response_timeout: std::time::Duration, - connection_timeout: std::time::Duration, - ) -> RedisResult { - crate::aio::ConnectionManager::new_with_backoff_and_timeouts( - self.clone(), - exponent_base, - factor, - number_of_retries, - response_timeout, - connection_timeout, - ) - .await - } - /// Returns an async [`ConnectionManager`][connection-manager] from the client. /// /// The connection manager wraps a diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs index d89d063b78..9d44b25cab 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs @@ -361,7 +361,7 @@ mod tests { } fn create_container_with_strategy( - stragey: ReadFromReplicaStrategy, + strategy: ReadFromReplicaStrategy, use_management_connections: bool, ) -> ConnectionsContainer { let slot_map = SlotMap::new( @@ -411,7 +411,7 @@ mod tests { ConnectionsContainer { slot_map, connection_map, - read_from_replica_strategy: stragey, + read_from_replica_strategy: strategy, topology_hash: 0, } } diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs index 7de2493000..7a2306ca19 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs @@ -5,7 +5,7 @@ use super::{ Connect, }; use crate::{ - aio::{ConnectionLike, DisconnectNotifier, Runtime}, + aio::{ConnectionLike, DisconnectNotifier}, client::GlideConnectionOptions, cluster::get_connection_info, cluster_client::ClusterParams, @@ -462,9 +462,7 @@ async fn check_connection(conn: &mut C, timeout: std::time::Duration) -> Redi where C: ConnectionLike + Send + 'static, { - Runtime::locate() - .timeout(timeout, crate::cmd("PING").query_async::<_, String>(conn)) - .await??; + tokio::time::timeout(timeout, crate::cmd("PING").query_async::<_, String>(conn)).await??; Ok(()) } diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs index 
aa9f02e1e6..28024f7649 100644 --- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -1218,7 +1218,7 @@ where } } - // Validate all existing user connections and try to reconnect if nessesary. + // Validate all existing user connections and try to reconnect if necessary. // In addition, as a safety measure, drop nodes that do not have any assigned slots. // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently. // The function does not discover the topology from the cluster and assumes the cached topology is valid. @@ -1293,7 +1293,7 @@ where let connections_container = inner.conn_lock.read().await; let cluster_params = &inner.cluster_params; let subscriptions_by_address = &inner.subscriptions_by_address; - let glide_connection_optons = &inner.glide_connection_options; + let glide_connection_options = &inner.glide_connection_options; stream::iter(addresses.into_iter()) .fold( @@ -1315,7 +1315,7 @@ where node_option, &cluster_params, conn_type, - glide_connection_optons.clone(), + glide_connection_options.clone(), ) .await; match node { @@ -1528,7 +1528,7 @@ where Self::check_topology_and_refresh_if_diff(inner.clone(), &RefreshPolicy::Throttable) .await; if !topology_changed { - // This serves as a safety measure for validating pubsub subsctiptions state in case it has drifted + // This serves as a safety measure for validating pubsub subscriptions state in case it has drifted // while topology stayed the same. // For example, a failed attempt to refresh a connection which is triggered from refresh_pubsub_subscriptions(), // might leave a node unconnected indefinitely in case topology is stable and no request are attempted to this node. diff --git a/glide-core/redis-rs/redis/src/cluster_client.rs b/glide-core/redis-rs/redis/src/cluster_client.rs index 5815bede1e..185b8547ad 100644 --- a/glide-core/redis-rs/redis/src/cluster_client.rs +++ b/glide-core/redis-rs/redis/src/cluster_client.rs @@ -327,13 +327,13 @@ impl ClusterClientBuilder { self } - /// Sets maximal wait time in millisceonds between retries for the new ClusterClient. + /// Sets maximal wait time in milliseconds between retries for the new ClusterClient. pub fn max_retry_wait(mut self, max_wait: u64) -> ClusterClientBuilder { self.builder_params.retries_configuration.max_wait_time = max_wait; self } - /// Sets minimal wait time in millisceonds between retries for the new ClusterClient. + /// Sets minimal wait time in milliseconds between retries for the new ClusterClient. pub fn min_retry_wait(mut self, min_wait: u64) -> ClusterClientBuilder { self.builder_params.retries_configuration.min_wait_time = min_wait; self @@ -400,9 +400,9 @@ impl ClusterClientBuilder { } /// Enables periodic connections checks for this client. - /// If enabled, the conenctions to the cluster nodes will be validated periodicatly, per configured interval. + /// If enabled, the connections to the cluster nodes will be validated periodically, per configured interval. /// In addition, for tokio runtime, passive disconnections could be detected instantly, - /// triggering reestablishemnt, w/o waiting for the next periodic check. + /// triggering reestablishment, w/o waiting for the next periodic check. 
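+    ///
+    /// A minimal usage sketch (illustrative only; `nodes` and the surrounding
+    /// setup are assumed here and are not part of this patch):
+    ///
+    /// ```rust,no_run
+    /// use std::time::Duration;
+    /// // Validate the cluster connections every three seconds.
+    /// let builder = ClusterClient::builder(nodes)
+    ///     .periodic_connections_checks(Duration::from_secs(3));
+    /// ```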
#[cfg(feature = "cluster-async")] pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder { self.builder_params.connections_validation_interval = Some(interval); diff --git a/glide-core/redis-rs/redis/src/commands/mod.rs b/glide-core/redis-rs/redis/src/commands/mod.rs index d5c937fa70..22a68cc987 100644 --- a/glide-core/redis-rs/redis/src/commands/mod.rs +++ b/glide-core/redis-rs/redis/src/commands/mod.rs @@ -2176,15 +2176,12 @@ impl ToRedisArgs for SetOptions { pub fn resp3_hello(connection_info: &RedisConnectionInfo) -> Cmd { let mut hello_cmd = cmd("HELLO"); hello_cmd.arg("3"); - if connection_info.password.is_some() { + if let Some(password) = &connection_info.password { let username: &str = match connection_info.username.as_ref() { None => "default", Some(username) => username, }; - hello_cmd - .arg("AUTH") - .arg(username) - .arg(connection_info.password.as_ref().unwrap()); + hello_cmd.arg("AUTH").arg(username).arg(password); } hello_cmd } diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs index 7e303df68a..6fd564b203 100644 --- a/glide-core/redis-rs/redis/src/types.rs +++ b/glide-core/redis-rs/redis/src/types.rs @@ -186,6 +186,12 @@ pub(crate) enum ServerError { }, } +impl From for RedisError { + fn from(_: tokio::time::error::Elapsed) -> Self { + RedisError::from((ErrorKind::IoError, "Operation timed out")) + } +} + impl From for RedisError { fn from(value: ServerError) -> Self { // TODO - Consider changing RedisError to explicitly represent whether an error came from the server or not. Today it is only implied. diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md index a3391c3282..7185ce1359 100644 --- a/node/DEVELOPER.md +++ b/node/DEVELOPER.md @@ -65,11 +65,7 @@ Before starting this step, make sure you've installed all software requirments. git clone https://github.com/valkey-io/valkey-glide.git cd valkey-glide ``` -2. Initialize git submodule: - ```bash - git submodule update --init --recursive - ``` -3. Install all node dependencies: +2. Install all node dependencies: ```bash cd node @@ -79,7 +75,7 @@ Before starting this step, make sure you've installed all software requirments. cd .. ``` -4. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder): +3. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder): 1. Build in release mode, stripped from all debug symbols (optimized and minimized binary size): @@ -101,14 +97,14 @@ Before starting this step, make sure you've installed all software requirments. Once building completed, you'll find the compiled JavaScript code in the`./build-ts` folder. -5. Run tests: +4. Run tests: 1. Ensure that you have installed server and valkey-cli on your host. You can download Valkey at the following link: [Valkey Download page](https://valkey.io/download/). 2. Execute the following command from the node folder: ```bash npm run build # make sure we have a debug build compiled first npm test ``` -6. Integrating the built GLIDE package into your project: +5. Integrating the built GLIDE package into your project: Add the package to your project using the folder path with the command `npm install /node`. - For a fast build, execute `npm run build`. This will perform a full, unoptimized build, which is suitable for developing tests. 
Keep in mind that performance is significantly affected in an unoptimized build, so it's required to build with the `build:release` or `build:benchmark` option when measuring performance. diff --git a/node/tests/SharedTests.ts b/node/tests/SharedTests.ts index eea277241a..7b1af14ed1 100644 --- a/node/tests/SharedTests.ts +++ b/node/tests/SharedTests.ts @@ -9269,7 +9269,6 @@ export function runBaseTests(config: { { sortOrder: SortOrder.DESC, storeDist: true }, ), ).toEqual(3); - // TODO deep close to https://github.com/maasencioh/jest-matcher-deep-close-to expect( await client.zrangeWithScores( key2, @@ -9277,11 +9276,20 @@ export function runBaseTests(config: { { reverse: true }, ), ).toEqual( - convertElementsAndScores({ - edge2: 236529.17986494553, - Palermo: 166274.15156960033, - Catania: 0.0, - }), + expect.arrayContaining([ + { + element: "edge2", + score: expect.closeTo(236529.17986494553, 0.0001), + }, + { + element: "Palermo", + score: expect.closeTo(166274.15156960033, 0.0001), + }, + { + element: "Catania", + score: expect.closeTo(0.0, 0.0001), + }, + ]), ); // test search by box, unit: feet, from member, with limited count 2, with hash From 76f68a6d93ea19a9cc3a59fa71d8d5602d8f277d Mon Sep 17 00:00:00 2001 From: Yi-Pin Chen Date: Mon, 21 Oct 2024 15:39:45 -0700 Subject: [PATCH 030/180] Java: add JSON.ARRAPPEND command (#2489) * Java: add JSON.ARRAPPEND command --------- Signed-off-by: Yi-Pin Chen --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 81 +++++++++++++++++++ .../test/java/glide/modules/JsonTests.java | 55 +++++++++++++ 3 files changed, 137 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71d007cce5..464beda5b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) +* Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 8398a9f168..868ad6276f 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -22,6 +22,7 @@ public class Json { private static final String JSON_PREFIX = "JSON."; public static final String JSON_SET = JSON_PREFIX + "SET"; public static final String JSON_GET = JSON_PREFIX + "GET"; + private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; @@ -390,6 +391,86 @@ public static CompletableFuture get( new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).add(paths).toArray()); } + /** + * Appends one or more values to the JSON array at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the values + * will be appended. + * @param values The values to append to the JSON array at the specified path + * . + * @return + *
+   *     <ul>
+   *       <li>For JSONPath (<code>path</code> starts with <code>$</code>):<br>
+   *           Returns a list of integers for every possible path, indicating the new length of the
+   *           new array after appending <code>values</code>, or <code>null</code> for JSON values
+   *           matching the path that are not an array. If <code>path</code> does not exist, an
+   *           empty array will be returned.
+   *       <li>For legacy path (<code>path</code> doesn't start with <code>$</code>):<br>
+   *           Returns the length of the new array after appending <code>values</code> to the array
+   *           at <code>path</code>. If multiple paths are matched, returns the last updated array.
+   *           If the JSON value at <code>path</code> is not an array or if <code>path</code> doesn't
+   *           exist, an error is raised. If <code>key</code> doesn't exist, an error is raised.
+   *     </ul>
+   * @example
+   *     <pre>{@code
+   * Json.set(client, "doc", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
+   * var res = Json.arrappend(client, "doc", "$.b", new String[] {"\"three\""}).get();
+   * assert Arrays.equals((Object[]) res, new Object[] {3L}); // New length of the array after appending
+   * res = Json.arrappend(client, "doc", ".b", new String[] {"\"four\""}).get();
+   * assert res.equals(4L); // New length of the array after appending
+   * }</pre>
+   */
+  public static CompletableFuture<Object> arrappend(
+      @NonNull BaseClient client,
+      @NonNull String key,
+      @NonNull String path,
+      @NonNull String[] values) {
+    return executeCommand(
+        client, concatenateArrays(new String[] {JSON_ARRAPPEND, key, path}, values));
+  }
+
+  /**
+   * Appends one or more <code>values</code> to the JSON array at the specified <code>path</code>
+   * within the JSON document stored at <code>key</code>.
+   *
+   * @param client The client to execute the command.
+   * @param key The key of the JSON document.
+   * @param path Represents the path within the JSON document where the <code>values</code>
+   *     will be appended.
+   * @param values The values to append to the JSON array at the specified <code>path</code>.
+   * @return
+   *     <ul>
+   *       <li>For JSONPath (<code>path</code> starts with <code>$</code>):<br>
+   *           Returns a list of integers for every possible path, indicating the new length of the
+   *           new array after appending <code>values</code>, or <code>null</code> for JSON values
+   *           matching the path that are not an array. If <code>path</code> does not exist, an
+   *           empty array will be returned.
+   *       <li>For legacy path (<code>path</code> doesn't start with <code>$</code>):<br>
+   *           Returns the length of the new array after appending <code>values</code> to the array
+   *           at <code>path</code>. If multiple paths are matched, returns the last updated array.
+   *           If the JSON value at <code>path</code> is not an array or if <code>path</code> doesn't
+   *           exist, an error is raised. If <code>key</code> doesn't exist, an error is raised.
+   *     </ul>
+   * @example
+   *     <pre>{@code
+   * Json.set(client, "doc", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
+   * var res = Json.arrappend(client, gs("doc"), gs("$.b"), new GlideString[] {gs("\"three\"")}).get();
+   * assert Arrays.equals((Object[]) res, new Object[] {3L}); // New length of the array after appending
+   * res = Json.arrappend(client, gs("doc"), gs(".b"), new GlideString[] {gs("\"four\"")}).get();
+   * assert res.equals(4L); // New length of the array after appending
+   * }</pre>
      + */ + public static CompletableFuture arrappend( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString[] values) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_ARRAPPEND)).add(key).add(path).add(values).toArray()); + } + /** * Inserts one or more values into the array at the specified path within the JSON * document stored at key, before the given index. diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 07af34ab3e..969b6b30a4 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.google.gson.JsonParser; @@ -20,6 +21,7 @@ import glide.api.models.commands.InfoOptions.Section; import glide.api.models.commands.json.JsonGetOptions; import java.util.UUID; +import java.util.concurrent.ExecutionException; import lombok.SneakyThrows; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -157,6 +159,59 @@ public void json_set_get_formatting() { assertEquals(expectedGetResult2, actualGetResult2); } + @Test + @SneakyThrows + public void arrappend() { + String key = UUID.randomUUID().toString(); + String doc = "{\"a\": 1, \"b\": [\"one\", \"two\"]}"; + + assertEquals(OK, Json.set(client, key, "$", doc).get()); + + assertArrayEquals( + new Object[] {3L}, + (Object[]) Json.arrappend(client, key, "$.b", new String[] {"\"three\""}).get()); + assertEquals( + 5L, Json.arrappend(client, key, ".b", new String[] {"\"four\"", "\"five\""}).get()); + + String getResult = Json.get(client, key, new String[] {"$"}).get(); + String expectedGetResult = + "[{\"a\": 1, \"b\": [\"one\", \"two\", \"three\", \"four\", \"five\"]}]"; + assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult)); + + assertArrayEquals( + new Object[] {null}, + (Object[]) Json.arrappend(client, key, "$.a", new String[] {"\"value\""}).get()); + + // JSONPath, path doesn't exist + assertArrayEquals( + new Object[] {}, + (Object[]) + Json.arrappend(client, gs(key), gs("$.c"), new GlideString[] {gs("\"value\"")}).get()); + + // Legacy path, path doesn't exist + var exception = + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".c", new String[] {"\"value\""}).get()); + + // Legacy path, the JSON value at path is not a array + exception = + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".a", new String[] {"\"value\""}).get()); + + exception = + assertThrows( + ExecutionException.class, + () -> + Json.arrappend(client, "non_existing_key", "$.b", new String[] {"\"six\""}).get()); + + exception = + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, "non_existing_key", ".b", new String[] {"\"six\""}).get()); + } + @Test @SneakyThrows public void arrinsert() { From ee0f86fd97bedfca1840a92a28e5e1e7dcc150aa Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Mon, 21 Oct 2024 17:05:04 -0700 Subject: [PATCH 031/180] Fix python CI (#2487) * Fix python CI Signed-off-by: Yury-Fridlyand Signed-off-by: Prateek Kumar Co-authored-by: Prateek Kumar --- 
.github/workflows/python.yml | 107 ++++++------------ CHANGELOG.md | 1 + .../search/test_ft_create.py | 5 + .../tests/tests_server_modules/test_ft.py | 7 ++ 4 files changed, 47 insertions(+), 73 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c3aa78072b..ebb83894e6 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -42,11 +42,6 @@ concurrency: group: python-${{ github.head_ref || github.ref }} cancel-in-progress: true -permissions: - contents: read - # Allows the GITHUB_TOKEN to make an API call to generate an OIDC token. - id-token: write - jobs: load-engine-matrix: runs-on: ubuntu-latest @@ -291,75 +286,41 @@ jobs: path: | python/python/tests/pytest_report.html - start-self-hosted-runner: - if: github.event.pull_request.head.repo.owner.login == 'valkey-io' - runs-on: ubuntu-latest - environment: AWS_ACTIONS - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Start self hosted EC2 runner - uses: ./.github/workflows/start-self-hosted-runner - with: - role-to-assume: ${{ secrets.ROLE_TO_ASSUME }} - aws-region: ${{ secrets.AWS_REGION }} - ec2-instance-id: ${{ secrets.AWS_EC2_INSTANCE_ID }} - test-modules: - needs: [start-self-hosted-runner, load-engine-matrix] - name: Running Module Tests - runs-on: ${{ matrix.host.RUNNER }} - timeout-minutes: 35 - strategy: - fail-fast: false - matrix: - engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }} - python: - - "3.12" - host: - - { - OS: "ubuntu", - NAMED_OS: "linux", - RUNNER: ["self-hosted", "Linux", "ARM64"], - TARGET: "aarch64-unknown-linux-gnu", - } + if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io' + environment: AWS_ACTIONS + name: Running Module Tests + runs-on: [self-hosted, linux, ARM64] + timeout-minutes: 15 - steps: - - name: Setup self-hosted runner access - if: ${{ contains(matrix.host.RUNNER, 'self-hosted') }} - run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - - - uses: actions/checkout@v4 - with: - submodules: recursive - - - name: Setup Python for self-hosted Ubuntu runners - run: | - sudo apt update -y - sudo apt upgrade -y - sudo apt install python3 python3-venv python3-pip -y - - - name: Build Python wrapper - uses: ./.github/workflows/build-python-wrapper - with: - os: ${{ matrix.host.OS }} - target: ${{ matrix.host.TARGET }} - github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: ${{ matrix.engine.version }} + steps: + - name: Setup self-hosted runner access + if: ${{ contains(matrix.host.RUNNER, 'self-hosted') }} + run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide + + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Build Python wrapper + uses: ./.github/workflows/build-python-wrapper + with: + os: ubuntu + target: aarch64-unknown-linux-gnu + github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Test with pytest - working-directory: ./python - run: | - source .env/bin/activate - cd python/tests/ - pytest --asyncio-mode=auto --tls --cluster-endpoints=${{ secrets.MEMDB_MODULES_ENDPOINT }} -k server_modules --html=pytest_report.html --self-contained-html + - name: Test with pytest + working-directory: ./python + run: | + source .env/bin/activate + cd python/tests/ + pytest --asyncio-mode=auto --tls --cluster-endpoints=${{ secrets.MEMDB_MODULES_ENDPOINT }} -k server_modules --html=pytest_report.html --self-contained-html - - name: Upload test 
reports - if: always() - continue-on-error: true - uses: actions/upload-artifact@v4 - with: - name: modules-test-report-${{ matrix.host.TARGET }}-python-${{ matrix.python }}-server-${{ matrix.engine.version }} - path: | - python/python/tests/pytest_report.html + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: modules-test-report + path: | + python/python/tests/pytest_report.html diff --git a/CHANGELOG.md b/CHANGELOG.md index 464beda5b7..fc89b888bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ * Java: Add modules CI ([#2388](https://github.com/valkey-io/valkey-glide/pull/2388), [#2404](https://github.com/valkey-io/valkey-glide/pull/2404), [#2416](https://github.com/valkey-io/valkey-glide/pull/2416)) * Node: Add modules CI ([#2472](https://github.com/valkey-io/valkey-glide/pull/2472)) +* Python: Fix modules CI ([#2487](https://github.com/valkey-io/valkey-glide/pull/2487)) ## 1.1.0 (2024-09-24) diff --git a/python/python/tests/tests_server_modules/search/test_ft_create.py b/python/python/tests/tests_server_modules/search/test_ft_create.py index c08346563b..eba7592698 100644 --- a/python/python/tests/tests_server_modules/search/test_ft_create.py +++ b/python/python/tests/tests_server_modules/search/test_ft_create.py @@ -43,6 +43,7 @@ async def test_ft_create(self, glide_client: GlideClusterClient): glide_client, index, fields, FtCreateOptions(DataType.HASH, prefixes) ) assert result == OK + assert await ft.dropindex(glide_client, indexName=index) == OK # Create an index with multiple fields with JSON data type. index2 = str(uuid.uuid4()) @@ -50,6 +51,7 @@ async def test_ft_create(self, glide_client: GlideClusterClient): glide_client, index2, fields, FtCreateOptions(DataType.JSON, prefixes) ) assert result == OK + assert await ft.dropindex(glide_client, indexName=index2) == OK # Create an index for vectors of size 2 # FT.CREATE hash_idx1 ON HASH PREFIX 1 hash: SCHEMA vec AS VEC VECTOR HNSW 6 DIM 2 TYPE FLOAT32 DISTANCE_METRIC L2 @@ -71,6 +73,7 @@ async def test_ft_create(self, glide_client: GlideClusterClient): glide_client, index3, fields, FtCreateOptions(DataType.HASH, prefixes) ) assert result == OK + assert await ft.dropindex(glide_client, indexName=index3) == OK # Create a 6-dimensional JSON index using the HNSW algorithm # FT.CREATE json_idx1 ON JSON PREFIX 1 json: SCHEMA $.vec AS VEC VECTOR HNSW 6 DIM 6 TYPE FLOAT32 DISTANCE_METRIC L2 @@ -92,12 +95,14 @@ async def test_ft_create(self, glide_client: GlideClusterClient): glide_client, index4, fields, FtCreateOptions(DataType.JSON, prefixes) ) assert result == OK + assert await ft.dropindex(glide_client, indexName=index4) == OK # Create an index without FtCreateOptions index5 = str(uuid.uuid4()) result = await ft.create(glide_client, index5, fields, FtCreateOptions()) assert result == OK + assert await ft.dropindex(glide_client, indexName=index5) == OK # TO-DO: # Add additional tests from VSS documentation that require a combination of commands to run. diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index 39b068246b..a44a68bfc8 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -30,6 +30,7 @@ async def test_ft_aliasadd(self, glide_client: GlideClusterClient): # Test ft.aliasadd successfully adds an alias to an existing index. 
await TestFt.create_test_index_hash_type(self, glide_client, indexName) assert await ft.aliasadd(glide_client, alias, indexName) == OK + assert await ft.dropindex(glide_client, indexName=indexName) == OK # Test ft.aliasadd for input of bytes type. indexNameString = str(uuid.uuid4()) @@ -37,6 +38,7 @@ async def test_ft_aliasadd(self, glide_client: GlideClusterClient): aliasNameBytes = b"alias-bytes" await TestFt.create_test_index_hash_type(self, glide_client, indexNameString) assert await ft.aliasadd(glide_client, aliasNameBytes, indexNameBytes) == OK + assert await ft.dropindex(glide_client, indexName=indexNameString) == OK @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) @@ -57,6 +59,8 @@ async def test_ft_aliasdel(self, glide_client: GlideClusterClient): assert await ft.aliasadd(glide_client, alias, indexName) == OK assert await ft.aliasdel(glide_client, bytes(alias, "utf-8")) == OK + assert await ft.dropindex(glide_client, indexName=indexName) == OK + @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): @@ -79,6 +83,9 @@ async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): == OK ) + assert await ft.dropindex(glide_client, indexName=indexName) == OK + assert await ft.dropindex(glide_client, indexName=newIndexName) == OK + async def create_test_index_hash_type( self, glide_client: GlideClusterClient, index_name: TEncodable ): From 9073eb0bbcb12c490ddb77bf73158c684a4c2013 Mon Sep 17 00:00:00 2001 From: prateek-kumar-improving Date: Mon, 21 Oct 2024 18:59:00 -0700 Subject: [PATCH 032/180] Python: Add `FT.SEARCH` command (#2470) * Python: Add FT.SEARCH command --------- Signed-off-by: Prateek Kumar Signed-off-by: prateek-kumar-improving --- CHANGELOG.md | 1 + python/python/glide/__init__.py | 8 + .../glide/async_commands/server_modules/ft.py | 41 ++++- .../server_modules/ft_options/ft_constants.py | 16 +- .../ft_options/ft_search_options.py | 131 +++++++++++++++ .../search/test_ft_search.py | 154 ++++++++++++++++++ 6 files changed, 349 insertions(+), 2 deletions(-) create mode 100644 python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py create mode 100644 python/python/tests/tests_server_modules/search/test_ft_search.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fc89b888bd..19303101ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Add FT.SEARCH command([#2470](https://github.com/valkey-io/valkey-glide/pull/2470)) * Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE([#2471](https://github.com/valkey-io/valkey-glide/pull/2471)) * Python: Python FT.DROPINDEX command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437)) * Python: Python: Added FT.CREATE command([#2413](https://github.com/valkey-io/valkey-glide/pull/2413)) diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 05910eb480..5bb31c75fb 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -49,6 +49,11 @@ VectorFieldAttributesHnsw, VectorType, ) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSeachOptions, + FtSearchLimit, + ReturnField, +) from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -265,4 +270,7 @@ "VectorFieldAttributesFlat", "VectorFieldAttributesHnsw", 
"VectorType", + "FtSearchLimit", + "ReturnField", + "FtSeachOptions", ] diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index ccd1fd8735..82118e9070 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -3,7 +3,7 @@ module for `vector search` commands. """ -from typing import List, Optional, cast +from typing import List, Mapping, Optional, Union, cast from glide.async_commands.server_modules.ft_options.ft_constants import ( CommandNames, @@ -13,6 +13,9 @@ Field, FtCreateOptions, ) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSeachOptions, +) from glide.constants import TOK, TEncodable from glide.glide_client import TGlideClient @@ -78,6 +81,42 @@ async def dropindex(client: TGlideClient, indexName: TEncodable) -> TOK: return cast(TOK, await client.custom_command(args)) +async def search( + client: TGlideClient, + indexName: TEncodable, + query: TEncodable, + options: Optional[FtSeachOptions], +) -> List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]: + """ + Uses the provided query expression to locate keys within an index. Once located, the count and/or the content of indexed fields within those keys can be returned. + + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The index name to search into. + query (TEncodable): The text query to search. + options (Optional[FtSeachOptions]): The search options. See `FtSearchOptions`. + + Returns: + List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]: A two element array, where first element is count of documents in result set, and the second element, which has the format Mapping[TEncodable, Mapping[TEncodable, TEncodable]] is a mapping between document names and map of their attributes. + If count(option in `FtSearchOptions`) is set to true or limit(option in `FtSearchOptions`) is set to FtSearchLimit(0, 0), the command returns array with only one element - the count of the documents. + Examples: + For the following example to work the following must already exist: + - An index named "idx", with fields having identifiers as "a" and "b" and prefix as "{json:}" + - A key named {json:}1 with value {"a":1, "b":2} + + >>> from glide.async_commands.server_modules import ft + >>> result = await ft.search(glide_client, "idx", "*", options=FtSeachOptions(return_fields=[ReturnField(field_identifier="first"), ReturnField(field_identifier="second")])) + [1, { b'json:1': { b'first': b'42', b'second': b'33' } }] # The first element, 1 is the number of keys returned in the search result. The second element is a map of data queried per key. 
+ """ + args: List[TEncodable] = [CommandNames.FT_SEARCH, indexName, query] + if options: + args.extend(options.toArgs()) + return cast( + List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]], + await client.custom_command(args), + ) + + async def aliasadd( client: TGlideClient, alias: TEncodable, indexName: TEncodable ) -> TOK: diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index 14fef2a681..541b286d83 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -8,6 +8,7 @@ class CommandNames: FT_CREATE = "FT.CREATE" FT_DROPINDEX = "FT.DROPINDEX" + FT_SEARCH = "FT.SEARCH" FT_ALIASADD = "FT.ALIASADD" FT_ALIASDEL = "FT.ALIASDEL" FT_ALIASUPDATE = "FT.ALIASUPDATE" @@ -15,7 +16,7 @@ class CommandNames: class FtCreateKeywords: """ - Keywords used in the FT.CREATE command statment. + Keywords used in the FT.CREATE command. """ SCHEMA = "SCHEMA" @@ -34,3 +35,16 @@ class FtCreateKeywords: M = "M" EF_CONSTRUCTION = "EF_CONSTRUCTION" EF_RUNTIME = "EF_RUNTIME" + + +class FtSeachKeywords: + """ + Keywords used in the FT.SEARCH command. + """ + + RETURN = "RETURN" + TIMEOUT = "TIMEOUT" + PARAMS = "PARAMS" + LIMIT = "LIMIT" + COUNT = "COUNT" + AS = "AS" diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py new file mode 100644 index 0000000000..79f5422edc --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py @@ -0,0 +1,131 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +from typing import List, Mapping, Optional + +from glide.async_commands.server_modules.ft_options.ft_constants import FtSeachKeywords +from glide.constants import TEncodable + + +class FtSearchLimit: + """ + This class represents the arguments for the LIMIT option of the FT.SEARCH command. + """ + + def __init__(self, offset: int, count: int): + """ + Initialize a new FtSearchLimit instance. + + Args: + offset (int): The number of keys to skip before returning the result for the FT.SEARCH command. + count (int): The total number of keys to be returned by FT.SEARCH command. + """ + self.offset = offset + self.count = count + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments for the LIMIT option of FT.SEARCH. + + Returns: + List[TEncodable]: A list of LIMIT option arguments. + """ + args: List[TEncodable] = [ + FtSeachKeywords.LIMIT, + str(self.offset), + str(self.count), + ] + return args + + +class ReturnField: + """ + This class represents the arguments for the RETURN option of the FT.SEARCH command. + """ + + def __init__( + self, field_identifier: TEncodable, alias: Optional[TEncodable] = None + ): + """ + Initialize a new ReturnField instance. + + Args: + field_identifier (TEncodable): The identifier for the field of the key that has to returned as a result of FT.SEARCH command. + alias (Optional[TEncodable]): The alias to override the name of the field in the FT.SEARCH result. + """ + self.field_identifier = field_identifier + self.alias = alias + + def toArgs(self) -> List[TEncodable]: + """ + Get the arguments for the RETURN option of FT.SEARCH. + + Returns: + List[TEncodable]: A list of RETURN option arguments. 
+ """ + args: List[TEncodable] = [self.field_identifier] + if self.alias: + args.append(FtSeachKeywords.AS) + args.append(self.alias) + return args + + +class FtSeachOptions: + """ + This class represents the input options to be used in the FT.SEARCH command. + All fields in this class are optional inputs for FT.SEARCH. + """ + + def __init__( + self, + return_fields: Optional[List[ReturnField]] = None, + timeout: Optional[int] = None, + params: Optional[Mapping[TEncodable, TEncodable]] = None, + limit: Optional[FtSearchLimit] = None, + count: Optional[bool] = False, + ): + """ + Initialize the FT.SEARCH optional fields. + + Args: + return_fields (Optional[List[ReturnField]]): The fields of a key that are returned by FT.SEARCH command. See `ReturnField`. + timeout (Optional[int]): This value overrides the timeout parameter of the module. The unit for the timout is in milliseconds. + params (Optional[Mapping[TEncodable, TEncodable]]): Param key/value pairs that can be referenced from within the query expression. + limit (Optional[FtSearchLimit]): This option provides pagination capability. Only the keys that satisfy the offset and count values are returned. See `FtSearchLimit`. + count (Optional[bool]): This flag option suppresses returning the contents of keys. Only the number of keys is returned. + """ + self.return_fields = return_fields + self.timeout = timeout + self.params = params + self.limit = limit + self.count = count + + def toArgs(self) -> List[TEncodable]: + """ + Get the optional arguments for the FT.SEARCH command. + + Returns: + List[TEncodable]: + List of FT.SEARCH optional agruments. + """ + args: List[TEncodable] = [] + if self.return_fields: + args.append(FtSeachKeywords.RETURN) + return_field_args: List[TEncodable] = [] + for return_field in self.return_fields: + return_field_args.extend(return_field.toArgs()) + args.append(str(len(return_field_args))) + args.extend(return_field_args) + if self.timeout: + args.append(FtSeachKeywords.TIMEOUT) + args.append(str(self.timeout)) + if self.params: + args.append(FtSeachKeywords.PARAMS) + args.append(str(len(self.params))) + for name, value in self.params.items(): + args.append(name) + args.append(value) + if self.limit: + args.extend(self.limit.toArgs()) + if self.count: + args.append(FtSeachKeywords.COUNT) + return args diff --git a/python/python/tests/tests_server_modules/search/test_ft_search.py b/python/python/tests/tests_server_modules/search/test_ft_search.py new file mode 100644 index 0000000000..80d8319676 --- /dev/null +++ b/python/python/tests/tests_server_modules/search/test_ft_search.py @@ -0,0 +1,154 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +import json +import time +import uuid +from typing import List, Mapping, Union, cast + +import pytest +from glide.async_commands.server_modules import ft +from glide.async_commands.server_modules import json as GlideJson +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + FtCreateOptions, + NumericField, +) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSeachOptions, + ReturnField, +) +from glide.config import ProtocolVersion +from glide.constants import OK, TEncodable +from glide.glide_client import GlideClusterClient + + +@pytest.mark.asyncio +class TestFtSearch: + sleep_wait_time = 0.5 # This value is in seconds + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + 
async def test_ft_search(self, glide_client: GlideClusterClient): + prefix = "{json-search-" + str(uuid.uuid4()) + "}:" + json_key1 = prefix + str(uuid.uuid4()) + json_key2 = prefix + str(uuid.uuid4()) + json_value1 = {"a": 11111, "b": 2, "c": 3} + json_value2 = {"a": 22222, "b": 2, "c": 3} + prefixes: List[TEncodable] = [] + prefixes.append(prefix) + index = prefix + str(uuid.uuid4()) + + # Create an index + assert ( + await ft.create( + glide_client, + index, + schema=[ + NumericField("$.a", "a"), + NumericField("$.b", "b"), + ], + options=FtCreateOptions(DataType.JSON), + ) + == OK + ) + + # Create a json key + assert ( + await GlideJson.set(glide_client, json_key1, "$", json.dumps(json_value1)) + == OK + ) + assert ( + await GlideJson.set(glide_client, json_key2, "$", json.dumps(json_value2)) + == OK + ) + + # Wait for index to be updated to avoid this error - ResponseError: The index is under construction. + time.sleep(self.sleep_wait_time) + + # Search the index for string inputs + result1 = await ft.search( + glide_client, + index, + "*", + options=FtSeachOptions( + return_fields=[ + ReturnField(field_identifier="a", alias="a_new"), + ReturnField(field_identifier="b", alias="b_new"), + ] + ), + ) + # Check if we get the expected result from ft.search for string inputs + TestFtSearch._ft_search_deep_compare_result( + self, + result=result1, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + # Search the index for byte inputs + result2 = await ft.search( + glide_client, + bytes(index, "utf-8"), + b"*", + options=FtSeachOptions( + return_fields=[ + ReturnField(field_identifier=b"a", alias=b"a_new"), + ReturnField(field_identifier=b"b", alias=b"b_new"), + ] + ), + ) + + # Check if we get the expected result from ft.search from byte inputs + TestFtSearch._ft_search_deep_compare_result( + self, + result=result2, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + assert await ft.dropindex(glide_client, indexName=index) == OK + + def _ft_search_deep_compare_result( + self, + result: List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]], + json_key1: str, + json_key2: str, + json_value1: dict, + json_value2: dict, + fieldName1: str, + fieldName2: str, + ): + """ + Deep compare the keys and values in FT.SEARCH result array. + + Args: + result (List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]): + json_key1 (str): The first key in search result. + json_key2 (str): The second key in the search result. + json_value1 (dict): The fields map for first key in the search result. + json_value2 (dict): The fields map for second key in the search result. 
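+            fieldName1 (str): The name of the first field expected in each key's value map.
+            fieldName2 (str): The name of the second field expected in each key's value map.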
+ """ + assert len(result) == 2 + assert result[0] == 2 + searchResultMap: Mapping[TEncodable, Mapping[TEncodable, TEncodable]] = cast( + Mapping[TEncodable, Mapping[TEncodable, TEncodable]], result[1] + ) + expectedResultMap: Mapping[TEncodable, Mapping[TEncodable, TEncodable]] = { + json_key1.encode(): { + fieldName1.encode(): str(json_value1.get(fieldName1)).encode(), + fieldName2.encode(): str(json_value1.get(fieldName2)).encode(), + }, + json_key2.encode(): { + fieldName1.encode(): str(json_value2.get(fieldName1)).encode(), + fieldName2.encode(): str(json_value2.get(fieldName2)).encode(), + }, + } + assert searchResultMap == expectedResultMap From 3642038d37442cad2e7ebd248a6338372a8aba69 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Mon, 21 Oct 2024 09:30:26 +0000 Subject: [PATCH 033/180] Glide-core UDS Socket Handling Rework: 1.Introduced a user-land mechanism for ensuring singleton behavior of the socket, rather than relying on OS-specific semantics. This addresses the issue where macOS and Linux report different errors when the socket path already exists. 2.Simplified the implementation by removing unnecessary abstractions, including redundant connection retry logic. Signed-off-by: ikolomi --- CHANGELOG.md | 1 + glide-core/src/retry_strategies.rs | 1 + glide-core/src/socket_listener.rs | 237 ++++++++++------------- glide-core/tests/test_socket_listener.rs | 15 +- 4 files changed, 117 insertions(+), 137 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19303101ad..935450afcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ #### Breaking Changes #### Fixes +* Core: UDS Socket Handling Rework ([#2482](https://github.com/valkey-io/valkey-glide/pull/2482)) #### Operational Enhancements diff --git a/glide-core/src/retry_strategies.rs b/glide-core/src/retry_strategies.rs index dbe5683347..d851cb63dd 100644 --- a/glide-core/src/retry_strategies.rs +++ b/glide-core/src/retry_strategies.rs @@ -56,6 +56,7 @@ pub(crate) fn get_exponential_backoff( } #[cfg(feature = "socket-layer")] +#[allow(dead_code)] pub(crate) fn get_fixed_interval_backoff( fixed_interval: u32, number_of_retries: u32, diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 50445c881d..f823f908e5 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -11,11 +11,10 @@ use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::response; use crate::response::Response; -use crate::retry_strategies::get_fixed_interval_backoff; use bytes::Bytes; use directories::BaseDirs; -use dispose::{Disposable, Dispose}; use logger_core::{log_debug, log_error, log_info, log_trace, log_warn}; +use once_cell::sync::Lazy; use protobuf::{Chars, Message}; use redis::cluster_routing::{ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, @@ -23,18 +22,18 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value}; use std::cell::Cell; +use std::collections::HashSet; use std::rc::Rc; +use std::sync::RwLock; use std::{env, str}; use std::{io, thread}; use thiserror::Error; -use tokio::io::ErrorKind::AddrInUse; use tokio::net::{UnixListener, UnixStream}; use tokio::runtime::Builder; use tokio::sync::mpsc; use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::task; -use tokio_retry::Retry; use tokio_util::task::LocalPoolHandle; use ClosingReason::*; use 
PipeListeningResult::*; @@ -53,20 +52,6 @@ pub const ZSET: &str = "zset"; pub const HASH: &str = "hash"; pub const STREAM: &str = "stream"; -/// struct containing all objects needed to bind to a socket and clean it. -struct SocketListener { - socket_path: String, - cleanup_socket: bool, -} - -impl Dispose for SocketListener { - fn dispose(self) { - if self.cleanup_socket { - close_socket(&self.socket_path); - } - } -} - /// struct containing all objects needed to read from a unix stream. struct UnixStreamListener { read_socket: Rc, @@ -734,109 +719,6 @@ async fn listen_on_client_stream(socket: UnixStream) { log_trace("client closing", "closing connection"); } -enum SocketCreationResult { - // Socket creation was successful, returned a socket listener. - Created(UnixListener), - // There's an existing a socket listener. - PreExisting, - // Socket creation failed with an error. - Err(io::Error), -} - -impl SocketListener { - fn new(socket_path: String) -> Self { - SocketListener { - socket_path, - // Don't cleanup the socket resources unless we know that the socket is in use, and owned by this listener. - cleanup_socket: false, - } - } - - /// Return true if it's possible to connect to socket. - async fn socket_is_available(&self) -> bool { - if UnixStream::connect(&self.socket_path).await.is_ok() { - return true; - } - - let retry_strategy = get_fixed_interval_backoff(10, 3); - - let action = || async { - UnixStream::connect(&self.socket_path) - .await - .map(|_| ()) - .map_err(|_| ()) - }; - let result = Retry::spawn(retry_strategy.get_iterator(), action).await; - result.is_ok() - } - - async fn get_socket_listener(&self) -> SocketCreationResult { - const RETRY_COUNT: u8 = 3; - let mut retries = RETRY_COUNT; - while retries > 0 { - match UnixListener::bind(self.socket_path.clone()) { - Ok(listener) => { - return SocketCreationResult::Created(listener); - } - Err(err) if err.kind() == AddrInUse => { - if self.socket_is_available().await { - return SocketCreationResult::PreExisting; - } else { - // socket file might still exist, even if nothing is listening on it. - close_socket(&self.socket_path); - retries -= 1; - continue; - } - } - Err(err) => { - return SocketCreationResult::Err(err); - } - } - } - SocketCreationResult::Err(io::Error::new( - io::ErrorKind::Other, - "Failed to connect to socket", - )) - } - - pub(crate) async fn listen_on_socket(&mut self, init_callback: InitCallback) - where - InitCallback: FnOnce(Result) + Send + 'static, - { - // Bind to socket - let listener = match self.get_socket_listener().await { - SocketCreationResult::Created(listener) => listener, - SocketCreationResult::Err(err) => { - log_info("listen_on_socket", format!("failed with error: {err}")); - init_callback(Err(err.to_string())); - return; - } - SocketCreationResult::PreExisting => { - init_callback(Ok(self.socket_path.clone())); - return; - } - }; - - self.cleanup_socket = true; - init_callback(Ok(self.socket_path.clone())); - let local_set_pool = LocalPoolHandle::new(num_cpus::get()); - loop { - match listener.accept().await { - Ok((stream, _addr)) => { - local_set_pool.spawn_pinned(move || listen_on_client_stream(stream)); - } - Err(err) => { - log_debug( - "listen_on_socket", - format!("Socket closed with error: `{err}`"), - ); - return; - } - } - } - } -} - #[derive(Debug)] /// Enum describing the reason that a socket listener stopped listening on a socket. 
pub enum ClosingReason { @@ -924,23 +806,114 @@ pub fn start_socket_listener_internal( init_callback: InitCallback, socket_path: Option, ) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { + static INITIALIZED_SOCKETS: Lazy>> = + Lazy::new(|| RwLock::new(HashSet::new())); + + let socket_path = socket_path.unwrap_or_else(get_socket_path); + + { + // Optimize for already initialized + let initialized_sockets = INITIALIZED_SOCKETS + .read() + .expect("Failed to acquire sockets db read guard"); + if initialized_sockets.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + } + + // Retry with write lock, will be dropped upon the function completion + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + if sockets_write_guard.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + + let (tx, rx) = std::sync::mpsc::channel(); + let socket_path_cloned = socket_path.clone(); + let init_callback_cloned = init_callback.clone(); + let tx_cloned = tx.clone(); thread::Builder::new() .name("socket_listener_thread".to_string()) .spawn(move || { - let runtime = Builder::new_current_thread().enable_all().build(); - match runtime { - Ok(runtime) => { - let mut listener = Disposable::new(SocketListener::new( - socket_path.unwrap_or_else(get_socket_path), - )); - runtime.block_on(listener.listen_on_socket(init_callback)); - } - Err(err) => init_callback(Err(err.to_string())), + let init_result = { + let runtime = match Builder::new_current_thread().enable_all().build() { + Err(err) => { + log_error( + "listen_on_socket", + format!("Error failed to create a new tokio thread: {err}"), + ); + return Err(err); + } + Ok(runtime) => runtime, + }; + + runtime.block_on(async move { + let listener_socket = match UnixListener::bind(socket_path_cloned.clone()) { + Err(err) => { + log_error( + "listen_on_socket", + format!("Error failed to bind listening socket: {err}"), + ); + return Err(err); + } + Ok(listener_socket) => listener_socket, + }; + + // Signal initialization is successful. + // IMPORTANT: + // tx.send() must be called before init_callback_cloned() to ensure runtimes, such as Python, can properly complete the main function + let _ = tx.send(true); + init_callback_cloned(Ok(socket_path_cloned.clone())); + + let local_set_pool = LocalPoolHandle::new(num_cpus::get()); + loop { + match listener_socket.accept().await { + Ok((stream, _addr)) => { + local_set_pool + .spawn_pinned(move || listen_on_client_stream(stream)); + } + Err(err) => { + log_error( + "listen_on_socket", + format!("Error accepting connection: {err}"), + ); + break; + } + } + } + + // ensure socket file removal + drop(listener_socket); + let _ = std::fs::remove_file(socket_path_cloned.clone()); + + // no more listening on socket - update the sockets db + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + sockets_write_guard.remove(&socket_path_cloned); + Ok(()) + }) }; + + if let Err(err) = init_result { + init_callback(Err(err.to_string())); + let _ = tx_cloned.send(false); + } + Ok(()) }) .expect("Thread spawn failed. 
Cannot report error because callback was moved."); + + // wait for thread initialization signaling, callback invocation is done in the thread + let _ = rx.recv().map(|res| { + if res { + sockets_write_guard.insert(socket_path); + } + }); } /// Creates a new thread with a main loop task listening on the socket for new connections. @@ -950,7 +923,7 @@ pub fn start_socket_listener_internal( /// * `init_callback` - called when the socket listener fails to initialize, with the reason for the failure. pub fn start_socket_listener(init_callback: InitCallback) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { start_socket_listener_internal(init_callback, None); } diff --git a/glide-core/tests/test_socket_listener.rs b/glide-core/tests/test_socket_listener.rs index c6e2b7b15d..6f2aa566b9 100644 --- a/glide-core/tests/test_socket_listener.rs +++ b/glide-core/tests/test_socket_listener.rs @@ -518,8 +518,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_working_after_socket_listener_was_dropped() { - let socket_path = - get_socket_path_from_name("test_working_after_socket_listener_was_dropped".to_string()); + let socket_path = get_socket_path_from_name(format!( + "{}_test_working_after_socket_listener_was_dropped", + std::process::id() + )); close_socket(&socket_path); // create a socket listener and drop it, to simulate a panic in a previous iteration. Builder::new_current_thread() @@ -528,6 +530,8 @@ mod socket_listener { .unwrap() .block_on(async { let _ = UnixListener::bind(socket_path.clone()).unwrap(); + // UDS sockets require explicit removal of the socket file + close_socket(&socket_path); }); const CALLBACK_INDEX: u32 = 99; @@ -554,9 +558,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_multiple_listeners_competing_for_the_socket() { - let socket_path = get_socket_path_from_name( - "test_multiple_listeners_competing_for_the_socket".to_string(), - ); + let socket_path = get_socket_path_from_name(format!( + "{}_test_multiple_listeners_competing_for_the_socket", + std::process::id() + )); close_socket(&socket_path); let server = Arc::new(RedisServer::new(ServerType::Tcp { tls: false })); From f977b919e9fe14d9fe3bb5f2835ea820667a5d5d Mon Sep 17 00:00:00 2001 From: James Xin Date: Tue, 22 Oct 2024 09:47:29 -0700 Subject: [PATCH 034/180] Java: add JSON.DEL and JSON.FORGET (#2490) * Java: add JSON.DEL and JSON.FORGET Signed-off-by: James Xin --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 167 ++++++++++++++- .../api/commands/servermodules/JsonTest.java | 198 ++++++++++++++++-- java/integTest/build.gradle | 3 - .../test/java/glide/modules/JsonTests.java | 66 ++++-- 5 files changed, 396 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 935450afcf..9297ff4644 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) +* Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490)) * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules 
([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 868ad6276f..4f43acafd4 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -20,11 +20,13 @@ public class Json { private static final String JSON_PREFIX = "JSON."; - public static final String JSON_SET = JSON_PREFIX + "SET"; - public static final String JSON_GET = JSON_PREFIX + "GET"; + private static final String JSON_SET = JSON_PREFIX + "SET"; + private static final String JSON_GET = JSON_PREFIX + "GET"; private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; + private static final String JSON_DEL = JSON_PREFIX + "DEL"; + private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; private Json() {} @@ -41,7 +43,7 @@ private Json() {} * @return A simple "OK" response if the value is successfully set. * @example *
      {@code
      -     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}").get();
      +     * String value = Json.set(client, "doc", ".", "{\"a\": 1.0, \"b\": 2}").get();
            * assert value.equals("OK");
            * }
      */ @@ -66,7 +68,7 @@ public static CompletableFuture set( * @return A simple "OK" response if the value is successfully set. * @example *
      {@code
      -     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}")).get();
      +     * String value = Json.set(client, gs("doc"), gs("."), gs("{\"a\": 1.0, \"b\": 2}")).get();
            * assert value.equals("OK");
            * }
      */ @@ -93,7 +95,7 @@ public static CompletableFuture set( * set because of setCondition, returns null. * @example *
      {@code
      -     * String value = Json.set(client, "doc", ".", "{'a': 1.0, 'b': 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
      +     * String value = Json.set(client, "doc", ".", "{\"a\": 1.0, \"b\": 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
            * assert value.equals("OK");
            * }
      */ @@ -122,7 +124,7 @@ public static CompletableFuture set( * set because of setCondition, returns null. * @example *
      {@code
      -     * String value = Json.set(client, gs("doc"), gs("."), gs("{'a': 1.0, 'b': 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
      +     * String value = Json.set(client, gs("doc"), gs("."), gs("{\"a\": 1.0, \"b\": 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
            * assert value.equals("OK");
            * }
      */ @@ -147,7 +149,7 @@ public static CompletableFuture set( * @example *
      {@code
            * String value = Json.get(client, "doc").get();
      -     * assert value.equals("{'a': 1.0, 'b': 2}");
      +     * assert value.equals("{\"a\": 1.0, \"b\": 2}");
            * }
      */ public static CompletableFuture get(@NonNull BaseClient client, @NonNull String key) { @@ -164,7 +166,7 @@ public static CompletableFuture get(@NonNull BaseClient client, @NonNull * @example *
      {@code
            * GlideString value = Json.get(client, gs("doc")).get();
      -     * assert value.equals(gs("{'a': 1.0, 'b': 2}"));
      +     * assert value.equals(gs("{\"a\": 1.0, \"b\": 2}"));
            * }
      */ public static CompletableFuture get( @@ -199,7 +201,7 @@ public static CompletableFuture get( * @example *
      {@code
            * String value = Json.get(client, "doc", new String[] {"$"}).get();
      -     * assert value.equals("{'a': 1.0, 'b': 2}");
      +     * assert value.equals("{\"a\": 1.0, \"b\": 2}");
            * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}).get();
            * assert value.equals("{\"$.a\": [1.0], \"$.b\": [2]}");
            * }
      @@ -236,7 +238,7 @@ public static CompletableFuture get( * @example *
      {@code
            * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$")}).get();
      -     * assert value.equals(gs("{'a': 1.0, 'b': 2}"));
      +     * assert value.equals(gs("{\"a\": 1.0, \"b\": 2}"));
            * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}).get();
            * assert value.equals(gs("{\"$.a\": [1.0], \"$.b\": [2]}"));
            * }
      @@ -702,6 +704,151 @@ public static CompletableFuture arrlen( return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key}); } + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, "doc").get();
      +     * assertEquals(result, 1L);
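+     * // Illustrative follow-up: deleting an already-removed document is a no-op.
+     * assertEquals(Json.del(client, "doc").get(), 0L);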
      +     * }
      + */ + public static CompletableFuture del(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_DEL, key}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, gs("doc")).get();
      +     * assertEquals(result, 1L);
      +     * }
      + */ + public static CompletableFuture del(@NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_DEL), key}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, "doc", "$..a").get();
      +     * assertEquals(result, 2L);
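+     * // Illustrative follow-up: with every "a" removed, the remaining document
+     * // is {"nested": {"b": 3}} (exact formatting may differ).
+     * String remaining = Json.get(client, "doc").get();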
      +     * }
      + */ + public static CompletableFuture del( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_DEL, key, path}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, gs("doc"), gs("$..a")).get();
      +     * assertEquals(result, 2L);
      +     * }
      + */ + public static CompletableFuture del( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_DEL), key, path}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, "doc").get();
      +     * assertEquals(result, 1L);
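+     * // Illustrative follow-up: the document no longer exists, so a repeated
+     * // call finds nothing to remove.
+     * assertEquals(Json.forget(client, "doc").get(), 0L);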
      +     * }
      + */ + public static CompletableFuture forget(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_FORGET, key}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, gs("doc")).get();
      +     * assertEquals(result, 1L);
      +     * }
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_FORGET), key}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, "doc", "$..a").get();
      +     * assertEquals(result, 2L);
      +     * }
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_FORGET, key, path}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, gs("doc"), gs("$..a")).get();
      +     * assertEquals(result, 2L);
      +     * }
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_FORGET), key, path}); + } + /** * A wrapper for custom command API. * diff --git a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java index 81754474a3..0425831ea2 100644 --- a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java +++ b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java @@ -44,7 +44,7 @@ void set_returns_success() { String expectedResponseValue = "OK"; expectedResponse.complete(expectedResponseValue); when(glideClient - .customCommand(eq(new String[] {Json.JSON_SET, key, path, jsonValue})) + .customCommand(eq(new String[] {"JSON.SET", key, path, jsonValue})) .thenApply(any())) .thenReturn(expectedResponse); @@ -68,7 +68,7 @@ void set_binary_returns_success() { String expectedResponseValue = "OK"; expectedResponse.complete(expectedResponseValue); when(glideClient - .customCommand(eq(new GlideString[] {gs(Json.JSON_SET), key, path, jsonValue})) + .customCommand(eq(new GlideString[] {gs("JSON.SET"), key, path, jsonValue})) .thenApply(any())) .thenReturn(expectedResponse); @@ -94,7 +94,7 @@ void set_with_condition_returns_success() { expectedResponse.complete(expectedResponseValue); when(glideClient .customCommand( - eq(new String[] {Json.JSON_SET, key, path, jsonValue, setCondition.getValkeyApi()})) + eq(new String[] {"JSON.SET", key, path, jsonValue, setCondition.getValkeyApi()})) .thenApply(any())) .thenReturn(expectedResponse); @@ -123,7 +123,7 @@ void set_binary_with_condition_returns_success() { .customCommand( eq( new GlideString[] { - gs(Json.JSON_SET), key, path, jsonValue, gs(setCondition.getValkeyApi()) + gs("JSON.SET"), key, path, jsonValue, gs(setCondition.getValkeyApi()) })) .thenApply(any())) .thenReturn(expectedResponse); @@ -146,7 +146,7 @@ void get_with_no_path_returns_success() { CompletableFuture expectedResponse = new CompletableFuture<>(); String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; expectedResponse.complete(expectedResponseValue); - when(glideClient.customCommand(eq(new String[] {Json.JSON_GET, key})).thenApply(any())) + when(glideClient.customCommand(eq(new String[] {"JSON.GET", key})).thenApply(any())) .thenReturn(expectedResponse); // exercise @@ -167,7 +167,7 @@ void get_binary_with_no_path_returns_success() { GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); expectedResponse.complete(expectedResponseValue); when(glideClient - .customCommand(eq(new GlideString[] {gs(Json.JSON_GET), key})) + .customCommand(eq(new GlideString[] {gs("JSON.GET"), key})) .thenApply(any())) .thenReturn(expectedResponse); @@ -192,7 +192,7 @@ void get_with_multiple_paths_returns_success() { String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; expectedResponse.complete(expectedResponseValue); when(glideClient - .customCommand(eq(new String[] {Json.JSON_GET, key, path1, path2})) + .customCommand(eq(new String[] {"JSON.GET", key, path1, path2})) .thenApply(any())) .thenReturn(expectedResponse); @@ -217,7 +217,7 @@ void get_binary_with_multiple_paths_returns_success() { GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); expectedResponse.complete(expectedResponseValue); when(glideClient - .customCommand(eq(new GlideString[] {gs(Json.JSON_GET), key, path1, path2})) + .customCommand(eq(new GlideString[] 
{gs("JSON.GET"), key, path1, path2})) .thenApply(any())) .thenReturn(expectedResponse); @@ -243,7 +243,7 @@ void get_with_no_path_and_options_returns_success() { .customCommand( eq( ArrayTransformUtils.concatenateArrays( - new String[] {Json.JSON_GET, key}, options.toArgs()))) + new String[] {"JSON.GET", key}, options.toArgs()))) .thenApply(any())) .thenReturn(expectedResponse); @@ -270,7 +270,7 @@ void get_binary_with_no_path_and_options_returns_success() { .customCommand( eq( new ArgsBuilder() - .add(new GlideString[] {gs(Json.JSON_GET), key}) + .add(new GlideString[] {gs("JSON.GET"), key}) .add(options.toArgs()) .toArray())) .thenApply(any())) @@ -298,7 +298,7 @@ void get_with_multiple_paths_and_options_returns_success() { String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; expectedResponse.complete(expectedResponseValue); ArrayList argsList = new ArrayList<>(); - argsList.add(Json.JSON_GET); + argsList.add("JSON.GET"); argsList.add(key); Collections.addAll(argsList, options.toArgs()); Collections.addAll(argsList, paths); @@ -330,7 +330,7 @@ void get_binary_with_multiple_paths_and_options_returns_success() { GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); expectedResponse.complete(expectedResponseValue); GlideString[] args = - new ArgsBuilder().add(Json.JSON_GET).add(key).add(options.toArgs()).add(paths).toArray(); + new ArgsBuilder().add("JSON.GET").add(key).add(options.toArgs()).add(paths).toArray(); when(glideClient.customCommand(eq(args)).thenApply(any())) .thenReturn(expectedResponse); @@ -342,4 +342,178 @@ void get_binary_with_multiple_paths_and_options_returns_success() { assertEquals(expectedResponse, actualResponse); assertEquals(expectedResponseValue, actualResponseValue); } + + @Test + @SneakyThrows + void del_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.DEL", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.DEL"), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.DEL", key, path})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key, path); + Long 
actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.DEL"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.FORGET", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.FORGET"), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {"JSON.FORGET", key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.FORGET"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + 
assertEquals(expectedResponseValue, actualResponseValue); + } } diff --git a/java/integTest/build.gradle b/java/integTest/build.gradle index 8c6b48a3cc..93cea92401 100644 --- a/java/integTest/build.gradle +++ b/java/integTest/build.gradle @@ -29,9 +29,6 @@ dependencies { //lombok testCompileOnly 'org.projectlombok:lombok:1.18.32' testAnnotationProcessor 'org.projectlombok:lombok:1.18.32' - - // jsonassert - testImplementation 'org.skyscreamer:jsonassert:1.5.3' } def standaloneHosts = '' diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 969b6b30a4..ad5ae55489 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -26,8 +26,6 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import org.skyscreamer.jsonassert.JSONAssert; -import org.skyscreamer.jsonassert.JSONCompareMode; public class JsonTests { @@ -65,12 +63,13 @@ public void json_set_get() { String getResult = Json.get(client, key).get(); - JSONAssert.assertEquals(jsonValue, getResult, JSONCompareMode.LENIENT); + assertEquals(JsonParser.parseString(jsonValue), JsonParser.parseString(getResult)); String getResultWithMultiPaths = Json.get(client, key, new String[] {"$.a", "$.b"}).get(); - JSONAssert.assertEquals( - "{\"$.a\":[1.0],\"$.b\":[2]}", getResultWithMultiPaths, JSONCompareMode.LENIENT); + assertEquals( + JsonParser.parseString("{\"$.a\":[1.0],\"$.b\":[2]}"), + JsonParser.parseString(getResultWithMultiPaths)); assertNull(Json.get(client, "non_existing_key").get()); assertEquals("[]", Json.get(client, key, new String[] {"$.d"}).get()); @@ -86,21 +85,20 @@ public void json_set_get_multiple_values() { GlideString getResult = Json.get(client, gs(key), new GlideString[] {gs("$..c")}).get(); - JSONAssert.assertEquals("[true, 1, 2]", getResult.getString(), JSONCompareMode.LENIENT); + assertEquals( + JsonParser.parseString("[true, 1, 2]"), JsonParser.parseString(getResult.getString())); String getResultWithMultiPaths = Json.get(client, key, new String[] {"$..c", "$.c"}).get(); - JSONAssert.assertEquals( - "{\"$..c\": [True, 1, 2], \"$.c\": [True]}", - getResultWithMultiPaths, - JSONCompareMode.LENIENT); + assertEquals( + JsonParser.parseString("{\"$..c\": [True, 1, 2], \"$.c\": [True]}"), + JsonParser.parseString(getResultWithMultiPaths)); assertEquals(OK, Json.set(client, key, "$..c", "\"new_value\"").get()); String getResultAfterSetNewValue = Json.get(client, key, new String[] {"$..c"}).get(); - JSONAssert.assertEquals( - "[\"new_value\", \"new_value\", \"new_value\"]", - getResultAfterSetNewValue, - JSONCompareMode.LENIENT); + assertEquals( + JsonParser.parseString("[\"new_value\", \"new_value\", \"new_value\"]"), + JsonParser.parseString(getResultAfterSetNewValue)); } @Test @@ -322,4 +320,44 @@ public void arrlen() { res = Json.arrlen(client, key).get(); assertEquals(5L, res); } + + @Test + @SneakyThrows + public void json_del() { + String key = UUID.randomUUID().toString(); + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": {\"a\": 1, \"b\": 2.5, \"c\": true}}") + .get()); + assertEquals(2L, Json.del(client, key, "$..a").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$..a"}).get()); + String expectedGetResult = "{\"b\": {\"b\": 2.5, \"c\": true}}"; + String actualGetResult = Json.get(client, key).get(); + assertEquals( + JsonParser.parseString(expectedGetResult), 
JsonParser.parseString(actualGetResult)); + + assertEquals(1L, Json.del(client, gs(key), gs("$")).get()); + assertEquals(0L, Json.del(client, key).get()); + assertNull(Json.get(client, key, new String[] {"$"}).get()); + } + + @Test + @SneakyThrows + public void json_forget() { + String key = UUID.randomUUID().toString(); + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": {\"a\": 1, \"b\": 2.5, \"c\": true}}") + .get()); + assertEquals(2L, Json.forget(client, key, "$..a").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$..a"}).get()); + String expectedGetResult = "{\"b\": {\"b\": 2.5, \"c\": true}}"; + String actualGetResult = Json.get(client, key).get(); + assertEquals( + JsonParser.parseString(expectedGetResult), JsonParser.parseString(actualGetResult)); + + assertEquals(1L, Json.forget(client, gs(key), gs("$")).get()); + assertEquals(0L, Json.forget(client, key).get()); + assertNull(Json.get(client, key, new String[] {"$"}).get()); + } } From f65b0fe29b55e9d0fcb00a86998e82472f9dac86 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:33:15 -0700 Subject: [PATCH 035/180] Node: Add command JSON.TOGGLE (#2491) * Node: Add command JSON.TOGGLE Signed-off-by: TJ Zhang Co-authored-by: TJ Zhang --- CHANGELOG.md | 1 + node/src/server-modules/GlideJson.ts | 70 ++++++++++++++++++++++++++-- node/tests/ServerModules.test.ts | 52 +++++++++++++++++++++ 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9297ff4644..974fcec9a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) * Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) +* Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) #### Breaking Changes diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 6dd57b16d3..9cab5092e4 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -7,7 +7,7 @@ import { ConditionalChange } from "../Commands"; import { GlideClient } from "../GlideClient"; import { GlideClusterClient, RouteOption } from "../GlideClusterClient"; -export type ReturnTypeJson = GlideString | (GlideString | null)[]; +export type ReturnTypeJson = T | (T | null)[]; /** * Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. @@ -80,6 +80,7 @@ export class GlideJson { /** * Sets the JSON value at the specified `path` stored at `key`. * + * @param client The client to execute the command. * @param key - The key of the JSON document. * @param path - Represents the path within the JSON document where the value will be set. * The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added. @@ -123,8 +124,11 @@ export class GlideJson { /** * Retrieves the JSON value at the specified `paths` stored at `key`. * + * @param client The client to execute the command. * @param key - The key of the JSON document. - * @param options - Options for formatting the byte representation of the JSON data. See {@link JsonGetOptions}. 
+ * @param options - (Optional) Additional parameters: + * - (Optional) Options for formatting the byte representation of the JSON data. See {@link JsonGetOptions}. + * - (Optional) `decoder`: see {@link DecoderOption}. * @returns ReturnTypeJson: * - If one path is given: * - For JSONPath (path starts with `$`): @@ -164,10 +168,10 @@ export class GlideJson { * ``` */ static async get( - client: GlideClient | GlideClusterClient, + client: BaseClient, key: GlideString, options?: JsonGetOptions & DecoderOption, - ): Promise { + ): Promise> { const args = ["JSON.GET", key]; if (options) { @@ -175,6 +179,62 @@ export class GlideJson { args.push(...optionArgs); } - return _executeCommand(client, args, options); + return _executeCommand>( + client, + args, + options, + ); + } + + /** + * Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) The JSONPath to specify. Defaults to the root if not specified. + * @returns - For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value, + * or null for JSON values matching the path that are not boolean. + * - For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`. + * - Note that when sending legacy path syntax, If `path` doesn't exist or the value at `path` isn't a boolean, an error is raised. + * + * @example + * ```typescript + * const value = {bool: true, nested: {bool: false, nested: {bool: 10}}}; + * const jsonStr = JSON.stringify(value); + * const resultSet = await GlideJson.set("doc", "$", jsonStr); + * // Output: 'OK' + * + * const resultToggle = await.GlideJson.toggle(client, "doc", "$.bool") + * // Output: [false, true, null] - Indicates successful toggling of the Boolean values at path '$.bool' in the key stored at `doc`. + * + * const resultToggle = await.GlideJson.toggle(client, "doc", "bool") + * // Output: true - Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`. + * + * const resultToggle = await.GlideJson.toggle(client, "doc", "bool") + * // Output: true - Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`. + * + * const jsonGetStr = await GlideJson.get(client, "doc", "$"); + * console.log(JSON.stringify(jsonGetStr)); + * // Output: [{bool: true, nested: {bool: true, nested: {bool: 10}}}] - The updated JSON value in the key stored at `doc`. + * + * // Without specifying a path, the path defaults to root. 
+ * console.log(await GlideJson.set(client, "doc2", ".", true)); // Output: "OK" + * console.log(await GlideJson.toggle(client,"doc2")); // Output: "false" + * console.log(await GlideJson.toggle(client, "doc2")); // Output: "true" + * ``` + */ + static async toggle( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.TOGGLE", key]; + + if (options !== undefined) { + args.push(options.path); + } + + return _executeCommand>(client, args); } } diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 855dfd4305..fad81f2bbd 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -17,6 +17,7 @@ import { InfoOptions, JsonGetOptions, ProtocolVersion, + RequestError, } from ".."; import { ValkeyCluster } from "../../utils/TestUtils"; import { @@ -227,4 +228,55 @@ describe("GlideJson", () => { expect(result).toEqual(expectedResult2); }, ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.toggle tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption(cluster.getAddresses(), protocol), + ); + const key = uuidv4(); + const key2 = uuidv4(); + const jsonValue = { + bool: true, + nested: { bool: false, nested: { bool: 10 } }, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.toggle(client, key, { path: "$..bool" }), + ).toEqual([false, true, null]); + expect(await GlideJson.toggle(client, key, { path: "bool" })).toBe( + true, + ); + expect( + await GlideJson.toggle(client, key, { path: "$.non_existing" }), + ).toEqual([]); + expect( + await GlideJson.toggle(client, key, { path: "$.nested" }), + ).toEqual([null]); + + // testing behavior with default pathing + expect(await GlideJson.set(client, key2, ".", "true")).toBe("OK"); + expect(await GlideJson.toggle(client, key2)).toBe(false); + expect(await GlideJson.toggle(client, key2)).toBe(true); + + // expect request errors + await expect( + GlideJson.toggle(client, key, { path: "nested" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, key, { path: ".non_existing" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, "non_existing_key", { path: "$" }), + ).rejects.toThrow(RequestError); + }, + ); }); From bea5434aa1a8490e3e2d2c830046dd0f3958ae1c Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Tue, 22 Oct 2024 15:59:01 -0700 Subject: [PATCH 036/180] Java: fix IT (#2497) fix IT Signed-off-by: Yury-Fridlyand --- java/integTest/src/test/java/glide/SharedCommandTests.java | 1 + 1 file changed, 1 insertion(+) diff --git a/java/integTest/src/test/java/glide/SharedCommandTests.java b/java/integTest/src/test/java/glide/SharedCommandTests.java index 88e0256a69..d9517eb6a3 100644 --- a/java/integTest/src/test/java/glide/SharedCommandTests.java +++ b/java/integTest/src/test/java/glide/SharedCommandTests.java @@ -1119,6 +1119,7 @@ public void non_UTF8_GlideString_map_with_geospatial(BaseClient client) { @ParameterizedTest(autoCloseArguments = false) @MethodSource("getClients") public void non_UTF8_GlideString_map_of_arrays(BaseClient client) { + assumeTrue(SERVER_VERSION.isGreaterThanOrEqualTo("7.0.0")); byte[] nonUTF8Bytes = new byte[] {(byte) 0xEE}; GlideString key = gs(UUID.randomUUID().toString()); GlideString nonUTF8Key = gs(new byte[] {(byte) 0xFE}); From 0a11fdec64bb0bd851f0dd8c46803c1e860e9820 Mon Sep 17 00:00:00 
2001
From: prateek-kumar-improving
Date: Tue, 22 Oct 2024 16:11:56 -0700
Subject: [PATCH 037/180] Python `FT.INFO` command added (#2494)

* Python FT.INFO command added

---------

Signed-off-by: Prateek Kumar
---
 CHANGELOG.md | 1 +
 python/python/glide/__init__.py | 2 +
 .../glide/async_commands/server_modules/ft.py | 67 +++++++++--
 .../server_modules/ft_options/ft_constants.py | 1 +
 python/python/glide/constants.py | 15 +++
 .../tests/tests_server_modules/test_ft.py | 110 ++++++++++++++++--
 6 files changed, 182 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 974fcec9a6..8c16c45936 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,5 @@
 #### Changes
+* Python: FT.INFO command added ([#2494](https://github.com/valkey-io/valkey-glide/pull/2494))
 * Python: Add FT.SEARCH command([#2470](https://github.com/valkey-io/valkey-glide/pull/2470))
 * Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE([#2471](https://github.com/valkey-io/valkey-glide/pull/2471))
 * Python: Python FT.DROPINDEX command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437))
diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py
index 5bb31c75fb..b690e81137 100644
--- a/python/python/glide/__init__.py
+++ b/python/python/glide/__init__.py
@@ -104,6 +104,7 @@ from glide.constants import (
     OK,
     TOK,
+    FtInfoResponse,
     TClusterResponse,
     TEncodable,
     TFunctionListResponse,
@@ -170,6 +171,7 @@
     "TResult",
     "TXInfoStreamFullResponse",
     "TXInfoStreamResponse",
+    "FtInfoResponse",
     # Commands
     "BitEncoding",
     "BitFieldGet",
diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py
index 82118e9070..dbfa0b9cb7 100644
--- a/python/python/glide/async_commands/server_modules/ft.py
+++ b/python/python/glide/async_commands/server_modules/ft.py
@@ -16,7 +16,7 @@
 from glide.async_commands.server_modules.ft_options.ft_search_options import (
     FtSeachOptions,
 )
-from glide.constants import TOK, TEncodable
+from glide.constants import TOK, FtInfoResponse, TEncodable
 from glide.glide_client import TGlideClient


@@ -39,7 +39,7 @@ async def create(
         TOK: A simple "OK" response.

     Examples:
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
         >>> schema: List[Field] = []
         >>> field: TextField = TextField("title")
         >>> schema.append(field)
@@ -72,7 +72,7 @@ async def dropindex(client: TGlideClient, indexName: TEncodable) -> TOK:
     Examples:
         For the following example to work, an index named 'idx' must be already created. If not created, you will get an error.
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
        >>> indexName = "idx"
        >>> result = await ft.dropindex(glide_client, indexName)
            'OK' # Indicates successful deletion/dropping of index named 'idx'
@@ -99,12 +99,13 @@ async def search(
     Returns:
         List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]: A two element array, where first element is count of documents in result set, and the second element, which has the format Mapping[TEncodable, Mapping[TEncodable, TEncodable]] is a mapping between document names and map of their attributes. If count(option in `FtSearchOptions`) is set to true or limit(option in `FtSearchOptions`) is set to FtSearchLimit(0, 0), the command returns array with only one element - the count of the documents.
+
     Examples:
         For the following example to work the following must already exist:
        - An index named "idx", with fields having identifiers as "a" and "b" and prefix as "{json:}"
        - A key named {json:}1 with value {"a":1, "b":2}
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
        >>> result = await ft.search(glide_client, "idx", "*", options=FtSeachOptions(return_fields=[ReturnField(field_identifier="first"), ReturnField(field_identifier="second")]))
            [1, { b'json:1': { b'first': b'42', b'second': b'33' } }] # The first element, 1 is the number of keys returned in the search result. The second element is a map of data queried per key.
     """
@@ -132,7 +133,7 @@ async def aliasadd(
         TOK: A simple "OK" response.

     Examples:
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
        >>> result = await ft.aliasadd(glide_client, "myalias", "myindex")
            'OK' # Indicates the successful addition of the alias named "myalias" for the index.
     """
@@ -152,7 +153,7 @@ async def aliasdel(client: TGlideClient, alias: TEncodable) -> TOK:
         TOK: A simple "OK" response.

     Examples:
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
        >>> result = await ft.aliasdel(glide_client, "myalias")
            'OK' # Indicates the successful deletion of the alias named "myalias"
     """
@@ -175,9 +176,61 @@ async def aliasupdate(
         TOK: A simple "OK" response.

     Examples:
-        >>> from glide.async_commands.server_modules import ft
+        >>> from glide import ft
        >>> result = await ft.aliasupdate(glide_client, "myalias", "myindex")
            'OK' # Indicates the successful update of the alias to point to the index named "myindex"
     """
     args: List[TEncodable] = [CommandNames.FT_ALIASUPDATE, alias, indexName]
     return cast(TOK, await client.custom_command(args))
+
+
+async def info(client: TGlideClient, indexName: TEncodable) -> FtInfoResponse:
+    """
+    Returns information about a given index.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        indexName (TEncodable): The index name for which the information has to be returned.
+
+    Returns:
+        FtInfoResponse: Nested maps with info about the index. See example for more details. See `FtInfoResponse`.
+
+    Examples:
+        An index with name 'myIndex', 1 text field and 1 vector field is already created for getting the output of this example.
+ >>> from glide import ft + >>> result = await ft.info(glide_client, "myIndex") + [ + b'index_name', + b'myIndex', + b'creation_timestamp', 1729531116945240, + b'key_type', b'JSON', + b'key_prefixes', [b'key-prefix'], + b'fields', [ + [ + b'identifier', b'$.vec', + b'field_name', b'VEC', + b'type', b'VECTOR', + b'option', b'', + b'vector_params', [ + b'algorithm', b'HNSW', b'data_type', b'FLOAT32', b'dimension', 2, b'distance_metric', b'L2', b'initial_capacity', 1000, b'current_capacity', 1000, b'maximum_edges', 16, b'ef_construction', 200, b'ef_runtime', 10, b'epsilon', b'0.01' + ] + ], + [ + b'identifier', b'$.text-field', + b'field_name', b'text-field', + b'type', b'TEXT', + b'option', b'' + ] + ], + b'space_usage', 653351, + b'fulltext_space_usage', 0, + b'vector_space_usage', 653351, + b'num_docs', 0, + b'num_indexed_vectors', 0, + b'current_lag', 0, + b'index_status', b'AVAILABLE', + b'index_degradation_percentage', 0 + ] + """ + args: List[TEncodable] = [CommandNames.FT_INFO, indexName] + return cast(FtInfoResponse, await client.custom_command(args)) diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index 541b286d83..0077c8c3f3 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -9,6 +9,7 @@ class CommandNames: FT_CREATE = "FT.CREATE" FT_DROPINDEX = "FT.DROPINDEX" FT_SEARCH = "FT.SEARCH" + FT_INFO = "FT.INFO" FT_ALIASADD = "FT.ALIASADD" FT_ALIASDEL = "FT.ALIASDEL" FT_ALIASUPDATE = "FT.ALIASUPDATE" diff --git a/python/python/glide/constants.py b/python/python/glide/constants.py index 754aacf6fa..4ecd2003a3 100644 --- a/python/python/glide/constants.py +++ b/python/python/glide/constants.py @@ -74,3 +74,18 @@ List[Mapping[bytes, Union[bytes, int, List[List[Union[bytes, int]]]]]], ], ] + +FtInfoResponse = Mapping[ + TEncodable, + Union[ + TEncodable, + int, + List[TEncodable], + List[ + Mapping[ + TEncodable, + Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]], + ] + ], + ], +] diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index a44a68bfc8..812e184d7e 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -1,14 +1,19 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 import uuid -from typing import List +from typing import List, Mapping, Union, cast import pytest from glide.async_commands.server_modules import ft from glide.async_commands.server_modules.ft_options.ft_create_options import ( DataType, + DistanceMetricType, Field, FtCreateOptions, TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributesHnsw, + VectorType, ) from glide.config import ProtocolVersion from glide.constants import OK, TEncodable @@ -18,6 +23,17 @@ @pytest.mark.asyncio class TestFt: + SearchResultField = Mapping[ + TEncodable, Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]] + ] + + SerchResultFieldsList = List[ + Mapping[ + TEncodable, + Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]], + ] + ] + @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_ft_aliasadd(self, glide_client: GlideClusterClient): @@ -28,7 +44,7 @@ async def test_ft_aliasadd(self, 
glide_client: GlideClusterClient): await ft.aliasadd(glide_client, alias, indexName) # Test ft.aliasadd successfully adds an alias to an existing index. - await TestFt.create_test_index_hash_type(self, glide_client, indexName) + await TestFt._create_test_index_hash_type(self, glide_client, indexName) assert await ft.aliasadd(glide_client, alias, indexName) == OK assert await ft.dropindex(glide_client, indexName=indexName) == OK @@ -36,7 +52,7 @@ async def test_ft_aliasadd(self, glide_client: GlideClusterClient): indexNameString = str(uuid.uuid4()) indexNameBytes = bytes(indexNameString, "utf-8") aliasNameBytes = b"alias-bytes" - await TestFt.create_test_index_hash_type(self, glide_client, indexNameString) + await TestFt._create_test_index_hash_type(self, glide_client, indexNameString) assert await ft.aliasadd(glide_client, aliasNameBytes, indexNameBytes) == OK assert await ft.dropindex(glide_client, indexName=indexNameString) == OK @@ -45,7 +61,7 @@ async def test_ft_aliasadd(self, glide_client: GlideClusterClient): async def test_ft_aliasdel(self, glide_client: GlideClusterClient): indexName: TEncodable = str(uuid.uuid4()) alias: str = "alias" - await TestFt.create_test_index_hash_type(self, glide_client, indexName) + await TestFt._create_test_index_hash_type(self, glide_client, indexName) # Test if deleting a non existent alias throws an error. with pytest.raises(RequestError): @@ -66,12 +82,12 @@ async def test_ft_aliasdel(self, glide_client: GlideClusterClient): async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): indexName: str = str(uuid.uuid4()) alias: str = "alias" - await TestFt.create_test_index_hash_type(self, glide_client, indexName) + await TestFt._create_test_index_hash_type(self, glide_client, indexName) assert await ft.aliasadd(glide_client, alias, indexName) == OK newAliasName: str = "newAlias" newIndexName: str = str(uuid.uuid4()) - await TestFt.create_test_index_hash_type(self, glide_client, newIndexName) + await TestFt._create_test_index_hash_type(self, glide_client, newIndexName) assert await ft.aliasadd(glide_client, newAliasName, newIndexName) == OK # Test if updating an already existing alias to point to an existing index returns "OK". @@ -86,7 +102,7 @@ async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): assert await ft.dropindex(glide_client, indexName=indexName) == OK assert await ft.dropindex(glide_client, indexName=newIndexName) == OK - async def create_test_index_hash_type( + async def _create_test_index_hash_type( self, glide_client: GlideClusterClient, index_name: TEncodable ): # Helper function used for creating a basic index with hash data type with one text field. @@ -102,3 +118,83 @@ async def create_test_index_hash_type( glide_client, index_name, fields, FtCreateOptions(DataType.HASH, prefixes) ) assert result == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_info(self, glide_client: GlideClusterClient): + indexName = str(uuid.uuid4()) + await TestFt._create_test_index_with_vector_field( + self, glide_client=glide_client, index_name=indexName + ) + result = await ft.info(glide_client, indexName) + assert await ft.dropindex(glide_client, indexName=indexName) == OK + + assert indexName.encode() == result.get(b"index_name") + assert b"JSON" == result.get(b"key_type") + assert [b"key-prefix"] == result.get(b"key_prefixes") + + # Get vector and text fields from the fields array. 
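The assertions that follow lean on the shape of the `FT.INFO` reply: a mapping whose `b"fields"` entry holds one element per indexed field. When a server returns the flat `[key, value, ...]` form shown in the `ft.info` docstring above instead, each element can be folded into a dict first. A minimal sketch under that assumption — `pairs_to_dict` is a hypothetical helper, not part of the glide API:

```python
# Hypothetical helper (not part of glide): fold a flat [key, value, ...]
# reply entry, as shown in the ft.info docstring example, into a dict.
def pairs_to_dict(flat):
    return {flat[i]: flat[i + 1] for i in range(0, len(flat), 2)}

info = await ft.info(glide_client, "myIndex")
fields = [
    pairs_to_dict(entry) if isinstance(entry, list) else entry
    for entry in info[b"fields"]
]
vector_field = next(f for f in fields if f[b"type"] == b"VECTOR")
assert vector_field[b"field_name"] == b"VEC"
```

With every entry in mapping form, the test can read values with `.get(...)` directly, as the next lines do: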
+ fields: TestFt.SerchResultFieldsList = cast( + TestFt.SerchResultFieldsList, result.get(b"fields") + ) + assert len(fields) == 2 + textField: TestFt.SearchResultField = {} + vectorField: TestFt.SearchResultField = {} + if fields[0].get(b"type") == b"VECTOR": + vectorField = cast(TestFt.SearchResultField, fields[0]) + textField = cast(TestFt.SearchResultField, fields[1]) + else: + vectorField = cast(TestFt.SearchResultField, fields[1]) + textField = cast(TestFt.SearchResultField, fields[0]) + + # Compare vector field arguments + assert b"$.vec" == vectorField.get(b"identifier") + assert b"VECTOR" == vectorField.get(b"type") + assert b"VEC" == vectorField.get(b"field_name") + vectorFieldParams: Mapping[TEncodable, Union[TEncodable, int]] = cast( + Mapping[TEncodable, Union[TEncodable, int]], + vectorField.get(b"vector_params"), + ) + assert DistanceMetricType.L2.value.encode() == vectorFieldParams.get( + b"distance_metric" + ) + assert 2 == vectorFieldParams.get(b"dimension") + assert b"HNSW" == vectorFieldParams.get(b"algorithm") + assert b"FLOAT32" == vectorFieldParams.get(b"data_type") + + # Compare text field arguments. + assert b"$.text-field" == textField.get(b"identifier") + assert b"TEXT" == textField.get(b"type") + assert b"text-field" == textField.get(b"field_name") + + # Querying a missing index throws an error. + with pytest.raises(RequestError): + await ft.info(glide_client, str(uuid.uuid4())) + + async def _create_test_index_with_vector_field( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating an index with JSON data type with a text and vector field. + fields: List[Field] = [] + textField: Field = TextField("$.text-field", "text-field") + + vectorFieldHash: VectorField = VectorField( + name="$.vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dim=2, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + ), + alias="VEC", + ) + fields.append(vectorFieldHash) + fields.append(textField) + + prefixes: List[TEncodable] = [] + prefixes.append("key-prefix") + + await ft.create( + glide_client, + indexName=index_name, + schema=fields, + options=FtCreateOptions(DataType.JSON, prefixes=prefixes), + ) From ebf6460ce824b70cf419d6509ed72ca83ab6c501 Mon Sep 17 00:00:00 2001 From: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Date: Wed, 23 Oct 2024 05:51:29 +0300 Subject: [PATCH 038/180] Ci/redis rs (#2488) * adjusting old pathes to submodules to fit changes Signed-off-by: avifenesh * redis-rs CI Signed-off-by: avifenesh * Update settings.json Signed-off-by: avifenesh * Update settings.json Signed-off-by: avifenesh --------- Signed-off-by: avifenesh --- .github/workflows/csharp.yml | 10 +- .github/workflows/go.yml | 10 +- .github/workflows/install-redis/action.yml | 3 +- .github/workflows/java-cd.yml | 6 +- .github/workflows/java.yml | 13 +- .github/workflows/lint-rust/action.yml | 5 +- .github/workflows/node.yml | 16 +- .github/workflows/python.yml | 15 +- .github/workflows/redis-rs.yml | 142 ++++++++++++++++++ .github/workflows/rust.yml | 14 +- deny.toml | 56 +------ glide-core/Cargo.toml | 8 +- .../redis/src/aio/multiplexed_connection.rs | 1 - 13 files changed, 185 insertions(+), 114 deletions(-) create mode 100644 .github/workflows/redis-rs.yml diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index aa85d9a991..eab61c6dc1 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -9,7 +9,7 @@ on: paths: - csharp/** - glide-core/src/** 
- - submodules/** + - glide-core/redis-rs/redis/src/** - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml @@ -20,7 +20,7 @@ on: paths: - csharp/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml @@ -76,8 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up dotnet ${{ matrix.dotnet }} uses: actions/setup-dotnet@v4 @@ -122,8 +121,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: ./.github/workflows/lint-rust with: diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 6eaf3d1d19..3290839a6a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - go/** - .github/workflows/go.yml - .github/workflows/install-shared-dependencies/action.yml @@ -19,7 +19,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - go/** - .github/workflows/go.yml - .github/workflows/install-shared-dependencies/action.yml @@ -73,8 +73,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 @@ -204,8 +203,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: ./.github/workflows/lint-rust with: diff --git a/.github/workflows/install-redis/action.yml b/.github/workflows/install-redis/action.yml index f4a7491d04..b60f0687b5 100644 --- a/.github/workflows/install-redis/action.yml +++ b/.github/workflows/install-redis/action.yml @@ -17,8 +17,7 @@ runs: shell: bash - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: actions/cache@v3 id: cache-redis diff --git a/.github/workflows/java-cd.yml b/.github/workflows/java-cd.yml index e9df283c50..d3f2038313 100644 --- a/.github/workflows/java-cd.yml +++ b/.github/workflows/java-cd.yml @@ -86,8 +86,7 @@ jobs: echo "No cleaning needed" fi - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up JDK uses: actions/setup-java@v4 @@ -225,8 +224,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up JDK uses: actions/setup-java@v4 diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index ebc6a06169..2c31562f78 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - java/** - .github/workflows/java.yml - .github/workflows/install-shared-dependencies/action.yml @@ -19,7 +19,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - java/** - .github/workflows/java.yml - .github/workflows/install-shared-dependencies/action.yml @@ -74,8 +74,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: gradle/actions/wrapper-validation@v3 @@ -191,8 +190,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: ./.github/workflows/lint-rust with: @@ -210,8 +208,7 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: 
actions/checkout@v4 - with: - submodules: recursive + - name: Set up JDK uses: actions/setup-java@v4 diff --git a/.github/workflows/lint-rust/action.yml b/.github/workflows/lint-rust/action.yml index 11ca944f71..0823dda958 100644 --- a/.github/workflows/lint-rust/action.yml +++ b/.github/workflows/lint-rust/action.yml @@ -15,8 +15,7 @@ runs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Install Rust toolchain and protoc uses: ./.github/workflows/install-rust-and-protoc @@ -42,7 +41,7 @@ runs: - run: | cargo update - cargo install --locked --version 0.15.1 cargo-deny + cargo install --locked cargo-deny cargo deny check --config ${GITHUB_WORKSPACE}/deny.toml working-directory: ${{ inputs.cargo-toml-folder }} shell: bash diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index a9b6b4be18..c4c17a7e46 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - node/** - utils/cluster_manager.py - .github/workflows/node.yml @@ -22,7 +22,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - node/** - utils/cluster_manager.py - .github/workflows/node.yml @@ -67,8 +67,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Use Node.js 16.x uses: actions/setup-node@v3 @@ -125,8 +124,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: ./.github/workflows/lint-rust with: @@ -239,8 +237,7 @@ jobs: apk add git - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Setup musl on Linux uses: ./.github/workflows/setup-musl-on-linux @@ -287,8 +284,7 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Use Node.js 18.x uses: actions/setup-node@v4 diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index ebb83894e6..6c45d7707c 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -9,7 +9,7 @@ on: paths: - python/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/**/** - utils/cluster_manager.py - .github/workflows/python.yml - .github/workflows/build-python-wrapper/action.yml @@ -25,7 +25,7 @@ on: paths: - python/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/python.yml - .github/workflows/build-python-wrapper/action.yml @@ -84,8 +84,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up Python uses: actions/setup-python@v4 @@ -167,8 +166,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Set up Python uses: actions/setup-python@v4 @@ -204,8 +202,7 @@ jobs: timeout-minutes: 15 steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: lint rust uses: ./.github/workflows/lint-rust @@ -299,8 +296,6 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 - with: - submodules: recursive - name: Build Python wrapper uses: ./.github/workflows/build-python-wrapper diff --git a/.github/workflows/redis-rs.yml b/.github/workflows/redis-rs.yml new file mode 100644 index 0000000000..cdc4967759 --- /dev/null +++ b/.github/workflows/redis-rs.yml @@ -0,0 +1,142 @@ +name: Redis-rs CI + +on: + push: + 
branches: + - main + - release-* + - v* + paths: + - glide-core/redis-rs/redis/** + - utils/cluster_manager.py + - deny.toml + - .github/workflows/install-shared-dependencies/action.yml + - .github/workflows/redis-rs.yml + pull_request: + paths: + - glide-core/redis-rs/redis/** + - utils/cluster_manager.py + - deny.toml + - .github/workflows/install-shared-dependencies/action.yml + - .github/workflows/redis-rs.yml + workflow_dispatch: + workflow_call: + +concurrency: + group: redis-rs-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +env: + CARGO_TERM_COLOR: always + +jobs: + redis-rs-CI: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install shared software dependencies + uses: ./.github/workflows/install-shared-dependencies + with: + os: "ubuntu" + target: "x86_64-unknown-linux-gnu" + engine-version: "7.2.5" + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true + workspaces: ./glide-core/redis-rs/redis + + + - name: Build project + run: cargo build --release + working-directory: ./glide-core/redis-rs/redis/src + + - name: Lint redis-rs + shell: bash + run: | + cargo fmt --all -- --check + cargo clippy -- -D warnings + cargo install --locked cargo-deny + cargo deny check all --config ${GITHUB_WORKSPACE}/deny.toml --exclude-dev all + working-directory: ./glide-core/redis-rs/redis + + - name: Test + run: | + cargo test --release -- --test-threads=1 | tee ../test-results.xml + echo "### Tests passed :v:" >> $GITHUB_STEP_SUMMARY + working-directory: ./glide-core/redis-rs/redis/src + + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: test-reports-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/test-results.xml + + - name: Run benchmarks + run: | + cargo bench | tee bench-results.xml + working-directory: ./glide-core/redis-rs/redis + + - name: Upload benchmark results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: benchmark-results-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/bench-results.xml + + - name: Test docs + run: | + cargo test --doc + working-directory: ./glide-core/redis-rs/redis/src + + - name: Security audit + run: | + cargo audit | tee audit-results.txt + if grep -q "Crate: " audit-results.txt; then + echo "## Security audit results summary: Security vulnerabilities found :exclamation: :exclamation:" >> $GITHUB_STEP_SUMMARY + echo "Security audit results summary: Security vulnerabilities found" + exit 1 + else + echo "### Security audit results summary: All good, no security vulnerabilities found :closed_lock_with_key:" >> $GITHUB_STEP_SUMMARY + echo "Security audit results summary: All good, no security vulnerabilities found" + fi + working-directory: ./glide-core/redis-rs/redis + + - name: Upload audit results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: audit-results-redis-rs--${{ github.sha }} + path: ./glide-core/redis-rs/redis/audit-results.txt + + - name: Run cargo machete + run: | + cargo install cargo-machete + cargo machete | tee machete-results.txt + if grep -A1 "cargo-machete found the following unused dependencies in this directory:" machete-results.txt | sed -n '2p' | grep -v "^if" > /dev/null; then + echo "Machete results summary: Unused dependencies found" >> $GITHUB_STEP_SUMMARY + echo "Machete results summary: 
Unused dependencies found" + cat machete-results.txt | grep -A1 "cargo-machete found the following unused dependencies in this directory:" | sed -n '2p' | grep -v "^if" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Machete results summary: All good, no unused dependencies found :rocket:" >> $GITHUB_STEP_SUMMARY + echo "Machete results summary: All good, no unused dependencies found" + fi + working-directory: ./glide-core/redis-rs/redis + + - name: Upload machete results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: machete-results-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/machete-results.txt diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c632880a2b..95b47f2ce2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -9,7 +9,7 @@ on: paths: - logger_core/** - glide-core/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml @@ -21,7 +21,7 @@ on: paths: - logger_core/** - glide-core/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml @@ -63,8 +63,7 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Install shared software dependencies uses: ./.github/workflows/install-shared-dependencies @@ -78,11 +77,11 @@ jobs: - name: Run tests working-directory: ./glide-core - run: cargo test --all-features -- --nocapture --test-threads=1 # TODO remove the concurrency limit after we fix test flakyness. + run: cargo test --all-features --release -- --test-threads=1 # TODO remove the concurrency limit after we fix test flakyness. - name: Run logger tests working-directory: ./logger_core - run: cargo test --all-features -- --nocapture --test-threads=1 + run: cargo test --all-features -- --test-threads=1 - name: Check features working-directory: ./glide-core @@ -99,8 +98,7 @@ jobs: timeout-minutes: 30 steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - uses: ./.github/workflows/lint-rust with: diff --git a/deny.toml b/deny.toml index f97a82f0c8..a0c30726d4 100644 --- a/deny.toml +++ b/deny.toml @@ -9,24 +9,6 @@ # The values provided in this template are the default values that will be used # when any section or field is not specified in your own configuration -# If 1 or more target triples (and optionally, target_features) are specified, -# only the specified targets will be checked when running `cargo deny check`. -# This means, if a particular package is only ever used as a target specific -# dependency, such as, for example, the `nix` crate only being used via the -# `target_family = "unix"` configuration, that only having windows targets in -# this list would mean the nix crate, as well as any of its exclusive -# dependencies not shared by any other crates, would be ignored, as the target -# list here is effectively saying which targets you are building for. -targets = [ - # The triple can be any string, but only the target triples built in to - # rustc (as of 1.40) can be checked against actual config expressions - #{ triple = "x86_64-unknown-linux-musl" }, - # You can also specify which target_features you promise are enabled for a - # particular target. target_features are currently not validated against - # the actual valid features supported by the target architecture. 
- #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, -] - # This section is considered when running `cargo deny check advisories` # More documentation for the advisories section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html @@ -35,22 +17,13 @@ targets = [ db-path = "~/.cargo/advisory-db" # The url(s) of the advisory databases to use db-urls = ["https://github.com/rustsec/advisory-db"] -# The lint level for security vulnerabilities -vulnerability = "deny" -# The lint level for unmaintained crates -unmaintained = "deny" # The lint level for crates that have been yanked from their source registry yanked = "deny" -# The lint level for crates with security notices. Note that as of -# 2019-12-17 there are no security notice advisories in -# https://github.com/rustsec/advisory-db -notice = "deny" -unsound = "deny" # A list of advisory IDs to ignore. Note that ignored advisories will still # output a note when they are encountered. ignore = [ # Unmaintained dependency error that needs more attention due to nested dependencies - "RUSTSEC-2024-0370" + "RUSTSEC-2024-0370", ] # Threshold for security vulnerabilities, any vulnerability with a CVSS score # lower than the range specified will be ignored. Note that ignored advisories @@ -72,8 +45,6 @@ ignore = [ # More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] -# The lint level for crates which do not have a detectable license -unlicensed = "deny" # List of explicitly allowed licenses # See https://spdx.org/licenses/ for list of possible licenses # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. @@ -85,28 +56,9 @@ allow = [ "BSD-3-Clause", "Unicode-DFS-2016", "ISC", - "OpenSSL" -] -# List of explicitly disallowed licenses -# See https://spdx.org/licenses/ for list of possible licenses -# [possible values: any SPDX 3.11 short identifier (+ optional exception)]. -deny = [ - #"Nokia", + "OpenSSL", + "MPL-2.0", ] -# Lint level for licenses considered copyleft -copyleft = "deny" -# Blanket approval or denial for OSI-approved or FSF Free/Libre licenses -# * both - The license will be approved if it is both OSI-approved *AND* FSF -# * either - The license will be approved if it is either OSI-approved *OR* FSF -# * osi-only - The license will be approved if is OSI-approved *AND NOT* FSF -# * fsf-only - The license will be approved if is FSF *AND NOT* OSI-approved -# * neither - This predicate is ignored and the default lint level is used -allow-osi-fsf-free = "neither" -# Lint level used when no other predicates are matched -# 1. License isn't in the allow or deny lists -# 2. License isn't copyleft -# 3. License isn't OSI/FSF, or allow-osi-fsf-free = "neither" -default = "deny" # The confidence threshold for detecting a license from license text. # The higher the value, the more closely the license text must be to the # canonical license text of a valid SPDX license file. 
@@ -137,7 +89,7 @@ expression = "MIT AND ISC AND OpenSSL" # depending on the rest of your configuration license-files = [ # Each entry is a crate relative path, and the (opaque) hash of its contents - { path = "LICENSE", hash = 0xbd0eed23 } + { path = "LICENSE", hash = 0xbd0eed23 }, ] [licenses.private] diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index 28cad8e646..150c0ff33d 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -32,7 +32,7 @@ integer-encoding = { version = "4.0.0", optional = true } thiserror = "1" rand = { version = "0.8.5" } futures-intrusive = "0.5.0" -directories = { version = "4.0", optional = true } +directories = { version = "5.0", optional = true } once_cell = "1.18.0" sha1_smol = "1.0.0" nanoid = "0.4.0" @@ -52,13 +52,13 @@ standalone_heartbeat = [] rsevents = "0.3.1" socket2 = "^0.5" tempfile = "3.3.0" -rstest = "^0.18" +rstest = "^0.23" serial_test = "3" criterion = { version = "^0.5", features = ["html_reports", "async_tokio"] } -which = "5" +which = "6" ctor = "0.2.2" redis = { path = "./redis-rs/redis", features = ["tls-rustls-insecure"] } -iai-callgrind = "0.9" +iai-callgrind = "0.14" tokio = { version = "1", features = ["rt-multi-thread"] } glide-core = { path = ".", features = [ "socket-layer", diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs index c23d4dfca4..15df4e9aa8 100644 --- a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -23,7 +23,6 @@ use pin_project_lite::pin_project; use std::collections::VecDeque; use std::fmt; use std::fmt::Debug; -use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; From 78c927b1eb44703156e873defa6ad602b4420a5e Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Wed, 23 Oct 2024 08:54:12 -0700 Subject: [PATCH 039/180] Node: add `FT.CREATE` command (#2501) * Add FT.CREATE command for Node Signed-off-by: Andrew Carbonetto --------- Signed-off-by: Andrew Carbonetto --- CHANGELOG.md | 1 + node/DEVELOPER.md | 10 + node/index.ts | 2 + node/npm/glide/index.ts | 16 + node/package.json | 3 +- node/src/server-modules/GlideFt.ts | 176 ++++++ node/src/server-modules/GlideFtOptions.ts | 120 +++++ node/tests/ServerModules.test.ts | 630 +++++++++++++++------- 8 files changed, 748 insertions(+), 210 deletions(-) create mode 100644 node/src/server-modules/GlideFt.ts create mode 100644 node/src/server-modules/GlideFtOptions.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c16c45936..7f1a3eff1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) +* Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) * Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490)) * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md index 7185ce1359..4d06e391a1 100644 --- a/node/DEVELOPER.md +++ b/node/DEVELOPER.md @@ -137,6 
+137,16 @@ To run the integration tests with existing servers, run the following command: ```bash npm run test -- --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 + +# If those endpoints use TLS, add `--tls=true` (applies to both endpoints) +npm run test -- --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 --tls=true +``` + +By default, the server_modules tests do not run using `npm run test`. After pointing to a server with JSON and VSS modules setup, +run the following command: + +```bash +npm run test-modules ``` ### Submodules diff --git a/node/index.ts b/node/index.ts index 2127178d07..ee035c2a49 100644 --- a/node/index.ts +++ b/node/index.ts @@ -10,4 +10,6 @@ export * from "./src/GlideClient"; export * from "./src/GlideClusterClient"; export * from "./src/Logger"; export * from "./src/server-modules/GlideJson"; +export * from "./src/server-modules/GlideFt"; +export * from "./src/server-modules/GlideFtOptions"; export * from "./src/Transaction"; diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 245ddcd2e4..7539524e32 100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -118,6 +118,14 @@ function initialize() { GlideClusterClient, GlideClientConfiguration, GlideJson, + GlideFt, + TextField, + TagField, + NumericField, + VectorField, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + FtCreateOptions, GlideRecord, GlideString, JsonGetOptions, @@ -228,6 +236,14 @@ function initialize() { Decoder, DecoderOption, GeoAddOptions, + GlideFt, + TextField, + TagField, + NumericField, + VectorField, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + FtCreateOptions, GlideRecord, GlideJson, GlideString, diff --git a/node/package.json b/node/package.json index b7595bc79d..53b77772b6 100644 --- a/node/package.json +++ b/node/package.json @@ -32,7 +32,8 @@ "compile-protobuf-files": "cd src && pbjs -t static-module -o ProtobufMessage.js ../../glide-core/src/protobuf/*.proto && pbts -o ProtobufMessage.d.ts ProtobufMessage.js", "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? 
writer.fork() : writer).ldelim' src/ProtobufMessage.js", "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='ServerModules'", - "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testPathPattern='ServerModules'", + "test-minimum": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='^(.(?!(GlideJson|GlideFt|pubsub|kill)))*$'", + "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='(GlideJson|GlideFt)'", "build-test-utils": "cd ../utils && npm i && npm run build", "lint:fix": "npm run install-linting && npx eslint -c ../eslint.config.mjs --fix && npm run prettier:format", "lint": "npm run install-linting && npx eslint -c ../eslint.config.mjs && npm run prettier:check:ci", diff --git a/node/src/server-modules/GlideFt.ts b/node/src/server-modules/GlideFt.ts new file mode 100644 index 0000000000..566e4d54c4 --- /dev/null +++ b/node/src/server-modules/GlideFt.ts @@ -0,0 +1,176 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ + +import { Decoder, DecoderOption, GlideString } from "../BaseClient"; +import { GlideClient } from "../GlideClient"; +import { GlideClusterClient } from "../GlideClusterClient"; +import { Field, FtCreateOptions } from "./GlideFtOptions"; + +/** Module for Vector Search commands */ +export class GlideFt { + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name for the index to be created. + * @param schema The fields of the index schema, specifying the fields and their types. + * @param options Optional arguments for the `FT.CREATE` command. See {@link FtCreateOptions}. + * + * @returns If the index is successfully created, returns "OK". 
+     *
+     * @example
+     * ```typescript
+     * // Example usage of FT.CREATE to create a 6-dimensional JSON index using the HNSW algorithm
+     * await GlideFt.create(client, "json_idx1", [{
+     *     type: "VECTOR",
+     *     name: "$.vec",
+     *     alias: "VEC",
+     *     attributes: {
+     *         algorithm: "HNSW",
+     *         type: "FLOAT32",
+     *         dimension: 6,
+     *         distanceMetric: "L2",
+     *         numberOfEdges: 32,
+     *     },
+     * }], {
+     *     dataType: "JSON",
+     *     prefixes: ["json:"]
+     * });
+     * ```
+     */
+    static async create(
+        client: GlideClient | GlideClusterClient,
+        indexName: GlideString,
+        schema: Field[],
+        options?: FtCreateOptions,
+    ): Promise<"OK"> {
+        const args: GlideString[] = ["FT.CREATE", indexName];
+
+        if (options) {
+            if ("dataType" in options) {
+                args.push("ON", options.dataType);
+            }
+
+            if ("prefixes" in options && options.prefixes) {
+                args.push(
+                    "PREFIX",
+                    options.prefixes.length.toString(),
+                    ...options.prefixes,
+                );
+            }
+        }
+
+        args.push("SCHEMA");
+
+        schema.forEach((f) => {
+            args.push(f.name);
+
+            if (f.alias) {
+                args.push("AS", f.alias);
+            }
+
+            args.push(f.type);
+
+            switch (f.type) {
+                case "TAG": {
+                    if (f.separator) {
+                        args.push("SEPARATOR", f.separator);
+                    }
+
+                    if (f.caseSensitive) {
+                        args.push("CASESENSITIVE");
+                    }
+
+                    break;
+                }
+
+                case "VECTOR": {
+                    if (f.attributes) {
+                        args.push(f.attributes.algorithm);
+
+                        const attributes: GlideString[] = [];
+
+                        // all VectorFieldAttributes attributes
+                        if (f.attributes.dimension) {
+                            attributes.push("DIM", f.attributes.dimension.toString());
+                        }
+
+                        if (f.attributes.distanceMetric) {
+                            attributes.push(
+                                "DISTANCE_METRIC",
+                                f.attributes.distanceMetric.toString(),
+                            );
+                        }
+
+                        if (f.attributes.type) {
+                            attributes.push("TYPE", f.attributes.type.toString());
+                        }
+
+                        if (f.attributes.initialCap) {
+                            attributes.push(
+                                "INITIAL_CAP",
+                                f.attributes.initialCap.toString(),
+                            );
+                        }
+
+                        // VectorFieldAttributesHnsw attributes (M, EF_CONSTRUCTION, EF_RUNTIME),
+                        // keyed by the property names declared in GlideFtOptions.ts
+                        if ("numberOfEdges" in f.attributes && f.attributes.numberOfEdges) {
+                            attributes.push("M", f.attributes.numberOfEdges.toString());
+                        }
+
+                        if (
+                            "vectorsExaminedOnConstruction" in f.attributes &&
+                            f.attributes.vectorsExaminedOnConstruction
+                        ) {
+                            attributes.push(
+                                "EF_CONSTRUCTION",
+                                f.attributes.vectorsExaminedOnConstruction.toString(),
+                            );
+                        }
+
+                        if (
+                            "vectorsExaminedOnRuntime" in f.attributes &&
+                            f.attributes.vectorsExaminedOnRuntime
+                        ) {
+                            attributes.push(
+                                "EF_RUNTIME",
+                                f.attributes.vectorsExaminedOnRuntime.toString(),
+                            );
+                        }
+
+                        args.push(attributes.length.toString(), ...attributes);
+                    }
+
+                    break;
+                }
+
+                default:
+                // no-op
+            }
+        });
+
+        return _handleCustomCommand(client, args, {
+            decoder: Decoder.String,
+        }) as Promise<"OK">;
+    }
+}
+
+/**
+ * @internal
+ */
+function _handleCustomCommand(
+    client: GlideClient | GlideClusterClient,
+    args: GlideString[],
+    decoderOption: DecoderOption,
+) {
+    return client instanceof GlideClient
+        ? (client as GlideClient).customCommand(args, decoderOption)
+        : (client as GlideClusterClient).customCommand(args, decoderOption);
+}
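The `switch` above flattens every schema field into the raw `FT.CREATE` argument vector, emitting the number of attribute tokens right after the algorithm name and mapping `numberOfEdges`, `vectorsExaminedOnConstruction` and `vectorsExaminedOnRuntime` (declared in `GlideFtOptions.ts` below) onto the server parameters `M`, `EF_CONSTRUCTION` and `EF_RUNTIME`. For the `json_idx1` example in the doc comment, the arguments handed to `customCommand` should come out as follows — a sketch of the expected token layout, written as a Python list for readability:

```python
# Expected FT.CREATE argument layout for the json_idx1 example (a sketch).
args = [
    "FT.CREATE", "json_idx1",
    "ON", "JSON",
    "PREFIX", "1", "json:",          # prefix count, then the prefixes
    "SCHEMA",
    "$.vec", "AS", "VEC", "VECTOR",  # field name, alias, field type
    "HNSW", "8",                     # algorithm, then the attribute token count
    "DIM", "6",
    "DISTANCE_METRIC", "L2",
    "TYPE", "FLOAT32",
    "M", "32",                       # numberOfEdges
]
# The count after "HNSW" equals the number of tokens that follow it:
assert len(args) - args.index("HNSW") - 2 == 8
```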
diff --git a/node/src/server-modules/GlideFtOptions.ts b/node/src/server-modules/GlideFtOptions.ts
new file mode 100644
index 0000000000..6fe723cc9d
--- /dev/null
+++ b/node/src/server-modules/GlideFtOptions.ts
@@ -0,0 +1,120 @@
+/**
+ * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+ */
+
+import { GlideString } from "../BaseClient";
+
+interface BaseField {
+    /** The name of the field. */
+    name: GlideString;
+    /** An alias for field. */
+    alias?: GlideString;
+}
+
+/**
+ * Field contains any blob of data.
+ */
+export type TextField = BaseField & {
+    /** Field identifier */
+    type: "TEXT";
+};
+
+/**
+ * Tag fields are similar to full-text fields, but they interpret the text as a simple list of
+ * tags delimited by a separator character.
+ *
+ * For HASH fields, separator default is a comma (`,`). For JSON fields, there is no default
+ * separator; you must declare one explicitly if needed.
+ */
+export type TagField = BaseField & {
+    /** Field identifier */
+    type: "TAG";
+    /** Specify how text in the attribute is split into individual tags. Must be a single character. */
+    separator?: GlideString;
+    /** Preserve the original letter cases of tags. If set to false, characters are converted to lowercase by default. */
+    caseSensitive?: boolean;
+};
+
+/**
+ * Field contains a number.
+ */
+export type NumericField = BaseField & {
+    /** Field identifier */
+    type: "NUMERIC";
+};
+
+/**
+ * Superclass for vector field implementations, contains common logic.
+ */
+export type VectorField = BaseField & {
+    /** Field identifier */
+    type: "VECTOR";
+    /** Additional attributes to be passed with the vector field after the algorithm name. */
+    attributes: VectorFieldAttributesFlat | VectorFieldAttributesHnsw;
+};
+
+/**
+ * Base class for defining vector field attributes to be used after the vector algorithm name.
+ */
+export interface VectorFieldAttributes {
+    /** Number of dimensions in the vector. Equivalent to DIM in the option. */
+    dimension: number;
+    /**
+     * The distance metric used in vector type field. Can be one of [L2 | IP | COSINE].
+     */
+    distanceMetric: "L2" | "IP" | "COSINE";
+    /** Vector type. The only supported type is FLOAT32. */
+    type: "FLOAT32";
+    /**
+     * Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024.
+     */
+    initialCap?: number;
+}
+
+/**
+ * Vector field that supports vector search by FLAT (brute force) algorithm.
+ *
+ * The algorithm is a brute force linear processing of each vector in the index, yielding exact
+ * answers within the bounds of the precision of the distance computations.
+ */
+export type VectorFieldAttributesFlat = VectorFieldAttributes & {
+    algorithm: "FLAT";
+};
+
+/**
+ * Vector field that supports vector search by HNSW (Hierarchical Navigable Small World) algorithm.
+ *
+ * The algorithm provides an approximation of the correct answer in exchange for substantially
+ * lower execution times.
+ */
+export type VectorFieldAttributesHnsw = VectorFieldAttributes & {
+    algorithm: "HNSW";
+    /**
+     * Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is 16, maximum is 512.
+     * Equivalent to the `m` attribute.
+     */
+    numberOfEdges?: number;
+    /**
+     * Controls the number of vectors examined during index construction. Default value is 200, Maximum value is 4096.
+     * Equivalent to the `efConstruction` attribute.
+     */
+    vectorsExaminedOnConstruction?: number;
+    /**
+     * Controls the number of vectors examined during query operations. Default value is 10, Maximum value is 4096.
+     * Equivalent to the `efRuntime` attribute.
+     */
+    vectorsExaminedOnRuntime?: number;
+};
+
+export type Field = TextField | TagField | NumericField | VectorField;
+
+/**
+ * Represents the input options to be used in the FT.CREATE command.
+ * All fields in this class are optional inputs for FT.CREATE.
+ */
+export interface FtCreateOptions {
+    /** The type of data to be indexed using FT.CREATE. */
+    dataType: "JSON" | "HASH";
+    /** The prefix of the key to be indexed.
*/ + prefixes?: GlideString[]; +} diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index fad81f2bbd..cffdc77dcd 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -13,11 +13,13 @@ import { v4 as uuidv4 } from "uuid"; import { ConditionalChange, GlideClusterClient, + GlideFt, GlideJson, InfoOptions, JsonGetOptions, ProtocolVersion, RequestError, + VectorField, } from ".."; import { ValkeyCluster } from "../../utils/TestUtils"; import { @@ -29,10 +31,10 @@ import { } from "./TestUtilities"; const TIMEOUT = 50000; -describe("GlideJson", () => { - const testsFailed = 0; + +describe("Server Module Tests", () => { let cluster: ValkeyCluster; - let client: GlideClusterClient; + beforeAll(async () => { const clusterAddresses = parseCommandLineArgs()["cluster-endpoints"]; cluster = await ValkeyCluster.initFromExistingCluster( @@ -42,241 +44,451 @@ describe("GlideJson", () => { ); }, 20000); - afterEach(async () => { - await flushAndCloseClient(true, cluster.getAddresses(), client); - }); - afterAll(async () => { - if (testsFailed === 0) { - await cluster.close(); - } + await cluster.close(); }, TIMEOUT); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "check modules loaded", - async (protocol) => { - client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), - ); - const info = await client.info({ - sections: [InfoOptions.Modules], - route: "randomNode", - }); - expect(info).toContain("# json_core_metrics"); - expect(info).toContain("# search_index_stats"); - }, - ); + describe("GlideJson", () => { + let client: GlideClusterClient; - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set and json.get tests", - async (protocol) => { - client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), - ); - const key = uuidv4(); - const jsonValue = { a: 1.0, b: 2 }; + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); - // JSON.set - expect( - await GlideJson.set( - client, - key, - "$", - JSON.stringify(jsonValue), - ), - ).toBe("OK"); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "check modules loaded", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const info = await client.info({ + sections: [InfoOptions.Modules], + route: "randomNode", + }); + expect(info).toContain("# json_core_metrics"); + expect(info).toContain("# search_index_stats"); + }, + ); - // JSON.get - let result = await GlideJson.get(client, key, { paths: ["."] }); - expect(JSON.parse(result.toString())).toEqual(jsonValue); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set and json.get tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: 2 }; - // JSON.get with array of paths - result = await GlideJson.get(client, key, { - paths: ["$.a", "$.b"], - }); - expect(JSON.parse(result.toString())).toEqual({ - "$.a": [1.0], - "$.b": [2], - }); + // JSON.set + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); - // JSON.get with non-existing key - expect( - await GlideJson.get(client, "non_existing_key", { + // JSON.get + let result = await 
GlideJson.get(client, key, { paths: ["."] }); + expect(JSON.parse(result.toString())).toEqual(jsonValue); + + // JSON.get with array of paths + result = await GlideJson.get(client, key, { + paths: ["$.a", "$.b"], + }); + expect(JSON.parse(result.toString())).toEqual({ + "$.a": [1.0], + "$.b": [2], + }); + + // JSON.get with non-existing key + expect( + await GlideJson.get(client, "non_existing_key", { + paths: ["$"], + }), + ); + + // JSON.get with non-existing path + result = await GlideJson.get(client, key, { paths: ["$.d"] }); + expect(result).toEqual("[]"); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set and json.get tests with multiple value", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + + // JSON.set with complex object + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify({ + a: { c: 1, d: 4 }, + b: { c: 2 }, + c: true, + }), + ), + ).toBe("OK"); + + // JSON.get with deep path + let result = await GlideJson.get(client, key, { + paths: ["$..c"], + }); + expect(JSON.parse(result.toString())).toEqual([true, 1, 2]); + + // JSON.set with deep path + expect( + await GlideJson.set(client, key, "$..c", '"new_value"'), + ).toBe("OK"); + + // verify JSON.set result + result = await GlideJson.get(client, key, { paths: ["$..c"] }); + expect(JSON.parse(result.toString())).toEqual([ + "new_value", + "new_value", + "new_value", + ]); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.set conditional set", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const value = JSON.stringify({ a: 1.0, b: 2 }); + + expect( + await GlideJson.set(client, key, "$", value, { + conditionalChange: ConditionalChange.ONLY_IF_EXISTS, + }), + ).toBeNull(); + + expect( + await GlideJson.set(client, key, "$", value, { + conditionalChange: + ConditionalChange.ONLY_IF_DOES_NOT_EXIST, + }), + ).toBe("OK"); + + expect( + await GlideJson.set(client, key, "$.a", "4.5", { + conditionalChange: + ConditionalChange.ONLY_IF_DOES_NOT_EXIST, + }), + ).toBeNull(); + let result = await GlideJson.get(client, key, { + paths: [".a"], + }); + expect(result).toEqual("1"); + + expect( + await GlideJson.set(client, key, "$.a", "4.5", { + conditionalChange: ConditionalChange.ONLY_IF_EXISTS, + }), + ).toBe("OK"); + result = await GlideJson.get(client, key, { paths: [".a"] }); + expect(result).toEqual("4.5"); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.get formatting", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + // Set initial JSON value + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify({ a: 1.0, b: 2, c: { d: 3, e: 4 } }), + ), + ).toBe("OK"); + // JSON.get with formatting options + let result = await GlideJson.get(client, key, { paths: ["$"], - }), - ); + indent: " ", + newline: "\n", + space: " ", + } as JsonGetOptions); - // JSON.get with non-existing path - result = await GlideJson.get(client, key, { paths: ["$.d"] }); - expect(result).toEqual("[]"); - }, - ); + const expectedResult1 = + '[\n {\n "a": 1,\n "b": 2,\n "c": {\n "d": 3,\n "e": 4\n }\n }\n]'; + expect(result).toEqual(expectedResult1); + // 
JSON.get with different formatting options + result = await GlideJson.get(client, key, { + paths: ["$"], + indent: "~", + newline: "\n", + space: "*", + } as JsonGetOptions); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set and json.get tests with multiple value", - async (protocol) => { - client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), - ); - const key = uuidv4(); + const expectedResult2 = + '[\n~{\n~~"a":*1,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]'; + expect(result).toEqual(expectedResult2); + }, + ); - // JSON.set with complex object - expect( - await GlideJson.set( - client, - key, - "$", - JSON.stringify({ a: { c: 1, d: 4 }, b: { c: 2 }, c: true }), - ), - ).toBe("OK"); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.toggle tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const key2 = uuidv4(); + const jsonValue = { + bool: true, + nested: { bool: false, nested: { bool: 10 } }, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.toggle(client, key, { path: "$..bool" }), + ).toEqual([false, true, null]); + expect( + await GlideJson.toggle(client, key, { path: "bool" }), + ).toBe(true); + expect( + await GlideJson.toggle(client, key, { + path: "$.non_existing", + }), + ).toEqual([]); + expect( + await GlideJson.toggle(client, key, { path: "$.nested" }), + ).toEqual([null]); - // JSON.get with deep path - let result = await GlideJson.get(client, key, { paths: ["$..c"] }); - expect(JSON.parse(result.toString())).toEqual([true, 1, 2]); + // testing behavior with default pathing + expect(await GlideJson.set(client, key2, ".", "true")).toBe( + "OK", + ); + expect(await GlideJson.toggle(client, key2)).toBe(false); + expect(await GlideJson.toggle(client, key2)).toBe(true); - // JSON.set with deep path - expect( - await GlideJson.set(client, key, "$..c", '"new_value"'), - ).toBe("OK"); - - // verify JSON.set result - result = await GlideJson.get(client, key, { paths: ["$..c"] }); - expect(JSON.parse(result.toString())).toEqual([ - "new_value", - "new_value", - "new_value", - ]); - }, - ); - - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set conditional set", - async (protocol) => { + // expect request errors + await expect( + GlideJson.toggle(client, key, { path: "nested" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, key, { path: ".non_existing" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, "non_existing_key", { path: "$" }), + ).rejects.toThrow(RequestError); + }, + ); + }); + + describe("GlideFt", () => { + let client: GlideClusterClient; + + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); + + it("ServerModules check Vector Search module is loaded", async () => { client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + const info = await client.info({ + sections: [InfoOptions.Modules], + route: "randomNode", + }); + expect(info).toContain("# search_index_stats"); + }); + + it("Ft.Create test", async () => { + client = await GlideClusterClient.createClient( + 
getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), ); - const key = uuidv4(); - const value = JSON.stringify({ a: 1.0, b: 2 }); + // Create a few simple indices: + const vectorField_1: VectorField = { + type: "VECTOR", + name: "vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimension: 2, + distanceMetric: "L2", + }, + }; expect( - await GlideJson.set(client, key, "$", value, { - conditionalChange: ConditionalChange.ONLY_IF_EXISTS, - }), - ).toBeNull(); + await GlideFt.create(client, uuidv4(), [vectorField_1]), + ).toEqual("OK"); expect( - await GlideJson.set(client, key, "$", value, { - conditionalChange: ConditionalChange.ONLY_IF_DOES_NOT_EXIST, - }), - ).toBe("OK"); + await GlideFt.create( + client, + "json_idx1", + [ + { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimension: 6, + distanceMetric: "L2", + numberOfEdges: 32, + }, + }, + ], + { + dataType: "JSON", + prefixes: ["json:"], + }, + ), + ).toEqual("OK"); + const vectorField_2: VectorField = { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "FLAT", + type: "FLOAT32", + dimension: 6, + distanceMetric: "L2", + }, + }; expect( - await GlideJson.set(client, key, "$.a", "4.5", { - conditionalChange: ConditionalChange.ONLY_IF_DOES_NOT_EXIST, - }), - ).toBeNull(); - let result = await GlideJson.get(client, key, { paths: [".a"] }); - expect(result).toEqual("1"); + await GlideFt.create(client, uuidv4(), [vectorField_2]), + ).toEqual("OK"); + // create an index with HNSW vector with additional parameters + const vectorField_3: VectorField = { + type: "VECTOR", + name: "doc_embedding", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimension: 1536, + distanceMetric: "COSINE", + numberOfEdges: 40, + vectorsExaminedOnConstruction: 250, + vectorsExaminedOnRuntime: 40, + }, + }; expect( - await GlideJson.set(client, key, "$.a", "4.5", { - conditionalChange: ConditionalChange.ONLY_IF_EXISTS, + await GlideFt.create(client, uuidv4(), [vectorField_3], { + dataType: "HASH", + prefixes: ["docs:"], }), - ).toBe("OK"); - result = await GlideJson.get(client, key, { paths: [".a"] }); - expect(result).toEqual("4.5"); - }, - ); - - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.get formatting", - async (protocol) => { - client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), - ); - const key = uuidv4(); - // Set initial JSON value + ).toEqual("OK"); + + // create an index with multiple fields expect( - await GlideJson.set( + await GlideFt.create( client, - key, - "$", - JSON.stringify({ a: 1.0, b: 2, c: { d: 3, e: 4 } }), + uuidv4(), + [ + { type: "TEXT", name: "title" }, + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ], + { dataType: "HASH", prefixes: ["blog:post:"] }, ), - ).toBe("OK"); - // JSON.get with formatting options - let result = await GlideJson.get(client, key, { - paths: ["$"], - indent: " ", - newline: "\n", - space: " ", - } as JsonGetOptions); - - const expectedResult1 = - '[\n {\n "a": 1,\n "b": 2,\n "c": {\n "d": 3,\n "e": 4\n }\n }\n]'; - expect(result).toEqual(expectedResult1); - // JSON.get with different formatting options - result = await GlideJson.get(client, key, { - paths: ["$"], - indent: "~", - newline: "\n", - space: "*", - } as JsonGetOptions); - - const expectedResult2 = - 
'[\n~{\n~~"a":*1,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]'; - expect(result).toEqual(expectedResult2); - }, - ); - - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.toggle tests", - async (protocol) => { - client = await GlideClusterClient.createClient( - getClientConfigurationOption(cluster.getAddresses(), protocol), - ); - const key = uuidv4(); - const key2 = uuidv4(); - const jsonValue = { - bool: true, - nested: { bool: false, nested: { bool: 10 } }, - }; + ).toEqual("OK"); + + // create an index with multiple prefixes + const name = uuidv4(); expect( - await GlideJson.set( + await GlideFt.create( client, - key, - "$", - JSON.stringify(jsonValue), + name, + [ + { type: "TAG", name: "author_id" }, + { type: "TAG", name: "author_ids" }, + { type: "TEXT", name: "title" }, + { type: "TEXT", name: "name" }, + ], + { + dataType: "HASH", + prefixes: ["author:details:", "book:details:"], + }, ), - ).toBe("OK"); - expect( - await GlideJson.toggle(client, key, { path: "$..bool" }), - ).toEqual([false, true, null]); - expect(await GlideJson.toggle(client, key, { path: "bool" })).toBe( - true, - ); - expect( - await GlideJson.toggle(client, key, { path: "$.non_existing" }), - ).toEqual([]); - expect( - await GlideJson.toggle(client, key, { path: "$.nested" }), - ).toEqual([null]); - - // testing behavior with default pathing - expect(await GlideJson.set(client, key2, ".", "true")).toBe("OK"); - expect(await GlideJson.toggle(client, key2)).toBe(false); - expect(await GlideJson.toggle(client, key2)).toBe(true); - - // expect request errors - await expect( - GlideJson.toggle(client, key, { path: "nested" }), - ).rejects.toThrow(RequestError); - await expect( - GlideJson.toggle(client, key, { path: ".non_existing" }), - ).rejects.toThrow(RequestError); - await expect( - GlideJson.toggle(client, "non_existing_key", { path: "$" }), - ).rejects.toThrow(RequestError); - }, - ); + ).toEqual("OK"); + + // create a duplicating index - expect a RequestError + try { + expect( + await GlideFt.create(client, name, [ + { type: "TEXT", name: "title" }, + { type: "TEXT", name: "name" }, + ]), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("already exists"); + } + + // create an index without fields - expect a RequestError + try { + expect( + await GlideFt.create(client, uuidv4(), []), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain( + "wrong number of arguments", + ); + } + + // duplicated field name - expect a RequestError + try { + expect( + await GlideFt.create(client, uuidv4(), [ + { type: "TEXT", name: "name" }, + { type: "TEXT", name: "name" }, + ]), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("already exists"); + } + }); + }); }); From 71257a7f0b52a3088e0c23f15b53d9c02d68f001 Mon Sep 17 00:00:00 2001 From: Yi-Pin Chen Date: Wed, 23 Oct 2024 11:12:50 -0700 Subject: [PATCH 040/180] Java: add JSON.TOGGLE command (#2504) * Java: add JSON.TOGGLE command --------- Signed-off-by: Yi-Pin Chen --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 111 ++++++++++++++++++ .../test/java/glide/modules/JsonTests.java | 33 ++++++ 3 files changed, 145 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f1a3eff1c..c23a827531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) * Java: Added `JSON.ARRAPPEND` 
([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) +* Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 4f43acafd4..7beb5fbab5 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -27,6 +27,7 @@ public class Json { private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; private static final String JSON_DEL = JSON_PREFIX + "DEL"; private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; + private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE"; private Json() {} @@ -849,6 +850,116 @@ public static CompletableFuture<Long> forget( return executeCommand(client, new GlideString[] {gs(JSON_FORGET), key, path}); }
+    /**
+     * Toggles a Boolean value stored at the root within the JSON document stored at <code>key</code>.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns the toggled boolean value at the root of the document, or <code>null</code> for
+     *     JSON values matching the root that are not boolean. If <code>key</code> doesn't exist,
+     *     returns <code>null</code>.
+     * @example
      {@code
+     * Json.set(client, "doc", ".", "true").get();
      +     * var res = Json.toggle(client, "doc").get();
      +     * assert res.equals(false);
      +     * res = Json.toggle(client, "doc").get();
      +     * assert res.equals(true);
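+     * // a hypothetical follow-up read: after the two toggles the root holds `true` again;
+     * // a JSONPath query returns the value wrapped in a JSON array string
+     * assert Json.get(client, "doc", new String[] {"$"}).get().equals("[true]");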
      +     * }
      + */ + public static CompletableFuture toggle(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_TOGGLE, key}); + } + + /** + * Toggles a Boolean value stored at the root within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the toggled boolean value at the root of the document, or null for + * JSON values matching the root that are not boolean. If key doesn't exist, + * returns null. + * @example + *
      {@code
+     * Json.set(client, "doc", ".", "true").get();
      +     * var res = Json.toggle(client, gs("doc")).get();
      +     * assert res.equals(false);
      +     * res = Json.toggle(client, gs("doc")).get();
      +     * assert res.equals(true);
      +     * }
      + */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).toArray()); + } + + /** + * Toggles a Boolean value stored at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a Boolean[] with the toggled boolean value for every possible + * path, or null for JSON values matching the path that are not boolean. + *
      • For legacy path (path doesn't start with $):
        + * Returns the value of the toggled boolean in path. If path + * doesn't exist or the value at path isn't a boolean, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}").get();
      +     * var res = Json.toggle(client, "doc", "$..bool").get();
      +     * assert Arrays.equals((Boolean[]) res, new Boolean[] {false, true, null});
      +     * res = Json.toggle(client, "doc", "bool").get();
      +     * assert res.equals(true);
      +     * var getResult = Json.get(client, "doc", "$").get();
      +     * assert getResult.equals("{\"bool\": true, \"nested\": {\"bool\": true, \"nested\": {\"bool\": 10}}}");
      +     * }
      + */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_TOGGLE, key, path}); + } + + /** + * Toggles a Boolean value stored at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a Boolean[] with the toggled boolean value for every possible + * path, or null for JSON values matching the path that are not boolean. + *
      • For legacy path (path doesn't start with $):
        + * Returns the value of the toggled boolean in path. If path + * doesn't exist or the value at path isn't a boolean, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}").get();
      +     * var res = Json.toggle(client, gs("doc"), gs("$..bool")).get();
      +     * assert Arrays.equals((Boolean[]) res, new Boolean[] {false, true, null});
      +     * res = Json.toggle(client, gs("doc"), gs("bool")).get();
      +     * assert res.equals(true);
      +     * var getResult = Json.get(client, "doc", "$").get();
      +     * assert getResult.equals("{\"bool\": true, \"nested\": {\"bool\": true, \"nested\": {\"bool\": 10}}}");
      +     * }
      + */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).add(path).toArray()); + } + /** * A wrapper for custom command API. * diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index ad5ae55489..5a5d38c9fa 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -360,4 +360,37 @@ public void json_forget() { assertEquals(0L, Json.forget(client, key).get()); assertNull(Json.get(client, key, new String[] {"$"}).get()); } + + @Test + @SneakyThrows + public void toggle() { + String key = UUID.randomUUID().toString(); + String key2 = UUID.randomUUID().toString(); + String doc = "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}"; + + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + assertArrayEquals( + new Object[] {false, true, null}, (Object[]) Json.toggle(client, key, "$..bool").get()); + + assertEquals(true, Json.toggle(client, gs(key), gs("bool")).get()); + + assertArrayEquals(new Object[] {}, (Object[]) Json.toggle(client, key, "$.non_existing").get()); + assertArrayEquals(new Object[] {null}, (Object[]) Json.toggle(client, key, "$.nested").get()); + + // testing behaviour with default path + assertEquals("OK", Json.set(client, key2, ".", "true").get()); + assertEquals(false, Json.toggle(client, key2).get()); + assertEquals(true, Json.toggle(client, gs(key2)).get()); + + // expect request errors + var exception = + assertThrows(ExecutionException.class, () -> Json.toggle(client, key, "nested").get()); + exception = + assertThrows( + ExecutionException.class, () -> Json.toggle(client, key, ".non_existing").get()); + exception = + assertThrows( + ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get()); + } } From fb8a63b0ea19489315d527a12674f288b7173a1c Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:47:44 -0700 Subject: [PATCH 041/180] Node: Add command JSON.DEL and JSON.FORGET (#2505) * Node: Add command JSON.DEL and JSON.FORGET Signed-off-by: TJ Zhang --- CHANGELOG.md | 1 + node/src/server-modules/GlideJson.ts | 83 ++++++++++- node/tests/ServerModules.test.ts | 208 +++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c23a827531..29bf37edfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) * Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) +* Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) #### Breaking Changes diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 9cab5092e4..0e2f10228a 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -87,7 +87,7 @@ export class GlideJson { * @param value - The value to set at the specific path, in JSON formatted bytes or str. 
* @param options - (Optional) Additional parameters: * - (Optional) `conditionalChange` - Set the value only if the given condition is met (within the key or path). - * Equivalent to [`XX` | `NX`] in the module API. Defaults to null. + * Equivalent to [`XX` | `NX`] in the module API. * - (Optional) `decoder`: see {@link DecoderOption}. * * @returns If the value is successfully set, returns `"OK"`. @@ -134,11 +134,11 @@ export class GlideJson { * - For JSONPath (path starts with `$`): * - Returns a stringified JSON list of bytes replies for every possible path, * or a byte string representation of an empty array, if path doesn't exist. - * If `key` doesn't exist, returns null. + * If `key` doesn't exist, returns `null`. * - For legacy path (path doesn't start with `$`): * Returns a byte string representation of the value in `path`. * If `path` doesn't exist, an error is raised. - * If `key` doesn't exist, returns null. + * If `key` doesn't exist, returns `null`. * - If multiple paths are given: * Returns a stringified JSON object in bytes, in which each path is a key, and it's corresponding value, is the value as if the path was executed in the command as a single path. * In case of multiple paths, and `paths` are a mix of both JSONPath and legacy path, the command behaves as if all are JSONPath paths. @@ -192,9 +192,9 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) The JSONPath to specify. Defaults to the root if not specified. + * - (Optional) path - The JSONPath to specify. Defaults to the root if not specified. * @returns - For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value, - * or null for JSON values matching the path that are not boolean. + * or `null` for JSON values matching the path that are not boolean. * - For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`. * - Note that when sending legacy path syntax, If `path` doesn't exist or the value at `path` isn't a boolean, an error is raised. * @@ -231,10 +231,81 @@ export class GlideJson { ): Promise> { const args = ["JSON.TOGGLE", key]; - if (options !== undefined) { + if (options) { args.push(options.path); } return _executeCommand>(client, args); } + + /** + * Deletes the JSON value at the specified `path` within the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) path - If `null`, deletes the entire JSON document at `key`. + * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '{a: 1, nested: {a:2, b:3}}')); + * // Output: "OK" - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * console.log(await GlideJson.del(client, "doc", "$..a")); + * // Output: 2 - Indicates successful deletion of the specific values in the key stored at `doc`. + * console.log(await GlideJson.get(client, "doc", "$")); + * // Output: "[{nested: {b: 3}}]" - Returns the value at path '$' in the JSON document stored at `doc`. + * console.log(await GlideJson.del(client, "doc")); + * // Output: 1 - Deletes the entire JSON document stored at `doc`. 
+ * ``` + */ + static async del( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise { + const args = ["JSON.DEL", key]; + + if (options) { + args.push(options.path); + } + + return _executeCommand(client, args); + } + + /** + * Deletes the JSON value at the specified `path` within the JSON document stored at `key`. This command is + * an alias of {@link del}. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) path - If `null`, deletes the entire JSON document at `key`. + * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '{a: 1, nested: {a:2, b:3}}')); + * // Output: "OK" - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * console.log(await GlideJson.forget(client, "doc", "$..a")); + * // Output: 2 - Indicates successful deletion of the specific values in the key stored at `doc`. + * console.log(await GlideJson.get(client, "doc", "$")); + * // Output: "[{nested: {b: 3}}]" - Returns the value at path '$' in the JSON document stored at `doc`. + * console.log(await GlideJson.forget(client, "doc")); + * // Output: 1 - Deletes the entire JSON document stored at `doc`. + * ``` + */ + static async forget( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise { + const args = ["JSON.FORGET", key]; + + if (options) { + args.push(options.path); + } + + return _executeCommand(client, args); + } } diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index cffdc77dcd..160707aded 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -314,6 +314,214 @@ describe("Server Module Tests", () => { ).rejects.toThrow(RequestError); }, ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.del tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: { a: 1, b: 2.5, c: true } }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // non-existing paths + expect( + await GlideJson.del(client, key, { path: "$..path" }), + ).toBe(0); + expect( + await GlideJson.del(client, key, { path: "..path" }), + ).toBe(0); + + // deleting existing paths + expect(await GlideJson.del(client, key, { path: "$..a" })).toBe( + 2, + ); + expect( + await GlideJson.get(client, key, { paths: ["$..a"] }), + ).toBe("[]"); + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key, { path: "..a" })).toBe( + 2, + ); + await expect( + GlideJson.get(client, key, { paths: ["..a"] }), + ).rejects.toThrow(RequestError); + + // verify result + const result = await GlideJson.get(client, key, { + paths: ["$"], + }); + expect(JSON.parse(result as string)).toEqual([ + { b: { b: 2.5, c: true } }, + ]); + + // test root deletion operations + expect(await GlideJson.del(client, key, { path: "$" })).toBe(1); + + // reset and test dot deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key, { path: "." 
})).toBe(1); + + // reset and test key deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key)).toBe(1); + expect(await GlideJson.del(client, key)).toBe(0); + expect( + await GlideJson.get(client, key, { paths: ["$"] }), + ).toBeNull(); + + // non-existing keys + expect( + await GlideJson.del(client, "non_existing_key", { + path: "$", + }), + ).toBe(0); + expect( + await GlideJson.del(client, "non_existing_key", { + path: ".", + }), + ).toBe(0); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.forget tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: { a: 1, b: 2.5, c: true } }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // non-existing paths + expect( + await GlideJson.forget(client, key, { path: "$..path" }), + ).toBe(0); + expect( + await GlideJson.forget(client, key, { path: "..path" }), + ).toBe(0); + + // deleting existing paths + expect( + await GlideJson.forget(client, key, { path: "$..a" }), + ).toBe(2); + expect( + await GlideJson.get(client, key, { paths: ["$..a"] }), + ).toBe("[]"); + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.forget(client, key, { path: "..a" }), + ).toBe(2); + await expect( + GlideJson.get(client, key, { paths: ["..a"] }), + ).rejects.toThrow(RequestError); + + // verify result + const result = await GlideJson.get(client, key, { + paths: ["$"], + }); + expect(JSON.parse(result as string)).toEqual([ + { b: { b: 2.5, c: true } }, + ]); + + // test root deletion operations + expect(await GlideJson.forget(client, key, { path: "$" })).toBe( + 1, + ); + + // reset and test dot deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.forget(client, key, { path: "." 
})).toBe( + 1, + ); + + // reset and test key deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.forget(client, key)).toBe(1); + expect(await GlideJson.forget(client, key)).toBe(0); + expect( + await GlideJson.get(client, key, { paths: ["$"] }), + ).toBeNull(); + + // non-existing keys + expect( + await GlideJson.forget(client, "non_existing_key", { + path: "$", + }), + ).toBe(0); + expect( + await GlideJson.forget(client, "non_existing_key", { + path: ".", + }), + ).toBe(0); + }, + ); }); describe("GlideFt", () => { From 63c225b7c31f7ff43b3856f326763599daae0a4a Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Thu, 24 Oct 2024 10:44:51 -0700 Subject: [PATCH 042/180] Minor update for node build (#2393) Signed-off-by: Yury-Fridlyand --- node/package.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/node/package.json b/node/package.json index 53b77772b6..6cf654aa29 100644 --- a/node/package.json +++ b/node/package.json @@ -20,7 +20,7 @@ "glide-rs" ], "scripts": { - "build": "npm run build-internal && npm run build-protobuf && npm run build-external", + "build": "npm run prereq && npm run build-internal && npm run build-protobuf && npm run build-external", "build:release": "npm run build-internal:release && npm run build-protobuf && npm run build-external:release", "build:benchmark": "npm run build-internal:benchmark && npm run build-protobuf && npm run build-external", "build-internal": "cd rust-client && npm run build", @@ -30,6 +30,7 @@ "build-external:release": "rm -rf build-ts && tsc --stripInternal", "build-protobuf": "npm run compile-protobuf-files && npm run fix-protobuf-file", "compile-protobuf-files": "cd src && pbjs -t static-module -o ProtobufMessage.js ../../glide-core/src/protobuf/*.proto && pbts -o ProtobufMessage.d.ts ProtobufMessage.js", + "clean": "rm -rf build-ts rust-client/target docs glide-logs rust-client/glide-rs.*.node rust-client/index.* src/ProtobufMessage.*", "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? writer.fork() : writer).ldelim' src/ProtobufMessage.js", "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='ServerModules'", "test-minimum": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='^(.(?!(GlideJson|GlideFt|pubsub|kill)))*$'", @@ -39,6 +40,7 @@ "lint": "npm run install-linting && npx eslint -c ../eslint.config.mjs && npm run prettier:check:ci", "install-linting": "cd ../ & npm install", "prepack": "npmignore --auto", + "prereq": "git submodule update --init --recursive && npm install", "prettier:check:ci": "npx prettier --check . --ignore-unknown '!**/*.{js,d.ts}'", "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'" }, From b44368e2d904a66573480280697a6b960d2804f3 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Thu, 24 Oct 2024 12:43:29 -0700 Subject: [PATCH 043/180] Java: `JSON.OBJLEN` and `JSON.OBJKEYS`. (#2492) * `JSON.OBJLEN` and `JSON.OBJKEYS`. 
Signed-off-by: Yury-Fridlyand --------- Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 297 ++++++++++++++++-- .../test/java/glide/modules/JsonTests.java | 42 +++ 3 files changed, 316 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29bf37edfb..d65cabacbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) +* Java: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2492](https://github.com/valkey-io/valkey-glide/pull/2492)) * Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490)) * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 7beb5fbab5..80b5eaf028 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -25,6 +25,8 @@ public class Json { private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; + private static final String JSON_OBJLEN = JSON_PREFIX + "OBJLEN"; + private static final String JSON_OBJKEYS = JSON_PREFIX + "OBJKEYS"; private static final String JSON_DEL = JSON_PREFIX + "DEL"; private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE"; @@ -705,6 +707,245 @@ public static CompletableFuture arrlen( return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key}); } + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key.
      + * Equivalent to {@link #objlen(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The object length stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, "doc").get();
+     * assert res == 2; // the size of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
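+     * // per the contract above, a non-object root raises an error instead, e.g. if the
+     * // document had been created with Json.set(client, "doc", ".", "[1, 2, 3]")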
      +     * }
      + */ + public static CompletableFuture objlen(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_OBJLEN, key}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key.
      + * Equivalent to {@link #objlen(BaseClient, GlideString, GlideString)} with path set + * to gs("."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The object length stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, gs("doc"), gs(".")).get();
+     * assert res == 2; // the size of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
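+     * // gs(...) wraps a binary-safe GlideString; the result matches the String overload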
      +     * }
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJLEN), key}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of long integers for every possible + * path, indicating the number of key-value pairs for each matching object, or + * null + * for JSON values matching the path that are not an object. If path + * does not exist, an empty array will be returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns the number of key-value pairs for the object value matching the path. If + * multiple paths are matched, returns the length of the first matching object. If + * path doesn't exist or the value at path is not an object, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, "doc", ".").get(); // legacy path - command returns first value as `Long`
+     * assert res == 2L; // the size of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     *
      +     * res = Json.objlen(client, "doc", "$.b").get(); // JSONPath - command returns an array
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // the length of the object matching the JSONPath `$.b`.
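+     * // a sketch of the non-existing-path case described above: JSONPath matches
+     * // nothing, so an empty array is returned instead of an error
+     * res = Json.objlen(client, "doc", "$.non_existing").get();
+     * assert ((Object[]) res).length == 0;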
      +     * }
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_OBJLEN, key, path}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of long integers for every possible + * path, indicating the number of key-value pairs for each matching object, or + * null + * for JSON values matching the path that are not an object. If path + * does not exist, an empty array will be returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns the number of key-value pairs for the object value matching the path. If + * multiple paths are matched, returns the length of the first matching object. If + * path doesn't exist or the value at path is not an object, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, gs("doc"), gs(".")).get(); // legacy path - command returns first value as `Long`
+     * assert res == 2L; // the size of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     *
      +     * res = Json.objlen(client, gs("doc"), gs("$.b")).get(); // JSONPath - command returns an array
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // the length of the object matching the JSONPath `$.b`.
      +     * }
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJLEN), key, path}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key.
+ * Equivalent to {@link #objkeys(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The keys of the object stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, "doc").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { "a", "b" }); // the keys of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
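+     * // per the contract above, a missing key (hypothetical name) returns null
+     * assert Json.objkeys(client, "non_existing_key").get() == null;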
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_OBJKEYS, key}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key.
+ * Equivalent to {@link #objkeys(BaseClient, GlideString, GlideString)} with path set + * to gs("."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The keys of the object stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, gs("doc"), gs(".")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { gs("a"), gs("b") }); // the keys of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJKEYS), key}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an Object[][] with each nested array containing the key names of + * one matching object, for every possible path, or null for JSON values matching the path that are + * not an object. If path does not exist, an empty sub-array will be + * returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns an array of object keys for the object value matching the path. If multiple + * paths are matched, returns the keys of the first matching object. If path + * doesn't exist or the value at path is not an object, an error is + * raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, "doc", ".").get(); // legacy path - command returns array for first matched object
      +     * assert Arrays.equals((Object[]) res, new Object[] { "a", "b" }); // key names for the object matching the path `.` as it is the only match.
      +     *
      +     * res = Json.objkeys(client, "doc", "$.b").get(); // JSONPath - command returns an array for each matched object
      +     * assert Arrays.equals((Object[]) res, new Object[][] { { "a", "b", "c" } }); // key names as a nested list for objects matching the JSONPath `$.b`.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_OBJKEYS, key, path}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an Object[][] with each nested array containing the key names of + * one matching object, for every possible path, or null for JSON values matching the path that are + * not an object. If path does not exist, an empty sub-array will be + * returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns an array of object keys for the object value matching the path. If multiple + * paths are matched, returns the keys of the first matching object. If path + * doesn't exist or the value at path is not an object, an error is + * raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, gs("doc"), gs(".")).get(); // legacy path - command returns array for first matched object
+     * assert Arrays.equals((Object[]) res, new Object[] { gs("a"), gs("b") }); // key names for the object matching the path `.` as it is the only match.
      +     *
      +     * res = Json.objkeys(client, gs("doc"), gs("$.b")).get(); // JSONPath - command returns an array for each matched object
+     * assert Arrays.deepEquals((Object[]) res, new Object[][] { { gs("a"), gs("b"), gs("c") } }); // key names as a nested list for objects matching the JSONPath `$.b`.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJKEYS), key, path}); + } + /** * Deletes the JSON document stored at key. * @@ -713,9 +954,9 @@ public static CompletableFuture arrlen( * @return The number of elements deleted. 0 if the key does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.del(client, "doc").get();
      -     * assertEquals(result, 1L);
      +     * assert result == 1L;
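+     * // a follow-up sketch: the whole document was deleted, so a root query returns null
+     * assert Json.get(client, "doc", new String[] {"$"}).get() == null;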
            * }
      */ public static CompletableFuture del(@NonNull BaseClient client, @NonNull String key) { @@ -730,9 +971,9 @@ public static CompletableFuture del(@NonNull BaseClient client, @NonNull S * @return The number of elements deleted. 0 if the key does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.del(client, gs("doc")).get();
      -     * assertEquals(result, 1L);
      +     * assert result == 1L;
            * }
      */ public static CompletableFuture del(@NonNull BaseClient client, @NonNull GlideString key) { @@ -740,17 +981,19 @@ public static CompletableFuture del(@NonNull BaseClient client, @NonNull G } /** - * Deletes the JSON value at the specified path within the JSON document stored at key. + * Deletes the JSON value at the specified path within the JSON document stored at + * key. * * @param client The Valkey GLIDE client to execute the command. * @param key The key of the JSON document. * @param path Represents the path within the JSON document where the value will be deleted. - * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.del(client, "doc", "$..a").get();
      -     * assertEquals(result, 2L);
      +     * assert result == 2L;
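+     * // both "a" members are removed; the remaining document is {"nested": {"b": 3}}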
            * }
      */ public static CompletableFuture del( @@ -759,17 +1002,19 @@ public static CompletableFuture del( } /** - * Deletes the JSON value at the specified path within the JSON document stored at key. + * Deletes the JSON value at the specified path within the JSON document stored at + * key. * * @param client The Valkey GLIDE client to execute the command. * @param key The key of the JSON document. * @param path Represents the path within the JSON document where the value will be deleted. - * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.del(client, gs("doc"), gs("$..a")).get();
      -     * assertEquals(result, 2L);
      +     * assert result == 2L;
            * }
      */ public static CompletableFuture del( @@ -785,9 +1030,9 @@ public static CompletableFuture del( * @return The number of elements deleted. 0 if the key does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.forget(client, "doc").get();
      -     * assertEquals(result, 1L);
      +     * assert result == 1L;
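+     * // forget is an alias of del; a second call on the now-deleted key returns 0L
+     * assert Json.forget(client, "doc").get() == 0L;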
            * }
      */ public static CompletableFuture forget(@NonNull BaseClient client, @NonNull String key) { @@ -802,9 +1047,9 @@ public static CompletableFuture forget(@NonNull BaseClient client, @NonNul * @return The number of elements deleted. 0 if the key does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.forget(client, gs("doc")).get();
      -     * assertEquals(result, 1L);
      +     * assert result == 1L;
            * }
      */ public static CompletableFuture forget( @@ -813,17 +1058,19 @@ public static CompletableFuture forget( } /** - * Deletes the JSON value at the specified path within the JSON document stored at key. + * Deletes the JSON value at the specified path within the JSON document stored at + * key. * * @param client The Valkey GLIDE client to execute the command. * @param key The key of the JSON document. * @param path Represents the path within the JSON document where the value will be deleted. - * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.forget(client, "doc", "$..a").get();
      -     * assertEquals(result, 2L);
      +     * assert result == 2L;
            * }
      */ public static CompletableFuture forget( @@ -832,17 +1079,19 @@ public static CompletableFuture forget( } /** - * Deletes the JSON value at the specified path within the JSON document stored at key. + * Deletes the JSON value at the specified path within the JSON document stored at + * key. * * @param client The Valkey GLIDE client to execute the command. * @param key The key of the JSON document. * @param path Represents the path within the JSON document where the value will be deleted. - * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is invalid or does not exist. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. * @example *
      {@code
      -     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}");
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
            * Long result = Json.forget(client, gs("doc"), gs("$..a")).get();
      -     * assertEquals(result, 2L);
      +     * assert result == 2L;
            * }
      */ public static CompletableFuture forget( diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 5a5d38c9fa..de59753c20 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -321,6 +321,27 @@ public void arrlen() { assertEquals(5L, res); } + @Test + @SneakyThrows + public void objlen() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.objlen(client, key, "$..").get(); + assertArrayEquals(new Object[] {2L, 3L, 2L}, (Object[]) res); + + res = Json.objlen(client, gs(key), gs("..b")).get(); + assertEquals(3L, res); + + // without path + res = Json.objlen(client, key).get(); + assertEquals(2L, res); + res = Json.objlen(client, gs(key)).get(); + assertEquals(2L, res); + } + @Test @SneakyThrows public void json_del() { @@ -341,6 +362,27 @@ public void json_del() { assertNull(Json.get(client, key, new String[] {"$"}).get()); } + @Test + @SneakyThrows + public void objkeys() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.objkeys(client, key, "..").get(); + assertArrayEquals(new Object[] {"a", "b"}, res); + + res = Json.objkeys(client, gs(key), gs("$..b")).get(); + assertArrayEquals(new Object[][] {{gs("a"), gs("b"), gs("c")}, {}}, res); + + // without path + res = Json.objkeys(client, key).get(); + assertArrayEquals(new Object[] {"a", "b"}, res); + res = Json.objkeys(client, gs(key)).get(); + assertArrayEquals(new Object[] {gs("a"), gs("b")}, res); + } + @Test @SneakyThrows public void json_forget() { From 50786a3aacac611355bcecb75ccca9e50d2509ac Mon Sep 17 00:00:00 2001 From: prateek-kumar-improving Date: Thu, 24 Oct 2024 13:19:16 -0700 Subject: [PATCH 044/180] FT.EXPLAIN and FT.EXPLAINCLI commands added (#2508) * FT.EXPLAIN and FT.EXPLAINCLI commands added --------- Signed-off-by: Prateek Kumar --- CHANGELOG.md | 1 + .../glide/async_commands/server_modules/ft.py | 46 ++++++++ .../server_modules/ft_options/ft_constants.py | 2 + .../tests/tests_server_modules/test_ft.py | 107 +++++++++++++++++- 4 files changed, 154 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d65cabacbb..cb3817dea2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: FT.EXPLAIN and FT.EXPLAINCLI commands added([#2508](https://github.com/valkey-io/valkey-glide/pull/2508)) * Python: Python FT.INFO command added([#2429](https://github.com/valkey-io/valkey-glide/pull/2494)) * Python: Add FT.SEARCH command([#2470](https://github.com/valkey-io/valkey-glide/pull/2470)) * Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE([#2471](https://github.com/valkey-io/valkey-glide/pull/2471)) diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index dbfa0b9cb7..d96352a36d 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -234,3 +234,49 @@ async def info(client: TGlideClient, indexName: TEncodable) -> FtInfoResponse: """ args: List[TEncodable] = [CommandNames.FT_INFO, indexName] return cast(FtInfoResponse, 
await client.custom_command(args)) + + +async def explain( + client: TGlideClient, indexName: TEncodable, query: TEncodable +) -> TEncodable: + """ + Parse a query and return information about how that query was parsed. + + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + + Returns: + TEncodable: A string containing the parsed results representing the execution plan. + + Examples: + >>> from glide import ft + >>> result = await ft.explain(glide_client, indexName="myIndex", query="@price:[0 10]") + b'Field {\n price\n 0\n 10\n}\n' # Parsed results. + """ + args: List[TEncodable] = [CommandNames.FT_EXPLAIN, indexName, query] + return cast(TEncodable, await client.custom_command(args)) + + +async def explaincli( + client: TGlideClient, indexName: TEncodable, query: TEncodable +) -> List[TEncodable]: + """ + Same as the FT.EXPLAIN command except that the results are displayed in a different format. More useful with cli. + + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + + Returns: + List[TEncodable]: An array containing the execution plan. + + Examples: + >>> from glide import ft + >>> result = await ft.explaincli(glide_client, indexName="myIndex", query="@price:[0 10]") + [b'Field {', b' price', b' 0', b' 10', b'}', b''] # Parsed results. + """ + args: List[TEncodable] = [CommandNames.FT_EXPLAINCLI, indexName, query] + return cast(List[TEncodable], await client.custom_command(args)) diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index 0077c8c3f3..1755c6136e 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -13,6 +13,8 @@ class CommandNames: FT_ALIASADD = "FT.ALIASADD" FT_ALIASDEL = "FT.ALIASDEL" FT_ALIASUPDATE = "FT.ALIASUPDATE" + FT_EXPLAIN = "FT.EXPLAIN" + FT_EXPLAINCLI = "FT.EXPLAINCLI" class FtCreateKeywords: diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index 812e184d7e..9d38531737 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -9,6 +9,7 @@ DistanceMetricType, Field, FtCreateOptions, + NumericField, TextField, VectorAlgorithm, VectorField, @@ -107,10 +108,10 @@ async def _create_test_index_hash_type( ): # Helper function used for creating a basic index with hash data type with one text field. 
fields: List[Field] = [] - text_field_title: TextField = TextField("$title") + text_field_title: TextField = TextField("title") fields.append(text_field_title) - prefix = "{json-search-" + str(uuid.uuid4()) + "}:" + prefix = "{hash-search-" + str(uuid.uuid4()) + "}:" prefixes: List[TEncodable] = [] prefixes.append(prefix) @@ -198,3 +199,105 @@ async def _create_test_index_with_vector_field( schema=fields, options=FtCreateOptions(DataType.JSON, prefixes=prefixes), ) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_explain(self, glide_client: GlideClusterClient): + indexName = str(uuid.uuid4()) + await TestFt._create_test_index_for_ft_explain_commands( + self=self, glide_client=glide_client, index_name=indexName + ) + + # FT.EXPLAIN on a search query containing numeric field. + query = "@price:[0 10]" + result = await ft.explain(glide_client, indexName=indexName, query=query) + resultString = cast(bytes, result).decode(encoding="utf-8") + assert "price" in resultString and "0" in resultString and "10" in resultString + + # FT.EXPLAIN on a search query containing numeric field and having bytes type input to the command. + result = await ft.explain( + glide_client, indexName=indexName.encode(), query=query.encode() + ) + resultString = cast(bytes, result).decode(encoding="utf-8") + assert "price" in resultString and "0" in resultString and "10" in resultString + + # FT.EXPLAIN on a search query that returns all data. + result = await ft.explain(glide_client, indexName=indexName, query="*") + resultString = cast(bytes, result).decode(encoding="utf-8") + assert "*" in resultString + + assert await ft.dropindex(glide_client, indexName=indexName) + + # FT.EXPLAIN on a missing index throws an error. + with pytest.raises(RequestError): + await ft.explain(glide_client, str(uuid.uuid4()), "*") + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_explaincli(self, glide_client: GlideClusterClient): + indexName = str(uuid.uuid4()) + await TestFt._create_test_index_for_ft_explain_commands( + self=self, glide_client=glide_client, index_name=indexName + ) + + # FT.EXPLAINCLI on a search query containing numeric field. + query = "@price:[0 10]" + result = await ft.explaincli(glide_client, indexName=indexName, query=query) + resultStringArr = [] + for i in result: + resultStringArr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert ( + "price" in resultStringArr + and "0" in resultStringArr + and "10" in resultStringArr + ) + + # FT.EXPLAINCLI on a search query containing numeric field and having bytes type input to the command. + result = await ft.explaincli( + glide_client, indexName=indexName.encode(), query=query.encode() + ) + resultStringArr = [] + for i in result: + resultStringArr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert ( + "price" in resultStringArr + and "0" in resultStringArr + and "10" in resultStringArr + ) + + # FT.EXPLAINCLI on a search query that returns all data. + result = await ft.explaincli(glide_client, indexName=indexName, query="*") + resultStringArr = [] + for i in result: + resultStringArr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert "*" in resultStringArr + + assert await ft.dropindex(glide_client, indexName=indexName) + + # FT.EXPLAINCLI on a missing index throws an error. 
+ with pytest.raises(RequestError): + await ft.explaincli(glide_client, str(uuid.uuid4()), "*") + + async def _create_test_index_for_ft_explain_commands( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating an index having hash data type, one text field and one numeric field. + fields: List[Field] = [] + numeric_field: NumericField = NumericField("price") + text_field: TextField = TextField("title") + fields.append(text_field) + fields.append(numeric_field) + + prefix = "{hash-search-" + str(uuid.uuid4()) + "}:" + prefixes: List[TEncodable] = [] + prefixes.append(prefix) + + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.HASH, prefixes), + ) + == OK + ) From f19976bd566c45d96a00cf2cf5df59514882d5e6 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:09:24 -0700 Subject: [PATCH 045/180] Node: Add command JSON.TYPE (#2510) * Node: Add command JSON.TYPE and JSON.RESP --------- Signed-off-by: TJ Zhang Co-authored-by: TJ Zhang --- CHANGELOG.md | 1 + node/src/server-modules/GlideJson.ts | 44 ++++++++++++++++++ node/tests/ServerModules.test.ts | 68 ++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb3817dea2..9d5bbae376 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) +* Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) #### Breaking Changes diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 0e2f10228a..db40a31efc 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -308,4 +308,48 @@ export class GlideJson { return _executeCommand(client, args); } + + /** + * Reports the type of values at the given path. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) path - defaults to root if not provided. + * @returns ReturnTypeJson: + * - For JSONPath (path starts with `$`): + * - Returns an array of strings that represents the type of value at each path. + * The type is one of "null", "boolean", "string", "number", "integer", "object" and "array". + * - If a path does not exist, its corresponding return value is `null`. + * - Empty array if the document key does not exist. + * - For legacy path (path doesn't start with `$`): + * - String that represents the type of the value. + * - `null` if the document key does not exist. + * - `null` if the JSON path is invalid or does not exist. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", "[1, 2.3, "foo", true, null, {}, []]")); + * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. 
+ * const result = await GlideJson.type(client, "doc", "$[*]"); + * console.log(result); + * // Output: ["integer", "number", "string", "boolean", null, "object", "array"]; + * console.log(await GlideJson.set(client, "doc2", ".", "{Name: 'John', Age: 27}")); + * console.log(await GlideJson.type(client, "doc2")); // Output: "object" + * console.log(await GlideJson.type(client, "doc2", ".Age")); // Output: "integer" + * ``` + */ + static async type( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.TYPE", key]; + + if (options) { + args.push(options.path); + } + + return _executeCommand>(client, args); + } } diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 160707aded..29e10bcfaf 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -522,6 +522,74 @@ describe("Server Module Tests", () => { ).toBe(0); }, ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.type tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = [1, 2.3, "foo", true, null, {}, []]; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.type(client, key, { path: "$[*]" }), + ).toEqual([ + "integer", + "number", + "string", + "boolean", + "null", + "object", + "array", + ]); + expect( + await GlideJson.type(client, "non_existing", { + path: "$[*]", + }), + ).toBeNull(); + expect( + await GlideJson.type(client, key, { + path: "$non_existing", + }), + ).toEqual([]); + + const key2 = uuidv4(); + const jsonValue2 = { Name: "John", Age: 27 }; + // setup + expect( + await GlideJson.set( + client, + key2, + "$", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + expect( + await GlideJson.type(client, key2, { path: "." }), + ).toEqual("object"); + expect( + await GlideJson.type(client, key2, { path: ".Age" }), + ).toEqual("integer"); + expect( + await GlideJson.type(client, key2, { path: ".Job" }), + ).toBeNull(); + expect( + await GlideJson.type(client, "non_existing", { path: "." }), + ).toBeNull(); + }, + ); }); describe("GlideFt", () => { From b13535aad2fcdc0e5a41ee552d72f1f399b47943 Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Thu, 24 Oct 2024 15:09:41 -0700 Subject: [PATCH 046/180] Docs: Update `ft.create` docs for Java and Python (#2512) * Docs: Update Java and Python FT.CREATE docs --------- Signed-off-by: Andrew Carbonetto --- .../glide/api/commands/servermodules/FT.java | 38 +++--- .../models/commands/FT/FTCreateOptions.java | 81 +++++++------ .../java/glide/modules/VectorSearchTests.java | 18 +-- python/DEVELOPER.md | 8 +- python/python/glide/__init__.py | 2 - .../glide/async_commands/server_modules/ft.py | 14 +-- .../ft_options/ft_create_options.py | 110 ++++++++++-------- .../search/test_ft_create.py | 8 +- .../tests/tests_server_modules/test_ft.py | 4 +- 9 files changed, 156 insertions(+), 127 deletions(-) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 38b0ec7096..714e1c1109 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -27,8 +27,9 @@ public class FT { * * @param client The client to execute the command. 
* @param indexName The index name. - * @param fields Fields to populate into the index. - * @return OK. + * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module + * API. + * @return "OK". * @example *
      {@code
            * // Create an index for vectors of size 2:
      @@ -44,11 +45,11 @@ public class FT {
            * }
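*
* // A hedged sketch combining several field types (hypothetical index name
* // "products_idx"; TextField, NumericField and TagField are the schema field
* // classes from FTCreateOptions):
* {@code
* FT.create(client, "products_idx", new FieldInfo[] {
*     new FieldInfo("title", new TextField()),
*     new FieldInfo("published_at", new NumericField()),
*     new FieldInfo("category", new TagField())
* }).get(); // "OK"
* }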
      */ public static CompletableFuture create( - @NonNull BaseClient client, @NonNull String indexName, @NonNull FieldInfo[] fields) { + @NonNull BaseClient client, @NonNull String indexName, @NonNull FieldInfo[] schema) { // Node: bug in meme DB - command fails if cmd is too short even though all mandatory args are // present // TODO confirm is it fixed or not and update docs if needed - return create(client, indexName, fields, FTCreateOptions.builder().build()); + return create(client, indexName, schema, FTCreateOptions.builder().build()); } /** @@ -56,9 +57,10 @@ public static CompletableFuture create( * * @param client The client to execute the command. * @param indexName The index name. - * @param fields Fields to populate into the index. + * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module + * API. * @param options Additional parameters for the command - see {@link FTCreateOptions}. - * @return OK. + * @return "OK". * @example *
      {@code
            * // Create a 6-dimensional JSON index using the HNSW algorithm:
      @@ -66,16 +68,16 @@ public static CompletableFuture create(
            *     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
            *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
            *     },
      -     *     FTCreateOptions.builder().indexType(JSON).prefixes(new String[] {"json:"}).build(),
      +     *     FTCreateOptions.builder().dataType(JSON).prefixes(new String[] {"json:"}).build(),
            * ).get();
            * }
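*
* // A hypothetical variant tuning the HNSW parameters (EF_CONSTRUCTION and
* // EF_RUNTIME on the module API); the values shown are illustrative only:
* {@code
* FT.create(client, "json_idx2",
*     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
*         VectorFieldHnsw.builder(DistanceMetric.COSINE, 1536)
*             .numberOfEdges(40)
*             .vectorsExaminedOnConstruction(250)
*             .vectorsExaminedOnRuntime(40)
*             .build())
*     },
*     FTCreateOptions.builder().dataType(JSON).prefixes(new String[] {"json:"}).build()
* ).get(); // "OK"
* }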
      */ public static CompletableFuture create( @NonNull BaseClient client, @NonNull String indexName, - @NonNull FieldInfo[] fields, + @NonNull FieldInfo[] schema, @NonNull FTCreateOptions options) { - return create(client, gs(indexName), fields, options); + return create(client, gs(indexName), schema, options); } /** @@ -83,8 +85,9 @@ public static CompletableFuture create( * * @param client The client to execute the command. * @param indexName The index name. - * @param fields Fields to populate into the index. - * @return OK. + * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module + * API. + * @return "OK". * @example *
      {@code
            * // Create an index for vectors of size 2:
      @@ -100,11 +103,11 @@ public static CompletableFuture create(
            * }
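*
* // The same pattern with binary-safe strings (hypothetical index name; gs()
* // wraps a Java String into a GlideString):
* {@code
* FT.create(client, gs("hash_idx2"), new FieldInfo[] {
*     new FieldInfo(gs("vec"), VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
* }).get(); // "OK"
* }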
      */ public static CompletableFuture create( - @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull FieldInfo[] fields) { + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull FieldInfo[] schema) { // Node: bug in meme DB - command fails if cmd is too short even though all mandatory args are // present // TODO confirm is it fixed or not and update docs if needed - return create(client, indexName, fields, FTCreateOptions.builder().build()); + return create(client, indexName, schema, FTCreateOptions.builder().build()); } /** @@ -112,7 +115,8 @@ public static CompletableFuture create( * * @param client The client to execute the command. * @param indexName The index name. - * @param fields Fields to populate into the index. + * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module + * API. * @param options Additional parameters for the command - see {@link FTCreateOptions}. * @return OK. * @example @@ -122,21 +126,21 @@ public static CompletableFuture create( * new FieldInfo[] { new FieldInfo(gs("$.vec"), gs("VEC"), * VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build()) * }, - * FTCreateOptions.builder().indexType(JSON).prefixes(new String[] {"json:"}).build(), + * FTCreateOptions.builder().dataType(JSON).prefixes(new String[] {"json:"}).build(), * ).get(); * } */ public static CompletableFuture create( @NonNull BaseClient client, @NonNull GlideString indexName, - @NonNull FieldInfo[] fields, + @NonNull FieldInfo[] schema, @NonNull FTCreateOptions options) { var args = Stream.of( new GlideString[] {gs("FT.CREATE"), indexName}, options.toArgs(), new GlideString[] {gs("SCHEMA")}, - Arrays.stream(fields) + Arrays.stream(schema) .map(FieldInfo::toArgs) .flatMap(Arrays::stream) .toArray(GlideString[]::new)) diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java index 1cdb6c77d0..ae651ee2a6 100644 --- a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java @@ -24,14 +24,14 @@ */ @Builder public class FTCreateOptions { - /** The index type. If not given a {@link IndexType#HASH} index is created. */ - private final IndexType indexType; + /** The index data type. If not defined a {@link DataType#HASH} index is created. */ + private final DataType dataType; /** A list of prefixes of index definitions. */ private final GlideString[] prefixes; - FTCreateOptions(IndexType indexType, GlideString[] prefixes) { - this.indexType = indexType; + FTCreateOptions(DataType dataType, GlideString[] prefixes) { + this.dataType = dataType; this.prefixes = prefixes; } @@ -41,9 +41,9 @@ public static FTCreateOptionsBuilder builder() { public GlideString[] toArgs() { var args = new ArrayList(); - if (indexType != null) { + if (dataType != null) { args.add(gs("ON")); - args.add(gs(indexType.toString())); + args.add(gs(dataType.toString())); } if (prefixes != null && prefixes.length > 0) { args.add(gs("PREFIX")); @@ -61,10 +61,10 @@ public FTCreateOptionsBuilder prefixes(String[] prefixes) { } /** Type of the index dataset. */ - public enum IndexType { - /** Data stored in hashes, so field identifiers are field names within the hashes. */ + public enum DataType { + /** Data stored in hashes. Field identifiers are field names within the hashes. 
*/ HASH, - /** Data stored in JSONs, so field identifiers are JSON Path expressions. */ + /** Data stored as a JSON document. Field identifiers are JSON Path expressions. */ JSON } @@ -110,8 +110,8 @@ public String[] toArgs() { /** * Tag fields are similar to full-text fields, but they interpret the text as a simple list of * tags delimited by a separator character.
      - * For {@link IndexType#HASH} fields, separator default is a comma (,). For {@link - * IndexType#JSON} fields, there is no default separator; you must declare one explicitly if + * For {@link DataType#HASH} fields, separator default is a comma (,). For {@link + * DataType#JSON} fields, there is no default separator; you must declare one explicitly if * needed. */ public static class TagField implements Field { @@ -127,7 +127,8 @@ public TagField() { /** * Create a TAG field. * - * @param separator The tag separator. + * @param separator Specify how text in the attribute is split into individual tags. Must be a + * single character. */ public TagField(char separator) { this.separator = Optional.of(separator); @@ -137,8 +138,10 @@ public TagField(char separator) { /** * Create a TAG field. * - * @param separator The tag separator. - * @param caseSensitive Whether to keep the original case. + * @param separator Specify how text in the attribute is split into individual tags. Must be a + * single character. + * @param caseSensitive Preserve the original letter cases of tags. If set to False, characters + * are converted to lowercase by default. */ public TagField(char separator, boolean caseSensitive) { this.separator = Optional.of(separator); @@ -148,7 +151,8 @@ public TagField(char separator, boolean caseSensitive) { /** * Create a TAG field. * - * @param caseSensitive Whether to keep the original case. + * @param caseSensitive Preserve the original letter cases of tags. If set to False, characters + * are converted to lowercase by default. */ public TagField(boolean caseSensitive) { this.caseSensitive = caseSensitive; @@ -204,6 +208,7 @@ public String[] toArgs() { } } + /** Algorithm for vector type fields used for vector similarity search. */ private enum VectorAlgorithm { HNSW, FLAT @@ -234,8 +239,9 @@ private VectorFieldHnsw(Map params) { * Init a builder. * * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two - * vectors. - * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768 + * vectors. Equivalent to DISTANCE_METRIC on the module API. + * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768. + * Equivalent to DIM on the module API. */ public static VectorFieldHnswBuilder builder( @NonNull DistanceMetric distanceMetric, int dimensions) { @@ -256,6 +262,7 @@ public VectorFieldHnsw build() { /** * Number of maximum allowed outgoing edges for each node in the graph in each layer. On layer * zero the maximal number of outgoing edges is doubled. Default is 16 Maximum is 512. + * Equivalent to M on the module API. */ public VectorFieldHnswBuilder numberOfEdges(int numberOfEdges) { params.put(VectorAlgorithmParam.M, Integer.toString(numberOfEdges)); @@ -265,7 +272,8 @@ public VectorFieldHnswBuilder numberOfEdges(int numberOfEdges) { /** * (Optional) The number of vectors examined during index construction. Higher values for this * parameter will improve recall ratio at the expense of longer index creation times. Default - * value is 200. Maximum value is 4096. + * value is 200. Maximum value is 4096. Equivalent to EF_CONSTRUCTION on the module + * API. */ public VectorFieldHnswBuilder vectorsExaminedOnConstruction(int vectorsExaminedOnConstruction) { params.put( @@ -277,6 +285,7 @@ public VectorFieldHnswBuilder vectorsExaminedOnConstruction(int vectorsExaminedO * (Optional) The number of vectors examined during query operations. 
Higher values for this * parameter can yield improved recall at the expense of longer query times. The value of this * parameter can be overriden on a per-query basis. Default value is 10. Maximum value is 4096. + * Equivalent to EF_RUNTIME on the module API. */ public VectorFieldHnswBuilder vectorsExaminedOnRuntime(int vectorsExaminedOnRuntime) { params.put(VectorAlgorithmParam.EF_RUNTIME, Integer.toString(vectorsExaminedOnRuntime)); @@ -299,8 +308,9 @@ private VectorFieldFlat(Map params) { * Init a builder. * * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two - * vectors. - * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768 + * vectors. Equivalent to DISTANCE_METRIC on the module API. + * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768. + * Equivalent to DIM on the module API. */ public static VectorFieldFlatBuilder builder( @NonNull DistanceMetric distanceMetric, int dimensions) { @@ -330,7 +340,7 @@ abstract static class VectorFieldBuilder> { /** * Initial vector capacity in the index affecting memory allocation size of the index. Defaults - * to 1024. + * to 1024. Equivalent to INITIAL_CAP on the module API. */ @SuppressWarnings("unchecked") public T initialCapacity(int initialCapacity) { @@ -343,18 +353,18 @@ public T initialCapacity(int initialCapacity) { /** Field definition to be added into index schema. */ public static class FieldInfo { - private final GlideString identifier; + private final GlideString name; private final GlideString alias; private final Field field; /** * Field definition to be added into index schema. * - * @param identifier Field identifier (name). + * @param name Field name. * @param field The {@link Field} itself. */ - public FieldInfo(@NonNull String identifier, @NonNull Field field) { - this.identifier = gs(identifier); + public FieldInfo(@NonNull String name, @NonNull Field field) { + this.name = gs(name); this.field = field; this.alias = null; } @@ -362,12 +372,12 @@ public FieldInfo(@NonNull String identifier, @NonNull Field field) { /** * Field definition to be added into index schema. * - * @param identifier Field identifier (name). + * @param name Field name. * @param alias Field alias. * @param field The {@link Field} itself. */ - public FieldInfo(@NonNull String identifier, @NonNull String alias, @NonNull Field field) { - this.identifier = gs(identifier); + public FieldInfo(@NonNull String name, @NonNull String alias, @NonNull Field field) { + this.name = gs(name); this.alias = gs(alias); this.field = field; } @@ -375,11 +385,11 @@ public FieldInfo(@NonNull String identifier, @NonNull String alias, @NonNull Fie /** * Field definition to be added into index schema. * - * @param identifier Field identifier (name). + * @param name Field name. * @param field The {@link Field} itself. */ - public FieldInfo(@NonNull GlideString identifier, @NonNull Field field) { - this.identifier = identifier; + public FieldInfo(@NonNull GlideString name, @NonNull Field field) { + this.name = name; this.field = field; this.alias = null; } @@ -387,13 +397,12 @@ public FieldInfo(@NonNull GlideString identifier, @NonNull Field field) { /** * Field definition to be added into index schema. * - * @param identifier Field identifier (name). + * @param name Field name. * @param alias Field alias. * @param field The {@link Field} itself. 
*/ - public FieldInfo( - @NonNull GlideString identifier, @NonNull GlideString alias, @NonNull Field field) { - this.identifier = identifier; + public FieldInfo(@NonNull GlideString name, @NonNull GlideString alias, @NonNull Field field) { + this.name = name; this.alias = alias; this.field = field; } @@ -401,7 +410,7 @@ public FieldInfo( /** Convert to module API. */ public GlideString[] toArgs() { var args = new ArrayList(); - args.add(identifier); + args.add(name); if (alias != null) { args.add(gs("AS")); args.add(alias); diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index bb39afe19c..09cf22cabf 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -5,6 +5,7 @@ import static glide.TestUtilities.commonClusterClientConfig; import static glide.api.BaseClient.OK; import static glide.api.models.GlideString.gs; +import static glide.api.models.commands.FT.FTCreateOptions.DataType; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -27,7 +28,6 @@ import glide.api.models.commands.FT.FTCreateOptions; import glide.api.models.commands.FT.FTCreateOptions.DistanceMetric; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; -import glide.api.models.commands.FT.FTCreateOptions.IndexType; import glide.api.models.commands.FT.FTCreateOptions.NumericField; import glide.api.models.commands.FT.FTCreateOptions.TagField; import glide.api.models.commands.FT.FTCreateOptions.TextField; @@ -99,7 +99,7 @@ public void ft_create() { "$.vec", "VEC", VectorFieldFlat.builder(DistanceMetric.L2, 6).build()) }, FTCreateOptions.builder() - .indexType(IndexType.JSON) + .dataType(DataType.JSON) .prefixes(new String[] {"json:"}) .build()) .get()); @@ -120,7 +120,7 @@ public void ft_create() { .build()) }, FTCreateOptions.builder() - .indexType(IndexType.HASH) + .dataType(DataType.HASH) .prefixes(new String[] {"docs:"}) .build()) .get()); @@ -137,7 +137,7 @@ public void ft_create() { new FieldInfo("category", new TagField()) }, FTCreateOptions.builder() - .indexType(IndexType.HASH) + .dataType(DataType.HASH) .prefixes(new String[] {"blog:post:"}) .build()) .get()); @@ -156,7 +156,7 @@ public void ft_create() { new FieldInfo("name", new TextField()) }, FTCreateOptions.builder() - .indexType(IndexType.HASH) + .dataType(DataType.HASH) .prefixes(new String[] {"author:details:", "book:details:"}) .build()) .get()); @@ -217,7 +217,7 @@ public void ft_search() { new FieldInfo("vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) }, FTCreateOptions.builder() - .indexType(IndexType.HASH) + .dataType(DataType.HASH) .prefixes(new String[] {prefix}) .build()) .get()); @@ -365,7 +365,7 @@ public void ft_aggregate() { new FieldInfo("$.condition", "condition", new TagField(',')), }, FTCreateOptions.builder() - .indexType(IndexType.JSON) + .dataType(DataType.JSON) .prefixes(new String[] {prefixBicycles}) .build()) .get()); @@ -567,7 +567,7 @@ public void ft_aggregate() { new FieldInfo("votes", new NumericField()), }, FTCreateOptions.builder() - .indexType(IndexType.HASH) + .dataType(DataType.HASH) .prefixes(new String[] {prefixMovies}) .build()) .get()); @@ -739,7 +739,7 @@ public void ft_info() { 
new FieldInfo("$.name", new TextField()), }, FTCreateOptions.builder() - .indexType(IndexType.JSON) + .dataType(DataType.JSON) .prefixes(new String[] {"123"}) .build()) .get()); diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index a3e5b07237..66127913c3 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -109,7 +109,7 @@ cd python python3 -m venv .env source .env/bin/activate pip install -r requirements.txt -pip install -r python/dev_requirements.txt +pip install -r dev_requirements.txt ``` ## Build the package (in release mode): @@ -117,7 +117,7 @@ pip install -r python/dev_requirements.txt ```bash maturin develop --release --strip ``` - + > **Note:** to build the wrapper binary with debug symbols remove the `--strip` flag. > **Note 2:** for a faster build time, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the `--release` flag when measuring performance. @@ -156,8 +156,8 @@ pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpo # Generate protobuf files --- -During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made -to the protobuf definition files (`.proto` files located in `glide-core/src/protofuf`), it becomes necessary to +During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made +to the protobuf definition files (`.proto` files located in `glide-core/src/protofuf`), it becomes necessary to regenerate the Python protobuf files. To do so, run: ```bash diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index b690e81137..a4fdc27d67 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -37,7 +37,6 @@ DataType, DistanceMetricType, Field, - FieldType, FtCreateOptions, NumericField, TagField, @@ -261,7 +260,6 @@ "DataType", "DistanceMetricType", "Field", - "FieldType", "FtCreateOptions", "NumericField", "TagField", diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index d96352a36d..cce57bd727 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -31,8 +31,8 @@ async def create( Args: client (TGlideClient): The client to execute the command. - indexName (TEncodable): The index name for the index to be created - schema (List[Field]): The fields of the index schema, specifying the fields and their types. + indexName (TEncodable): The index name. + schema (List[Field]): Fields to populate into the index. Equivalent to `SCHEMA` block in the module API. options (Optional[FtCreateOptions]): Optional arguments for the FT.CREATE command. See `FtCreateOptions`. 
Returns: @@ -40,13 +40,9 @@ async def create( Examples: >>> from glide import ft - >>> schema: List[Field] = [] - >>> field: TextField = TextField("title") - >>> schema.append(field) - >>> prefixes: List[str] = [] - >>> prefixes.append("blog:post:") - >>> index = "idx" - >>> result = await ft.create(glide_client, index, schema, FtCreateOptions(DataType.HASH, prefixes)) + >>> schema: List[Field] = [TextField("title")] + >>> prefixes: List[str] = ["blog:post:"] + >>> result = await ft.create(glide_client, "my_idx1", schema, FtCreateOptions(DataType.HASH, prefixes)) 'OK' # Indicates successful creation of index named 'idx' """ args: List[TEncodable] = [CommandNames.FT_CREATE, indexName] diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py index 89ac1d760d..90aa2d9fdf 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py @@ -18,7 +18,7 @@ class FieldType(Enum): """ TAG = "TAG" """ - If the field contains a tag field. + If the field contains a tag field. """ NUMERIC = "NUMERIC" """ @@ -47,7 +47,10 @@ class VectorAlgorithm(Enum): class DistanceMetricType(Enum): """ - The metric options for the distance in vector type field. + Distance metrics to measure the degree of similarity between two vectors. + + The above metrics calculate distance between two vectors, where the smaller the value is, the + closer the two vectors are in the vector space. """ L2 = "L2" @@ -77,7 +80,7 @@ class VectorType(Enum): class Field(ABC): """ - Abstract base class for defining fields in a schema. + Abstract base class for a vector search field. """ @abstractmethod @@ -116,7 +119,7 @@ def toArgs(self) -> List[TEncodable]: class TextField(Field): """ - Class for defining text fields in a schema. + Field contains any blob of data. """ def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): @@ -142,7 +145,11 @@ def toArgs(self) -> List[TEncodable]: class TagField(Field): """ - Class for defining tag fields in a schema. + Tag fields are similar to full-text fields, but they interpret the text as a simple list of + tags delimited by a separator character. + + For `HASH fields, separator default is a comma `,`. For `JSON` fields, there is no + default separator; you must declare one explicitly if needed. """ def __init__( @@ -182,7 +189,7 @@ def toArgs(self) -> List[TEncodable]: class NumericField(Field): """ - Class for defining the numeric fields in a schema. + Field contains a number. """ def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): @@ -212,16 +219,18 @@ class VectorFieldAttributes(ABC): """ @abstractmethod - def __init__(self, dim: int, distance_metric: DistanceMetricType, type: VectorType): + def __init__( + self, dimensions: int, distance_metric: DistanceMetricType, type: VectorType + ): """ Initialize a new vector field attributes instance. Args: - dim (int): Number of dimensions in the vector. - distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. - type (VectorType): Vector type. The only supported type is FLOAT32. See `VectorType`. + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. 
Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. """ - self.dim = dim + self.dimensions = dimensions self.distance_metric = distance_metric self.type = type @@ -234,8 +243,8 @@ def toArgs(self) -> List[TEncodable]: List[TEncodable]: A list of arguments. """ args: List[TEncodable] = [] - if self.dim: - args.extend([FtCreateKeywords.DIM, str(self.dim)]) + if self.dimensions: + args.extend([FtCreateKeywords.DIM, str(self.dimensions)]) if self.distance_metric: args.extend([FtCreateKeywords.DISTANCE_METRIC, self.distance_metric.name]) if self.type: @@ -250,7 +259,7 @@ class VectorFieldAttributesFlat(VectorFieldAttributes): def __init__( self, - dim: int, + dimensions: int, distance_metric: DistanceMetricType, type: VectorType, initial_cap: Optional[int] = None, @@ -259,12 +268,12 @@ def __init__( Initialize a new flat vector field attributes instance. Args: - dim (int): Number of dimensions in the vector. - distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. See `DistanceMetricType`. - type (VectorType): Vector type. The only supported type is FLOAT32. See `VectorType`. - initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024. + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. + initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` on the module API. """ - super().__init__(dim, distance_metric, type) + super().__init__(dimensions, distance_metric, type) self.initial_cap = initial_cap def toArgs(self) -> List[TEncodable]: @@ -287,31 +296,31 @@ class VectorFieldAttributesHnsw(VectorFieldAttributes): def __init__( self, - dim: int, + dimensions: int, distance_metric: DistanceMetricType, type: VectorType, initial_cap: Optional[int] = None, - m: Optional[int] = None, - ef_contruction: Optional[int] = None, - ef_runtime: Optional[int] = None, + number_of_edges: Optional[int] = None, + vectors_examined_on_construction: Optional[int] = None, + vectors_examined_on_runtime: Optional[int] = None, ): """ - Initialize a new TagField instance. + Initialize a new HNSW vector field attributes instance. Args: - dim (int): Number of dimensions in the vector. - distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. See `DistanceMetricType`. - type (VectorType): Vector type. The only supported type is FLOAT32. See `VectorType`. - initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024. - m (Optional[int]): Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is 16, maximum is 512. - ef_contruction (Optional[int]): Controls the number of vectors examined during index construction. Default value is 200, Maximum value is 4096. - ef_runtime (Optional[int]): Controls the number of vectors examined during query operations. Default value is 10, Maximum value is 4096. 
- """ - super().__init__(dim, distance_metric, type) + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. + initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` on the module API. + number_of_edges (Optional[int]): Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is `16`, maximum is `512`. Equivalent to `M` on the module API. + vectors_examined_on_construction (Optional[int]): Controls the number of vectors examined during index construction. Default value is `200`, Maximum value is `4096`. Equivalent to `EF_CONSTRUCTION` on the module API. + vectors_examined_on_runtime (Optional[int]): Controls the number of vectors examined during query operations. Default value is `10`, Maximum value is `4096`. Equivalent to `EF_RUNTIME` on the module API. + """ + super().__init__(dimensions, distance_metric, type) self.initial_cap = initial_cap - self.m = m - self.ef_contruction = ef_contruction - self.ef_runtime = ef_runtime + self.number_of_edges = number_of_edges + self.vectors_examined_on_construction = vectors_examined_on_construction + self.vectors_examined_on_runtime = vectors_examined_on_runtime def toArgs(self) -> List[TEncodable]: """ @@ -323,12 +332,19 @@ def toArgs(self) -> List[TEncodable]: args = super().toArgs() if self.initial_cap: args.extend([FtCreateKeywords.INITIAL_CAP, str(self.initial_cap)]) - if self.m: - args.extend([FtCreateKeywords.M, str(self.m)]) - if self.ef_contruction: - args.extend([FtCreateKeywords.EF_CONSTRUCTION, str(self.ef_contruction)]) - if self.ef_runtime: - args.extend([FtCreateKeywords.EF_RUNTIME, str(self.ef_runtime)]) + if self.number_of_edges: + args.extend([FtCreateKeywords.M, str(self.number_of_edges)]) + if self.vectors_examined_on_construction: + args.extend( + [ + FtCreateKeywords.EF_CONSTRUCTION, + str(self.vectors_examined_on_construction), + ] + ) + if self.vectors_examined_on_runtime: + args.extend( + [FtCreateKeywords.EF_RUNTIME, str(self.vectors_examined_on_runtime)] + ) return args @@ -375,16 +391,16 @@ def toArgs(self) -> List[TEncodable]: class DataType(Enum): """ - Options for the type of data for which the index is being created. + Type of the index dataset. """ HASH = "HASH" """ - If the created index will index HASH data. + Data stored in hashes, so field identifiers are field names within the hashes. """ JSON = "JSON" """ - If the created index will index JSON document data. + Data stored as a JSON document, so field identifiers are JSON Path expressions. """ @@ -403,8 +419,8 @@ def __init__( Initialize the FT.CREATE optional fields. Args: - data_type (Optional[DataType]): The type of data to be indexed using FT.CREATE. See `DataType`. - prefixes (Optional[List[TEncodable]]): The prefix of the key to be indexed. + data_type (Optional[DataType]): The index data type. If not defined a `HASH` index is created. See `DataType`. + prefixes (Optional[List[TEncodable]]): A list of prefixes of index definitions. 
""" self.data_type = data_type self.prefixes = prefixes diff --git a/python/python/tests/tests_server_modules/search/test_ft_create.py b/python/python/tests/tests_server_modules/search/test_ft_create.py index eba7592698..6655fac0c0 100644 --- a/python/python/tests/tests_server_modules/search/test_ft_create.py +++ b/python/python/tests/tests_server_modules/search/test_ft_create.py @@ -63,7 +63,9 @@ async def test_ft_create(self, glide_client: GlideClusterClient): name="vec", algorithm=VectorAlgorithm.HNSW, attributes=VectorFieldAttributesHnsw( - dim=2, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + dimensions=2, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, ), alias="VEC", ) @@ -85,7 +87,9 @@ async def test_ft_create(self, glide_client: GlideClusterClient): name="$.vec", algorithm=VectorAlgorithm.HNSW, attributes=VectorFieldAttributesHnsw( - dim=6, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + dimensions=6, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, ), alias="VEC", ) diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index 9d38531737..ea6ad8261b 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -183,7 +183,9 @@ async def _create_test_index_with_vector_field( name="$.vec", algorithm=VectorAlgorithm.HNSW, attributes=VectorFieldAttributesHnsw( - dim=2, distance_metric=DistanceMetricType.L2, type=VectorType.FLOAT32 + dimensions=2, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, ), alias="VEC", ) From a3fbf8f75048afec27b65fd20c275d550077f6a7 Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Fri, 25 Oct 2024 13:21:25 -0700 Subject: [PATCH 047/180] Node: add `FT.DROPINDEX` command (#2516) * Node: Add FT.DROPINDEX command Signed-off-by: Andrew Carbonetto --- CHANGELOG.md | 1 + .../glide/api/commands/servermodules/FT.java | 2 + node/src/server-modules/GlideFt.ts | 31 ++++++++++- node/src/server-modules/GlideFtOptions.ts | 24 ++++----- node/tests/ServerModules.test.ts | 53 +++++++++++++++++-- 5 files changed, 93 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d5bbae376..d21043f57f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) +* Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 714e1c1109..0250865648 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -285,6 +285,7 @@ public static CompletableFuture search( /** * Deletes an index and associated content. Indexed document keys are unaffected. * + * @param client The client to execute the command. * @param indexName The index name. * @return "OK". * @example @@ -300,6 +301,7 @@ public static CompletableFuture dropindex( /** * Deletes an index and associated content. Indexed document keys are unaffected. * + * @param client The client to execute the command. 
* @param indexName The index name. * @return "OK". * @example diff --git a/node/src/server-modules/GlideFt.ts b/node/src/server-modules/GlideFt.ts index 566e4d54c4..e78f961c8c 100644 --- a/node/src/server-modules/GlideFt.ts +++ b/node/src/server-modules/GlideFt.ts @@ -92,10 +92,10 @@ export class GlideFt { const attributes: GlideString[] = []; // all VectorFieldAttributes attributes - if (f.attributes.dimension) { + if (f.attributes.dimensions) { attributes.push( "DIM", - f.attributes.dimension.toString(), + f.attributes.dimensions.toString(), ); } @@ -111,6 +111,8 @@ export class GlideFt { "TYPE", f.attributes.type.toString(), ); + } else { + attributes.push("TYPE", "FLOAT32"); } if (f.attributes.initialCap) { @@ -160,6 +162,31 @@ export class GlideFt { decoder: Decoder.String, }) as Promise<"OK">; } + + /** + * Deletes an index and associated content. Indexed document keys are unaffected. + * + * @param client The client to execute the command. + * @param indexName The index name. + * + * @returns "OK" + * + * @example + * ```typescript + * // Example usage of FT.DROPINDEX to drop an index + * await GlideFt.dropindex(client, "json_idx1"); // "OK" + * ``` + */ + static async dropindex( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + ): Promise<"OK"> { + const args: GlideString[] = ["FT.DROPINDEX", indexName]; + + return _handleCustomCommand(client, args, { + decoder: Decoder.String, + }) as Promise<"OK">; + } } /** diff --git a/node/src/server-modules/GlideFtOptions.ts b/node/src/server-modules/GlideFtOptions.ts index 6fe723cc9d..24846da6d2 100644 --- a/node/src/server-modules/GlideFtOptions.ts +++ b/node/src/server-modules/GlideFtOptions.ts @@ -31,7 +31,7 @@ export type TagField = BaseField & { type: "TAG"; /** Specify how text in the attribute is split into individual tags. Must be a single character. */ separator?: GlideString; - /** Preserve the original letter cases of tags. If set to False, characters are converted to lowercase by default. */ + /** Preserve the original letter cases of tags. If set to `false`, characters are converted to lowercase by default. */ caseSensitive?: boolean; }; @@ -57,16 +57,16 @@ export type VectorField = BaseField & { * Base class for defining vector field attributes to be used after the vector algorithm name. */ export interface VectorFieldAttributes { - /** Number of dimensions in the vector. Equivalent to DIM in the option. */ - dimension: number; + /** Number of dimensions in the vector. Equivalent to `DIM` in the module API. */ + dimensions: number; /** - * The distance metric used in vector type field. Can be one of [L2 | IP | COSINE]. + * The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` in the module API. */ distanceMetric: "L2" | "IP" | "COSINE"; /** Vector type. The only supported type is FLOAT32. */ - type: "FLOAT32"; + type?: "FLOAT32"; /** - * Initial vector capacity in the index affecting memory allocation size of the index. Defaults to 1024. + * Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` in the module API. */ initialCap?: number; } @@ -90,18 +90,18 @@ export type VectorFieldAttributesFlat = VectorFieldAttributes & { export type VectorFieldAttributesHnsw = VectorFieldAttributes & { algorithm: "HNSW"; /** - * Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is 16, maximum is 512. - * Equivalent to the `m` attribute. 
+ * Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is `16`, maximum is `512`. + * Equivalent to `M` in the module API. */ numberOfEdges?: number; /** - * Controls the number of vectors examined during index construction. Default value is 200, Maximum value is 4096. - * Equivalent to the `efContruction` attribute. + * Controls the number of vectors examined during index construction. Default value is `200`, Maximum value is `4096`. + * Equivalent to `EF_CONSTRUCTION` in the module API. */ vectorsExaminedOnConstruction?: number; /** - * Controls the number of vectors examined during query operations. Default value is 10, Maximum value is 4096. - * Equivalent to the `efRuntime` attribute. + * Controls the number of vectors examined during query operations. Default value is `10`, Maximum value is `4096`. + * Equivalent to `EF_RUNTIME` in the module API. */ vectorsExaminedOnRuntime?: number; }; diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 29e10bcfaf..899bf88644 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -629,7 +629,7 @@ describe("Server Module Tests", () => { attributes: { algorithm: "HNSW", type: "FLOAT32", - dimension: 2, + dimensions: 2, distanceMetric: "L2", }, }; @@ -649,7 +649,7 @@ describe("Server Module Tests", () => { attributes: { algorithm: "HNSW", type: "FLOAT32", - dimension: 6, + dimensions: 6, distanceMetric: "L2", numberOfEdges: 32, }, @@ -669,7 +669,7 @@ describe("Server Module Tests", () => { attributes: { algorithm: "FLAT", type: "FLOAT32", - dimension: 6, + dimensions: 6, distanceMetric: "L2", }, }; @@ -684,7 +684,7 @@ describe("Server Module Tests", () => { attributes: { algorithm: "HNSW", type: "FLOAT32", - dimension: 1536, + dimensions: 1536, distanceMetric: "COSINE", numberOfEdges: 40, vectorsExaminedOnConstruction: 250, @@ -766,5 +766,50 @@ describe("Server Module Tests", () => { expect((e as Error).message).toContain("already exists"); } }); + + it("Ft.DROPINDEX test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + // create an index + const index = uuidv4(); + expect( + await GlideFt.create(client, index, [ + { + type: "VECTOR", + name: "vec", + attributes: { + algorithm: "HNSW", + distanceMetric: "L2", + dimensions: 2, + }, + }, + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + + const before = await client.customCommand(["FT._LIST"]); + expect(before).toContain(index); + + // DROP it + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + + const after = await client.customCommand(["FT._LIST"]); + expect(after).not.toContain(index); + + // dropping the index again results in an error + try { + expect( + await GlideFt.dropindex(client, index), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("Index does not exist"); + } + }); }); }); From f434b79e69c8bc7fa5e99e8e61695b141fa7571b Mon Sep 17 00:00:00 2001 From: James Xin Date: Fri, 25 Oct 2024 14:23:42 -0700 Subject: [PATCH 048/180] Java: add JSON.RESP (#2513) * Java: add JSON.RESP --------- Signed-off-by: James Xin Signed-off-by: Andrew Carbonetto Co-authored-by: Andrew Carbonetto --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 189 +++++++++++++++++- .../api/commands/servermodules/JsonTest.java | 82 ++++++++ .../test/java/glide/modules/JsonTests.java | 90 
+++++++++ 4 files changed, 354 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d21043f57f..d226dffc62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) +* Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) #### Breaking Changes diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 80b5eaf028..7fee5fefac 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -30,6 +30,8 @@ public class Json { private static final String JSON_DEL = JSON_PREFIX + "DEL"; private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE"; + private static final String JSON_RESP = JSON_PREFIX + "RESP"; + private static final String JSON_TYPE = JSON_PREFIX + "TYPE"; private Json() {} @@ -189,11 +191,12 @@ public static CompletableFuture get( *
        *
      • For JSONPath (path starts with $): Returns a stringified JSON list * replies for every possible path, or a string representation of an empty array, - * if path doesn't exist. If key doesn't exist, returns None. + * if path doesn't exist. If key doesn't exist, returns null + * . *
      • For legacy path (path doesn't start with $): Returns a string * representation of the value in paths. If paths * doesn't exist, an error is raised. If key doesn't exist, returns - * None. + * null. *
      *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, * and its corresponding value is the value as if the path was executed in the command @@ -226,11 +229,12 @@ public static CompletableFuture get( *
        *
      • For JSONPath (path starts with $): Returns a stringified JSON list * replies for every possible path, or a string representation of an empty array, - * if path doesn't exist. If key doesn't exist, returns None. + * if path doesn't exist. If key doesn't exist, returns null + * . *
      • For legacy path (path doesn't start with $): Returns a string * representation of the value in paths. If paths * doesn't exist, an error is raised. If key doesn't exist, returns - * None. + * null. *
      *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, * and its corresponding value is the value as if the path was executed in the command @@ -317,11 +321,12 @@ public static CompletableFuture get( *
        *
      • For JSONPath (path starts with $): Returns a stringified JSON list * replies for every possible path, or a string representation of an empty array, - * if path doesn't exist. If key doesn't exist, returns None. + * if path doesn't exist. If key doesn't exist, returns null + * . *
      • For legacy path (path doesn't start with $): Returns a string * representation of the value in paths. If paths * doesn't exist, an error is raised. If key doesn't exist, returns - * None. + * null. *
      *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, * and its corresponding value is the value as if the path was executed in the command @@ -363,11 +368,12 @@ public static CompletableFuture get( *
        *
      • For JSONPath (path starts with $): Returns a stringified JSON list * replies for every possible path, or a string representation of an empty array, - * if path doesn't exist. If key doesn't exist, returns None. + * if path doesn't exist. If key doesn't exist, returns null + * . *
      • For legacy path (path doesn't start with $): Returns a string * representation of the value in paths. If paths * doesn't exist, an error is raised. If key doesn't exist, returns - * None. + * null. *
      *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, * and its corresponding value is the value as if the path was executed in the command @@ -1209,6 +1215,173 @@ public static CompletableFuture toggle( client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).add(path).toArray()); } + /** + * Retrieves the JSON document stored at key. The returned result is in the Valkey or Redis OSS Serialization Protocol (RESP). + *
        + *
• JSON null is mapped to the RESP Null Bulk String.
• JSON Booleans are mapped to RESP Simple string.
• JSON integers are mapped to RESP Integers.
• JSON doubles are mapped to RESP Bulk Strings.
• JSON strings are mapped to RESP Bulk Strings.
• JSON arrays are represented as RESP arrays, where the first element is the simple string [, followed by the array's elements.
• JSON objects are represented as RESP object, where the first element is the simple string {, followed by key-value pairs, each of which is a RESP bulk string.
      + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns the JSON document in its RESP form. + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"b1\": 1}, \"c\": 42}");
      +     * Object actualResult = Json.resp(client, "doc").get();
      +     * Object[] expectedResult = new Object[] {
      +     *     "{",
      +     *     new Object[] {"a", new Object[] {"[", 1L, 2L, 3L}},
      +     *     new Object[] {"b", new Object[] {"{", new Object[] {"b1", 1L}}},
      +     *     new Object[] {"c", 42L}
      +     * };
      +     * assertInstanceOf(Object[].class, actualResult);
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * }
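*
* // A hypothetical scalar case following the mapping rules above: a JSON double
* // arrives as a RESP bulk string, a JSON boolean as a RESP simple string.
* {@code
* Json.set(client, "doc2", ".", "3.14");
* assertEquals("3.14", Json.resp(client, "doc2").get());
* }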
*/ + public static CompletableFuture resp(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_RESP, key}); + } + + /** + * Retrieves the JSON document stored at key. The returned result is in the Valkey or Redis OSS Serialization Protocol (RESP). + *
        + *
• JSON null is mapped to the RESP Null Bulk String.
• JSON Booleans are mapped to RESP Simple string.
• JSON integers are mapped to RESP Integers.
• JSON doubles are mapped to RESP Bulk Strings.
• JSON strings are mapped to RESP Bulk Strings.
• JSON arrays are represented as RESP arrays, where the first element is the simple string [, followed by the array's elements.
• JSON objects are represented as RESP object, where the first element is the simple string {, followed by key-value pairs, each of which is a RESP bulk string.
      + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns the JSON document in its RESP form. + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"b1\": 1}, \"c\": 42}");
      +     * Object actualResultBinary = Json.resp(client, gs("doc")).get();
      +     * Object[] expectedResultBinary = new Object[] {
      +     *     "{",
      +     *     new Object[] {gs("a"), new Object[] {gs("["), 1L, 2L, 3L}},
      +     *     new Object[] {gs("b"), new Object[] {gs("{"), new Object[] {gs("b1"), 1L}}},
      +     *     new Object[] {gs("c"), 42L}
      +     * };
      +     * assertInstanceOf(Object[].class, actualResultBinary);
      +     * assertArrayEquals(expectedResultBinary, (Object[]) actualResultBinary);
      +     * }
*/ + public static CompletableFuture resp( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_RESP), key}); + } + + /** + * Retrieves the JSON value at the specified path within the JSON document stored at + * key. The returned result is in the Valkey or Redis OSS Serialization Protocol + * (RESP). + *
        + *
      • JSON null is mapped to the RESP Null Bulk String. + *
      • JSON Booleans are mapped to RESP Simple string. + *
      • JSON integers are mapped to RESP Integers. + *
      • JSON doubles are mapped to RESP Bulk Strings. + *
      • JSON strings are mapped to RESP Bulk Strings. + *
      • JSON arrays are represented as RESP arrays, where the first element is the simple string + * [, followed by the array's elements. + *
      • JSON objects are represented as RESP object, where the first element is the simple string + * {, followed by key-value pairs, each of which is a RESP bulk string. + *
      + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $): Returns a list of + * replies for every possible path, indicating the RESP form of the JSON value. If + * path doesn't exist, returns an empty list. + *
• For legacy path (path doesn't start with $): Returns a + * single reply for the JSON value at the specified path, in its RESP form. If multiple + * paths match, the value of the first JSON value match is returned. If path + * doesn't exist, an error is raised.
      + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}");
      +     * Object actualResult = Json.resp(client, "doc", "$..a").get(); // JSONPath returns all possible paths
      +     * Object[] expectedResult = new Object[] {
      +     *                 new Object[] {"[", 1L, 2L, 3L},
      +     *                 new Object[] {"[", 1L, 2L},
      +     *                 42L};
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * // legacy path only returns the first JSON value match
      +     * assertArrayEquals(new Object[] {"[", 1L, 2L, 3L}, (Object[]) Json.resp(client, key, "..a").get());
      +     * }
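*
* // A hedged sketch of the path-existence rules above: a JSONPath with no matches
* // yields an empty array, while a missing legacy path raises an error.
* {@code
* assertArrayEquals(new Object[] {}, (Object[]) Json.resp(client, "doc", "$.missing").get());
* assertThrows(ExecutionException.class, () -> Json.resp(client, "doc", ".missing").get());
* }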
*/ + public static CompletableFuture resp( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_RESP, key, path}); + } + + /** + * Retrieves the JSON value at the specified path within the JSON document stored at + * key. The returned result is in the Valkey or Redis OSS Serialization Protocol + * (RESP). + *
        + *
      • JSON null is mapped to the RESP Null Bulk String. + *
      • JSON Booleans are mapped to RESP Simple string. + *
      • JSON integers are mapped to RESP Integers. + *
      • JSON doubles are mapped to RESP Bulk Strings. + *
      • JSON strings are mapped to RESP Bulk Strings. + *
      • JSON arrays are represented as RESP arrays, where the first element is the simple string + * [, followed by the array's elements. + *
      • JSON objects are represented as RESP object, where the first element is the simple string + * {, followed by key-value pairs, each of which is a RESP bulk string. + *
      + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $): Returns a list of + * replies for every possible path, indicating the RESP form of the JSON value. If + * path doesn't exist, returns an empty list. + *
• For legacy path (path doesn't start with $): Returns a + * single reply for the JSON value at the specified path, in its RESP form. If multiple + * paths match, the value of the first JSON value match is returned. If path + * doesn't exist, an error is raised.
      + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}");
      +     * Object actualResult = Json.resp(client, gs("doc"), gs("$..a")).get(); // JSONPath returns all possible paths
      +     * Object[] expectedResult = new Object[] {
      +     *                 new Object[] {gs("["), 1L, 2L, 3L},
      +     *                 new Object[] {gs("["), 1L, 2L},
      +     *                 42L};
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * // legacy path only returns the first JSON value match
      +     * assertArrayEquals(new Object[] {gs("["), 1L, 2L, 3L}, (Object[]) Json.resp(client, gs(key), gs("..a")).get());
      +     * }
+     */
+    public static CompletableFuture<Object> resp(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_RESP), key, path});
+    }
+
     /**
      * A wrapper for custom command API.
      *
diff --git a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java
index 0425831ea2..f05a9b2bb6 100644
--- a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java
+++ b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java
@@ -516,4 +516,86 @@ void forget_binary_with_path_returns_success() {
         assertEquals(expectedResponse, actualResponse);
         assertEquals(expectedResponseValue, actualResponseValue);
     }
+
+    @Test
+    @SneakyThrows
+    void resp_without_path_returns_success() {
+        // setup
+        String key = "testKey";
+        CompletableFuture<Object> expectedResponse = new CompletableFuture<>();
+        String expectedResponseValue = "foo";
+        expectedResponse.complete(expectedResponseValue);
+        when(glideClient.customCommand(eq(new String[] {"JSON.RESP", key})).thenApply(any()))
+                .thenReturn(expectedResponse);
+
+        // exercise
+        CompletableFuture<Object> actualResponse = Json.resp(glideClient, key);
+        Object actualResponseValue = actualResponse.get();
+
+        // verify
+        assertEquals(expectedResponse, actualResponse);
+        assertEquals(expectedResponseValue, actualResponseValue);
+    }
+
+    @Test
+    @SneakyThrows
+    void resp_binary_without_path_returns_success() {
+        // setup
+        GlideString key = gs("testKey");
+        CompletableFuture<Object> expectedResponse = new CompletableFuture<>();
+        GlideString expectedResponseValue = gs("foo");
+        expectedResponse.complete(expectedResponseValue);
+        when(glideClient.customCommand(eq(new GlideString[] {gs("JSON.RESP"), key})).thenApply(any()))
+                .thenReturn(expectedResponse);
+
+        // exercise
+        CompletableFuture<Object> actualResponse = Json.resp(glideClient, key);
+        Object actualResponseValue = actualResponse.get();
+
+        // verify
+        assertEquals(expectedResponse, actualResponse);
+        assertEquals(expectedResponseValue, actualResponseValue);
+    }
+
+    @Test
+    @SneakyThrows
+    void resp_with_path_returns_success() {
+        // setup
+        String key = "testKey";
+        CompletableFuture<Object> expectedResponse = new CompletableFuture<>();
+        String expectedResponseValue = "foo";
+        expectedResponse.complete(expectedResponseValue);
+        when(glideClient.customCommand(eq(new String[] {"JSON.RESP", key, "$"})).thenApply(any()))
+                .thenReturn(expectedResponse);
+
+        // exercise
+        CompletableFuture<Object> actualResponse = Json.resp(glideClient, key, "$");
+        Object actualResponseValue = actualResponse.get();
+
+        // verify
+        assertEquals(expectedResponse, actualResponse);
+        assertEquals(expectedResponseValue, actualResponseValue);
+    }
+
+    @Test
+    @SneakyThrows
+    void resp_binary_with_path_returns_success() {
+        // setup
+        GlideString key = gs("testKey");
+        CompletableFuture<Object> expectedResponse = new CompletableFuture<>();
+        GlideString expectedResponseValue = gs("foo");
+        expectedResponse.complete(expectedResponseValue);
+        when(glideClient
+                        .customCommand(eq(new GlideString[] {gs("JSON.RESP"), key, gs("$")}))
+                        .thenApply(any()))
+                .thenReturn(expectedResponse);
+
+        // exercise
+        CompletableFuture<Object> actualResponse = Json.resp(glideClient, key, gs("$"));
+        Object actualResponseValue = actualResponse.get();
+
+        // verify
+        assertEquals(expectedResponse, actualResponse);
+        assertEquals(expectedResponseValue, actualResponseValue);
+    }
 }
diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java
index de59753c20..834a4c7f0e 100644
--- a/java/integTest/src/test/java/glide/modules/JsonTests.java
+++ b/java/integTest/src/test/java/glide/modules/JsonTests.java
@@ -8,6 +8,7 @@
 import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM;
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -435,4 +436,93 @@ public void toggle() {
         assertThrows(
                 ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get());
     }
+
+    @Test
+    @SneakyThrows
+    public void json_resp() {
+        String key = UUID.randomUUID().toString();
+        String jsonValue =
+                "{\"obj\":{\"a\":1, \"b\":2}, \"arr\":[1,2,3], \"str\": \"foo\", \"bool\": true, \"int\":"
+                        + " 42, \"float\": 3.14, \"nullVal\": null}";
+        assertEquals(OK, Json.set(client, key, "$", jsonValue).get());
+
+        Object actualResult1 = Json.resp(client, key, "$.*").get();
+        Object[] expectedResult1 =
+                new Object[] {
+                    new Object[] {
+                        "{",
+                        new Object[] {"a", 1L},
+                        new Object[] {"b", 2L} // leading "{" indicates JSON objects
+                    },
+                    new Object[] {"[", 1L, 2L, 3L}, // leading "[" indicates JSON arrays
+                    "foo",
+                    "true",
+                    42L,
+                    "3.14",
+                    null
+                };
+        assertInstanceOf(Object[].class, actualResult1);
+        assertArrayEquals(expectedResult1, (Object[]) actualResult1);
+
+        // multiple path match, the first will be returned
+        Object actualResult2 = Json.resp(client, key, "*").get();
+        Object[] expectedResult2 = new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}};
+        assertInstanceOf(Object[].class, actualResult2);
+        assertArrayEquals(expectedResult2, (Object[]) actualResult2);
+
+        Object actualResult3 = Json.resp(client, key, "$").get();
+        Object[] expectedResult3 =
+                new Object[] {
+                    new Object[] {
+                        "{",
+                        new Object[] {
+                            "obj", new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}}
+                        },
+                        new Object[] {"arr", new Object[] {"[", 1L, 2L, 3L}},
+                        new Object[] {"str", "foo"},
+                        new Object[] {"bool", "true"},
+                        new Object[] {"int", 42L},
+                        new Object[] {"float", "3.14"},
+                        new Object[] {"nullVal", null}
+                    }
+                };
+        assertInstanceOf(Object[].class, actualResult3);
+        assertArrayEquals(expectedResult3, (Object[]) actualResult3);
+
+        Object actualResult4 = Json.resp(client, key, ".").get();
+        Object[] expectedResult4 =
+                new Object[] {
+                    "{",
+                    new Object[] {"obj", new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}}},
+                    new Object[] {"arr", new Object[] {"[", 1L, 2L, 3L}},
+                    new Object[] {"str", "foo"},
+                    new Object[] {"bool", "true"},
+                    new Object[] {"int", 42L},
+                    new Object[] {"float", "3.14"},
+                    new Object[] {"nullVal", null}
+                };
+        assertInstanceOf(Object[].class, actualResult4);
+        assertArrayEquals(expectedResult4, (Object[]) actualResult4);
+
+        // resp without path defaults to the same behavior of passing "." as path
+        Object actualResult4WithoutPath = Json.resp(client, key).get();
+        assertArrayEquals(expectedResult4, (Object[]) actualResult4WithoutPath);
+
+        Object actualResult5 = Json.resp(client, gs(key), gs("$.str")).get();
+        Object[] expectedResult5 = new Object[] {gs("foo")};
+        assertInstanceOf(Object[].class, actualResult5);
+        assertArrayEquals(expectedResult5, (Object[]) actualResult5);
+
+        Object actualResult6 = Json.resp(client, key, ".str").get();
+        String expectedResult6 = "foo";
+        assertEquals(expectedResult6, actualResult6);
+
+        assertArrayEquals(new Object[] {}, (Object[]) Json.resp(client, key, "$.nonexistent").get());
+
+        assertThrows(ExecutionException.class, () -> Json.resp(client, key, "nonexistent").get());
+
+        assertNull(Json.resp(client, "nonexistent_key", "$").get());
+        assertNull(Json.resp(client, "nonexistent_key", ".").get());
+        assertNull(Json.resp(client, "nonexistent_key").get());
+    }
 }
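A note on reading the expected values in `json_resp`: the command returns RESP structures in which a JSON object arrives as an `Object[]` whose first element is "{" and a JSON array as an `Object[]` whose first element is "[", with integers mapped to `Long` and booleans and floats to `String`. The walker below is an illustrative sketch only, not part of this patch, and `describeResp` is a hypothetical helper name; it shows how a caller might traverse that shape.

    // Illustrative sketch: walks the Object[] shape that Json.resp(...) returns,
    // per the conventions asserted in json_resp above.
    static String describeResp(Object node) {
        if (node == null) {
            return "null"; // JSON null arrives as a Java null
        }
        if (!(node instanceof Object[])) {
            return node.toString(); // Long for integers; String for "true", "3.14", "foo"
        }
        Object[] arr = (Object[]) node;
        boolean isObject = arr.length > 0 && "{".equals(arr[0]);
        boolean isArray = arr.length > 0 && "[".equals(arr[0]);
        StringBuilder sb = new StringBuilder(isObject ? "object(" : isArray ? "array(" : "pair(");
        for (int i = (isObject || isArray) ? 1 : 0; i < arr.length; i++) {
            if (sb.charAt(sb.length() - 1) != '(') {
                sb.append(", ");
            }
            sb.append(describeResp(arr[i]));
        }
        return sb.append(")").toString();
    }

Run against `expectedResult1` above, this would render the document as object(pair(a, 1), pair(b, 2)) style output, making the "{"/"[" markers easy to spot when debugging.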
From 826bb08721af115b2e040788b7ae0b7112d82869 Mon Sep 17 00:00:00 2001
From: Yi-Pin Chen
Date: Fri, 25 Oct 2024 14:44:18 -0700
Subject: [PATCH 049/180] Java: add JSON.ARRTRIM command (#2518)

* Java: add JSON.ARRTRIM command

---------

Signed-off-by: Yi-Pin Chen
---
 CHANGELOG.md                                  |   1 +
 .../api/commands/servermodules/Json.java      | 107 ++++++++++++++++++
 .../test/java/glide/modules/JsonTests.java    |  89 +++++++++++++++
 3 files changed, 197 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d226dffc62..e9224d1245 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@
 * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461))
 * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427))
 * Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489))
+* Java: Added `JSON.ARRTRIM` ([#2518](https://github.com/valkey-io/valkey-glide/pull/2518))
 * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491))
 * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505))
 * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504))
diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
index 7fee5fefac..ac7f9af620 100644
--- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java
+++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
@@ -25,6 +25,7 @@ public class Json {
     private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND";
     private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT";
     private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN";
+    private static final String JSON_ARRTRIM = JSON_PREFIX + "ARRTRIM";
     private static final String JSON_OBJLEN = JSON_PREFIX + "OBJLEN";
     private static final String JSON_OBJKEYS = JSON_PREFIX + "OBJKEYS";
     private static final String JSON_DEL = JSON_PREFIX + "DEL";
@@ -713,6 +714,112 @@ public static CompletableFuture<Object> arrlen(
         return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key});
     }
 
+    /**
+     * Trims an array at the specified path within the JSON document stored at key
+     * so that it becomes a subarray [start, end], both inclusive.
+     * If start < 0, it is treated as 0.
+     * If end >= size (size of the array), it is treated as size - 1.
+     * If start >= size or start > end, the array is emptied
+     * and 0 is returned.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @param start The index of the first element to keep, inclusive.
+     * @param end The index of the last element to keep, inclusive.
+     * @return
+     *     • For JSONPath (path starts with $):
+     *       Returns an Object[] with a list of integers for every possible path,
+     *       indicating the new length of the array, or null for JSON values matching
+     *       the path that are not an array. If the array is empty, its corresponding return value
+     *       is 0. If path doesn't exist, an empty array will be returned. If an index
+     *       argument is out of bounds, an error is raised.
+     *     • For legacy path (path doesn't start with $):
+     *       Returns an integer representing the new length of the array. If the array is empty,
+     *       its corresponding return value is 0. If multiple paths match, the length of the first
+     *       trimmed array match is returned. If path doesn't exist, or the value at
+     *       path is not an array, an error is raised. If an index argument is out of
+     *       bounds, an error is raised.
+     *     If key doesn't exist, returns null.
+     * @example
+     *     {@code
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"], [\"a\", \"b\", \"c\"]]").get();
      +     * var res = Json.arrtrim(client, "doc", "$[*]", 0, 1).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { 0, 1, 2, 2 }); // New lengths of arrays after trimming
      +     *
      +     * Json.set(client, "doc", "$", "{\"children\": [\"John\", \"Jack\", \"Tom\", \"Bob\", \"Mike\"]}").get();
      +     * res = Json.arrtrim(client, "doc", ".children", 0, 1).get();
      +     * assert res == 2; // new length after trimming
      +     * }
+     */
+    public static CompletableFuture<Object> arrtrim(
+            @NonNull BaseClient client, @NonNull String key, @NonNull String path, int start, int end) {
+        return executeCommand(
+                client,
+                new String[] {JSON_ARRTRIM, key, path, Integer.toString(start), Integer.toString(end)});
+    }
+
+    /**
+     * Trims an array at the specified path within the JSON document stored at key
+     * so that it becomes a subarray [start, end], both inclusive.
+     * If start < 0, it is treated as 0.
+     * If end >= size (size of the array), it is treated as size - 1.
+     * If start >= size or start > end, the array is emptied
+     * and 0 is returned.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @param start The index of the first element to keep, inclusive.
+     * @param end The index of the last element to keep, inclusive.
+     * @return
+     *     • For JSONPath (path starts with $):
+     *       Returns an Object[] with a list of integers for every possible path,
+     *       indicating the new length of the array, or null for JSON values matching
+     *       the path that are not an array. If the array is empty, its corresponding return value
+     *       is 0. If path doesn't exist, an empty array will be returned. If an index
+     *       argument is out of bounds, an error is raised.
+     *     • For legacy path (path doesn't start with $):
+     *       Returns an integer representing the new length of the array. If the array is empty,
+     *       its corresponding return value is 0. If multiple paths match, the length of the first
+     *       trimmed array match is returned. If path doesn't exist, or the value at
+     *       path is not an array, an error is raised. If an index argument is out of
+     *       bounds, an error is raised.
+     *     If key doesn't exist, returns null.
+     * @example
+     *     {@code
+     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"], [\"a\", \"b\", \"c\"]]").get();
      +     * var res = Json.arrtrim(client, gs("doc"), gs("$[*]"), 0, 1).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { 0, 1, 2, 2 }); // New lengths of arrays after trimming
      +     *
      +     * Json.set(client, "doc", "$", "{\"children\": [\"John\", \"Jack\", \"Tom\", \"Bob\", \"Mike\"]}").get();
      +     * res = Json.arrtrim(client, gs("doc"), gs(".children"), 0, 1).get();
      +     * assert res == 2; // new length after trimming
      +     * }
+     */
+    public static CompletableFuture<Object> arrtrim(
+            @NonNull BaseClient client,
+            @NonNull GlideString key,
+            @NonNull GlideString path,
+            int start,
+            int end) {
+        return executeCommand(
+                client,
+                new ArgsBuilder()
+                        .add(gs(JSON_ARRTRIM))
+                        .add(key)
+                        .add(path)
+                        .add(Integer.toString(start))
+                        .add(Integer.toString(end))
+                        .toArray());
+    }
+
     /**
      * Retrieves the number of key-value pairs in the object values at the specified path
      * within the JSON document stored at key.
diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java
index 834a4c7f0e..8aa93413d1 100644
--- a/java/integTest/src/test/java/glide/modules/JsonTests.java
+++ b/java/integTest/src/test/java/glide/modules/JsonTests.java
@@ -322,6 +322,95 @@ public void arrlen() {
         assertEquals(5L, res);
     }
 
+    @Test
+    @SneakyThrows
+    public void arrtrim() {
+        String key = UUID.randomUUID().toString();
+
+        String doc =
+                "{\"a\": [0, 1, 2, 3, 4, 5, 6, 7, 8], \"b\": {\"a\": [0, 9, 10, 11, 12, 13], \"c\": {\"a\":"
+                        + " 42}}}";
+        assertEquals("OK", Json.set(client, key, "$", doc).get());
+
+        // Basic trim
+        var res = Json.arrtrim(client, key, "$..a", 1, 7).get();
+        assertArrayEquals(new Object[] {7L, 5L, null}, (Object[]) res);
+
+        String getResult = Json.get(client, key, new String[] {"$..a"}).get();
+        String expectedGetResult = "[[1, 2, 3, 4, 5, 6, 7], [9, 10, 11, 12, 13], 42]";
+        assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult));
+
+        // Test end >= size (should be treated as size-1)
+        res = Json.arrtrim(client, key, "$.a", 0, 10).get();
+        assertArrayEquals(new Object[] {7L}, (Object[]) res);
+        res = Json.arrtrim(client, key, ".a", 0, 10).get();
+        assertEquals(7L, res);
+
+        // Test negative start (should be treated as 0)
+        res = Json.arrtrim(client, key, "$.a", -1, 5).get();
+        assertArrayEquals(new Object[] {6L}, (Object[]) res);
+        res = Json.arrtrim(client, key, ".a", -1, 5).get();
+        assertEquals(6L, res);
+
+        // Test start >= size (should empty the array)
+        res = Json.arrtrim(client, key, "$.a", 7, 10).get();
+        assertArrayEquals(new Object[] {0L}, (Object[]) res);
+
+        assertEquals("OK", Json.set(client, key, ".a", "[\"a\", \"b\", \"c\"]").get());
+        res = Json.arrtrim(client, key, ".a", 7, 10).get();
+        assertEquals(0L, res);
+
+        // Test start > end (should empty the array)
+        res = Json.arrtrim(client, key, "$..a", 2, 1).get();
+        assertArrayEquals(new Object[] {0L, 0L, null}, (Object[]) res);
+
+        assertEquals("OK", Json.set(client, key, ".a", "[\"a\", \"b\", \"c\", \"d\"]").get());
+        res = Json.arrtrim(client, key, "..a", 2, 1).get();
+        assertEquals(0L, res);
+
+        // Multiple path match
+        assertEquals("OK", Json.set(client, key, "$", doc).get());
+        res = Json.arrtrim(client, key, "..a", 1, 10).get();
+        assertEquals(8L, res);
+
+        getResult = Json.get(client, key, new String[] {"$..a"}).get();
+        expectedGetResult = "[[1,2,3,4,5,6,7,8], [9,10,11,12,13], 42]";
+        assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult));
+
+        // Test with non-existing path
+        var exception =
+                assertThrows(
+                        ExecutionException.class, () -> Json.arrtrim(client, key, ".non_existing", 0, 1).get());
+
+        res = Json.arrtrim(client, key, "$.non_existing", 0, 1).get();
+        assertArrayEquals(new Object[] {}, (Object[]) res);
+
+        // Test with non-array path
+        res = Json.arrtrim(client, key, "$", 0, 1).get();
+        assertArrayEquals(new Object[] {null}, (Object[]) res);
+
+        exception =
+                assertThrows(ExecutionException.class, () -> Json.arrtrim(client, key, ".", 0, 1).get());
+
+        // Test with non-existing key
+        exception =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> Json.arrtrim(client, "non_existing_key", "$", 0, 1).get());
+
+        exception =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> Json.arrtrim(client, "non_existing_key", ".", 0, 1).get());
+
+        // Test with empty array
+        assertEquals("OK", Json.set(client, key, "$.empty", "[]").get());
+        res = Json.arrtrim(client, key, "$.empty", 0, 1).get();
+        assertArrayEquals(new Object[] {0L}, (Object[]) res);
+        res = Json.arrtrim(client, key, ".empty", 0, 1).get();
+        assertEquals(0L, res);
+    }
+
     @Test
     @SneakyThrows
     public void objlen() {
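Between the two patches it is worth making the trimming arithmetic concrete. The helper below is an illustrative sketch only, not part of either patch, and `trimLikeArrtrim` is a hypothetical name; it applies the clamping rules stated in the `arrtrim` Javadoc to a plain in-memory list.

    import java.util.List;

    // Illustrative sketch: the index-clamping rules documented for JSON.ARRTRIM,
    // applied to an in-memory list. Returns the new length, as the command does.
    static int trimLikeArrtrim(List<Object> array, int start, int end) {
        int size = array.size();
        if (start < 0) {
            start = 0; // start < 0 is treated as 0
        }
        if (end >= size) {
            end = size - 1; // end >= size is treated as size - 1
        }
        if (start >= size || start > end) {
            array.clear(); // out-of-range start or an inverted range empties the array
            return 0;
        }
        array.subList(end + 1, size).clear(); // drop the tail first so indices stay valid
        array.subList(0, start).clear(); // then drop the head
        return array.size(); // end - start + 1
    }

On the document used in the test above, trimming the nine-element array at `$.a` with start 1 and end 7 leaves seven elements, matching the leading `7L` the command returns for `$..a`.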
From e2c3b8abf44faeccdc12594debefcdfa58f34b65 Mon Sep 17 00:00:00 2001
From: Yury-Fridlyand
Date: Fri, 25 Oct 2024 15:37:35 -0700
Subject: [PATCH 050/180] Java: `JSON.CLEAR`. (#2519)

* `JSON.CLEAR`.

Signed-off-by: Yury-Fridlyand
---
 CHANGELOG.md                                  |   1 +
 .../api/commands/servermodules/Json.java      | 118 ++++++++++++++++++
 .../test/java/glide/modules/JsonTests.java    |  25 ++++
 3 files changed, 144 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9224d1245..7f0a3e6e00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,7 @@
 * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491))
 * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505))
 * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504))
+* Java: Added `JSON.CLEAR` ([#2519](https://github.com/valkey-io/valkey-glide/pull/2519))
 * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510))
 * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513))
 * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516))
diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
index ac7f9af620..fd473c6e85 100644
--- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java
+++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
@@ -31,6 +31,7 @@ public class Json {
     private static final String JSON_DEL = JSON_PREFIX + "DEL";
     private static final String JSON_FORGET = JSON_PREFIX + "FORGET";
     private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE";
+    private static final String JSON_CLEAR = JSON_PREFIX + "CLEAR";
     private static final String JSON_RESP = JSON_PREFIX + "RESP";
     private static final String JSON_TYPE = JSON_PREFIX + "TYPE";
 
@@ -1322,6 +1323,123 @@ public static CompletableFuture<Boolean> toggle(
                 client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).add(path).toArray());
     }
 
+    /**
+     * Clears an array and an object at the root of the JSON document stored at key.
+     * Equivalent to {@link #clear(BaseClient, String, String)} with path set to ".".
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @return 1 if the document wasn't empty or 0 if it was.
+     *     If key doesn't exist, an error is raised.
+     * @example
+     *     {@code
      +     * Json.set(client, "doc", "$", "{\"a\":1, \"b\":2}").get();
      +     * long res = Json.clear(client, "doc").get();
      +     * assert res == 1;
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{}]");
      +     *
      +     * res = Json.clear(client, "doc").get();
      +     * assert res == 0; // the doc is already empty
      +     * }
+     */
+    public static CompletableFuture<Long> clear(@NonNull BaseClient client, @NonNull String key) {
+        return executeCommand(client, new String[] {JSON_CLEAR, key});
+    }
+
+    /**
+     * Clears an array and an object at the root of the JSON document stored at key.
+     * Equivalent to {@link #clear(BaseClient, GlideString, GlideString)} with path set to ".".
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @return 1 if the document wasn't empty or 0 if it was.
+     *     If key doesn't exist, an error is raised.
+     * @example
+     *     {@code
      +     * Json.set(client, "doc", "$", "{\"a\":1, \"b\":2}").get();
      +     * long res = Json.clear(client, gs("doc")).get();
      +     * assert res == 1;
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{}]");
      +     *
      +     * res = Json.clear(client, gs("doc")).get();
      +     * assert res == 0; // the doc is already empty
      +     * }
+     */
+    public static CompletableFuture<Long> clear(
+            @NonNull BaseClient client, @NonNull GlideString key) {
+        return executeCommand(client, new GlideString[] {gs(JSON_CLEAR), key});
+    }
+
+    /**
+     * Clears arrays and objects at the specified path within the JSON document stored at key.
+     * Numeric values are set to 0, boolean values are set to false, and
+     * string values are converted to empty strings.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return The number of containers cleared.
+     *     If path doesn't exist, or the value at path is already cleared
+     *     (e.g., an empty array, object, or string), 0 is returned. If key doesn't
+     *     exist, an error is raised.
+     * @example
+     *     {@code
      +     * Json.set(client, "doc", "$", "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true,
      +     *     \"int\": 42, \"float\": 3.14, \"nullVal\": null}").get();
      +     * long res = Json.clear(client, "doc", "$.*").get();
      +     * assert res == 6; // 6 values are cleared: "obj", "arr", "str", "bool", "int", and "float"; "nullVal" is not clearable.
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]");
      +     *
      +     * res = Json.clear(client, "doc", "$.*").get();
      +     * assert res == 0; // containers are already empty and nothing is cleared
      +     * }
+     */
+    public static CompletableFuture<Long> clear(
+            @NonNull BaseClient client, @NonNull String key, @NonNull String path) {
+        return executeCommand(client, new String[] {JSON_CLEAR, key, path});
+    }
+
+    /**
+     * Clears arrays and objects at the specified path within the JSON document stored at key.
+     * Numeric values are set to 0, boolean values are set to false, and
+     * string values are converted to empty strings.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return The number of containers cleared.
+     *     If path doesn't exist, or the value at path is already cleared
+     *     (e.g., an empty array, object, or string), 0 is returned. If key doesn't
+     *     exist, an error is raised.
+     * @example
+     *     {@code
      +     * Json.set(client, "doc", "$", "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true,
      +     *     \"int\": 42, \"float\": 3.14, \"nullVal\": null}").get();
      +     * long res = Json.clear(client, gs("doc"), gs("$.*")).get();
      +     * assert res == 6; // 6 values are cleared: "obj", "arr", "str", "bool", "int", and "float"; "nullVal" is not clearable.
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]");
      +     *
      +     * res = Json.clear(client, gs("doc"), gs("$.*")).get();
      +     * assert res == 0; // containers are already empty and nothing is cleared
      +     * }
+     */
+    public static CompletableFuture<Long> clear(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_CLEAR), key, path});
+    }
+
     /**
      * Retrieves the JSON document stored at key. The returning result is in the Valkey
      * or Redis OSS Serialization Protocol (RESP).
diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java
index 8aa93413d1..69a609a65f 100644
--- a/java/integTest/src/test/java/glide/modules/JsonTests.java
+++ b/java/integTest/src/test/java/glide/modules/JsonTests.java
@@ -322,6 +322,31 @@ public void arrlen() {
         assertEquals(5L, res);
     }
 
+    @Test
+    @SneakyThrows
+    public void clear() {
+        String key = UUID.randomUUID().toString();
+        String json =
+                "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true,"
+                        + " \"int\": 42, \"float\": 3.14, \"nullVal\": null}";
+
+        assertEquals("OK", Json.set(client, key, "$", json).get());
+
+        assertEquals(6L, Json.clear(client, key, "$.*").get());
+        var doc = Json.get(client, key, new String[] {"$"}).get();
+        assertEquals(
+                "[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]",
+                doc);
+        assertEquals(0L, Json.clear(client, gs(key), gs(".*")).get());
+
+        assertEquals(1L, Json.clear(client, gs(key)).get());
+        doc = Json.get(client, key, new String[] {"$"}).get();
+        assertEquals("[{}]", doc);
+
+        assertThrows(
+                ExecutionException.class, () -> Json.clear(client, UUID.randomUUID().toString()).get());
+    }
+
     @Test
     @SneakyThrows
     public void arrtrim() {
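As a reading aid for the `clear` Javadoc above: the per-type reset rules fit in a few lines. The snippet below is a sketch, not part of the patch; it assumes a parsed JSON tree represented with plain Java types (`Map` for objects, `List` for arrays), which is an illustrative choice rather than anything GLIDE exposes.

    // Illustrative sketch: what JSON.CLEAR resets each JSON type to,
    // per the Javadoc above. JSON null is not clearable and is left as is.
    static Object clearedValue(Object value) {
        if (value instanceof java.util.Map) {
            return new java.util.HashMap<>(); // objects become {}
        }
        if (value instanceof java.util.List) {
            return new java.util.ArrayList<>(); // arrays become []
        }
        if (value instanceof String) {
            return ""; // strings become ""
        }
        if (value instanceof Boolean) {
            return false; // booleans become false
        }
        if (value instanceof Number) {
            return 0; // numbers become 0
        }
        return value; // null stays null
    }

This mirrors the integration test's expectation that clearing `$.*` turns the sample document into `{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}`.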
+ "PACKAGE_MANAGERS": [ + "npm" + ] }, { "OS": "ubuntu", @@ -48,8 +74,10 @@ "ARCH": "x64", "TARGET": "x86_64-unknown-linux-musl", "RUNNER": "ubuntu-latest", - "IMAGE": "node:alpine", + "IMAGE": "node:lts-alpine3.19", "CONTAINER_OPTIONS": "--user root --privileged", - "PACKAGE_MANAGERS": ["npm"] + "PACKAGE_MANAGERS": [ + "npm" + ] } ] diff --git a/.github/workflows/install-shared-dependencies/action.yml b/.github/workflows/install-shared-dependencies/action.yml index 1cb56e63f0..ed065e9840 100644 --- a/.github/workflows/install-shared-dependencies/action.yml +++ b/.github/workflows/install-shared-dependencies/action.yml @@ -39,7 +39,7 @@ runs: if: "${{ inputs.os == 'macos' }}" run: | brew update - brew install git gcc pkgconfig openssl coreutils + brew install git openssl coreutils - name: Install software dependencies for Ubuntu GNU shell: bash diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index c4c17a7e46..d8f690e560 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -70,7 +70,7 @@ jobs: - name: Use Node.js 16.x - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: 16.x @@ -286,10 +286,10 @@ jobs: - uses: actions/checkout@v4 - - name: Use Node.js 18.x + - name: Use Node.js 16.x uses: actions/setup-node@v4 with: - node-version: 18.x + node-version: 16.x - name: Build Node wrapper uses: ./.github/workflows/build-node-wrapper diff --git a/.github/workflows/npm-cd.yml b/.github/workflows/npm-cd.yml index 3788d87e04..362117affb 100644 --- a/.github/workflows/npm-cd.yml +++ b/.github/workflows/npm-cd.yml @@ -119,10 +119,10 @@ jobs: INPUT_VERSION: ${{ github.event.inputs.version }} - name: Setup node - if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl' }} - uses: actions/setup-node@v3 + if: ${{ !contains(matrix.build.TARGET, 'musl') }} + uses: actions/setup-node@v4 with: - node-version: "20" + node-version: "latest" registry-url: "https://registry.npmjs.org" architecture: ${{ matrix.build.ARCH }} scope: "${{ vars.NPM_SCOPE }}" @@ -130,7 +130,7 @@ jobs: token: ${{ secrets.NPM_AUTH_TOKEN }} - name: Setup node for publishing - if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-musl' }} + if: ${{ !contains(matrix.build.TARGET, 'musl') }} working-directory: ./node run: | npm config set registry https://registry.npmjs.org/ @@ -185,7 +185,7 @@ jobs: # 2>&1 1>&3- redirects stderr to stdout and then redirects the original stdout to another file descriptor, # effectively separating stderr and stdout. The 3>&1 at the end redirects the original stdout back to the console. 
# https://github.com/npm/npm/issues/118#issuecomment-325440 - ignoring notice messages since currentlly they are directed to stderr - { npm_publish_err=$(npm publish --tag ${{ env.NPM_TAG }} --access public 2>&1 1>&3- | grep -v "notice") ;} 3>&1 + { npm_publish_err=$(npm publish --tag ${{ env.NPM_TAG }} --access public 2>&1 1>&3- | grep -Ev "notice|ExperimentalWarning") ;} 3>&1 if [[ "$npm_publish_err" == *"You cannot publish over the previously published versions"* ]] then echo "Skipping publishing, package already published" @@ -203,8 +203,11 @@ jobs: if: ${{ matrix.build.ARCH == 'arm64' }} shell: bash run: | - git reset --hard + echo "Resetting repository" git clean -xdf + git reset --hard + git fetch + git checkout ${{ github.sha }} publish-base-to-npm: if: github.event_name != 'pull_request' @@ -218,9 +221,9 @@ jobs: submodules: "true" - name: Install node - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: - node-version: "20" + node-version: "latest" registry-url: "https://registry.npmjs.org" scope: "${{ vars.NPM_SCOPE }}" always-auth: true @@ -336,10 +339,10 @@ jobs: arch: ${{ matrix.build.ARCH }} - name: Setup node - if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl' }} - uses: actions/setup-node@v3 + if: ${{ !contains(matrix.build.TARGET, 'musl') }} + uses: actions/setup-node@v4 with: - node-version: "16" + node-version: "latest" registry-url: "https://registry.npmjs.org" architecture: ${{ matrix.build.ARCH }} scope: "${{ vars.NPM_SCOPE }}" diff --git a/.github/workflows/pypi-cd.yml b/.github/workflows/pypi-cd.yml index e69343f234..5d547ac7f5 100644 --- a/.github/workflows/pypi-cd.yml +++ b/.github/workflows/pypi-cd.yml @@ -112,16 +112,10 @@ jobs: if: ${{ !contains(matrix.build.RUNNER, 'self-hosted') }} uses: actions/setup-python@v5 with: - python-version: "3.10" - - - name: Set up Python older versions for MacOS - if: startsWith(matrix.build.NAMED_OS, 'darwin') - run: | - brew update - brew install python@3.8 python@3.9 + python-version: "3.12" - name: Setup Python for self-hosted Ubuntu runners - if: contains(matrix.build.OS, 'ubuntu') && contains(matrix.build.RUNNER, 'self-hosted') + if: contains(matrix.build.RUNNER, 'self-hosted') run: | sudo apt update -y sudo apt upgrade -y @@ -140,7 +134,7 @@ jobs: target: ${{ matrix.build.TARGET }} publish: "true" github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" + engine-version: "7.2" - name: Include protobuf files in the package working-directory: ./python @@ -186,7 +180,7 @@ jobs: with: working-directory: ./python target: ${{ matrix.build.TARGET }} - args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12' || 'python3.10' }} + args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12' || 'python3.12' }} - name: Upload Python wheels if: github.event_name != 'pull_request' @@ -241,7 +235,7 @@ jobs: - name: Install ValKey uses: ./.github/workflows/install-valkey with: - version: "8.0.0" + version: "8.0" - name: Check if RC and set a distribution tag for the package shell: bash diff --git a/.github/workflows/setup-musl-on-linux/action.yml b/.github/workflows/setup-musl-on-linux/action.yml index ed4677b74a..f270c27507 100644 --- a/.github/workflows/setup-musl-on-linux/action.yml +++ b/.github/workflows/setup-musl-on-linux/action.yml @@ -42,10 +42,10 @@ runs: shell: bash run: | git config --global --add safe.directory "${{ inputs.workspace }}" 
+ git fetch origin ${{ github.sha }} + git checkout ${{ github.sha }} git clean -xdf git reset --hard - git submodule sync - git submodule update --init --recursive - name: Set up access for musl on ARM shell: bash diff --git a/node/package.json b/node/package.json index 6cf654aa29..85d92b0476 100644 --- a/node/package.json +++ b/node/package.json @@ -46,7 +46,7 @@ }, "devDependencies": { "@jest/globals": "^29.7.0", - "@types/jest": "^29.5.12", + "@types/jest": "^29.5.14", "@types/minimist": "^1.2.5", "@types/redis-server": "^1.2.2", "@types/semver": "^7.5.8", @@ -71,10 +71,14 @@ "tests/", "rust-client/**", "!build-ts/**", - "babel.config.js", + ".prettierignore", "jest.config.js", "hybrid-node-tests/**", - "docs/" + "docs/", + "DEVELOPER.md", + ".ort.yml", + "tsconfig.json", + "THIRD_PARTY_LICENSES_NODE" ] }, "//": [ From 8d09a1c1f94275eaf43a9dc9bc5f82882e0ab112 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Sun, 27 Oct 2024 10:33:49 +0200 Subject: [PATCH 052/180] Python: adds JSON.STRLEN, JSON.STRAPPEND commands (#2372) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Andrew Carbonetto Co-authored-by: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com> --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 248 +++++++++++++----- .../tests/tests_server_modules/test_json.py | 116 +++++++- 3 files changed, 289 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f0a3e6e00..18bef9336c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) +* Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) #### Breaking Changes diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 7da1c6c4aa..e2bfa20ac2 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -1,18 +1,18 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 -"""module for `RedisJSON` commands. +"""Glide module for `JSON` commands. Examples: - >>> from glide import json as redisJson - >>> import json + >>> from glide import json + >>> import json as jsonpy >>> value = {'a': 1.0, 'b': 2} - >>> json_str = json.dumps(value) # Convert Python dictionary to JSON string using json.dumps() - >>> await redisJson.set(client, "doc", "$", json_str) + >>> json_str = jsonpy.dumps(value) # Convert Python dictionary to JSON string using json.dumps() + >>> await json.set(client, "doc", "$", json_str) 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. - >>> json_get = await redisJson.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string. + >>> json_get = await json.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string. 
From 8d09a1c1f94275eaf43a9dc9bc5f82882e0ab112 Mon Sep 17 00:00:00 2001
From: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
Date: Sun, 27 Oct 2024 10:33:49 +0200
Subject: [PATCH 052/180] Python: adds JSON.STRLEN, JSON.STRAPPEND commands
 (#2372)

---------

Signed-off-by: Shoham Elias
Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
Co-authored-by: Andrew Carbonetto
Co-authored-by: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com>
---
 CHANGELOG.md                                  |   1 +
 .../async_commands/server_modules/json.py     | 248 +++++++++++++-----
 .../tests/tests_server_modules/test_json.py   | 116 +++++++-
 3 files changed, 289 insertions(+), 76 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f0a3e6e00..18bef9336c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,7 @@
 * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510))
 * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513))
 * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516))
+* Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372))

 #### Breaking Changes

diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py
index 7da1c6c4aa..e2bfa20ac2 100644
--- a/python/python/glide/async_commands/server_modules/json.py
+++ b/python/python/glide/async_commands/server_modules/json.py
@@ -1,18 +1,18 @@
 # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
-"""module for `RedisJSON` commands.
+"""Glide module for `JSON` commands.

     Examples:
-        >>> from glide import json as redisJson
-        >>> import json
+        >>> from glide import json
+        >>> import json as jsonpy
         >>> value = {'a': 1.0, 'b': 2}
-        >>> json_str = json.dumps(value) # Convert Python dictionary to JSON string using json.dumps()
-        >>> await redisJson.set(client, "doc", "$", json_str)
+        >>> json_str = jsonpy.dumps(value) # Convert Python dictionary to JSON string using json.dumps()
+        >>> await json.set(client, "doc", "$", json_str)
             'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
-        >>> json_get = await redisJson.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string.
+        >>> json_get = await json.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string.
         >>> print(json_get)
             b"[{\"a\":1.0,\"b\":2}]"
-        >>> json.loads(str(json_get))
+        >>> jsonpy.loads(str(json_get))
             [{"a": 1.0, "b" :2}]  # JSON object retrieved from the key `doc` using json.loads()
     """
 from typing import List, Optional, Union, cast
@@ -64,27 +64,25 @@ async def set(
     """
     Sets the JSON value at the specified `path` stored at `key`.

-    See https://valkey.io/commands/json.set/ for more details.
-
     Args:
-        client (TGlideClient): The Redis client to execute the command.
+        client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
         path (TEncodable): Represents the path within the JSON document where the value will be set.
             The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added.
         value (TEncodable): The value to set at the specific path, in JSON formatted bytes or str.
         set_condition (Optional[ConditionalChange]): Set the value only if the given condition is met (within the key or path).
-            Equivalent to [`XX` | `NX`] in the Redis API. Defaults to None.
+            Equivalent to [`XX` | `NX`] in the RESP API. Defaults to None.

     Returns:
         Optional[TOK]: If the value is successfully set, returns OK.
-            If value isn't set because of `set_condition`, returns None.
+            If `value` isn't set because of `set_condition`, returns None.

     Examples:
-        >>> from glide import json as redisJson
-        >>> import json
+        >>> from glide import json
+        >>> import json as jsonpy
        >>> value = {'a': 1.0, 'b': 2}
-        >>> json_str = json.dumps(value)
-        >>> await redisJson.set(client, "doc", "$", json_str)
+        >>> json_str = jsonpy.dumps(value)
+        >>> await json.set(client, "doc", "$", json_str)
             'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
     """
     args = ["JSON.SET", key, path, value]
@@ -99,33 +97,43 @@ async def get(
     key: TEncodable,
     paths: Optional[Union[TEncodable, List[TEncodable]]] = None,
     options: Optional[JsonGetOptions] = None,
-) -> Optional[bytes]:
+) -> TJsonResponse[Optional[bytes]]:
     """
     Retrieves the JSON value at the specified `paths` stored at `key`.

-    See https://valkey.io/commands/json.get/ for more details.
-
     Args:
-        client (TGlideClient): The Redis client to execute the command.
+        client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON document. Default is root `$`.
+        paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON document. Default to None.
         options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`.

     Returns:
-        bytes: A bytes representation of the returned value.
-            If `key` doesn't exists, returns None.
+        TJsonResponse[Optional[bytes]]:
+            If one path is given:
+                For JSONPath (path starts with `$`):
+                    Returns a stringified JSON list of bytes replies for every possible path,
+                    or a byte string representation of an empty array, if path doesn't exist.
+                    If `key` doesn't exist, returns None.
+                For legacy path (path doesn't start with `$`):
+                    Returns a byte string representation of the value in `path`.
+                    If `path` doesn't exist, an error is raised.
+                    If `key` doesn't exist, returns None.
+            If multiple paths are given:
+                Returns a stringified JSON object in bytes, in which each path is a key, and its corresponding value is the value as if the path was executed in the command as a single path.
+        In case of multiple paths, and `paths` are a mix of both JSONPath and legacy path, the command behaves as if all are JSONPath paths.
+        For more information about the returned type, see `TJsonResponse`.

     Examples:
-        >>> from glide import json as redisJson
-        >>> import json
-        >>> json_str = await redisJson.get(client, "doc", "$")
-        >>> json.loads(str(json_str)) # Parse JSON string to Python data
+        >>> from glide import json, JsonGetOptions
+        >>> import json as jsonpy
+        >>> json_str = await json.get(client, "doc", "$")
+        >>> jsonpy.loads(str(json_str)) # Parse JSON string to Python data
             [{"a": 1.0, "b" :2}]  # JSON object retrieved from the key `doc` using json.loads()
-        >>> await redisJson.get(client, "doc", "$")
+        >>> await json.get(client, "doc", "$")
             b"[{\"a\":1.0,\"b\":2}]"  # Returns the value at path '$' in the JSON document stored at `doc`.
-        >>> await redisJson.get(client, "doc", ["$.a", "$.b"], json.JsonGetOptions(indent="  ", newline="\n", space=" "))
+        >>> await json.get(client, "doc", ["$.a", "$.b"], JsonGetOptions(indent="  ", newline="\n", space=" "))
             b"{\n \"$.a\": [\n  1.0\n ],\n \"$.b\": [\n  2\n ]\n}"  # Returns the values at paths '$.a' and '$.b' in the JSON document stored at `doc`, with specified formatting options.
-        >>> await redisJson.get(client, "doc", "$.non_existing_path")
+        >>> await json.get(client, "doc", "$.non_existing_path")
             b"[]"  # Returns an empty array since the path '$.non_existing_path' does not exist in the JSON document stored at `doc`.
     """
     args = ["JSON.GET", key]
@@ -136,7 +144,7 @@
             paths = [paths]
         args.extend(paths)

-    return cast(bytes, await client.custom_command(args))
+    return cast(TJsonResponse[Optional[bytes]], await client.custom_command(args))


 async def arrlen(
@@ -167,7 +175,7 @@
     Examples:
         >>> from glide import json
        >>> await json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}')
-            b'OK'  # JSON is successfully set for doc
+            'OK'  # JSON is successfully set for doc
        >>> await json.arrlen(client, "doc", "$")
            [None]  # No array at the root path.
@@ -180,7 +188,7 @@
        >>> await json.set(client, "doc", "$", '[1, 2, 3, 4]')
-            b'OK'  # JSON is successfully set for doc
+            'OK'  # JSON is successfully set for doc
        >>> await json.arrlen(client, "doc")
            4  # Retrieves lengths of arrays in root.
@@ -205,7 +213,7 @@
     Args:
         client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        path (Optional[str]): The JSON path to the arrays or objects to be cleared. Defaults to None.
+        path (Optional[str]): The path within the JSON document. Default to None.

     Returns:
         int: The number of containers cleared, numeric values zeroed, and booleans toggled to `false`,
@@ -216,7 +224,7 @@
     Examples:
         >>> from glide import json
        >>> await json.set(client, "doc", "$", '{"obj":{"a":1, "b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}')
-            b'OK'  # JSON document is successfully set.
+            'OK'  # JSON document is successfully set.
        >>> await json.clear(client, "doc", "$.*")
            6  # 6 values are cleared (arrays/objects/strings/numbers/booleans), but `null` remains as is.
@@ -225,7 +233,7 @@
        >>> await json.set(client, "doc", "$", '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}')
-            b'OK'
+            'OK'
        >>> await json.clear(client, "doc", "b.a[1:3]")
            2  # 2 elements (`6` and `7`) are cleared.
        >>> await json.clear(client, "doc", "b.a[1:3]")
@@ -253,27 +261,25 @@ async def delete(
     """
     Deletes the JSON value at the specified `path` within the JSON document stored at `key`.

-    See https://valkey.io/commands/json.del/ for more details.
-
     Args:
-        client (TGlideClient): The Redis client to execute the command.
+        client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        path (Optional[TEncodable]): Represents the path within the JSON document where the value will be deleted.
+        path (Optional[TEncodable]): The path within the JSON document.
             If None, deletes the entire JSON document at `key`. Defaults to None.

     Returns:
         int: The number of elements removed.
-        If `key` or path doesn't exist, returns 0.
+        If `key` or `path` doesn't exist, returns 0.

     Examples:
-        >>> from glide import json as redisJson
-        >>> await redisJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+        >>> from glide import json
+        >>> await json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
             'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
-        >>> await redisJson.delete(client, "doc", "$..a")
+        >>> await json.delete(client, "doc", "$..a")
             2  # Indicates successful deletion of the specific values in the key stored at `doc`.
-        >>> await redisJson.get(client, "doc", "$")
+        >>> await json.get(client, "doc", "$")
             "[{\"nested\":{\"b\":3}}]"  # Returns the value at path '$' in the JSON document stored at `doc`.
-        >>> await redisJson.delete(client, "doc")
+        >>> await json.delete(client, "doc")
             1  # Deletes the entire JSON document stored at `doc`.
     """

@@ -290,27 +296,25 @@ async def forget(
     """
     Deletes the JSON value at the specified `path` within the JSON document stored at `key`.

-    See https://valkey.io/commands/json.forget/ for more details.
-
     Args:
-        client (TGlideClient): The Redis client to execute the command.
+        client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        path (Optional[TEncodable]): Represents the path within the JSON document where the value will be deleted.
+        path (Optional[TEncodable]): The path within the JSON document.
             If None, deletes the entire JSON document at `key`. Defaults to None.

     Returns:
         int: The number of elements removed.
-        If `key` or path doesn't exist, returns 0.
+        If `key` or `path` doesn't exist, returns 0.

     Examples:
-        >>> from glide import json as redisJson
-        >>> await redisJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+        >>> from glide import json
+        >>> await json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
             'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
-        >>> await redisJson.forget(client, "doc", "$..a")
+        >>> await json.forget(client, "doc", "$..a")
             2  # Indicates successful deletion of the specific values in the key stored at `doc`.
-        >>> await redisJson.get(client, "doc", "$")
+        >>> await json.get(client, "doc", "$")
             "[{\"nested\":{\"b\":3}}]"  # Returns the value at path '$' in the JSON document stored at `doc`.
-        >>> await redisJson.forget(client, "doc")
+        >>> await json.forget(client, "doc")
             1  # Deletes the entire JSON document stored at `doc`.
     """

@@ -404,6 +408,103 @@
     return cast(Optional[bytes], await client.custom_command(args))


+async def strappend(
+    client: TGlideClient,
+    key: TEncodable,
+    value: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> TJsonResponse[int]:
+    """
+    Appends the specified `value` to the string stored at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        value (TEncodable): The value to append to the string. Must be wrapped with single quotes. For example, to append "foo", pass '"foo"'.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the resulting string after appending `value`,
+                or None for JSON values matching the path that are not string.
+                If `key` doesn't exist, an error is raised.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the resulting string after appending `value` to the string at `path`.
+                If multiple paths match, the length of the last updated string is returned.
+                If the JSON value at `path` is not a string or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
+        For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import json
+        >>> import json as jsonpy
+        >>> await json.set(client, "doc", "$", jsonpy.dumps({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}))
+            'OK'
+        >>> await json.strappend(client, "doc", jsonpy.dumps("baz"), "$..a")
+            [6, 8, None]  # The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
+        >>> await json.strappend(client, "doc", '"foo"', "nested.a")
+            11  # The length of the string value after appending "foo" to the string at path 'nested.a' in the key stored at `doc`.
+        >>> jsonpy.loads(await json.get(client, jsonpy.dumps("doc"), "$"))
+            [{"a":"foobaz", "nested": {"a": "hellobazfoo"}, "nested2": {"a": 31}}]  # The updated JSON value in the key stored at `doc`.
+    """
+
+    return cast(
+        TJsonResponse[int],
+        await client.custom_command(
+            ["JSON.STRAPPEND", key] + ([path, value] if path else [value])
+        ),
+    )
+
+
+async def strlen(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> TJsonResponse[Optional[int]]:
+    """
+    Returns the length of the JSON string value stored at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        TJsonResponse[Optional[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the JSON string value,
+                or None for JSON values matching the path that are not string.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the JSON value at `path` or None if `key` doesn't exist.
+                If multiple paths match, the length of the first matched string is returned.
+                If the JSON value at `path` is not a string or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, None is returned.
+        For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import json
+        >>> import json as jsonpy
+        >>> await json.set(client, "doc", "$", jsonpy.dumps({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}))
+            'OK'
+        >>> await json.strlen(client, "doc", "$..a")
+            [3, 5, None]  # The length of the string values at path '$..a' in the key stored at `doc`.
+        >>> await json.strlen(client, "doc", "nested.a")
+            5  # The length of the JSON value at path 'nested.a' in the key stored at `doc`.
+        >>> await json.strlen(client, "doc", "$")
+            [None]  # Returns an array with None since the value at the root path in the JSON document stored at `doc` is not a string.
+        >>> await json.strlen(client, "non_existing_key", ".")
+            None  # `key` doesn't exist.
+    """
+
+    return cast(
+        TJsonResponse[Optional[int]],
+        await client.custom_command(
+            ["JSON.STRLEN", key, path] if path else ["JSON.STRLEN", key]
+        ),
+    )
+
+
 async def toggle(
     client: TGlideClient,
     key: TEncodable,
@@ -412,30 +513,33 @@
     """
     Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`.

-    See https://valkey.io/commands/json.toggle/ for more details.
-
     Args:
-        client (TGlideClient): The Redis client to execute the command.
+        client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        path (TEncodable): The JSONPath to specify.
+        path (TEncodable): The path within the JSON document.

     Returns:
-        TJsonResponse[bool]: For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value,
-            or None for JSON values matching the path that are not boolean.
-            For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`.
-            Note that when sending legacy path syntax, If `path` doesn't exist or the value at `path` isn't a boolean, an error is raised.
+        TJsonResponse[bool]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of boolean replies for every possible path, with the toggled boolean value,
+                or None for JSON values matching the path that are not boolean.
+                If `key` doesn't exist, an error is raised.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the value of the toggled boolean in `path`.
+                If the JSON value at `path` is not a boolean or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
             For more information about the returned type, see `TJsonResponse`.

     Examples:
-        >>> from glide import json as redisJson
-        >>> import json
-        >>> await redisJson.set(client, "doc", "$", json.dumps({"bool": True, "nested": {"bool": False, "nested": {"bool": 10}}}))
+        >>> from glide import json
+        >>> import json as jsonpy
+        >>> await json.set(client, "doc", "$", jsonpy.dumps({"bool": True, "nested": {"bool": False, "nested": {"bool": 10}}}))
             'OK'
-        >>> await redisJson.toggle(client, "doc", "$.bool")
+        >>> await json.toggle(client, "doc", "$.bool")
             [False, True, None]  # Indicates successful toggling of the Boolean values at path '$.bool' in the key stored at `doc`.
-        >>> await redisJson.toggle(client, "doc", "bool")
+        >>> await json.toggle(client, "doc", "bool")
             True  # Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`.
-        >>> json.loads(await redisJson.get(client, "doc", "$"))
+        >>> jsonpy.loads(await json.get(client, "doc", "$"))
             [{"bool": True, "nested": {"bool": True, "nested": {"bool": 10}}}]  # The updated JSON value in the key stored at `doc`.
     """

@@ -456,8 +560,7 @@
     Args:
         client (TGlideClient): The client to execute the command.
         key (TEncodable): The key of the JSON document.
-        path (Optional[TEncodable]): Represents the path within the JSON document where the type will be retrieved.
-            Defaults to None.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.

     Returns:
         Optional[Union[bytes, List[bytes]]]:
@@ -473,6 +576,7 @@
     Examples:
         >>> from glide import json
        >>> await json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+            'OK'
        >>> await json.type(client, "doc", "$.nested")
            [b'object']  # Indicates the type of the value at path '$.nested' in the key stored at `doc`.
        >>> await json.type(client, "doc", "$.nested.a")
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py
index 794e885cfe..982b563995 100644
--- a/python/python/tests/tests_server_modules/test_json.py
+++ b/python/python/tests/tests_server_modules/test_json.py
@@ -144,42 +144,78 @@ async def test_json_get_formatting(self, glide_client: TGlideClient):

     @pytest.mark.parametrize("cluster_mode", [True, False])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
-    async def test_del(self, glide_client: TGlideClient):
+    async def test_json_del(self, glide_client: TGlideClient):
         key = get_random_string(5)
         json_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}}
         assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK

+        # Non-existing paths
+        assert await json.delete(glide_client, key, "$..path") == 0
+        assert await json.delete(glide_client, key, "..path") == 0
+
         assert await json.delete(glide_client, key, "$..a") == 2
         assert await json.get(glide_client, key, "$..a") == b"[]"

+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.delete(glide_client, key, "..a") == 2
+        with pytest.raises(RequestError):
+            assert await json.get(glide_client, key, "..a")
+
         result = await json.get(glide_client, key, "$")
         assert isinstance(result, bytes)
         assert OuterJson.loads(result) == [{"b": {"b": 2.5, "c": True}}]

         assert await json.delete(glide_client, key, "$") == 1
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+        assert await json.delete(glide_client, key, ".") == 1
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+        assert await json.delete(glide_client, key) == 1
         assert await json.delete(glide_client, key) == 0
         assert await json.get(glide_client, key, "$") == None

+        # Non-existing keys
+        assert await json.delete(glide_client, "non_existing_key", "$") == 0
+        assert await json.delete(glide_client, "non_existing_key", ".") == 0
+
     @pytest.mark.parametrize("cluster_mode", [True, False])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
-    async def test_forget(self, glide_client: TGlideClient):
+    async def test_json_forget(self, glide_client: TGlideClient):
         key = get_random_string(5)
         json_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}}
         assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK

+        # Non-existing paths
+        assert await json.forget(glide_client, key, "$..path") == 0
+        assert await json.forget(glide_client, key, "..path") == 0
+
         assert await json.forget(glide_client, key, "$..a") == 2
         assert await json.get(glide_client, key, "$..a") == b"[]"

+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.forget(glide_client, key, "..a") == 2
+        with pytest.raises(RequestError):
+            assert await json.get(glide_client, key, "..a")
+
         result = await json.get(glide_client, key, "$")
         assert isinstance(result, bytes)
         assert OuterJson.loads(result) == [{"b": {"b": 2.5, "c": True}}]

         assert await json.forget(glide_client, key, "$") == 1
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+        assert await json.forget(glide_client, key, ".") == 1
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+        assert await json.forget(glide_client, key) == 1
         assert await json.forget(glide_client, key) == 0
         assert await json.get(glide_client, key, "$") == None

+        # Non-existing keys
+        assert await json.forget(glide_client, "non_existing_key", "$") == 0
+        assert await json.forget(glide_client, "non_existing_key", ".") == 0
+
     @pytest.mark.parametrize("cluster_mode", [True, False])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
     async def test_json_toggle(self, glide_client: TGlideClient):

         assert await json.toggle(glide_client, key, "$..bool") == [False, True, None]
         assert await json.toggle(glide_client, key, "bool") is True
+        assert await json.toggle(glide_client, key, "$.not_existing") == []

         assert await json.toggle(glide_client, key, "$.nested") == [None]
         with pytest.raises(RequestError):
             assert await json.toggle(glide_client, key, "nested")

+        with pytest.raises(RequestError):
+            assert await json.toggle(glide_client, key, ".not_existing")
+
         with pytest.raises(RequestError):
             assert await json.toggle(glide_client, "non_exiting_key", "$")

@@ -470,7 +510,7 @@ async def test_json_numincrby(self, glide_client: TGlideClient):
         assert result == b"76"

         # Check if the rest of the key1 path matches were updated and not only the last value
-        result = await json.get(glide_client, key, "$..key1")
+        result = await json.get(glide_client, key, "$..key1")  # type: ignore
         assert (
             result == b"[0,[16,17],76]"
         )  # First is 0 as 0 + 0 = 0, Second doesn't change as its an array type (non-numeric), third is 76 as 0 + 76 = 0

@@ -610,7 +650,7 @@ async def test_json_nummultby(self, glide_client: TGlideClient):
         assert result == b"1380"  # Expect the last updated key1 value multiplied by 2

         # Check if the rest of the key1 path matches were updated and not only the last value
-        result = await json.get(glide_client, key, "$..key1")
+        result = await json.get(glide_client, key, "$..key1")  # type: ignore
         assert result == b"[-16500,[140,175],1380]"

         # Check for non-existent path in legacy
         with pytest.raises(RequestError):
             ...
         # Check for Overflow in legacy
         with pytest.raises(RequestError):
             await json.nummultby(glide_client, key, ".key9", 1.7976931348623157e308)

@@ -624,3 +664,71 @@
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_strlen(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        json_value = {"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.strlen(glide_client, key, "$..a") == [3, 5, None]
+        assert await json.strlen(glide_client, key, "a") == 3
+
+        assert await json.strlen(glide_client, key, "$.nested") == [None]
+        with pytest.raises(RequestError):
+            assert await json.strlen(glide_client, key, "nested")
+
+        with pytest.raises(RequestError):
+            assert await json.strlen(glide_client, key)
+
+        assert await json.strlen(glide_client, key, "$.non_existing_path") == []
+        with pytest.raises(RequestError):
+            await json.strlen(glide_client, key, ".non_existing_path")
+
+        assert await json.strlen(glide_client, "non_exiting_key", ".") is None
+        assert await json.strlen(glide_client, "non_exiting_key", "$") is None
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_strappend(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        json_value = {"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.strappend(glide_client, key, '"bar"', "$..a") == [6, 8, None]
+        assert await json.strappend(glide_client, key, OuterJson.dumps("foo"), "a") == 9
+
+        json_str = await json.get(glide_client, key, ".")
+        assert isinstance(json_str, bytes)
+        assert OuterJson.loads(json_str) == {
+            "a": "foobarfoo",
+            "nested": {"a": "hellobar"},
+            "nested2": {"a": 31},
+        }
+
+        assert await json.strappend(
+            glide_client, key, OuterJson.dumps("bar"), "$.nested"
+        ) == [None]
+
+        with pytest.raises(RequestError):
+            await json.strappend(glide_client, key, OuterJson.dumps("bar"), ".nested")
+
+        with pytest.raises(RequestError):
+            await json.strappend(glide_client, key, OuterJson.dumps("bar"))
+
+        assert (
+            await json.strappend(
+                glide_client, key, OuterJson.dumps("try"), "$.non_existing_path"
+            )
+            == []
+        )
+        with pytest.raises(RequestError):
+            await json.strappend(
+                glide_client, key, OuterJson.dumps("try"), "non_existing_path"
+            )
+
+        with pytest.raises(RequestError):
+            await json.strappend(
+                glide_client, "non_exiting_key", OuterJson.dumps("try")
+            )

From d71298340968581f9019cc8fa360d1059154e941 Mon Sep 17 00:00:00 2001
From: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
Date: Sun, 27 Oct 2024 10:56:40 +0200
Subject: [PATCH 053/180] Python: adds JSON.OBJKEYS command (#2395)

---------

Signed-off-by: Shoham Elias
---
 CHANGELOG.md                                  |  1 +
 .../async_commands/server_modules/json.py     | 49 +++++++++++++++++++
 .../tests/tests_server_modules/test_json.py   | 39 +++++++++++++++
 3 files changed, 89 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18bef9336c..637cd071aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,7 @@
 * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513))
 * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516))
 * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372))
+* Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395))

 #### Breaking Changes

diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py
index e2bfa20ac2..1d9d5001b4 100644
---
a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -505,6 +505,55 @@ async def strlen( ) +async def objkeys( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[Union[List[bytes], List[List[bytes]]]]: + """ + Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): Represents the path within the JSON document where the key names will be retrieved. + Defaults to None. + + Returns: + Optional[Union[List[bytes], List[List[bytes]]]]: + For JSONPath (`path` starts with `$`): + Returns a list of arrays containing key names for each matching object. + If a value matching the path is not an object, an empty array is returned. + If `path` doesn't exist, an empty array is returned. + For legacy path (`path` starts with `.`): + Returns a list of key names for the object value matching the path. + If multiple objects match the path, the key names of the first object are returned. + If a value matching the path is not an object, an error is raised. + If `path` doesn't exist, None is returned. + If `key` doesn't exist, None is returned. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}') + b'OK' # Indicates successful setting of the value at the root path '$' in the key `doc`. + >>> await json.objkeys(client, "doc", "$") + [[b"a", b"b"]] # Returns a list of arrays containing the key names for objects matching the path '$'. + >>> await json.objkeys(client, "doc", ".") + [b"a", b"b"] # Returns key names for the object matching the path '.' as it is the only match. + >>> await json.objkeys(client, "doc", "$.b") + [[b"a", b"b", b"c"]] # Returns key names as a nested list for objects matching the JSONPath '$.b'. + >>> await json.objkeys(client, "doc", ".b") + [b"a", b"b", b"c"] # Returns key names for the nested object at path '.b'. 
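+ Two more illustrative cases, mirroring the accompanying tests:
+ >>> await json.objkeys(client, "doc", "$.a")
+ [[]] # The value at '$.a' is a number, not an object, so an empty array is returned for that match.
+ >>> await json.objkeys(client, "non_existing_key", "$")
+ None # Returns None since the key `non_existing_key` does not exist.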
+ """ + args = ["JSON.OBJKEYS", key] + if path: + args.append(path) + return cast( + Optional[Union[List[bytes], List[List[bytes]]]], + await client.custom_command(args), + ) + + async def toggle( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 982b563995..27d70b015f 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -216,6 +216,45 @@ async def test_json_forget(self, glide_client: TGlideClient): assert await json.forget(glide_client, "non_existing_key", "$") == 0 assert await json.forget(glide_client, "non_existing_key", ".") == 0 + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_objkeys(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = {"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": True}} + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + keys = await json.objkeys(glide_client, key, "$") + assert keys == [[b"a", b"b"]] + + keys = await json.objkeys(glide_client, key, ".") + assert keys == [b"a", b"b"] + + keys = await json.objkeys(glide_client, key, "$..") + assert keys == [[b"a", b"b"], [b"a", b"b", b"c"], [b"x", b"y"]] + + keys = await json.objkeys(glide_client, key, "..") + assert keys == [b"a", b"b"] + + keys = await json.objkeys(glide_client, key, "$..b") + assert keys == [[b"a", b"b", b"c"], []] + + keys = await json.objkeys(glide_client, key, "..b") + assert keys == [b"a", b"b", b"c"] + + # path doesn't exist + assert await json.objkeys(glide_client, key, "$.non_existing_path") == [] + assert await json.objkeys(glide_client, key, "non_existing_path") == None + + # Value at path isnt an object + assert await json.objkeys(glide_client, key, "$.a") == [[]] + with pytest.raises(RequestError): + assert await json.objkeys(glide_client, key, ".a") + + # Non-existing key + assert await json.objkeys(glide_client, "non_exiting_key", "$") == None + assert await json.objkeys(glide_client, "non_exiting_key", ".") == None + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_toggle(self, glide_client: TGlideClient): From 08cce19e092f4f10439d892222917233b3269545 Mon Sep 17 00:00:00 2001 From: Shoham Elias Date: Wed, 16 Oct 2024 10:53:28 +0000 Subject: [PATCH 054/180] Python: adds JSON.ARRINSERT command Signed-off-by: Shoham Elias --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 51 +++++ .../tests/tests_server_modules/test_json.py | 196 ++++++++++++++++++ 3 files changed, 248 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 637cd071aa..6f72e9153c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) * Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395)) +* Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464)) #### Breaking Changes diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 
1d9d5001b4..48048e6b26 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -147,6 +147,57 @@ async def get( return cast(TJsonResponse[Optional[bytes]], await client.custom_command(args)) +async def arrinsert( + client: TGlideClient, + key: TEncodable, + path: TEncodable, + index: int, + values: List[TEncodable], +) -> TJsonResponse[int]: + """ + Inserts one or more values into the array at the specified `path` within the JSON document stored at `key`, before the given `index`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (TEncodable): The path within the JSON document. + index (int): The array index before which values are inserted. + values (List[TEncodable]): The JSON values to be inserted into the array, in JSON formatted bytes or str. + Json string values must be wrapped with single quotes. For example, to append "foo", pass '"foo"'. + + Returns: + TJsonResponse[int]: + For JSONPath (`path` starts with '$'): + Returns a list of integer replies for every possible path, indicating the new length of the array, + or None for JSON values matching the path that are not an array. + If `path` does not exist, an empty array will be returned. + For legacy path (`path` doesn't start with '$'): + Returns an integer representing the new length of the array. + If multiple paths are matched, returns the length of the first modified array. + If `path` doesn't exist or the value at `path` is not an array, an error is raised. + If the index is out of bounds, an error is raised. + If `key` doesn't exist, an error is raised. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]') + 'OK' + >>> await json.arrinsert(client, "doc", "$[*]", 0, ['"c"', '{"key": "value"}', "true", "null", '["bar"]']) + [5, 6, 7] # New lengths of arrays after insertion + >>> await json.get(client, "doc") + b'[["c",{"key":"value"},true,null,["bar"]],["c",{"key":"value"},true,null,["bar"],"a"],["c",{"key":"value"},true,null,["bar"],"a","b"]]' + + >>> await json.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]') + 'OK' + >>> await json.arrinsert(client, "doc", ".", 0, ['"c"']) + 4 # New length of the root array after insertion + >>> await json.get(client, "doc") + b'[\"c\",[],[\"a\"],[\"a\",\"b\"]]' + """ + args = ["JSON.ARRINSERT", key, path, str(index)] + values + return cast(TJsonResponse[int], await client.custom_command(args)) + + async def arrlen( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 27d70b015f..c00dcdeca6 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -1,6 +1,7 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 import json as OuterJson +import typing import pytest from glide.async_commands.core import ConditionalChange, InfoSection @@ -771,3 +772,198 @@ async def test_json_strappend(self, glide_client: TGlideClient): await json.strappend( glide_client, "non_exiting_key", OuterJson.dumps("try") ) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + @typing.no_type_check # since this is a complex test, skip typing to be more effective + async def test_json_arrinsert(self, 
glide_client: TGlideClient): + key = get_random_string(10) + + assert ( + await json.set( + glide_client, + key, + "$", + """ + { + "a": [], + "b": { "a": [1, 2, 3, 4] }, + "c": { "a": "not an array" }, + "d": [{ "a": ["x", "y"] }, { "a": [["foo"]] }], + "e": [{ "a": 42 }, { "a": {} }], + "f": { "a": [true, false, null] } + } + """, + ) + == OK + ) + + # Insert different types of values into the matching paths + result = await json.arrinsert( + glide_client, + key, + "$..a", + 0, + ['"string_value"', "123", '{"key": "value"}', "true", "null", '["bar"]'], + ) + assert result == [6, 10, None, 8, 7, None, None, 9] + + updated_doc = await json.get(glide_client, key) + + expected_doc = { + "a": ["string_value", 123, {"key": "value"}, True, None, ["bar"]], + "b": { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + 1, + 2, + 3, + 4, + ], + }, + "c": {"a": "not an array"}, + "d": [ + { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + "x", + "y", + ] + }, + { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + ["foo"], + ] + }, + ], + "e": [{"a": 42}, {"a": {}}], + "f": { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + True, + False, + None, + ] + }, + } + + assert OuterJson.loads(updated_doc) == expected_doc + + # Insert into a specific index (non-zero) + result = await json.arrinsert( + glide_client, + key, + "$..a", + 2, + ['"insert_at_2"'], + ) + assert result == [7, 11, None, 9, 8, None, None, 10] + + # Check document after insertion at index 2 + updated_doc_at_2 = await json.get(glide_client, key) + expected_doc["a"].insert(2, "insert_at_2") + expected_doc["b"]["a"].insert(2, "insert_at_2") + expected_doc["d"][0]["a"].insert(2, "insert_at_2") + expected_doc["d"][1]["a"].insert(2, "insert_at_2") + expected_doc["f"]["a"].insert(2, "insert_at_2") + assert OuterJson.loads(updated_doc_at_2) == expected_doc + + # Insert with a legacy path + result = await json.arrinsert( + glide_client, + key, + "..a", # legacy path + 0, + ['"legacy_value"'], + ) + assert ( + result == 8 + ) # Returns length of the first modified array (in this case, 'a') + + # Check document after insertion at root legacy path (all matching arrays should be updated) + updated_doc_legacy = await json.get(glide_client, key) + + # Update `expected_doc` with the new value inserted at index 0 of all matching arrays + expected_doc["a"].insert(0, "legacy_value") + expected_doc["b"]["a"].insert(0, "legacy_value") + expected_doc["d"][0]["a"].insert(0, "legacy_value") + expected_doc["d"][1]["a"].insert(0, "legacy_value") + expected_doc["f"]["a"].insert(0, "legacy_value") + + assert OuterJson.loads(updated_doc_legacy) == expected_doc + + # Insert with an index out of range for some arrays + with pytest.raises(RequestError): + await json.arrinsert( + glide_client, + key, + "$..a", + 10, # Index out of range for some paths but valid for others + ['"out_of_range_value"'], + ) + + with pytest.raises(RequestError): + await json.arrinsert( + glide_client, + key, + "..a", + 10, # Index out of range for some paths but valid for others + ['"out_of_range_value"'], + ) + + # Negative index insertion (should insert from the end of the array) + result = await json.arrinsert( + glide_client, + key, + "$..a", + -1, + ['"negative_index_value"'], + ) + assert result == [9, 13, None, 11, 10, None, None, 12] # Update valid paths + + # Check document after negative index insertion + updated_doc_negative = await 
json.get(glide_client, key) + expected_doc["a"].insert(-1, "negative_index_value") + expected_doc["b"]["a"].insert(-1, "negative_index_value") + expected_doc["d"][0]["a"].insert(-1, "negative_index_value") + expected_doc["d"][1]["a"].insert(-1, "negative_index_value") + expected_doc["f"]["a"].insert(-1, "negative_index_value") + assert OuterJson.loads(updated_doc_negative) == expected_doc + + # Non-existing path + with pytest.raises(RequestError): + await json.arrinsert(glide_client, key, ".path", 5, ['"value"']) + + await json.arrinsert(glide_client, key, "$.path", 5, ['"value"']) == [] + + # Key doesnt exist + with pytest.raises(RequestError): + await json.arrinsert(glide_client, "non_existent_key", "$", 5, ['"value"']) + + with pytest.raises(RequestError): + await json.arrinsert(glide_client, "non_existent_key", ".", 5, ['"value"']) + + # value at path is not an array + with pytest.raises(RequestError): + await json.arrinsert(glide_client, key, ".e", 5, ['"value"']) From a99d89a34a6cc4723ba3583a3745a5dc509b21f2 Mon Sep 17 00:00:00 2001 From: Shoham Elias Date: Sun, 27 Oct 2024 09:17:52 +0000 Subject: [PATCH 055/180] Python: reorder json.py commands Signed-off-by: Shoham Elias --- .../async_commands/server_modules/json.py | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 48048e6b26..06fc8a943b 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -459,6 +459,55 @@ async def nummultby( return cast(Optional[bytes], await client.custom_command(args)) +async def objkeys( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[Union[List[bytes], List[List[bytes]]]]: + """ + Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): Represents the path within the JSON document where the key names will be retrieved. + Defaults to None. + + Returns: + Optional[Union[List[bytes], List[List[bytes]]]]: + For JSONPath (`path` starts with `$`): + Returns a list of arrays containing key names for each matching object. + If a value matching the path is not an object, an empty array is returned. + If `path` doesn't exist, an empty array is returned. + For legacy path (`path` starts with `.`): + Returns a list of key names for the object value matching the path. + If multiple objects match the path, the key names of the first object are returned. + If a value matching the path is not an object, an error is raised. + If `path` doesn't exist, None is returned. + If `key` doesn't exist, None is returned. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}') + b'OK' # Indicates successful setting of the value at the root path '$' in the key `doc`. + >>> await json.objkeys(client, "doc", "$") + [[b"a", b"b"]] # Returns a list of arrays containing the key names for objects matching the path '$'. + >>> await json.objkeys(client, "doc", ".") + [b"a", b"b"] # Returns key names for the object matching the path '.' as it is the only match. 
+ >>> await json.objkeys(client, "doc", "$.b") + [[b"a", b"b", b"c"]] # Returns key names as a nested list for objects matching the JSONPath '$.b'. + >>> await json.objkeys(client, "doc", ".b") + [b"a", b"b", b"c"] # Returns key names for the nested object at path '.b'. + """ + args = ["JSON.OBJKEYS", key] + if path: + args.append(path) + return cast( + Optional[Union[List[bytes], List[List[bytes]]]], + await client.custom_command(args), + ) + + async def strappend( client: TGlideClient, key: TEncodable, @@ -556,55 +605,6 @@ async def strlen( ) -async def objkeys( - client: TGlideClient, - key: TEncodable, - path: Optional[TEncodable] = None, -) -> Optional[Union[List[bytes], List[List[bytes]]]]: - """ - Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`. - - Args: - client (TGlideClient): The client to execute the command. - key (TEncodable): The key of the JSON document. - path (Optional[TEncodable]): Represents the path within the JSON document where the key names will be retrieved. - Defaults to None. - - Returns: - Optional[Union[List[bytes], List[List[bytes]]]]: - For JSONPath (`path` starts with `$`): - Returns a list of arrays containing key names for each matching object. - If a value matching the path is not an object, an empty array is returned. - If `path` doesn't exist, an empty array is returned. - For legacy path (`path` starts with `.`): - Returns a list of key names for the object value matching the path. - If multiple objects match the path, the key names of the first object are returned. - If a value matching the path is not an object, an error is raised. - If `path` doesn't exist, None is returned. - If `key` doesn't exist, None is returned. - - Examples: - >>> from glide import json - >>> await json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}') - b'OK' # Indicates successful setting of the value at the root path '$' in the key `doc`. - >>> await json.objkeys(client, "doc", "$") - [[b"a", b"b"]] # Returns a list of arrays containing the key names for objects matching the path '$'. - >>> await json.objkeys(client, "doc", ".") - [b"a", b"b"] # Returns key names for the object matching the path '.' as it is the only match. - >>> await json.objkeys(client, "doc", "$.b") - [[b"a", b"b", b"c"]] # Returns key names as a nested list for objects matching the JSONPath '$.b'. - >>> await json.objkeys(client, "doc", ".b") - [b"a", b"b", b"c"] # Returns key names for the nested object at path '.b'. 
- """ - args = ["JSON.OBJKEYS", key] - if path: - args.append(path) - return cast( - Optional[Union[List[bytes], List[List[bytes]]]], - await client.custom_command(args), - ) - - async def toggle( client: TGlideClient, key: TEncodable, From d95040ca86f2f9cad5b86f3d39eda77eb4061335 Mon Sep 17 00:00:00 2001 From: Muhammad Awawdi Date: Sun, 27 Oct 2024 14:19:04 +0200 Subject: [PATCH 056/180] Python: add JSON.DEBUG.FIELDS and JSON.DEBUG.MEMORY commands (#2481) --------- Signed-off-by: Muhammad Awawdi Signed-off-by: Shoham Elias Co-authored-by: Shoham Elias --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 101 ++++++++ .../tests/tests_server_modules/test_json.py | 224 ++++++++++++++++++ 3 files changed, 326 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f72e9153c..9e89f724f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409)) * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448)) * Python: Add JSON.NUMMULTBY command ([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) +* Python: Add `JSON.DEBUG_MEMORY` and `JSON.DEBUG_FIELDS` commands ([#2481](https://github.com/valkey-io/valkey-glide/pull/2481)) * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.INFO` ([#2405](https://github.com/valkey-io/valkey-glide/pull/2441)) * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 06fc8a943b..74fb0cc327 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -304,6 +304,107 @@ async def clear( return cast(int, await client.custom_command(args)) +async def debug_fields( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[Union[int, List[int]]]: + """ + Returns the number of fields of the JSON value at the specified `path` within the JSON document stored at `key`. + - **Primitive Values**: Each non-container JSON value (e.g., strings, numbers, booleans, and null) counts as one field. + - **Arrays and Objects:**: Each item in an array and each key-value pair in an object is counted as one field. (Each top-level value counts as one field, regardless of it's type.) + - Their nested values are counted recursively and added to the total. + - **Example**: For the JSON `{"a": 1, "b": [2, 3, {"c": 4}]}`, the count would be: + - Top-level: 2 fields (`"a"` and `"b"`) + - Nested: 3 fields in the array (`2`, `3`, and `{"c": 4}`) plus 1 for the object (`"c"`) + - Total: 2 (top-level) + 3 (from array) + 1 (from nested object) = 6 fields. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Defaults to root if not provided. + + Returns: + Optional[Union[int, List[int]]]: + For JSONPath (`path` starts with `$`): + Returns an array of integers, each indicating the number of fields for each matched `path`. + If `path` doesn't exist, an empty array will be returned. + For legacy path (`path` doesn't start with `$`): + Returns an integer indicating the number of fields for each matched `path`. 
+ If multiple paths match, number of fields of the first JSON value match is returned. + If `path` doesn't exist, an error is raised. + If `path` is not provided, it reports the total number of fields in the entire JSON document. + If `key` doesn't exist, None is returned. + + Examples: + >>> from glide import json + >>> await json.set(client, "k1", "$", '[1, 2.3, "foo", true, null, {}, [], {"a":1, "b":2}, [1,2,3]]') + 'OK' + >>> await json.debug_fields(client, "k1", "$[*]") + [1, 1, 1, 1, 1, 0, 0, 2, 3] + >>> await json.debug_fields(client, "k1", ".") + 14 # 9 top-level fields + 5 nested address fields + + >>> await json.set(client, "k1", "$", '{"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true,"address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"},"phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}],"children":[],"spouse":null}') + 'OK' + >>> await json.debug_fields(client, "k1") + 19 + >>> await json.debug_fields(client, "k1", ".address") + 4 + """ + args = ["JSON.DEBUG", "FIELDS", key] + if path: + args.append(path) + + return cast(Optional[Union[int, List[int]]], await client.custom_command(args)) + + +async def debug_memory( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[Union[int, List[int]]]: + """ + Reports memory usage in bytes of a JSON value at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Defaults to None. + + Returns: + Optional[Union[int, List[int]]]: + For JSONPath (`path` starts with `$`): + Returns an array of integers, indicating the memory usage in bytes of a JSON value for each matched `path`. + If `path` doesn't exist, an empty array will be returned. + For legacy path (`path` doesn't start with `$`): + Returns an integer, indicating the memory usage in bytes for the JSON value in `path`. + If multiple paths match, the memory usage of the first JSON value match is returned. + If `path` doesn't exist, an error is raised. + If `path` is not provided, it reports the total memory usage in bytes in the entire JSON document. + If `key` doesn't exist, None is returned. 
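+ Note:
+ The byte counts reported by `JSON.DEBUG MEMORY` reflect the server's internal
+ representation and can differ across server versions and configurations, so
+ the exact values shown in the examples below should be treated as illustrative.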
+ + Examples: + >>> from glide import json + >>> await json.set(client, "k1", "$", '[1, 2.3, "foo", true, null, {}, [], {"a":1, "b":2}, [1,2,3]]') + 'OK' + >>> await json.debug_memory(client, "k1", "$[*]") + [16, 16, 19, 16, 16, 16, 16, 66, 64] + + >>> await json.set(client, "k1", "$", '{"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true,"address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"},"phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}],"children":[],"spouse":null}') + 'OK' + >>> await json.debug_memory(client, "k1") + 472 + >>> await json.debug_memory(client, "k1", ".phoneNumbers") + 164 + """ + args = ["JSON.DEBUG", "MEMORY", key] + if path: + args.append(path) + + return cast(Optional[Union[int, List[int]]], await client.custom_command(args)) + + async def delete( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index c00dcdeca6..c37019911b 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -967,3 +967,227 @@ async def test_json_arrinsert(self, glide_client: TGlideClient): # value at path is not an array with pytest.raises(RequestError): await json.arrinsert(glide_client, key, ".e", 5, ['"value"']) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_debug_fields(self, glide_client: TGlideClient): + key = get_random_string(10) + + json_value = { + "key1": 1, + "key2": 3.5, + "key3": {"nested_key": {"key1": [4, 5]}}, + "key4": [1, 2, 3], + "key5": 0, + "key6": "hello", + "key7": None, + "key8": {"nested_key": {"key1": 3.5953862697246314e307}}, + "key9": 3.5953862697246314e307, + "key10": True, + } + + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + # Test JSONPath - Fields Subcommand + # Test integer + result = await json.debug_fields(glide_client, key, "$.key1") + assert result == [1] + + # Test float + result = await json.debug_fields(glide_client, key, "$.key2") + assert result == [1] + + # Test Nested Value + result = await json.debug_fields(glide_client, key, "$.key3") + assert result == [4] + + result = await json.debug_fields(glide_client, key, "$.key3.nested_key.key1") + assert result == [2] + + # Test Array + result = await json.debug_fields(glide_client, key, "$.key4[2]") + assert result == [1] + + # Test String + result = await json.debug_fields(glide_client, key, "$.key6") + assert result == [1] + + # Test Null + result = await json.debug_fields(glide_client, key, "$.key7") + assert result == [1] + + # Test Bool + result = await json.debug_fields(glide_client, key, "$.key10") + assert result == [1] + + # Test all keys + result = await json.debug_fields(glide_client, key, "$[*]") + assert result == [1, 1, 4, 3, 1, 1, 1, 2, 1, 1] + + # Test multiple paths + result = await json.debug_fields(glide_client, key, "$..key1") + assert result == [1, 2, 1] + + # Test for non-existent path + result = await json.debug_fields(glide_client, key, "$.key11") + assert result == [] + + # Test for non-existent key + result = await json.debug_fields(glide_client, "non_existent_key", "$.key10") + assert result == None + + # Test no provided path + # Total Fields (19) - breakdown: + # Top-Level Fields: 10 + # Fields within key3: 4 ($.key3, 
$.key3.nested_key, $.key3.nested_key.key1, $.key3.nested_key.key1) + # Fields within key4: 3 ($.key4[0], $.key4[1], $.key4[2]) + # Fields within key8: 2 ($.key8, $.key8.nested_key) + result = await json.debug_fields(glide_client, key) + assert result == 19 + + # Test legacy path - Fields Subcommand + # Test integer + result = await json.debug_fields(glide_client, key, ".key1") + assert result == 1 + + # Test float + result = await json.debug_fields(glide_client, key, ".key2") + assert result == 1 + + # Test Nested Value + result = await json.debug_fields(glide_client, key, ".key3") + assert result == 4 + + result = await json.debug_fields(glide_client, key, ".key3.nested_key.key1") + assert result == 2 + + # Test Array + result = await json.debug_fields(glide_client, key, ".key4[2]") + assert result == 1 + + # Test String + result = await json.debug_fields(glide_client, key, ".key6") + assert result == 1 + + # Test Null + result = await json.debug_fields(glide_client, key, ".key7") + assert result == 1 + + # Test Bool + result = await json.debug_fields(glide_client, key, ".key10") + assert result == 1 + + # Test multiple paths + result = await json.debug_fields(glide_client, key, "..key1") + assert result == 1 # Returns number of fields of the first JSON value + + # Test for non-existent path + with pytest.raises(RequestError): + await json.debug_fields(glide_client, key, ".key11") + + # Test for non-existent key + result = await json.debug_fields(glide_client, "non_existent_key", ".key10") + assert result == None + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_debug_memory(self, glide_client: TGlideClient): + key = get_random_string(10) + + json_value = { + "key1": 1, + "key2": 3.5, + "key3": {"nested_key": {"key1": [4, 5]}}, + "key4": [1, 2, 3], + "key5": 0, + "key6": "hello", + "key7": None, + "key8": {"nested_key": {"key1": 3.5953862697246314e307}}, + "key9": 3.5953862697246314e307, + "key10": True, + } + + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + # Test JSONPath - Memory Subcommand + # Test integer + result = await json.debug_memory(glide_client, key, "$.key1") + assert result == [16] + # Test float + result = await json.debug_memory(glide_client, key, "$.key2") + assert result == [16] + # Test Nested Value + result = await json.debug_memory(glide_client, key, "$.key3.nested_key.key1[0]") + assert result == [16] + # Test Array + result = await json.debug_memory(glide_client, key, "$.key4") + assert result == [16 * 4] + + result = await json.debug_memory(glide_client, key, "$.key4[2]") + assert result == [16] + # Test String + result = await json.debug_memory(glide_client, key, "$.key6") + assert result == [16] + # Test Null + result = await json.debug_memory(glide_client, key, "$.key7") + assert result == [16] + # Test Bool + result = await json.debug_memory(glide_client, key, "$.key10") + assert result == [16] + # Test all keys + result = await json.debug_memory(glide_client, key, "$[*]") + assert result == [16, 16, 110, 64, 16, 16, 16, 101, 39, 16] + # Test multiple paths + result = await json.debug_memory(glide_client, key, "$..key1") + assert result == [16, 48, 39] + # Test for non-existent path + result = await json.debug_memory(glide_client, key, "$.key11") + assert result == [] + # Test for non-existent key + result = await json.debug_memory(glide_client, "non_existent_key", "$.key10") + assert result == None + # Test 
no provided path + # Total Memory (504 bytes) - visual breakdown: + # ├── Root Object Overhead (129 bytes) + # └── JSON Elements (374 bytes) + # ├── key1: 16 bytes + # ├── key2: 16 bytes + # ├── key3: 110 bytes + # ├── key4: 64 bytes + # ├── key5: 16 bytes + # ├── key6: 16 bytes + # ├── key7: 16 bytes + # ├── key8: 101 bytes + # └── key9: 39 bytes + result = await json.debug_memory(glide_client, key) + assert result == 504 + # Test Legacy Path - Memory Subcommand + # Test integer + result = await json.debug_memory(glide_client, key, ".key1") + assert result == 16 + # Test float + result = await json.debug_memory(glide_client, key, ".key2") + assert result == 16 + # Test Nested Value + result = await json.debug_memory(glide_client, key, ".key3.nested_key.key1[0]") + assert result == 16 + # Test Array + result = await json.debug_memory(glide_client, key, ".key4[2]") + assert result == 16 + # Test String + result = await json.debug_memory(glide_client, key, ".key6") + assert result == 16 + # Test Null + result = await json.debug_memory(glide_client, key, ".key7") + assert result == 16 + # Test Bool + result = await json.debug_memory(glide_client, key, ".key10") + assert result == 16 + # Test multiple paths + result = await json.debug_memory(glide_client, key, "..key1") + assert result == 16 # Returns the memory usage of the first JSON value + # Test for non-existent path + with pytest.raises(RequestError): + await json.debug_memory(glide_client, key, ".key11") + # Test for non-existent key + result = await json.debug_memory(glide_client, "non_existent_key", ".key10") + assert result == None From c197f1d93319d017eb656c77fe38e5c69fd9c9f7 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Sun, 27 Oct 2024 15:45:18 +0200 Subject: [PATCH 057/180] Python: adds JSON.ARRTRIM command (#2457) --------- Signed-off-by: Shoham Elias --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 55 ++++++++++ .../tests/tests_server_modules/test_json.py | 103 ++++++++++++++++++ 3 files changed, 159 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e89f724f8..c1bec5d60b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) * Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395)) * Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464)) +* Python: Add `JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457)) #### Breaking Changes diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 74fb0cc327..40e4e80c51 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -252,6 +252,61 @@ async def arrlen( ) +async def arrtrim( + client: TGlideClient, + key: TEncodable, + path: TEncodable, + start: int, + end: int, +) -> TJsonResponse[int]: + """ + Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive.› + If `start` < 0, it is treated as 0. + If `end` >= size (size of the array), it is treated as size-1. + If `start` >= size or `start` > `end`, the array is emptied and 0 is returned. + + Args: + client (TGlideClient): The client to execute the command. 
+ key (TEncodable): The key of the JSON document. + path (TEncodable): The path within the JSON document. + start (int): The start index, inclusive. + end (int): The end index, inclusive. + + Returns: + TJsonResponse[int]: + For JSONPath (`path` starts with '$'): + Returns a list of integer replies for every possible path, indicating the new length of the array, or None for JSON values matching the path that are not an array. + If a value is an empty array, its corresponding return value is 0. + If `path` doesn't exist, an empty array will be returned. + For legacy path (`path` doesn't starts with `$`): + Returns an integer representing the new length of the array. + If the array is empty, returns 0. + If multiple paths match, the length of the first trimmed array match is returned. + If `path` doesn't exist, or the value at `path` is not an array, an error is raised. + If `key` doesn't exist, an error is raised. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]') + 'OK' + >>> await json.arrtrim(client, "doc", "$[*]", 0, 1) + [0, 1, 2, 2] + >>> await json.get(client, "doc") + b'[[],[\"a\"],[\"a\",\"b\"],[\"a\",\"b\"]]' + + >>> await json.set(client, "doc", "$", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}') + 'OK' + >>> await json.arrtrim(client, "doc", ".children", 0, 1) + 2 + >>> await json.get(client, "doc", ".children") + b'["John","Jack"]' + """ + return cast( + TJsonResponse[int], + await client.custom_command(["JSON.ARRTRIM", key, path, str(start), str(end)]), + ) + + async def clear( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index c37019911b..5989cb9f2e 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -1191,3 +1191,106 @@ async def test_json_debug_memory(self, glide_client: TGlideClient): # Test for non-existent key result = await json.debug_memory(glide_client, "non_existent_key", ".key10") assert result == None + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @typing.no_type_check + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrtrim(self, glide_client: TGlideClient): + key = get_random_string(5) + + # Test with enhanced path syntax + json_value = '{"a": [0, 1, 2, 3, 4, 5, 6, 7, 8], "b": {"a": [0, 9, 10, 11, 12, 13], "c": {"a": 42}}}' + assert await json.set(glide_client, key, "$", json_value) == OK + + # Basic trim + assert await json.arrtrim(glide_client, key, "$..a", 1, 7) == [7, 5, None] + assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [ + [1, 2, 3, 4, 5, 6, 7], + [9, 10, 11, 12, 13], + 42, + ] + + # Test negative start (should be treated as 0) + assert await json.arrtrim(glide_client, key, "$.a", -1, 5) == [6] + assert OuterJson.loads(await json.get(glide_client, key, "$.a")) == [ + [1, 2, 3, 4, 5, 6] + ] + assert await json.arrtrim(glide_client, key, ".a", -1, 5) == 6 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + + # Test end >= size (should be treated as size-1) + assert await json.arrtrim(glide_client, key, "$.a", 0, 10) == [6] + assert OuterJson.loads(await json.get(glide_client, key, "$.a")) == [ + [1, 2, 3, 4, 5, 6] + ] + + assert await json.arrtrim(glide_client, key, ".a", 0, 10) == 6 + assert OuterJson.loads(await 
json.get(glide_client, key, ".a")) == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + + # Test start >= size (should empty the array) + assert await json.arrtrim(glide_client, key, "$.a", 7, 10) == [0] + assert OuterJson.loads(await json.get(glide_client, key, "$.a")) == [[]] + + assert await json.set(glide_client, key, ".a", '["a", "b", "c"]') == OK + assert await json.arrtrim(glide_client, key, ".a", 7, 10) == 0 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [] + + # Test start > end (should empty the array) + assert await json.arrtrim(glide_client, key, "$..a", 2, 1) == [0, 0, None] + assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [ + [], + [], + 42, + ] + assert await json.set(glide_client, key, "..a", '["a", "b", "c", "d"]') == OK + assert await json.arrtrim(glide_client, key, "..a", 2, 1) == 0 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [] + + # Multiple path match + assert await json.set(glide_client, key, "$", json_value) == OK + assert await json.arrtrim(glide_client, key, "..a", 1, 10) == 8 + assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [ + [1, 2, 3, 4, 5, 6, 7, 8], + [9, 10, 11, 12, 13], + 42, + ] + + # Test with non-existent path + with pytest.raises(RequestError): + await json.arrtrim(glide_client, key, ".non_existent", 0, 1) + + assert await json.arrtrim(glide_client, key, "$.non_existent", 0, 1) == [] + + # Test with non-array path + assert await json.arrtrim(glide_client, key, "$", 0, 1) == [None] + + with pytest.raises(RequestError): + await json.arrtrim(glide_client, key, ".", 0, 1) + + # Test with non-existent key + with pytest.raises(RequestError): + await json.arrtrim(glide_client, "non_existent_key", "$", 0, 1) + + # Test with non-existent key + with pytest.raises(RequestError): + await json.arrtrim(glide_client, "non_existent_key", ".", 0, 1) + + # Test empty array + assert await json.set(glide_client, key, "$.empty", "[]") == OK + assert await json.arrtrim(glide_client, key, "$.empty", 0, 1) == [0] + assert await json.arrtrim(glide_client, key, ".empty", 0, 1) == 0 + assert OuterJson.loads(await json.get(glide_client, key, "$.empty")) == [[]] From ee0474497ddc3a8c996ab662e96fe9b3ac0927e8 Mon Sep 17 00:00:00 2001 From: BoazBD Date: Wed, 23 Oct 2024 07:10:43 +0000 Subject: [PATCH 058/180] implement python json mget command Signed-off-by: BoazBD --- .../async_commands/server_modules/json.py | 44 +++++++ .../tests/tests_server_modules/test_json.py | 116 ++++++++++++++++++ 2 files changed, 160 insertions(+) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 40e4e80c51..4438c72c23 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -147,6 +147,50 @@ async def get( return cast(TJsonResponse[Optional[bytes]], await client.custom_command(args)) +async def mget( + client: TGlideClient, + keys: List[TEncodable], + paths: Optional[Union[TEncodable, List[TEncodable]]] = None, + options: Optional[JsonGetOptions] = None, +) -> Optional[List[bytes]]: + """ + Retrieves the JSON values at the specified `paths` stored at multiple `keys`. + + See https://valkey.io/commands/json.mget/ for more details. + + Args: + client (TGlideClient): The Redis client to execute the command. + keys (List[TEncodable]): A list of keys for the JSON documents. 
+ paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON documents. Default is root `$`. + options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`. + + Returns: + Optional[List[bytes]]: A list of bytes representations of the returned values. + If a key doesn't exist, its corresponding entry will be `None`. + + Examples: + >>> from glide import json as redisJson + >>> import json + >>> json_strs = await redisJson.mget(client, ["doc1", "doc2"], ["$"]) + >>> [json.loads(js) for js in json_strs] # Parse JSON strings to Python data + [[{"a": 1.0, "b": 2}], [{"a": 2.0, "b": {"a": 3.0, "b" : 4.0}}]] # JSON objects retrieved from keys `doc1` and `doc2` + >>> await redisJson.mget(client, ["doc1", "doc2"], ["$.a"]) + [b"[1.0]", b"[2.0]"] # Returns values at path '$.a' for the JSON documents stored at `doc1` and `doc2`. + >>> await redisJson.mget(client, ["doc1"], ["$.non_existing_path"]) + [None] # Returns an empty array since the path '$.non_existing_path' does not exist in the JSON document stored at `doc1`. + """ + args = ["JSON.MGET"] + keys + if options: + args.extend(options.get_options()) + if paths: + if isinstance(paths, (str, bytes)): + paths = [paths] + args.extend(paths) + + results = await client.custom_command(args) + return [result if result is not None else None for result in results] + + async def arrinsert( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 5989cb9f2e..9c86857032 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -143,6 +143,122 @@ async def test_json_get_formatting(self, glide_client: TGlideClient): expected_result = b'[\n~{\n~~"a":*1.0,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]' assert result == expected_result + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_mget(self, glide_client: TGlideClient): + key = get_random_string(5) + key1 = f"{{{key}}}1" + key2 = f"{{{key}}}2" + # The prefix ensures that both keys hash to the same slot + + json1_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}} + json2_value = {"a": 3.0, "b": {"a": 1, "b": 4}} + + assert ( + await json.set(glide_client, key1, "$", OuterJson.dumps(json1_value)) == OK + ) + assert ( + await json.set(glide_client, key2, "$", OuterJson.dumps(json2_value)) == OK + ) + + result = await json.mget( + glide_client, + [key1, key2], + ["$"], + ) + expected_result = [ + b'[{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}]', + b'[{"a":3.0,"b":{"a":1,"b":4}}]', + ] + assert result == expected_result + + result = await json.mget( + glide_client, + [key1, key2], + ["."], + ) + expected_result = [ + b'{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}', + b'{"a":3.0,"b":{"a":1,"b":4}}', + ] + assert result == expected_result + + result = await json.mget( + glide_client, + [key1, key2], + ["$.a"], + ) + expected_result = [b"[1.0]", b"[3.0]"] + assert result == expected_result + + result = await json.mget( + glide_client, + [key1, key2], + ["$.b"], + ) + expected_result = [b'[{"a":1,"b":2.5,"c":true}]', b'[{"a":1,"b":4}]'] + assert result == expected_result + + result = await json.mget( + glide_client, + [key1, key2], + ["$..b"], + ) + expected_result = [b'[{"a":1,"b":2.5,"c":true},2.5]', 
b'[{"a":1,"b":4},4]'] + assert result == expected_result + + result = await json.mget( + glide_client, + [key1, key2], + [".b.b"], + ) + expected_result = [b"2.5", b"4"] + assert result == expected_result + + # Path doesn't exist + result = await json.mget( + glide_client, + [key1, key2], + ["$non_existing_path"], + ) + expected_result = [b"[]", b"[]"] + assert result == expected_result + + # Keys don't exist + result = await json.mget( + glide_client, + ["{non_existing_key}1", "{non_existing_key}2"], + ["$a"], + ) + expected_result = [None, None] + assert result == expected_result + + # Test with only one key + result = await json.mget( + glide_client, + [key1], + ["$.a"], + ) + expected_result = [b"[1.0]"] + assert result == expected_result + + # Value at path isnt an object + result = await json.mget( + glide_client, + [key1, key2], + ["$.e"], + ) + expected_result = [b"[]", b"[]"] + assert result == expected_result + + # No path given + result = await json.mget( + glide_client, + [key1, key2], + ) + expected_result = [None] + assert result == expected_result + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_del(self, glide_client: TGlideClient): From 0853c09a3525a3480b3c67890d4f92b274c15eac Mon Sep 17 00:00:00 2001 From: BoazBD Date: Wed, 23 Oct 2024 08:21:06 +0000 Subject: [PATCH 059/180] update changelog Signed-off-by: BoazBD --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1bec5d60b..f3182c1ec7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Python: Add commands FT.ALIASADD, FT.ALIASDEL, FT.ALIASUPDATE([#2471](https://github.com/valkey-io/valkey-glide/pull/2471)) * Python: Python FT.DROPINDEX command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437)) * Python: Python: Added FT.CREATE command([#2413](https://github.com/valkey-io/valkey-glide/pull/2413)) +* Python: Add JSON.MGET command ([#2507](https://github.com/valkey-io/valkey-glide/pull/2507)) * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) * Python: Add JSON.CLEAR command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418)) * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409)) From 4133d049c0b3377304530b4aae14d43f56c0d9b0 Mon Sep 17 00:00:00 2001 From: BoazBD Date: Wed, 23 Oct 2024 08:24:50 +0000 Subject: [PATCH 060/180] fix type of mget Signed-off-by: BoazBD --- python/python/glide/async_commands/server_modules/json.py | 7 +++---- python/python/tests/tests_server_modules/test_json.py | 6 ++---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 4438c72c23..0f5b0b8849 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -152,7 +152,7 @@ async def mget( keys: List[TEncodable], paths: Optional[Union[TEncodable, List[TEncodable]]] = None, options: Optional[JsonGetOptions] = None, -) -> Optional[List[bytes]]: +) -> List[Optional[bytes]]: """ Retrieves the JSON values at the specified `paths` stored at multiple `keys`. @@ -165,7 +165,7 @@ async def mget( options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`. 
Returns: - Optional[List[bytes]]: A list of bytes representations of the returned values. + List[Optional[bytes]]: A list of bytes representations of the returned values. If a key doesn't exist, its corresponding entry will be `None`. Examples: @@ -187,8 +187,7 @@ async def mget( paths = [paths] args.extend(paths) - results = await client.custom_command(args) - return [result if result is not None else None for result in results] + return cast(List[Optional[bytes]], await client.custom_command(args)) async def arrinsert( diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 9c86857032..635bd0b422 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -230,8 +230,7 @@ async def test_json_mget(self, glide_client: TGlideClient): ["{non_existing_key}1", "{non_existing_key}2"], ["$a"], ) - expected_result = [None, None] - assert result == expected_result + assert result == [None, None] # Test with only one key result = await json.mget( @@ -256,8 +255,7 @@ async def test_json_mget(self, glide_client: TGlideClient): glide_client, [key1, key2], ) - expected_result = [None] - assert result == expected_result + assert result == [None] @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) From 099d77fd8a1d4c99a35f9beae378ff05724e5a02 Mon Sep 17 00:00:00 2001 From: BoazBD Date: Sun, 27 Oct 2024 09:02:33 +0000 Subject: [PATCH 061/180] fix signature and docstring Signed-off-by: BoazBD --- .../async_commands/server_modules/json.py | 36 +++++++++--------- .../tests/tests_server_modules/test_json.py | 38 +++++++++---------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 0f5b0b8849..63fef358ea 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -150,42 +150,42 @@ async def get( async def mget( client: TGlideClient, keys: List[TEncodable], - paths: Optional[Union[TEncodable, List[TEncodable]]] = None, - options: Optional[JsonGetOptions] = None, + path: Optional[TEncodable] = None, ) -> List[Optional[bytes]]: """ - Retrieves the JSON values at the specified `paths` stored at multiple `keys`. - - See https://valkey.io/commands/json.mget/ for more details. + Retrieves the JSON values at the specified `path` stored at multiple `keys`. + Note: + When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. Args: client (TGlideClient): The Redis client to execute the command. keys (List[TEncodable]): A list of keys for the JSON documents. - paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON documents. Default is root `$`. - options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`. + path (Optional[TEncodable]): The path within the JSON documents. Default is root `$`. Returns: - List[Optional[bytes]]: A list of bytes representations of the returned values. - If a key doesn't exist, its corresponding entry will be `None`. + List[Optional[bytes]]: + For JSONPath (`path` starts with `$`): + Returns a list of byte representations of the values found at the given path for each key. 
If the path does not exist, + the entry will be an empty array. + For legacy path (`path` starts with `.`): + Returns a string representation of the value at the specified path. If the path does not exist, the entry will be None. + If a key doesn't exist, the corresponding list element will also be `None`. + Examples: >>> from glide import json as redisJson >>> import json - >>> json_strs = await redisJson.mget(client, ["doc1", "doc2"], ["$"]) + >>> json_strs = await redisJson.mget(client, ["doc1", "doc2"], "$") >>> [json.loads(js) for js in json_strs] # Parse JSON strings to Python data [[{"a": 1.0, "b": 2}], [{"a": 2.0, "b": {"a": 3.0, "b" : 4.0}}]] # JSON objects retrieved from keys `doc1` and `doc2` - >>> await redisJson.mget(client, ["doc1", "doc2"], ["$.a"]) + >>> await redisJson.mget(client, ["doc1", "doc2"], "$.a") [b"[1.0]", b"[2.0]"] # Returns values at path '$.a' for the JSON documents stored at `doc1` and `doc2`. - >>> await redisJson.mget(client, ["doc1"], ["$.non_existing_path"]) + >>> await redisJson.mget(client, ["doc1"], "$.non_existing_path") [None] # Returns an empty array since the path '$.non_existing_path' does not exist in the JSON document stored at `doc1`. """ args = ["JSON.MGET"] + keys - if options: - args.extend(options.get_options()) - if paths: - if isinstance(paths, (str, bytes)): - paths = [paths] - args.extend(paths) + if path: + args.append(path) return cast(List[Optional[bytes]], await client.custom_command(args)) diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 635bd0b422..bb4495665e 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -164,7 +164,7 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - ["$"], + "$", ) expected_result = [ b'[{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}]', @@ -175,7 +175,7 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - ["."], + ".", ) expected_result = [ b'{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}', @@ -186,7 +186,7 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - ["$.a"], + "$.a", ) expected_result = [b"[1.0]", b"[3.0]"] assert result == expected_result @@ -194,7 +194,7 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - ["$.b"], + "$.b", ) expected_result = [b'[{"a":1,"b":2.5,"c":true}]', b'[{"a":1,"b":4}]'] assert result == expected_result @@ -202,7 +202,7 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - ["$..b"], + "$..b", ) expected_result = [b'[{"a":1,"b":2.5,"c":true},2.5]', b'[{"a":1,"b":4},4]'] assert result == expected_result @@ -210,25 +210,34 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1, key2], - [".b.b"], + ".b.b", ) expected_result = [b"2.5", b"4"] assert result == expected_result - # Path doesn't exist + # JSONPath doesn't exist result = await json.mget( glide_client, [key1, key2], - ["$non_existing_path"], + "$non_existing_path", ) expected_result = [b"[]", b"[]"] assert result == expected_result + # Legacy path doesn't exist + result = await json.mget( + glide_client, + [key1, key2], + ".non_existing_path", + ) + expected_result = [None, None] + assert result 
== expected_result + # Keys don't exist result = await json.mget( glide_client, ["{non_existing_key}1", "{non_existing_key}2"], - ["$a"], + "$a", ) assert result == [None, None] @@ -236,20 +245,11 @@ async def test_json_mget(self, glide_client: TGlideClient): result = await json.mget( glide_client, [key1], - ["$.a"], + "$.a", ) expected_result = [b"[1.0]"] assert result == expected_result - # Value at path isnt an object - result = await json.mget( - glide_client, - [key1, key2], - ["$.e"], - ) - expected_result = [b"[]", b"[]"] - assert result == expected_result - # No path given result = await json.mget( glide_client, From 340cef786e2a00af6374b74cb8efdec0f673310e Mon Sep 17 00:00:00 2001 From: BoazBD Date: Sun, 27 Oct 2024 09:14:41 +0000 Subject: [PATCH 062/180] fix type test fail Signed-off-by: BoazBD --- python/python/tests/tests_server_modules/test_json.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index bb4495665e..a7d0498686 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -230,8 +230,7 @@ async def test_json_mget(self, glide_client: TGlideClient): [key1, key2], ".non_existing_path", ) - expected_result = [None, None] - assert result == expected_result + assert result == [None, None] # Keys don't exist result = await json.mget( From 52dd474029b9747d342b1ad39e2a0563591b5fe5 Mon Sep 17 00:00:00 2001 From: BoazBD <50696333+BoazBD@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:18:54 +0200 Subject: [PATCH 063/180] Python - Implement JSON.OBJLEN command functionality (#2495) --------- Signed-off-by: BoazBD Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 52 +++++++++++++++++++ .../tests/tests_server_modules/test_json.py | 50 ++++++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1bec5d60b..18a23c1d3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Add JSON.OBJLEN command ([#2495](https://github.com/valkey-io/valkey-glide/pull/2495)) * Python: FT.EXPLAIN and FT.EXPLAINCLI commands added([#2508](https://github.com/valkey-io/valkey-glide/pull/2508)) * Python: Python FT.INFO command added([#2429](https://github.com/valkey-io/valkey-glide/pull/2494)) * Python: Add FT.SEARCH command([#2470](https://github.com/valkey-io/valkey-glide/pull/2470)) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 40e4e80c51..eb75b6bfb3 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -615,6 +615,58 @@ async def nummultby( return cast(Optional[bytes], await client.custom_command(args)) +async def objlen( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[TJsonResponse[int]]: + """ + Retrieves the number of key-value pairs in the object stored at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Defaults to None. 
+
+    Returns:
+        Optional[TJsonResponse[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the object,
+                or None for JSON values matching the path that are not an object.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the object at `path`.
+                If multiple paths match, the length of the first object match is returned.
+                If the JSON value at `path` is not an object or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, None is returned.
+
+
+    Examples:
+        >>> from glide import json
+        >>> await json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}')
+            b'OK'  # Indicates successful setting of the value at the root path '$' in the key `doc`.
+        >>> await json.objlen(client, "doc", "$")
+            [2]  # Returns the number of key-value pairs at the root object, which has 2 keys: 'a' and 'b'.
+        >>> await json.objlen(client, "doc", ".")
+            2  # Returns the number of key-value pairs for the object matching the path '.', which has 2 keys: 'a' and 'b'.
+        >>> await json.objlen(client, "doc", "$.b")
+            [3]  # Returns the length of the object at path '$.b', which has 3 keys: 'a', 'b', and 'c'.
+        >>> await json.objlen(client, "doc", ".b")
+            3  # Returns the length of the nested object at path '.b', which has 3 keys.
+        >>> await json.objlen(client, "doc", "$..a")
+            [None, 2]  # `$..a` matches the scalar 1.0 (not an object, so None) and the object {"x": 1, "y": 2}.
+        >>> await json.objlen(client, "doc")
+            2  # Returns the number of key-value pairs for the object matching the path '.', which has 2 keys: 'a' and 'b'.
+    """
+    args = ["JSON.OBJLEN", key]
+    if path:
+        args.append(path)
+    return cast(
+        Optional[TJsonResponse[int]],
+        await client.custom_command(args),
+    )
+
+
 async def objkeys(
     client: TGlideClient,
     key: TEncodable,
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py
index 5989cb9f2e..c1ad3dfbfc 100644
--- a/python/python/tests/tests_server_modules/test_json.py
+++ b/python/python/tests/tests_server_modules/test_json.py
@@ -357,6 +357,56 @@ async def test_json_type(self, glide_client: TGlideClient):
         result = await json.type(glide_client, key, "[*]")
         assert result == b"string"  # Expecting only the first type (string for key1)
 
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_objlen(self, glide_client: TGlideClient):
+        key = get_random_string(5)
+
+        json_value = {"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": True}}
+
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        len = await json.objlen(glide_client, key, "$")
+        assert len == [2]
+
+        len = await json.objlen(glide_client, key, ".")
+        assert len == 2
+
+        len = await json.objlen(glide_client, key, "$..")
+        assert len == [2, 3, 2]
+
+        len = await json.objlen(glide_client, key, "..")
+        assert len == 2
+
+        len = await json.objlen(glide_client, key, "$..b")
+        assert len == [3, None]
+
+        len = await json.objlen(glide_client, key, "..b")
+        assert len == 3
+
+        len = await json.objlen(glide_client, key, "..a")
+        assert len == 2
+
+        len = await json.objlen(glide_client, key)
+        assert len == 2
+
+        # path doesn't exist
+        assert await json.objlen(glide_client, key, "$.non_existing_path") == []
+        with pytest.raises(RequestError):
+            await json.objlen(glide_client, key, "non_existing_path")
+
+        # Value at path isn't an object
+        assert await json.objlen(glide_client, key, "$.a") == [None]
+        with pytest.raises(RequestError):
+            await json.objlen(glide_client, key, ".a")
+
+        # Non-existing key
+        assert await json.objlen(glide_client, "non_existing_key", "$") is None
+        assert await json.objlen(glide_client, "non_existing_key", ".") is None
+
+        assert await json.set(glide_client, key, "$", '{"a": 1, "b": 2, "c":3, "d":4}')
+        assert await json.objlen(glide_client, key) == 4
+
 @pytest.mark.parametrize("cluster_mode", [True, False])
 @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
 async def test_json_arrlen(self, glide_client: TGlideClient):

From 98df6916b9ff534affd77c160b351cc01424cff2 Mon Sep 17 00:00:00 2001
From: Yury-Fridlyand
Date: Mon, 28 Oct 2024 08:44:32 -0700
Subject: [PATCH 064/180] Java: `JSON.ARRPOP`. (#2486)

* `JSON.ARRPOP`.

Signed-off-by: Yury-Fridlyand
---
 CHANGELOG.md                              |   1 +
 .../api/commands/servermodules/Json.java  | 201 ++++++++++++++++++
 .../test/java/glide/modules/JsonTests.java |  79 ++++---
 3 files changed, 255 insertions(+), 26 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18a23c1d3d..664786838a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@
 * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462))
 * Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501))
 * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476))
+* Java: Added `JSON.ARRPOP` ([#2486](https://github.com/valkey-io/valkey-glide/pull/2486))
 * Java: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2492](https://github.com/valkey-io/valkey-glide/pull/2492))
 * Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490))
diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
index fd473c6e85..5fd39aba1e 100644
--- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java
+++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java
@@ -25,6 +25,7 @@ public class Json {
     private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND";
     private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT";
     private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN";
+    private static final String JSON_ARRPOP = JSON_PREFIX + "ARRPOP";
     private static final String JSON_ARRTRIM = JSON_PREFIX + "ARRTRIM";
     private static final String JSON_OBJLEN = JSON_PREFIX + "OBJLEN";
     private static final String JSON_OBJKEYS = JSON_PREFIX + "OBJKEYS";
@@ -715,6 +716,206 @@ public static CompletableFuture arrlen(
         return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key});
     }
 
+    /**
+     * Pops the last element from the array stored in the root of the JSON document stored at
+     * key. Equivalent to {@link #arrpop(BaseClient, String, String)} with
+     * path set to ".".
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns a string representing the popped JSON value, or null if the array
+     *     at document root is empty.
        + * If the JSON value at document root is not an array or if key doesn't exist, an + * error is raised. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
        +     * var res = Json.arrpop(client, "doc").get();
        +     * assert res.equals("\"tree\"");
        +     * res = Json.arrpop(client, "doc").get();
        +     * assert res.equals("{\"a\": 42, \"b\": 33}");
        +     * }
        + */ + public static CompletableFuture arrpop(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_ARRPOP, key}); + } + + /** + * Pops the last element from the array located in the root of the JSON document stored at + * key. Equivalent to {@link #arrpop(BaseClient, GlideString, GlideString)} with + * path set to gs("."). + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representing the popped JSON value, or null if the array + * at document root is empty.
        + * If the JSON value at document root is not an array or if key doesn't exist, an + * error is raised. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
        +     * var res = Json.arrpop(client, gs("doc")).get();
        +     * assert res.equals(gs("\"tree\""));
        +     * res = Json.arrpop(client, gs("doc")).get();
        +     * assert res.equals(gs("{\"a\": 42, \"b\": 33}"));
        +     * }
        + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRPOP), key}); + } + + /** + * Pops the last element from the array located at path in the JSON document stored + * at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
+     *     Returns an array with strings for every possible path, representing the popped JSON
+     *     values, or null for JSON values matching the path that are not an array
+     *     or an empty array. If a value is not an array, its corresponding return value is
+     *     "null".
+     *
        • For legacy path (path doesn't start with $):
          + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
        + * If key doesn't exist, an error is raised. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
        +     * var res = Json.arrpop(client, "doc", "$").get();
        +     * assert Arrays.equals((Object[]) res, new Object[] { "\"tree\"" });
        +     * res = Json.arrpop(client, "doc", ".").get();
        +     * assert res.equals("{\"a\": 42, \"b\": 33}");
        +     * }
        + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_ARRPOP, key, path}); + } + + /** + * Pops the last element from the array located at path in the JSON document stored + * at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
+     *     Returns an array with strings for every possible path, representing the popped JSON
+     *     values, or null for JSON values matching the path that are not an array
+     *     or an empty array. If a value is not an array, its corresponding return value is
+     *     "null".
+     *
        • For legacy path (path doesn't start with $):
          + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
        + * If key doesn't exist, an error is raised. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
        +     * var res = Json.arrpop(client, gs("doc"), gs("$")).get();
        +     * assert Arrays.equals((Object[]) res, new Object[] { gs("\"tree\"") });
        +     * res = Json.arrpop(client, gs("doc"), gs(".")).get();
        +     * assert res.equals(gs("{\"a\": 42, \"b\": 33}"));
        +     * }
        + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRPOP), key, path}); + } + + /** + * Pops an element from the array located at path in the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The index of the element to pop. Out of boundary indexes are rounded to their + * respective array boundaries. + * @return + *
          + *
        • For JSONPath (path starts with $):
+     *     Returns an array with strings for every possible path, representing the popped JSON
+     *     values, or null for JSON values matching the path that are not an array
+     *     or an empty array. If a value is not an array, its corresponding return value is
+     *     "null".
+     *
        • For legacy path (path doesn't start with $):
          + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
        + * If key doesn't exist, an error is raised. + * @example + *
        {@code
        +     * String doc = "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\": 42}}}";
        +     * Json.set(client, "doc", "$", doc).get();
        +     * var res = Json.arrpop(client, "doc", "$.a", 1).get();
        +     * assert res.equals("2"); // Pop second element from array at path `$.a`
        +     *
        +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\", \"c\"]]").get();
        +     * res = Json.arrpop(client, "doc", ".", -1).get());
        +     * assert res.equals("[\"a\", \"b\", \"c\"]"); // Pop last elements at path `.`
        +     * }
        + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, long index) { + return executeCommand(client, new String[] {JSON_ARRPOP, key, path, Long.toString(index)}); + } + + /** + * Pops an element from the array located at path in the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The index of the element to pop. Out of boundary indexes are rounded to their + * respective array boundaries. + * @return + *
          + *
        • For JSONPath (path starts with $):
+     *     Returns an array with strings for every possible path, representing the popped JSON
+     *     values, or null for JSON values matching the path that are not an array
+     *     or an empty array. If a value is not an array, its corresponding return value is
+     *     "null".
+     *
        • For legacy path (path doesn't start with $):
          + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
        + * If key doesn't exist, an error is raised. + * @example + *
        {@code
        +     * String doc = "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\": 42}}}";
        +     * Json.set(client, "doc", "$", doc).get();
        +     * var res = Json.arrpop(client, gs("doc"), gs("$.a"), 1).get();
        +     * assert res.equals("2"); // Pop second element from array at path `$.a`
        +     *
        +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\", \"c\"]]").get();
        +     * res = Json.arrpop(client, gs("doc"), gs("."), -1).get());
        +     * assert res.equals(gs("[\"a\", \"b\", \"c\"]")); // Pop last elements at path `.`
        +     * }
        + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path, long index) { + return executeCommand( + client, new GlideString[] {gs(JSON_ARRPOP), key, path, gs(Long.toString(index))}); + } + /** * Trims an array at the specified path within the JSON document started at key * so that it becomes a subarray [start, end], both inclusive. diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 69a609a65f..ed5336844a 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -188,27 +188,22 @@ public void arrappend() { Json.arrappend(client, gs(key), gs("$.c"), new GlideString[] {gs("\"value\"")}).get()); // Legacy path, path doesn't exist - var exception = - assertThrows( - ExecutionException.class, - () -> Json.arrappend(client, key, ".c", new String[] {"\"value\""}).get()); + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".c", new String[] {"\"value\""}).get()); // Legacy path, the JSON value at path is not a array - exception = - assertThrows( - ExecutionException.class, - () -> Json.arrappend(client, key, ".a", new String[] {"\"value\""}).get()); + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".a", new String[] {"\"value\""}).get()); - exception = - assertThrows( - ExecutionException.class, - () -> - Json.arrappend(client, "non_existing_key", "$.b", new String[] {"\"six\""}).get()); + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, "non_existing_key", "$.b", new String[] {"\"six\""}).get()); - exception = - assertThrows( - ExecutionException.class, - () -> Json.arrappend(client, "non_existing_key", ".b", new String[] {"\"six\""}).get()); + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, "non_existing_key", ".b", new String[] {"\"six\""}).get()); } @Test @@ -311,7 +306,7 @@ public void arrlen() { assertArrayEquals(new Object[] {3L, 2L, null}, (Object[]) res); // Legacy path retrieves the first array match at ..a - res = Json.arrlen(client, key, "..a").get(); + res = Json.arrlen(client, gs(key), gs("..a")).get(); assertEquals(3L, res); doc = "[1, 2, true, null, \"tree\"]"; @@ -320,6 +315,42 @@ public void arrlen() { // no path res = Json.arrlen(client, key).get(); assertEquals(5L, res); + res = Json.arrlen(client, gs(key)).get(); + assertEquals(5L, res); + } + + @Test + @SneakyThrows + public void arrpop() { + String key = UUID.randomUUID().toString(); + String doc = + "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\":" + + " 42}}}"; + assertEquals(OK, Json.set(client, key, "$", doc).get()); + + var res = Json.arrpop(client, key, "$.a", 1).get(); + assertArrayEquals(new Object[] {"2"}, (Object[]) res); + + res = Json.arrpop(client, gs(key), gs("$..a")).get(); + assertArrayEquals(new Object[] {gs("true"), gs("5"), null}, (Object[]) res); + + res = Json.arrpop(client, key, "..a").get(); + assertEquals("1", res); + + // Even if only one array element was returned, ensure second array at `..a` was popped + doc = Json.get(client, key, new String[] {"$..a"}).get(); + assertEquals("[[],[3,4],42]", doc); + + // Out of index + res = Json.arrpop(client, key, "$..a", 10).get(); + assertArrayEquals(new Object[] {null, "4", null}, (Object[]) res); + + // pop without options + assertEquals(OK, Json.set(client, key, "$", 
doc).get()); + res = Json.arrpop(client, key).get(); + assertEquals("42", res); + res = Json.arrpop(client, gs(key)).get(); + assertEquals(gs("[3,4]"), res); } @Test @@ -541,14 +572,10 @@ public void toggle() { assertEquals(true, Json.toggle(client, gs(key2)).get()); // expect request errors - var exception = - assertThrows(ExecutionException.class, () -> Json.toggle(client, key, "nested").get()); - exception = - assertThrows( - ExecutionException.class, () -> Json.toggle(client, key, ".non_existing").get()); - exception = - assertThrows( - ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get()); + assertThrows(ExecutionException.class, () -> Json.toggle(client, key, "nested").get()); + assertThrows(ExecutionException.class, () -> Json.toggle(client, key, ".non_existing").get()); + assertThrows( + ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get()); } @Test From cb21081a8cb011b0d1075d6a55cf13d375f9af2e Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Mon, 28 Oct 2024 14:03:34 -0700 Subject: [PATCH 065/180] Java: `FT.PROFILE` (#2473) * `FT.PROFILE`. Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + glide-core/src/client/value_conversion.rs | 51 ++++++- .../glide/api/commands/servermodules/FT.java | 50 ++++++- .../commands/FT/FTAggregateOptions.java | 39 +++--- .../models/commands/FT/FTCreateOptions.java | 2 +- .../models/commands/FT/FTProfileOptions.java | 126 ++++++++++++++++++ .../models/commands/FT/FTSearchOptions.java | 10 +- .../java/glide/modules/VectorSearchTests.java | 106 +++++++-------- 8 files changed, 305 insertions(+), 80 deletions(-) create mode 100644 java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 664786838a..6e8f7e3bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) * Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) * Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) +* Java: Added `FT.PROFILE` ([#2473](https://github.com/valkey-io/valkey-glide/pull/2473)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index d16d4ef939..f4706e762a 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -24,6 +24,7 @@ pub(crate) enum ExpectedReturnType<'a> { ArrayOfDoubleOrNull, FTAggregateReturnType, FTSearchReturnType, + FTProfileReturnType(&'a Option>), FTInfoReturnType, Lolwut, ArrayOfStringAndArrays, @@ -939,7 +940,7 @@ pub(crate) fn convert_to_expected_type( let Value::Array(fields) = aggregation else { return Err(( ErrorKind::TypeError, - "Response couldn't be converted for FT.AGGREGATION", + "Response couldn't be converted for FT.AGGREGATE", format!("(`fields` was {:?})", get_value_type(&aggregation)), ) .into()); @@ -954,7 +955,7 @@ pub(crate) fn convert_to_expected_type( } _ => Err(( ErrorKind::TypeError, - "Response couldn't be converted to FT.AGGREGATION", + "Response couldn't be converted for FT.AGGREGATE", format!("(response was {:?})", 
get_value_type(&value)), ) .into()), @@ -1106,6 +1107,44 @@ pub(crate) fn convert_to_expected_type( ) .into()) }, + ExpectedReturnType::FTProfileReturnType(type_of_query) => match value { + /* + Example of the response + 1) + 2) 1) 1) "parse.time" + 2) 119 + 2) 1) "all.count" + 2) 4 + 3) 1) "sync.time" + 2) 0 + ... + + Converting response to + 1) + 2) 1# "parse.time" => 119 + 2# "all.count" => 4 + 3# "sync.time" => 0 + ... + + Converting first array element as it is needed for the inner query and second element to a map. + */ + Value::Array(mut array) if array.len() == 2 => { + let res = vec![ + convert_to_expected_type(array.remove(0), *type_of_query)?, + convert_to_expected_type(array.remove(0), Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))?]; + + Ok(Value::Array(res)) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.PROFILE", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) + } } } @@ -1472,6 +1511,14 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { }), b"FT.AGGREGATE" => Some(ExpectedReturnType::FTAggregateReturnType), b"FT.SEARCH" => Some(ExpectedReturnType::FTSearchReturnType), + // TODO replace with tuple + b"FT.PROFILE" => Some(ExpectedReturnType::FTProfileReturnType( + if cmd.arg_idx(2).is_some_and(|a| a == b"SEARCH") { + &Some(ExpectedReturnType::FTSearchReturnType) + } else { + &Some(ExpectedReturnType::FTAggregateReturnType) + }, + )), b"FT.INFO" => Some(ExpectedReturnType::FTInfoReturnType), _ => None, } diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 0250865648..7a5dbb8714 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -13,6 +13,7 @@ import glide.api.models.commands.FT.FTAggregateOptions; import glide.api.models.commands.FT.FTCreateOptions; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; +import glide.api.models.commands.FT.FTProfileOptions; import glide.api.models.commands.FT.FTSearchOptions; import java.util.Arrays; import java.util.Map; @@ -480,6 +481,54 @@ public static CompletableFuture[]> aggregate( .thenApply(res -> castArray(res, Map.class)); } + /** + * Runs a search or aggregation query and collects performance profiling information. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param options Querying and profiling parameters - see {@link FTProfileOptions}. + * @return A two-element array. The first element contains results of query being profiled, the + * second element stores profiling information. + * @example + *
        {@code
        +     * var options = FTSearchOptions.builder().params(Map.of(
        +     *         gs("query_vec"),
        +     *         gs(new byte[] { (byte) 0, (byte) 0, (byte) 0, (byte) 0 })))
        +     *     .build();
        +     * var result = FT.profile(client, "myIndex", new FTProfileOptions("*=>[KNN 2 @VEC $query_vec]", options)).get();
        +     * // result[0] contains `FT.SEARCH` response with the given options and query
        +     * // result[1] contains profiling data as a `Map`
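+     * // A sketch of reading the profiling part of the reply; entry names such as
+     * // "parse.time" or "all.count" come from the server reply and may vary by version.
+     * var stats = (Map) result[1];
+     * System.out.println(stats.keySet()); // e.g. parse.time, all.count, sync.time, ...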
        +     * }
        + */ + public static CompletableFuture profile( + @NonNull BaseClient client, @NonNull String indexName, @NonNull FTProfileOptions options) { + return profile(client, gs(indexName), options); + } + + /** + * Runs a search or aggregation query and collects performance profiling information. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param options Querying and profiling parameters - see {@link FTProfileOptions}. + * @return A two-element array. The first element contains results of query being profiled, the + * second element stores profiling information. + * @example + *
        {@code
        +     * var commandLine = new String[] { "*", "LOAD", "1", "__key", "GROUPBY", "1", "@condition", "REDUCE", "COUNT", "0", "AS", "bicylces" };
        +     * var result = FT.profile(client, gs("myIndex"), new FTProfileOptions(QueryType.AGGREGATE, commandLine)).get();
        +     * // result[0] contains `FT.AGGREGATE` response with the given command line
        +     * // result[1] contains profiling data as a `Map`
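+     * // Sketch, under the assumption that the profiled part mirrors a plain FT.AGGREGATE
+     * // reply: result[0] is an array with one map of grouped fields per result row.
+     * for (var row : (Object[]) result[0]) System.out.println(row);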
        +     * }
        + */ + public static CompletableFuture profile( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull FTProfileOptions options) { + var args = concatenateArrays(new GlideString[] {gs("FT.PROFILE"), indexName}, options.toArgs()); + return executeCommand(client, args, false); + } + /** * Returns information about a given index. * @@ -699,7 +748,6 @@ public static CompletableFuture aliasupdate( public static CompletableFuture aliasupdate( @NonNull BaseClient client, @NonNull GlideString aliasName, @NonNull GlideString indexName) { var args = new GlideString[] {gs("FT.ALIASUPDATE"), aliasName, indexName}; - return executeCommand(client, args, false); } diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java index 73ffdbf412..700695b7ae 100644 --- a/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java @@ -14,6 +14,7 @@ import java.util.Map; import java.util.stream.Stream; import lombok.Builder; +import lombok.NonNull; /** * Additional arguments for {@link FT#aggregate(BaseClient, String, String, FTAggregateOptions)} @@ -79,19 +80,19 @@ public FTAggregateOptionsBuilder loadAll() { return this; } - public FTAggregateOptionsBuilder loadFields(String[] fields) { + public FTAggregateOptionsBuilder loadFields(@NonNull String[] fields) { loadFields = toGlideStringArray(fields); loadAll = false; return this; } - public FTAggregateOptionsBuilder loadFields(GlideString[] fields) { + public FTAggregateOptionsBuilder loadFields(@NonNull GlideString[] fields) { loadFields = fields; loadAll = false; return this; } - public FTAggregateOptionsBuilder addExpression(FTAggregateExpression expression) { + public FTAggregateOptionsBuilder addExpression(@NonNull FTAggregateExpression expression) { if (expressions == null) expressions = new ArrayList<>(); expressions.add(expression); return this; @@ -138,11 +139,11 @@ GlideString[] toArgs() { public static class Filter extends FTAggregateExpression { private final GlideString expression; - public Filter(GlideString expression) { + public Filter(@NonNull GlideString expression) { this.expression = expression; } - public Filter(String expression) { + public Filter(@NonNull String expression) { this.expression = gs(expression); } @@ -160,22 +161,22 @@ public static class GroupBy extends FTAggregateExpression { private final GlideString[] properties; private final Reducer[] reducers; - public GroupBy(GlideString[] properties, Reducer[] reducers) { + public GroupBy(@NonNull GlideString[] properties, @NonNull Reducer[] reducers) { this.properties = properties; this.reducers = reducers; } - public GroupBy(String[] properties, Reducer[] reducers) { + public GroupBy(@NonNull String[] properties, @NonNull Reducer[] reducers) { this.properties = toGlideStringArray(properties); this.reducers = reducers; } - public GroupBy(GlideString[] properties) { + public GroupBy(@NonNull GlideString[] properties) { this.properties = properties; this.reducers = new Reducer[0]; } - public GroupBy(String[] properties) { + public GroupBy(@NonNull String[] properties) { this.properties = toGlideStringArray(properties); this.reducers = new Reducer[0]; } @@ -199,25 +200,25 @@ public static class Reducer { private final GlideString[] args; private final String alias; - public Reducer(String function, GlideString[] args, String alias) { + public 
Reducer(@NonNull String function, @NonNull GlideString[] args, @NonNull String alias) { this.function = function; this.args = args; this.alias = alias; } - public Reducer(String function, GlideString[] args) { + public Reducer(@NonNull String function, @NonNull GlideString[] args) { this.function = function; this.args = args; this.alias = null; } - public Reducer(String function, String[] args, String alias) { + public Reducer(@NonNull String function, @NonNull String[] args, @NonNull String alias) { this.function = function; this.args = toGlideStringArray(args); this.alias = alias; } - public Reducer(String function, String[] args) { + public Reducer(@NonNull String function, @NonNull String[] args) { this.function = function; this.args = toGlideStringArray(args); this.alias = null; @@ -240,12 +241,12 @@ public static class SortBy extends FTAggregateExpression { private final SortProperty[] properties; private final Integer max; - public SortBy(SortProperty[] properties) { + public SortBy(@NonNull SortProperty[] properties) { this.properties = properties; this.max = null; } - public SortBy(SortProperty[] properties, int max) { + public SortBy(@NonNull SortProperty[] properties, int max) { this.properties = properties; this.max = max; } @@ -273,12 +274,12 @@ public static class SortProperty { private final GlideString property; private final SortOrder order; - public SortProperty(GlideString property, SortOrder order) { + public SortProperty(@NonNull GlideString property, @NonNull SortOrder order) { this.property = property; this.order = order; } - public SortProperty(String property, SortOrder order) { + public SortProperty(@NonNull String property, @NonNull SortOrder order) { this.property = gs(property); this.order = order; } @@ -297,12 +298,12 @@ public static class Apply extends FTAggregateExpression { private final GlideString expression; private final GlideString alias; - public Apply(GlideString expression, GlideString alias) { + public Apply(@NonNull GlideString expression, @NonNull GlideString alias) { this.expression = expression; this.alias = alias; } - public Apply(String expression, String alias) { + public Apply(@NonNull String expression, @NonNull String alias) { this.expression = gs(expression); this.alias = gs(alias); } diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java index ae651ee2a6..81bb9e1dce 100644 --- a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java @@ -54,7 +54,7 @@ public GlideString[] toArgs() { } public static class FTCreateOptionsBuilder { - public FTCreateOptionsBuilder prefixes(String[] prefixes) { + public FTCreateOptionsBuilder prefixes(@NonNull String[] prefixes) { this.prefixes = Stream.of(prefixes).map(GlideString::gs).toArray(GlideString[]::new); return this; } diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java new file mode 100644 index 0000000000..5d9b7e892d --- /dev/null +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java @@ -0,0 +1,126 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.models.commands.FT; + +import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.concatenateArrays; + 
+import glide.api.commands.servermodules.FT;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.List;
+import lombok.NonNull;
+
+/** Mandatory parameters for {@link FT#profile} command. */
+public class FTProfileOptions {
+    private final QueryType queryType;
+    private final boolean limited;
+    private final GlideString[] commandLine;
+
+    /** Query type being profiled. */
+    enum QueryType {
+        SEARCH,
+        AGGREGATE
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     */
+    public FTProfileOptions(@NonNull String query, @NonNull FTAggregateOptions options) {
+        this(gs(query), options);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     */
+    public FTProfileOptions(@NonNull GlideString query, @NonNull FTAggregateOptions options) {
+        this(query, options, false);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     */
+    public FTProfileOptions(@NonNull String query, @NonNull FTSearchOptions options) {
+        this(gs(query), options);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     */
+    public FTProfileOptions(@NonNull GlideString query, @NonNull FTSearchOptions options) {
+        this(query, options, false);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull String query, @NonNull FTAggregateOptions options, boolean limited) {
+        this(gs(query), options, limited);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull GlideString query, @NonNull FTAggregateOptions options, boolean limited) {
+        queryType = QueryType.AGGREGATE;
+        commandLine = concatenateArrays(new GlideString[] {query}, options.toArgs());
+        this.limited = limited;
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull String query, @NonNull FTSearchOptions options, boolean limited) {
+        this(gs(query), options, limited);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull GlideString query, @NonNull FTSearchOptions options, boolean limited) {
+        queryType = QueryType.SEARCH;
+        commandLine = concatenateArrays(new GlideString[] {query}, options.toArgs());
+        this.limited = limited;
+    }
+
+    /** Convert to module API.
*/ + public GlideString[] toArgs() { + var args = new ArrayList(); + args.add(gs(queryType.toString())); + if (limited) args.add(gs("LIMITED")); + args.add(gs("QUERY")); + args.addAll(List.of(commandLine)); + return args.toArray(GlideString[]::new); + } +} diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java index 990eab2cb3..74407c64c0 100644 --- a/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java +++ b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java @@ -9,6 +9,7 @@ import java.util.HashMap; import java.util.Map; import lombok.Builder; +import lombok.NonNull; import org.apache.commons.lang3.tuple.Pair; /** Mandatory parameters for {@link FT#search}. */ @@ -84,25 +85,26 @@ void count(boolean count) {} void identifiers(Map identifiers) {} /** Add a field to be returned. */ - public FTSearchOptionsBuilder addReturnField(String field) { + public FTSearchOptionsBuilder addReturnField(@NonNull String field) { this.identifiers$value.put(gs(field), null); return this; } /** Add a field with an alias to be returned. */ - public FTSearchOptionsBuilder addReturnField(String field, String alias) { + public FTSearchOptionsBuilder addReturnField(@NonNull String field, @NonNull String alias) { this.identifiers$value.put(gs(field), gs(alias)); return this; } /** Add a field to be returned. */ - public FTSearchOptionsBuilder addReturnField(GlideString field) { + public FTSearchOptionsBuilder addReturnField(@NonNull GlideString field) { this.identifiers$value.put(field, null); return this; } /** Add a field with an alias to be returned. */ - public FTSearchOptionsBuilder addReturnField(GlideString field, GlideString alias) { + public FTSearchOptionsBuilder addReturnField( + @NonNull GlideString field, @NonNull GlideString alias) { this.identifiers$value.put(field, alias); return this; } diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 09cf22cabf..10d09a5974 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -33,6 +33,7 @@ import glide.api.models.commands.FT.FTCreateOptions.TextField; import glide.api.models.commands.FT.FTCreateOptions.VectorFieldFlat; import glide.api.models.commands.FT.FTCreateOptions.VectorFieldHnsw; +import glide.api.models.commands.FT.FTProfileOptions; import glide.api.models.commands.FT.FTSearchOptions; import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; @@ -66,6 +67,7 @@ public static void init() { @AfterAll @SneakyThrows public static void teardown() { + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); client.close(); } @@ -255,22 +257,22 @@ public void ft_search() { }))) .get()); Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index - var ftsearch = - FT.search( - client, - index, - "*=>[KNN 2 @VEC $query_vec]", - FTSearchOptions.builder() - .params( - Map.of( - gs("query_vec"), - gs( - new byte[] { - (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, - (byte) 0, (byte) 0 - }))) - .build()) - .get(); + + // FT.SEARCH hash_idx1 "*=>[KNN 2 @VEC $query_vec]" PARAMS 2 query_vec + // "\x00\x00\x00\x00\x00\x00\x00\x00" DIALECT 2 + var options = + FTSearchOptions.builder() + .params( + Map.of( + gs("query_vec"), + gs( + new byte[] 
{ + (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, + (byte) 0 + }))) + .build(); + var query = "*=>[KNN 2 @VEC $query_vec]"; + var ftsearch = FT.search(client, index, query, options).get(); assertArrayEquals( new Object[] { @@ -299,6 +301,9 @@ public void ft_search() { // TODO more tests with json index + var ftprofile = FT.profile(client, index, new FTProfileOptions(query, options)).get(); + assertArrayEquals(ftsearch, (Object[]) ftprofile[0]); + // querying non-existing index var exception = assertThrows( @@ -531,19 +536,15 @@ public void ft_aggregate() { Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index // FT.AGGREGATE idx:bicycle "*" LOAD 1 "__key" GROUPBY 1 "@condition" REDUCE COUNT 0 AS bicylces - var aggreg = - FT.aggregate( - client, - indexBicycles, - "*", - FTAggregateOptions.builder() - .loadFields(new String[] {"__key"}) - .addExpression( - new GroupBy( - new String[] {"@condition"}, - new Reducer[] {new Reducer("COUNT", new String[0], "bicycles")})) - .build()) - .get(); + var options = + FTAggregateOptions.builder() + .loadFields(new String[] {"__key"}) + .addExpression( + new GroupBy( + new String[] {"@condition"}, + new Reducer[] {new Reducer("COUNT", new String[0], "bicycles")})) + .build(); + var aggreg = FT.aggregate(client, indexBicycles, "*", options).get(); // elements (maps in array) could be reordered, comparing as sets assertDeepEquals( Set.of( @@ -658,30 +659,26 @@ public void ft_aggregate() { // FT.AGGREGATE idx:movie * LOAD * APPLY ceil(@rating) as r_rating GROUPBY 1 @genre REDUCE // COUNT 0 AS nb_of_movies REDUCE SUM 1 votes AS nb_of_votes REDUCE AVG 1 r_rating AS avg_rating // SORTBY 4 @avg_rating DESC @nb_of_votes DESC - aggreg = - FT.aggregate( - client, - indexMovies, - "*", - FTAggregateOptions.builder() - .loadAll() - .addExpression(new Apply("ceil(@rating)", "r_rating")) - .addExpression( - new GroupBy( - new String[] {"@genre"}, - new Reducer[] { - new Reducer("COUNT", new String[0], "nb_of_movies"), - new Reducer("SUM", new String[] {"votes"}, "nb_of_votes"), - new Reducer("AVG", new String[] {"r_rating"}, "avg_rating") - })) - .addExpression( - new SortBy( - new SortProperty[] { - new SortProperty("@avg_rating", SortOrder.DESC), - new SortProperty("@nb_of_votes", SortOrder.DESC) - })) - .build()) - .get(); + options = + FTAggregateOptions.builder() + .loadAll() + .addExpression(new Apply("ceil(@rating)", "r_rating")) + .addExpression( + new GroupBy( + new String[] {"@genre"}, + new Reducer[] { + new Reducer("COUNT", new String[0], "nb_of_movies"), + new Reducer("SUM", new String[] {"votes"}, "nb_of_votes"), + new Reducer("AVG", new String[] {"r_rating"}, "avg_rating") + })) + .addExpression( + new SortBy( + new SortProperty[] { + new SortProperty("@avg_rating", SortOrder.DESC), + new SortProperty("@nb_of_votes", SortOrder.DESC) + })) + .build(); + aggreg = FT.aggregate(client, indexMovies, "*", options).get(); // elements (maps in array) could be reordered, comparing as sets assertDeepEquals( Set.of( @@ -713,6 +710,9 @@ public void ft_aggregate() { gs("avg_rating"), 9.)), Set.of(aggreg)); + + var ftprofile = FT.profile(client, indexMovies, new FTProfileOptions("*", options)).get(); + assertDeepEquals(aggreg, ftprofile[0]); } @SuppressWarnings("unchecked") From a882b1f4e96e6951e88bde746c589fc5d986fa90 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Mon, 28 Oct 2024 14:28:45 -0700 Subject: [PATCH 066/180] Java: `JSON.DEBUG`. (#2520) * `JSON.DEBUG`. 
Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 228 ++++++++++++++++++ .../test/java/glide/modules/JsonTests.java | 25 ++ 3 files changed, 254 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e8f7e3bea..cb98c5549b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ * Java: Added `FT.PROFILE` ([#2473](https://github.com/valkey-io/valkey-glide/pull/2473)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) +* Java: Added `JSON.DEBUG` ([#2520](https://github.com/valkey-io/valkey-glide/pull/2520)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) * Java: Added `JSON.ARRPOP` ([#2486](https://github.com/valkey-io/valkey-glide/pull/2486)) * Java: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2492](https://github.com/valkey-io/valkey-glide/pull/2492)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 5fd39aba1e..efd0af082b 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -25,6 +25,8 @@ public class Json { private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; + private static final String[] JSON_DEBUG_MEMORY = new String[] {JSON_PREFIX + "DEBUG", "MEMORY"}; + private static final String[] JSON_DEBUG_FIELDS = new String[] {JSON_PREFIX + "DEBUG", "FIELDS"}; private static final String JSON_ARRPOP = JSON_PREFIX + "ARRPOP"; private static final String JSON_ARRTRIM = JSON_PREFIX + "ARRTRIM"; private static final String JSON_OBJLEN = JSON_PREFIX + "OBJLEN"; @@ -716,6 +718,232 @@ public static CompletableFuture arrlen( return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key}); } + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns an Object[] with a list of numbers for every possible path, + * indicating the memory usage. If path does not exist, an empty array will + * be returned. + *
        • For legacy path (path doesn't start with $):
          + * Returns an integer representing the memory usage. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
        + * If key doesn't exist, returns null. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugMemory(client, "doc", "..").get();
        +     * assert res == 258L;
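+     * // Illustrative follow-up: with a JSONPath, one number is reported per matched
+     * // value (the concrete byte counts are implementation- and version-dependent).
+     * var perElement = Json.debugMemory(client, "doc", "$[*]").get();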
        +     * }
        + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_MEMORY, new String[] {key, path})); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
+     * Each non-container JSON value counts as one field. Objects and arrays recursively count one
+     * field for each of the JSON values they contain. Each container value, except the root
+     * container, counts as one additional field.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return
+     *
          + *
        • For JSONPath (path starts with $):
          + * Returns an Object[] with a list of numbers for every possible path, + * indicating the number of fields. If path does not exist, an empty array + * will be returned. + *
        • For legacy path (path doesn't start with $):
          + * Returns an integer representing the number of fields. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
        + * If key doesn't exist, returns null. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugFields(client, "doc", "$[*]").get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {1, 1, 1, 1, 1, 0, 0, 2, 3});
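+     * // Reading the result: the five scalars count one field each, the empty object
+     * // and array contain zero, and {"a":1, "b":2} and [1, 2, 3] count their members.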
        +     * }
        + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_FIELDS, new String[] {key, path})); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns an Object[] with a list of numbers for every possible path, + * indicating the memory usage. If path does not exist, an empty array will + * be returned. + *
        • For legacy path (path doesn't start with $):
          + * Returns an integer representing the memory usage. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
        + * If key doesn't exist, returns null. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugMemory(client, gs("doc"), gs("..")).get();
        +     * assert res == 258L;
        +     * }
        + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_DEBUG_MEMORY).add(key).add(path).toArray()); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
+     * Each non-container JSON value counts as one field. Objects and arrays recursively count one
+     * field for each of the JSON values they contain. Each container value, except the root
+     * container, counts as one additional field.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return
+     *
          + *
        • For JSONPath (path starts with $):
          + * Returns an Object[] with a list of numbers for every possible path, + * indicating the number of fields. If path does not exist, an empty array + * will be returned. + *
        • For legacy path (path doesn't start with $):
          + * Returns an integer representing the number of fields. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
        + * If key doesn't exist, returns null. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugFields(client, gs("doc"), gs("$[*]")).get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {1, 1, 1, 1, 1, 0, 0, 2, 3});
        +     * }
        + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_DEBUG_FIELDS).add(key).add(path).toArray()); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key.
        + * Equivalent to {@link #debugMemory(BaseClient, String, String)} with path set to
        + * "..".
        + *
        + * @param client The client to execute the command.
        + * @param key The key of the JSON document.
        + * @return The total memory usage in bytes of the entire JSON document.
        + * If key doesn't exist, returns null.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugMemory(client, "doc").get();
        +     * assert res == 258L;
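+     * // Equivalent to Json.debugMemory(client, "doc", "..") per the equivalence
+     * // noted above.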
        +     * }
        + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_MEMORY, new String[] {key})); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
        + * Each non-container JSON value counts as one field. Objects and arrays recursively count one
        + * field for each of their containing JSON values. Each container value, except the root
        + * container, counts as one additional field.
        + * Equivalent to {@link #debugFields(BaseClient, String, String)} with path set to
        + * "..".
        + *
        + * @param client The client to execute the command.
        + * @param key The key of the JSON document.
        + * @return The total number of fields in the entire JSON document.
        + * If key doesn't exist, returns null.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugFields(client, "doc").get();
        +     * assert res == 14L;
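+     * // Hand-check of the 14: the five scalars and two empty containers contribute
+     * // 1 each (7), {"a":1, "b":2} contributes 3 (itself plus two fields), and
+     * // [1, 2, 3] contributes 4 (itself plus three elements).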
        +     * }
        + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_FIELDS, new String[] {key})); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key.
        + * Equivalent to {@link #debugMemory(BaseClient, GlideString, GlideString)} with path
        + * set to gs("..").
        + *
        + * @param client The client to execute the command.
        + * @param key The key of the JSON document.
        + * @return The total memory usage in bytes of the entire JSON document.
        + * If key doesn't exist, returns null.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugMemory(client, gs("doc")).get();
        +     * assert res == 258L;
        +     * }
        + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(JSON_DEBUG_MEMORY).add(key).toArray()); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
        + * Each non-container JSON value counts as one field. Objects and arrays recursively count one
        + * field for each of their containing JSON values. Each container value, except the root
        + * container, counts as one additional field.
        + * Equivalent to {@link #debugFields(BaseClient, GlideString, GlideString)} with path
        + * set to gs("..").
        + *
        + * @param client The client to execute the command.
        + * @param key The key of the JSON document.
        + * @return The total number of fields in the entire JSON document.
        + * If key doesn't exist, returns null.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
        +     * var res = Json.debugFields(client, gs("doc")).get();
        +     * assert res == 14L;
        +     * }
        + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(JSON_DEBUG_FIELDS).add(key).toArray()); + } + /** * Pops the last element from the array stored in the root of the JSON document stored at * key. Equivalent to {@link #arrpop(BaseClient, String, String)} with diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index ed5336844a..2b07f9d44e 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -291,6 +291,31 @@ public void arrinsert() { assertEquals(JsonParser.parseString(expected), JsonParser.parseString(doc)); } + @Test + @SneakyThrows + public void debug() { + String key = UUID.randomUUID().toString(); + + var doc = + "{ \"key1\": 1, \"key2\": 3.5, \"key3\": {\"nested_key\": {\"key1\": [4, 5]}}, \"key4\":" + + " [1, 2, 3], \"key5\": 0, \"key6\": \"hello\", \"key7\": null, \"key8\":" + + " {\"nested_key\": {\"key1\": 3.5953862697246314e307}}, \"key9\":" + + " 3.5953862697246314e307, \"key10\": true }"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + assertArrayEquals(new Object[] {1L}, (Object[]) Json.debugFields(client, key, "$.key1").get()); + + assertEquals(2L, Json.debugFields(client, gs(key), gs(".key3.nested_key.key1")).get()); + + assertArrayEquals( + new Object[] {16L}, (Object[]) Json.debugMemory(client, key, "$.key4[2]").get()); + + assertEquals(16L, Json.debugMemory(client, gs(key), gs(".key6")).get()); + + assertEquals(504L, Json.debugMemory(client, key).get()); + assertEquals(19L, Json.debugFields(client, gs(key)).get()); + } + @Test @SneakyThrows public void arrlen() { From 01308db4721849d6c6c45efd13257a4952edeb43 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:32:47 -0700 Subject: [PATCH 067/180] Node: Add command JSON.RESP (#2517) Node: Add command JSON.RESP Signed-off-by: TJ Zhang --- CHANGELOG.md | 1 + node/src/server-modules/GlideJson.ts | 48 ++++++++++ node/tests/ServerModules.test.ts | 129 +++++++++++++++++++++++++++ 3 files changed, 178 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb98c5549b..a554b9652b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) +* Node: Added `JSON.RESP` ([#2517](https://github.com/valkey-io/valkey-glide/pull/2517)) * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) * Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395)) * Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464)) diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index db40a31efc..92d3ae35a8 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -352,4 +352,52 @@ export class GlideJson { return _executeCommand>(client, args); } + + /** + * Retrieve the JSON value at the specified `path` within the JSON document stored at `key`. 
+ * The returning result is in the Valkey or Redis OSS Serialization Protocol (RESP). + * JSON null is mapped to the RESP Null Bulk String. + * JSON Booleans are mapped to RESP Simple string. + * JSON integers are mapped to RESP Integers. + * JSON doubles are mapped to RESP Bulk Strings. + * JSON strings are mapped to RESP Bulk Strings. + * JSON arrays are represented as RESP arrays, where the first element is the simple string [, followed by the array's elements. + * JSON objects are represented as RESP object, where the first element is the simple string {, followed by key-value pairs, each of which is a RESP bulk string. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) path - The path within the JSON document, Defaults to root if not provided. + * @returns ReturnTypeJson: + * - For JSONPath (path starts with `$`): + * - Returns an array of replies for every possible path, indicating the RESP form of the JSON value. + * If `path` doesn't exist, returns an empty array. + * - For legacy path (path doesn't start with `$`): + * - Returns a single reply for the JSON value at the specified `path`, in its RESP form. + * If multiple paths match, the value of the first JSON value match is returned. If `path` doesn't exist, an error is raised. + * - If `key` doesn't exist, `null` is returned. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", ".", "{a: [1, 2, 3], b: {a: [1, 2], c: {a: 42}}}")); + * // Output: 'OK' - Indicates successful setting of the value at path '.' in the key stored at `doc`. + * const result = await GlideJson.resp(client, "doc", "$..a"); + * console.log(result); + * // Output: [ ["[", 1L, 2L, 3L], ["[", 1L, 2L], [42L]]; + * console.log(await GlideJson.type(client, "doc", "..a")); // Output: ["[", 1L, 2L, 3L] + * ``` + */ + static async resp( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.RESP", key]; + + if (options) { + args.push(options.path); + } + + return _executeCommand>(client, args); + } } diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 899bf88644..4ec5757fd7 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -590,6 +590,135 @@ describe("Server Module Tests", () => { ).toBeNull(); }, ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.resp tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + obj: { a: 1, b: 2 }, + arr: [1, 2, 3], + str: "foo", + bool: true, + int: 42, + float: 3.14, + nullVal: null, + }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.resp(client, key, { path: "$.*" }), + ).toEqual([ + ["{", ["a", 1], ["b", 2]], + ["[", 1, 2, 3], + "foo", + "true", + 42, + "3.14", + null, + ]); // leading "{" - JSON objects, leading "[" - JSON arrays + + // multiple path match, the first will be returned + expect( + await GlideJson.resp(client, key, { path: "*" }), + ).toEqual(["{", ["a", 1], ["b", 2]]); + + // testing $ path + expect( + await GlideJson.resp(client, key, { path: "$" }), + ).toEqual([ + [ + "{", + ["obj", ["{", ["a", 1], ["b", 2]]], + ["arr", ["[", 1, 2, 3]], + ["str", "foo"], + 
["bool", "true"], + ["int", 42], + ["float", "3.14"], + ["nullVal", null], + ], + ]); + + // testing . path + expect( + await GlideJson.resp(client, key, { path: "." }), + ).toEqual([ + "{", + ["obj", ["{", ["a", 1], ["b", 2]]], + ["arr", ["[", 1, 2, 3]], + ["str", "foo"], + ["bool", "true"], + ["int", 42], + ["float", "3.14"], + ["nullVal", null], + ]); + + // $.str and .str + expect( + await GlideJson.resp(client, key, { path: "$.str" }), + ).toEqual(["foo"]); + expect( + await GlideJson.resp(client, key, { path: ".str" }), + ).toEqual("foo"); + + // setup new json value + const jsonValue2 = { + a: [1, 2, 3], + b: { a: [1, 2], c: { a: 42 } }, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + + expect( + await GlideJson.resp(client, key, { path: "..a" }), + ).toEqual(["[", 1, 2, 3]); + + expect( + await GlideJson.resp(client, key, { + path: "$.nonexistent", + }), + ).toEqual([]); + + // error case + await expect( + GlideJson.resp(client, key, { path: "nonexistent" }), + ).rejects.toThrow(RequestError); + + // non-existent key + expect( + await GlideJson.resp(client, "nonexistent_key", { + path: "$", + }), + ).toBeNull(); + expect( + await GlideJson.resp(client, "nonexistent_key", { + path: ".", + }), + ).toBeNull(); + expect( + await GlideJson.resp(client, "nonexistent_key"), + ).toBeNull(); + }, + ); }); describe("GlideFt", () => { From 5cf72b149533425f5f6f51735129d45688864058 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Mon, 28 Oct 2024 15:40:35 -0700 Subject: [PATCH 068/180] CI: Minor fixes (#2333) Update CI Signed-off-by: Yury-Fridlyand --- .github/workflows/lint-rust/action.yml | 2 -- .github/workflows/node.yml | 5 +++-- .github/workflows/rust.yml | 2 +- .github/workflows/semgrep.yml | 4 ++-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/lint-rust/action.yml b/.github/workflows/lint-rust/action.yml index 0823dda958..06b0b7a75a 100644 --- a/.github/workflows/lint-rust/action.yml +++ b/.github/workflows/lint-rust/action.yml @@ -23,8 +23,6 @@ runs: github-token: ${{ inputs.github-token }} - uses: Swatinem/rust-cache@v2 - with: - github-token: ${{ inputs.github-token }} - run: cargo fmt --all -- --check working-directory: ${{ inputs.cargo-toml-folder }} diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index d8f690e560..1927649247 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -126,10 +126,11 @@ jobs: - uses: actions/checkout@v4 - - uses: ./.github/workflows/lint-rust + - name: lint node rust + uses: ./.github/workflows/lint-rust with: cargo-toml-folder: ./node/rust-client - name: lint node rust + github-token: ${{ secrets.GITHUB_TOKEN }} # build-macos-latest: # runs-on: macos-latest diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 95b47f2ce2..2fdfa77c1f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -42,7 +42,7 @@ jobs: load-engine-matrix: runs-on: ubuntu-latest outputs: - matrix: ${{ steps.load-engine-matrix.outputs.matrix }} + matrix: ${{ steps.load-engine-matrix.outputs.matrix }} steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 58bb7cb238..a9cf3db6df 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -2,7 +2,7 @@ name: Semgrep on: # Scan changed files in PRs (diff-aware scanning): - pull_request: {} + pull_request: # Scan on-demand through GitHub Actions interface: 
workflow_dispatch: inputs: @@ -34,6 +34,6 @@ jobs: steps: # Fetch project source with GitHub Actions Checkout. - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 # Run the "semgrep ci" command on the command line of the docker image. - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key From 7db739e464791d3156d75300476767f2797c0843 Mon Sep 17 00:00:00 2001 From: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:44:21 -0700 Subject: [PATCH 069/180] Java: add JSON.STRAPPEND and JSON.STRLEN (#2522) * Implement JSON STRLEN and STRAPPEND commands --------- Signed-off-by: Jonathan Louie Signed-off-by: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com> --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 286 ++++++++++++++++++ .../test/java/glide/modules/JsonTests.java | 84 +++++ 3 files changed, 371 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a554b9652b..064e73e1f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) +* Java: Added `JSON.STRAPPEND` and `JSON.STRLEN` ([#2522](https://github.com/valkey-io/valkey-glide/pull/2522)) * Java: Added `JSON.CLEAR` ([#2519](https://github.com/valkey-io/valkey-glide/pull/2519)) * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index efd0af082b..3d8a299b26 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -34,6 +34,8 @@ public class Json { private static final String JSON_DEL = JSON_PREFIX + "DEL"; private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE"; + private static final String JSON_STRAPPEND = JSON_PREFIX + "STRAPPEND"; + private static final String JSON_STRLEN = JSON_PREFIX + "STRLEN"; private static final String JSON_CLEAR = JSON_PREFIX + "CLEAR"; private static final String JSON_RESP = JSON_PREFIX + "RESP"; private static final String JSON_TYPE = JSON_PREFIX + "TYPE"; @@ -1752,6 +1754,290 @@ public static CompletableFuture toggle( client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).add(path).toArray()); } + /** + * Appends the specified value to the string stored at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be wrapped with single quotes. For + * example, to append "foo", pass '"foo"'. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a list of integer replies for every possible path, indicating the length of
          + * the resulting string after appending value, or null for
          + * JSON values matching the path that are not string.
          + * If key doesn't exist, an error is raised.
          + *
        • For legacy path (path doesn't start with $):
          + * Returns the length of the resulting string after appending value to the
          + * string at path.
          + * If multiple paths match, the length of the last updated string is returned.
          + * If the JSON value at path is not a string or if path
          + * doesn't exist, an error is raised.
          + * If key doesn't exist, an error is raised.
          + *
        + *
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
        +     * var res = Json.strappend(client, "doc", "baz", "$..a").get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {6L, 8L, null}); // The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
        +     *
        +     * res = Json.strappend(client, "doc", '"foo"', "nested.a").get();
        +     * assert (Long) res == 11L; // The length of the string value after appending "foo" to the string at path 'nested.array' in the key stored at `doc`.
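+     * // "foo" (3) + "baz" (3) = 6 and "hello" (5) + "baz" (3) = 8; the value 31 at
+     * // nested2.a is not a string, hence the null entry above.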
        +     *
        +     * var getResult = Json.get(client, "doc", "$").get();
        +     * assert getResult.equals("[{\"a\":\"foobaz\", \"nested\": {\"a\": \"hellobazfoo\"}, \"nested2\": {\"a\": 31}}]"); // The updated JSON value in the key stored at `doc`.
        +     * }
        + */ + public static CompletableFuture strappend( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String value, + @NonNull String path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_STRAPPEND).add(key).add(path).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be wrapped with single quotes. For + * example, to append "foo", pass '"foo"'. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a list of integer replies for every possible path, indicating the length of
          + * the resulting string after appending value, or null for
          + * JSON values matching the path that are not string.
          + * If key doesn't exist, an error is raised.
          + *
        • For legacy path (path doesn't start with $):
          + * Returns the length of the resulting string after appending value to the
          + * string at path.
          + * If multiple paths match, the length of the last updated string is returned.
          + * If the JSON value at path is not a string or if path
          + * doesn't exist, an error is raised.
          + * If key doesn't exist, an error is raised.
          + *
        + *
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
        +     * var res = Json.strappend(client, gs("doc"), gs("baz"), gs("$..a")).get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {6L, 8L, null}); // The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
        +     *
        +     * res = Json.strappend(client, gs("doc"), gs("'\"foo\"'"), gs("nested.a")).get();
        +     * assert (Long) res == 11L; // The length of the string value after appending "foo" to the string at path 'nested.array' in the key stored at `doc`.
        +     *
        +     * var getResult = Json.get(client, gs("doc"), gs("$")).get();
        +     * assert getResult.equals("[{\"a\":\"foobaz\", \"nested\": {\"a\": \"hellobazfoo\"}, \"nested2\": {\"a\": 31}}]"); // The updated JSON value in the key stored at `doc`.
        +     * }
        + */ + public static CompletableFuture strappend( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString value, + @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRAPPEND)).add(key).add(path).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the root within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be wrapped with single quotes. For + * example, to append "foo", pass '"foo"'. + * @return Returns the length of the resulting string after appending value to the + * string at the root.
        + * If the JSON value at root is not a string, an error is raised.
        + * If key doesn't exist, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "'\"foo\"'").get();
        +     * var res = Json.strappend(client, "doc", "'\"baz\"'").get();
        +     * assert res == 6L; // The length of the string value after appending "foo" to the string at root in the key stored at `doc`.
        +     *
        +     * var getResult = Json.get(client, "doc").get();
        +     * assert getResult.equals("\"foobaz\""); // The updated JSON value in the key stored at `doc`.
        +     * }
        + */ + public static CompletableFuture strappend( + @NonNull BaseClient client, @NonNull String key, @NonNull String value) { + return executeCommand( + client, new ArgsBuilder().add(JSON_STRAPPEND).add(key).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the root within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be wrapped with single quotes. For + * example, to append "foo", pass '"foo"'. + * @return Returns the length of the resulting string after appending value to the + * string at the root.
        + * If the JSON value at root is not a string, an error is raised.
        + * If key doesn't exist, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "'\"foo\"'").get();
        +     * var res = Json.strappend(client, gs("doc"), gs("'\"baz\"'")).get();
        +     * assert res == 6L; // The length of the string value after appending "foo" to the string at root in the key stored at `doc`.
        +     *
        +     * var getResult = Json.get(client, gs("$"), gs("doc")).get();
        +     * assert getResult.equals("\"foobaz\""); // The updated JSON value in the key stored at `doc`.
        +     * }
        + */ + public static CompletableFuture strappend( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString value) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRAPPEND)).add(key).add(value).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the specified path within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a list of integer replies for every possible path, indicating the length of
          + * the JSON string value, or null for JSON values matching the path that
          + * are not string.
        • For legacy path (path doesn't start with $):
          + * Returns the length of the JSON value at path or null if
          + * key doesn't exist.
          + * If multiple paths match, the length of the first matched string is returned.
          + * If the JSON value at path is not a string or if path
          + * doesn't exist, an error is raised. If key doesn't exist, null
          + * is returned.
        + * + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
        +     * var res = Json.strlen(client, "doc", "$..a").get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {3L, 5L, null}); // The length of the string values at path '$..a' in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, "doc", "nested.a").get();
        +     * assert (Long) res == 5L; // The length of the JSON value at path 'nested.a' in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, "doc", "$").get();
+     * assert Arrays.equals((Object[]) res, new Object[] {null}); // Returns an array with null since the value at the root path in the JSON document stored at `doc` is not a string.
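+     * // With JSONPath, a non-string match yields null in the reply array instead of
+     * // raising an error as the legacy-path form would.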
        +     *
        +     * res = Json.strlen(client, "non_existing_key", ".").get();
        +     * assert res == null; // `key` doesn't exist.
        +     * }
        + */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new ArgsBuilder().add(JSON_STRLEN).add(key).add(path).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the specified path within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a list of integer replies for every possible path, indicating the length of
          + * the JSON string value, or null for JSON values matching the path that
          + * are not string.
        • For legacy path (path doesn't start with $):
          + * Returns the length of the JSON value at path or null if
          + * key doesn't exist.
          + * If multiple paths match, the length of the first matched string is returned.
          + * If the JSON value at path is not a string or if path
          + * doesn't exist, an error is raised. If key doesn't exist, null
          + * is returned.
        + * + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
        +     * var res = Json.strlen(client, gs("doc"), gs("$..a")).get();
        +     * assert Arrays.equals((Object[]) res, new Object[] {3L, 5L, null}); // The length of the string values at path '$..a' in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, gs("doc"), gs("nested.a")).get();
        +     * assert (Long) res == 5L; // The length of the JSON value at path 'nested.a' in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, gs("doc"), gs("$")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] {null}); // Returns an array with null since the value at the root path in the JSON document stored at `doc` is not a string.
        +     *
        +     * res = Json.strlen(client, gs("non_existing_key"), gs(".")).get();
        +     * assert res == null; // `key` doesn't exist.
        +     * }
        + */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRLEN)).add(key).add(path).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the root within the JSON document stored + * at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the length of the JSON value at the root.
        + * If the JSON value is not a string, an error is raised.
        + * If key doesn't exist, null is returned.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "\"Hello\"").get();
        +     * var res = Json.strlen(client, "doc").get();
        +     * assert res == 5L; // The length of the JSON value at the root in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, "non_existing_key").get();
        +     * assert res == null; // `key` doesn't exist.
        +     * }
        + */ + public static CompletableFuture strlen(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new ArgsBuilder().add(JSON_STRLEN).add(key).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the root within the JSON document stored + * at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the length of the JSON value at the root.
        + * If the JSON value is not a string, an error is raised.
        + * If key doesn't exist, null is returned.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "\"Hello\"").get();
        +     * var res = Json.strlen(client, gs("doc")).get();
        +     * assert res == 5L; // The length of the JSON value at the root in the key stored at `doc`.
        +     *
        +     * res = Json.strlen(client, gs("non_existing_key")).get();
        +     * assert res == null; // `key` doesn't exist.
        +     * }
        + */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(gs(JSON_STRLEN)).add(key).toArray()); + } + /** * Clears an array and an object at the root of the JSON document stored at key.
        * Equivalent to {@link #clear(BaseClient, String, String)} with path set to diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 2b07f9d44e..cb0d202bfd 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -603,6 +603,90 @@ public void toggle() { ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get()); } + @Test + @SneakyThrows + public void strappend() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": \"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}"; + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + assertArrayEquals( + new Object[] {6L, 8L, null}, + (Object[]) Json.strappend(client, key, "\"bar\"", "$..a").get()); + assertEquals(9L, (Long) Json.strappend(client, key, "\"foo\"", "a").get()); + + String jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals( + "{\"a\":\"foobarfoo\",\"nested\":{\"a\":\"hellobar\"},\"nested2\":{\"a\":31}}", jsonStr); + + assertArrayEquals( + new Object[] {null}, (Object[]) Json.strappend(client, key, "\"bar\"", "$.nested").get()); + + assertThrows( + ExecutionException.class, () -> Json.strappend(client, key, "\"bar\"", ".nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strappend(client, key, "\"bar\"").get()); + + assertArrayEquals( + new Object[] {}, + (Object[]) Json.strappend(client, key, "\"try\"", "$.non_existing_path").get()); + + assertThrows( + ExecutionException.class, + () -> Json.strappend(client, key, "\"try\"", "non_existing_path").get()); + + assertThrows( + ExecutionException.class, + () -> Json.strappend(client, "non_existing_key", "\"try\"").get()); + + // Binary test + // Binary with path + assertEquals(12L, (Long) Json.strappend(client, gs(key), gs("\"foo\""), gs("a")).get()); + jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals( + "{\"a\":\"foobarfoofoo\",\"nested\":{\"a\":\"hellobar\"},\"nested2\":{\"a\":31}}", jsonStr); + + // Binary no path + assertEquals("OK", Json.set(client, key, "$", "\"hi\"").get()); + assertEquals(5L, Json.strappend(client, gs(key), gs("\"foo\"")).get()); + jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals("\"hifoo\"", jsonStr); + } + + @Test + @SneakyThrows + public void strlen() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": \"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}"; + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + assertArrayEquals( + new Object[] {3L, 5L, null}, (Object[]) Json.strlen(client, key, "$..a").get()); + assertEquals(3L, (Long) Json.strlen(client, key, "a").get()); + + assertArrayEquals(new Object[] {null}, (Object[]) Json.strlen(client, key, "$.nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strlen(client, key, "nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strlen(client, key).get()); + + assertArrayEquals( + new Object[] {}, (Object[]) Json.strlen(client, key, "$.non_existing_path").get()); + assertThrows( + ExecutionException.class, () -> Json.strlen(client, key, ".non_existing_path").get()); + + assertNull(Json.strlen(client, "non_existing_key", ".").get()); + assertNull(Json.strlen(client, "non_existing_key", "$").get()); + + // Binary test + // Binary with path + assertEquals(3L, (Long) Json.strlen(client, gs(key), 
gs("a")).get()); + + // Binary no path + assertEquals("OK", Json.set(client, key, "$", "\"hi\"").get()); + assertEquals(2L, Json.strlen(client, gs(key)).get()); + } + @Test @SneakyThrows public void json_resp() { From 5e86b7a6fe34860d122e84ab5cc4f4930303bad1 Mon Sep 17 00:00:00 2001 From: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:48:09 -0700 Subject: [PATCH 070/180] Java: add JSON.NUMINCRBY and JSON.NUMMULTBY (#2511) * Implement NUMINCRBY and NUMMULTBY JSON commands --------- Signed-off-by: Jonathan Louie Signed-off-by: jonathanl-bq <72158117+jonathanl-bq@users.noreply.github.com> --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 173 ++++++++++ .../test/java/glide/modules/JsonTests.java | 295 ++++++++++++++++++ 3 files changed, 469 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 064e73e1f7..4ccae7464c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) +* Java: Added `JSON.NUMINCRBY` and `JSON.NUMMULTBY` ([#2511](https://github.com/valkey-io/valkey-glide/pull/2511)) * Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) * Java: Added `JSON.ARRTRIM` ([#2518](https://github.com/valkey-io/valkey-glide/pull/2518)) * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 3d8a299b26..5edf8b76d3 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -22,6 +22,8 @@ public class Json { private static final String JSON_PREFIX = "JSON."; private static final String JSON_SET = JSON_PREFIX + "SET"; private static final String JSON_GET = JSON_PREFIX + "GET"; + private static final String JSON_NUMINCRBY = JSON_PREFIX + "NUMINCRBY"; + private static final String JSON_NUMMULTBY = JSON_PREFIX + "NUMMULTBY"; private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; @@ -1252,6 +1254,177 @@ public static CompletableFuture arrtrim( .toArray()); } + /** + * Increments or decrements the JSON value(s) at the specified path by number + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to increment or decrement by. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a string representation of an array of strings, indicating the new values
          + * after incrementing for each matched path.
          + * If a value is not a number, its corresponding return value will be null.
          + *
          + * If path doesn't exist, a string representation of an empty array
          + * will be returned.
        • For legacy path (path doesn't start with $):
          + * Returns a string representation of the resulting value after the increment or
          + * decrement.
          + * If multiple paths match, the result of the last updated value is returned.
          + * If the value at the path is not a number or path doesn't
          + * exist, an error is raised.
          + *
        + * If key does not exist, an error is raised.
        + * If the result is out of the range of 64-bit IEEE double, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
        +     * var res = Json.numincrby(client, "doc", "$.d[*]", 10.0).get();
        +     * assert res.equals("[11,12,13]"); // Increment each element in `d` array by 10.
        +     *
        +     * res = Json.numincrby(client, "doc", ".c[1]", 10.0).get();
        +     * assert res.equals("12"); // Increment the second element in the `c` array by 10.
        +     * }
        + */ + public static CompletableFuture numincrby( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, Number number) { + return executeCommand(client, new String[] {JSON_NUMINCRBY, key, path, number.toString()}); + } + + /** + * Increments or decrements the JSON value(s) at the specified path by number + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to increment or decrement by. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a GlideString representation of an array of strings, indicating
          + * the new values after incrementing for each matched path.
          + * If a value is not a number, its corresponding return value will be null.
          + *
          + * If path doesn't exist, a byte string representation of an empty array
          + * will be returned.
        • For legacy path (path doesn't start with $):
          + * Returns a GlideString representation of the resulting value after the
          + * increment or decrement.
          + * If multiple paths match, the result of the last updated value is returned.
          + * If the value at the path is not a number or path doesn't
          + * exist, an error is raised.
          + *
        + * If key does not exist, an error is raised.
        + * If the result is out of the range of 64-bit IEEE double, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
        +     * var res = Json.numincrby(client, gs("doc"), gs("$.d[*]"), 10.0).get();
        +     * assert res.equals(gs("[11,12,13]")); // Increment each element in `d` array by 10.
        +     *
        +     * res = Json.numincrby(client, gs("doc"), gs(".c[1]"), 10.0).get();
        +     * assert res.equals(gs("12")); // Increment the second element in the `c` array by 10.
        +     * }
        + */ + public static CompletableFuture numincrby( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + Number number) { + return executeCommand( + client, new GlideString[] {gs(JSON_NUMINCRBY), key, path, gs(number.toString())}); + } + + /** + * Multiplies the JSON value(s) at the specified path by number within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to multiply by. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a string representation of an array of strings, indicating the new values
          + * after multiplication for each matched path.
          + * If a value is not a number, its corresponding return value will be null.
          + *
          + * If path doesn't exist, a string representation of an empty array
          + * will be returned.
        • For legacy path (path doesn't start with $):
          + * Returns a string representation of the resulting value after multiplication.
          + * If multiple paths match, the result of the last updated value is returned.
          + * If the value at the path is not a number or path doesn't
          + * exist, an error is raised.
          + *
        + * If key does not exist, an error is raised.
        + * If the result is out of the range of 64-bit IEEE double, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
        +     * var res = Json.nummultby(client, "doc", "$.d[*]", 2.0).get();
        +     * assert res.equals("[2,4,6]"); // Multiplies each element in the `d` array by 2.
        +     *
        +     * res = Json.nummultby(client, "doc", ".c[1]", 2.0).get();
        +     * assert res.equals("12"); // Multiplies the second element in the `c` array by 2.
        +     * }
        + */ + public static CompletableFuture nummultby( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, Number number) { + return executeCommand(client, new String[] {JSON_NUMMULTBY, key, path, number.toString()}); + } + + /** + * Multiplies the JSON value(s) at the specified path by number within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to multiply by. + * @return + *
          + *
        • For JSONPath (path starts with $):
          + * Returns a GlideString representation of an array of strings, indicating
          + * the new values after multiplication for each matched path.
          + * If a value is not a number, its corresponding return value will be null.
          + *
          + * If path doesn't exist, a byte string representation of an empty array
          + * will be returned.
        • For legacy path (path doesn't start with $):
          + * Returns a GlideString representation of the resulting value after
          + * multiplication.
          + * If multiple paths match, the result of the last updated value is returned.
          + * If the value at the path is not a number or path doesn't
          + * exist, an error is raised.
          + *
        + * If key does not exist, an error is raised.
        + * If the result is out of the range of 64-bit IEEE double, an error is raised.
        + * @example
        + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
        +     * var res = Json.nummultby(client, gs("doc"), gs("$.d[*]"), 2.0).get();
        +     * assert res.equals(gs("[2,4,6]")); // Multiplies each element in the `d` array by 2.
        +     *
        +     * res = Json.nummultby(client, gs("doc"), gs(".c[1]"), 2.0).get();
        +     * assert res.equals(gs("12")); // Multiplies the second element in the `c` array by 2.
        +     * }
        + */ + public static CompletableFuture nummultby( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + Number number) { + return executeCommand( + client, new GlideString[] {gs(JSON_NUMMULTBY), key, path, gs(number.toString())}); + } + /** * Retrieves the number of key-value pairs in the object values at the specified path * within the JSON document stored at key.
        diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index cb0d202bfd..040b8f1a9c 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -403,6 +403,301 @@ public void clear() { ExecutionException.class, () -> Json.clear(client, UUID.randomUUID().toString()).get()); } + @Test + @SneakyThrows + void numincrby() { + String key = UUID.randomUUID().toString(); + + var jsonValue = + "{" + + " \"key1\": 1," + + " \"key2\": 3.5," + + " \"key3\": {\"nested_key\": {\"key1\": [4, 5]}}," + + " \"key4\": [1, 2, 3]," + + " \"key5\": 0," + + " \"key6\": \"hello\"," + + " \"key7\": null," + + " \"key8\": {\"nested_key\": {\"key1\": 69}}," + + " \"key9\": 1.7976931348623157e308" + + "}"; + + // Set the initial JSON document at the key + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + // Test JSONPath + // Increment integer value (key1) by 5 + String result = Json.numincrby(client, key, "$.key1", 5).get(); + assertEquals("[6]", result); // Expect 1 + 5 = 6 + + // Increment float value (key2) by 2.5 + result = Json.numincrby(client, key, "$.key2", 2.5).get(); + assertEquals("[6]", result); // Expect 3.5 + 2.5 = 6 + + // Increment nested object (key3.nested_key.key1[0]) by 7 + result = Json.numincrby(client, key, "$.key3.nested_key.key1[1]", 7).get(); + assertEquals("[12]", result); // Expect 4 + 7 = 12 + + // Increment array element (key4[1]) by 1 + result = Json.numincrby(client, key, "$.key4[1]", 1).get(); + assertEquals("[3]", result); // Expect 2 + 1 = 3 + + // Increment zero value (key5) by 10.23 (float number) + result = Json.numincrby(client, key, "$.key5", 10.23).get(); + assertEquals("[10.23]", result); // Expect 0 + 10.23 = 10.23 + + // Increment a string value (key6) by a number + result = Json.numincrby(client, key, "$.key6", 99).get(); + assertEquals("[null]", result); // Expect null + + // Increment a None value (key7) by a number + result = Json.numincrby(client, key, "$.key7", 51).get(); + assertEquals("[null]", result); // Expect null + + // Check increment for all numbers in the document using JSON Path (First Null: key3 as an + // entire object. Second Null: The path checks under key3, which is an object, for numeric + // values). 
+ result = Json.numincrby(client, key, "$..*", 5).get(); + assertEquals( + "[11,11,null,null,15.23,null,null,null,1.7976931348623157e+308,null,null,9,17,6,8,8,null,74]", + result); + + // Check for multiple path match in enhanced + result = Json.numincrby(client, key, "$..key1", 1).get(); + assertEquals("[12,null,75]", result); // Expect null + + // Check for non existent path in JSONPath + result = Json.numincrby(client, key, "$.key10", 51).get(); + assertEquals("[]", result); // Expect Empty Array + + // Check for non existent key in JSONPath + assertThrows( + ExecutionException.class, + () -> Json.numincrby(client, "non_existent_key", "$.key10", 51).get()); + + // Check for Overflow in JSONPath + assertThrows( + ExecutionException.class, + () -> Json.numincrby(client, key, "$.key9", 1.7976931348623157e308).get()); + + // Decrement integer value (key1) by 12 + result = Json.numincrby(client, key, "$.key1", -12).get(); + assertEquals("[0]", result); // Expect 12 - 12 = 0 + + // Decrement integer value (key1) by 0.5 + result = Json.numincrby(client, key, "$.key1", -0.5).get(); + assertEquals("[-0.5]", result); // Expect 0 - 0.5 = -0.5 + + // Test Legacy Path + // Increment float value (key1) by 5 (integer) + result = Json.numincrby(client, key, "key1", 5).get(); + assertEquals("4.5", result); // Expect -0.5 + 5 = 4.5 + + // Decrement float value (key1) by 5.5 (integer) + result = Json.numincrby(client, key, "key1", -5.5).get(); + assertEquals("-1", result); // Expect 4.5 - 5.5 = -1 + + // Increment int value (key2) by 2.5 (a float number) + result = Json.numincrby(client, key, "key2", 2.5).get(); + assertEquals("13.5", result); // Expect 11 + 2.5 = 13.5 + + // Increment nested value (key3.nested_key.key1[0]) by 7 + result = Json.numincrby(client, key, "key3.nested_key.key1[0]", 7).get(); + assertEquals("16", result); // Expect 9 + 7 = 16 + + // Increment array element (key4[1]) by 1 + result = Json.numincrby(client, key, "key4[1]", 1).get(); + assertEquals("9", result); // Expect 8 + 1 = 9 + + // Increment a float value (key5) by 10.2 (a float number) + result = Json.numincrby(client, key, "key5", 10.2).get(); + assertEquals("25.43", result); // Expect 15.23 + 10.2 = 25.43 + + // Check for multiple path match in legacy and assure that the result of the last updated value + // is returned + result = Json.numincrby(client, key, "..key1", 1).get(); + assertEquals("76", result); + + // Check if the rest of the key1 path matches were updated and not only the last value + result = Json.get(client, key, new String[] {"$..key1"}).get(); + assertEquals( + "[0,[16,17],76]", + result); // First is 0 as 0 + 0 = 0, Second doesn't change as its an array type + // (non-numeric), third is 76 as 0 + 76 = 0 + + // Check for non existent path in legacy + assertThrows(ExecutionException.class, () -> Json.numincrby(client, key, ".key10", 51).get()); + + // Check for non existent key in legacy + assertThrows( + ExecutionException.class, + () -> Json.numincrby(client, "non_existent_key", ".key10", 51).get()); + + // Check for Overflow in legacy + assertThrows( + ExecutionException.class, + () -> Json.numincrby(client, key, ".key9", 1.7976931348623157e308).get()); + + // Binary tests + // Binary integer test + GlideString binaryResult = Json.numincrby(client, gs(key), gs("key4[1]"), 1).get(); + assertEquals(gs("10"), binaryResult); // Expect 9 + 1 = 10 + + // Binary float test + binaryResult = Json.numincrby(client, gs(key), gs("key5"), 1.0).get(); + assertEquals(gs("26.43"), binaryResult); // Expect 25.43 + 1.0 
= 26.43 + } + + @Test + @SneakyThrows + void nummultby() { + String key = UUID.randomUUID().toString(); + var jsonValue = + "{" + + " \"key1\": 1," + + " \"key2\": 3.5," + + " \"key3\": {\"nested_key\": {\"key1\": [4, 5]}}," + + " \"key4\": [1, 2, 3]," + + " \"key5\": 0," + + " \"key6\": \"hello\"," + + " \"key7\": null," + + " \"key8\": {\"nested_key\": {\"key1\": 69}}," + + " \"key9\": 3.5953862697246314e307" + + "}"; + + // Set the initial JSON document at the key + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + // Test JSONPath + // Multiply integer value (key1) by 5 + String result = Json.nummultby(client, key, "$.key1", 5).get(); + assertEquals("[5]", result); // Expect 1 * 5 = 5 + + // Multiply float value (key2) by 2.5 + result = Json.nummultby(client, key, "$.key2", 2.5).get(); + assertEquals("[8.75]", result); // Expect 3.5 * 2.5 = 8.75 + + // Multiply nested object (key3.nested_key.key1[1]) by 7 + result = Json.nummultby(client, key, "$.key3.nested_key.key1[1]", 7).get(); + assertEquals("[35]", result); // Expect 5 * 7 = 35 + + // Multiply array element (key4[1]) by 1 + result = Json.nummultby(client, key, "$.key4[1]", 1).get(); + assertEquals("[2]", result); // Expect 2 * 1 = 2 + + // Multiply zero value (key5) by 10.23 (float number) + result = Json.nummultby(client, key, "$.key5", 10.23).get(); + assertEquals("[0]", result); // Expect 0 * 10.23 = 0 + + // Multiply a string value (key6) by a number + result = Json.nummultby(client, key, "$.key6", 99).get(); + assertEquals("[null]", result); // Expect null + + // Multiply a None value (key7) by a number + result = Json.nummultby(client, key, "$.key7", 51).get(); + assertEquals("[null]", result); // Expect null + + // Check multiplication for all numbers in the document using JSON Path + // key1: 5 * 5 = 25 + // key2: 8.75 * 5 = 43.75 + // key3.nested_key.key1[0]: 4 * 5 = 20 + // key3.nested_key.key1[1]: 35 * 5 = 175 + // key4[0]: 1 * 5 = 5 + // key4[1]: 2 * 5 = 10 + // key4[2]: 3 * 5 = 15 + // key5: 0 * 5 = 0 + // key8.nested_key.key1: 69 * 5 = 345 + // key9: 3.5953862697246314e307 * 5 = 1.7976931348623157e308 + result = Json.nummultby(client, key, "$..*", 5).get(); + assertEquals( + "[25,43.75,null,null,0,null,null,null,1.7976931348623157e+308,null,null,20,175,5,10,15,null,345]", + result); + + // Check for multiple path matches in JSONPath + // key1: 25 * 2 = 50 + // key8.nested_key.key1: 345 * 2 = 690 + result = Json.nummultby(client, key, "$..key1", 2).get(); + assertEquals("[50,null,690]", result); // After previous multiplications + + // Check for non-existent path in JSONPath + result = Json.nummultby(client, key, "$.key10", 51).get(); + assertEquals("[]", result); // Expect Empty Array + + // Check for non-existent key in JSONPath + assertThrows( + ExecutionException.class, + () -> Json.nummultby(client, "non_existent_key", "$.key10", 51).get()); + + // Check for Overflow in JSONPath + assertThrows( + ExecutionException.class, + () -> Json.nummultby(client, key, "$.key9", 1.7976931348623157e308).get()); + + // Multiply integer value (key1) by -12 + result = Json.nummultby(client, key, "$.key1", -12).get(); + assertEquals("[-600]", result); // Expect 50 * -12 = -600 + + // Multiply integer value (key1) by -0.5 + result = Json.nummultby(client, key, "$.key1", -0.5).get(); + assertEquals("[300]", result); // Expect -600 * -0.5 = 300 + + // Test Legacy Path + // Multiply int value (key1) by 5 (integer) + result = Json.nummultby(client, key, "key1", 5).get(); + assertEquals("1500", result); // Expect 
300 * 5 = -1500 + + // Multiply int value (key1) by -5.5 (float number) + result = Json.nummultby(client, key, "key1", -5.5).get(); + assertEquals("-8250", result); // Expect -150 * -5.5 = -8250 + + // Multiply int float (key2) by 2.5 (a float number) + result = Json.nummultby(client, key, "key2", 2.5).get(); + assertEquals("109.375", result); // Expect 43.75 * 2.5 = 109.375 + + // Multiply nested value (key3.nested_key.key1[0]) by 7 + result = Json.nummultby(client, key, "key3.nested_key.key1[0]", 7).get(); + assertEquals("140", result); // Expect 20 * 7 = 140 + + // Multiply array element (key4[1]) by 1 + result = Json.nummultby(client, key, "key4[1]", 1).get(); + assertEquals("10", result); // Expect 10 * 1 = 10 + + // Multiply a float value (key5) by 10.2 (a float number) + result = Json.nummultby(client, key, "key5", 10.2).get(); + assertEquals("0", result); // Expect 0 * 10.2 = 0 + + // Check for multiple path matches in legacy and assure that the result of the last updated + // value is returned + // last updated value is key8.nested_key.key1: 690 * 2 = 1380 + result = Json.nummultby(client, key, "..key1", 2).get(); + assertEquals("1380", result); // Expect the last updated key1 value multiplied by 2 + + // Check if the rest of the key1 path matches were updated and not only the last value + result = Json.get(client, key, new String[] {"$..key1"}).get(); + assertEquals(result, "[-16500,[140,175],1380]"); + + // Check for non-existent path in legacy + assertThrows(ExecutionException.class, () -> Json.nummultby(client, key, ".key10", 51).get()); + + // Check for non-existent key in legacy + assertThrows( + ExecutionException.class, + () -> Json.nummultby(client, "non_existent_key", ".key10", 51).get()); + + // Check for Overflow in legacy + assertThrows( + ExecutionException.class, + () -> Json.nummultby(client, key, ".key9", 1.7976931348623157e308).get()); + + // Binary tests + // Binary integer test + GlideString binaryResult = Json.nummultby(client, gs(key), gs("key4[1]"), 1).get(); + assertEquals(gs("10"), binaryResult); // Expect 10 * 1 = 10 + + // Binary float test + binaryResult = Json.nummultby(client, gs(key), gs("key5"), 10.2).get(); + assertEquals(gs("0"), binaryResult); // Expect 0 * 10.2 = 0 + } + @Test @SneakyThrows public void arrtrim() { From ddca0676d448faf396c40c6323cc09c11e2a9a76 Mon Sep 17 00:00:00 2001 From: Chloe Yip <168601573+cyip10@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:08:46 -0700 Subject: [PATCH 071/180] Java: FT.EXPLAIN and FT.EXPLAINCLI (#2515) Signed-off-by: Chloe Signed-off-by: Chloe Yip Signed-off-by: Andrew Carbonetto Co-authored-by: Andrew Carbonetto --- CHANGELOG.md | 1 + .../glide/api/commands/servermodules/FT.java | 100 +++++++++++++++++ .../java/glide/modules/VectorSearchTests.java | 106 ++++++++++++++++++ 3 files changed, 207 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ccae7464c..3109af9f35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ * Java: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2492](https://github.com/valkey-io/valkey-glide/pull/2492)) * Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490)) * Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) +* Java: Added `FT.EXPLAIN`, `FT.EXPLAINCLI` ([#2515](https://github.com/valkey-io/valkey-glide/pull/2515)) * Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) * Node: 
Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) * Java: Added `JSON.NUMINCRBY` and `JSON.NUMMULTBY` ([#2511](https://github.com/valkey-io/valkey-glide/pull/2511)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 7a5dbb8714..9d1a75e9ea 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -751,6 +751,106 @@ public static CompletableFuture aliasupdate( return executeCommand(client, args, false); } + /** + * Parse a query and return information about how that query was parsed. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, String, String)} and {@link FT#aggregate(BaseClient, String, + * String)}. + * @return A String representing the execution plan. + * @example + *
        {@code
        +     * String result = FT.explain(client, "myIndex", "@price:[0 10]").get();
        +     * assert result.equals("Field {\n\tprice\n\t0\n\t10\n}");
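+     * // (The plan text above is illustrative for a numeric range query; the exact
+     * // output format may vary across server versions.)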
        +     * }
        + */ + public static CompletableFuture explain( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + GlideString[] args = {gs("FT.EXPLAIN"), gs(indexName), gs(query)}; + return FT.executeCommand(client, args, false).thenApply(GlideString::toString); + } + + /** + * Parse a query and return information about how that query was parsed. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, GlideString, GlideString)} and {@link FT#aggregate(BaseClient, + * GlideString, GlideString)}. + * @return A GlideString representing the execution plan. + * @example + *
        {@code
        +     * GlideString result = FT.explain(client, gs("myIndex"), gs("@price:[0 10]")).get();
+     * assert result.equals(gs("Field {\n\tprice\n\t0\n\t10\n}"));
        +     * }
        + */ + public static CompletableFuture explain( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + GlideString[] args = {gs("FT.EXPLAIN"), indexName, query}; + return executeCommand(client, args, false); + } + + /** + * Same as the {@link FT#explain(BaseClient, String, String)} except that the results are + * displayed in a different format. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, String, String)} and {@link FT#aggregate(BaseClient, String, + * String)}. + * @return A String[] representing the execution plan. + * @example + *
        {@code
        +     * String[] result = FT.explaincli(client, "myIndex",  "@price:[0 10]").get();
        +     * assert Arrays.equals(result, new String[]{
        +     *   "Field {",
        +     *   "  price",
        +     *   "  0",
        +     *   "  10",
        +     *   "}"
        +     * });
        +     * }
        + */ + public static CompletableFuture explaincli( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + CompletableFuture result = explaincli(client, gs(indexName), gs(query)); + return result.thenApply( + ret -> Arrays.stream(ret).map(GlideString::toString).toArray(String[]::new)); + } + + /** + * Same as the {@link FT#explain(BaseClient, String, String)} except that the results are + * displayed in a different format. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, GlideString, GlideString)} and {@link FT#aggregate(BaseClient, + * GlideString, GlideString)}. + * @return A GlideString[] representing the execution plan. + * @example + *
        {@code
        +     * GlideString[] result = FT.explaincli(client, gs("myIndex"),  gs("@price:[0 10]")).get();
        +     * assert Arrays.equals(result, new GlideString[]{
        +     *   gs("Field {"),
        +     *   gs("  price"),
        +     *   gs("  0"),
        +     *   gs("  10"),
        +     *   gs("}")
        +     * });
        +     * }
        + */ + public static CompletableFuture explaincli( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + GlideString[] args = new GlideString[] {gs("FT.EXPLAINCLI"), indexName, query}; + return FT.executeCommand(client, args, false) + .thenApply(ret -> castArray(ret, GlideString.class)); + } + /** * A wrapper for custom command API. * diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 10d09a5974..75151f103b 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -26,6 +26,7 @@ import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortOrder; import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortProperty; import glide.api.models.commands.FT.FTCreateOptions; +import glide.api.models.commands.FT.FTCreateOptions.DataType; import glide.api.models.commands.FT.FTCreateOptions.DistanceMetric; import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; import glide.api.models.commands.FT.FTCreateOptions.NumericField; @@ -38,11 +39,14 @@ import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; import glide.api.models.exceptions.RequestException; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; import lombok.SneakyThrows; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -825,4 +829,106 @@ public void ft_aliasadd_aliasdel_aliasupdate() { assertInstanceOf(RequestException.class, exception.getCause()); assertTrue(exception.getMessage().contains("Index does not exist")); } + + @SneakyThrows + @Test + public void ft_explain() { + + String indexName = UUID.randomUUID().toString(); + createIndexHelper(indexName); + + // search query containing numeric field. + String query = "@price:[0 10]"; + String result = FT.explain(client, indexName, query).get(); + assertTrue(result.contains("price")); + assertTrue(result.contains("0")); + assertTrue(result.contains("10")); + + GlideString resultGS = FT.explain(client, gs(indexName), gs(query)).get(); + assertTrue((resultGS).toString().contains("price")); + assertTrue((resultGS).toString().contains("0")); + assertTrue((resultGS).toString().contains("10")); + + // search query that returns all data. + GlideString resultGSAllData = FT.explain(client, gs(indexName), gs("*")).get(); + assertTrue(resultGSAllData.toString().contains("*")); + + assertEquals(OK, FT.dropindex(client, indexName).get()); + + // missing index throws an error. + var exception = + assertThrows( + ExecutionException.class, + () -> FT.explain(client, UUID.randomUUID().toString(), "*").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + + @SneakyThrows + @Test + public void ft_explaincli() { + + String indexName = UUID.randomUUID().toString(); + createIndexHelper(indexName); + + // search query containing numeric field. 
+        String query = "@price:[0 10]";
+        String[] result = FT.explaincli(client, indexName, query).get();
+        List resultList = Arrays.stream(result).map(String::trim).collect(Collectors.toList());
+
+        assertTrue(resultList.contains("price"));
+        assertTrue(resultList.contains("0"));
+        assertTrue(resultList.contains("10"));
+
+        GlideString[] resultGS = FT.explaincli(client, gs(indexName), gs(query)).get();
+        List resultListGS =
+                Arrays.stream(resultGS)
+                        .map(GlideString::toString)
+                        .map(String::trim)
+                        .collect(Collectors.toList());
+
+        assertTrue((resultListGS).contains("price"));
+        assertTrue((resultListGS).contains("0"));
+        assertTrue((resultListGS).contains("10"));
+
+        // search query that returns all data.
+        GlideString[] resultGSAllData = FT.explaincli(client, gs(indexName), gs("*")).get();
+        List resultListGSAllData =
+                Arrays.stream(resultGSAllData)
+                        .map(GlideString::toString)
+                        .map(String::trim)
+                        .collect(Collectors.toList());
+        assertTrue((resultListGSAllData).contains("*"));
+
+        assertEquals(OK, FT.dropindex(client, indexName).get());
+
+        // missing index throws an error.
+        var exception =
+                assertThrows(
+                        ExecutionException.class,
+                        () -> FT.explaincli(client, UUID.randomUUID().toString(), "*").get());
+        assertInstanceOf(RequestException.class, exception.getCause());
+        assertTrue(exception.getMessage().contains("Index not found"));
+    }
+
+    private void createIndexHelper(String indexName) throws ExecutionException, InterruptedException {
+        FieldInfo numericField = new FieldInfo("price", new NumericField());
+        FieldInfo textField = new FieldInfo("title", new TextField());
+
+        FieldInfo[] fields = new FieldInfo[] {numericField, textField};
+
+        String prefix = "{hash-search-" + UUID.randomUUID().toString() + "}:";
+
+        assertEquals(
+                OK,
+                FT.create(
+                                client,
+                                indexName,
+                                fields,
+                                FTCreateOptions.builder()
+                                        .dataType(DataType.HASH)
+                                        .prefixes(new String[] {prefix})
+                                        .build())
+                        .get());
+    }
 }

From 7b5b72d85d31a1d1519e3e03a59b3a8acd3db466 Mon Sep 17 00:00:00 2001
From: prateek-kumar-improving 
Date: Mon, 28 Oct 2024 18:27:24 -0700
Subject: [PATCH 072/180] Python: FT.AGGREGATE command added (#2530)

* Python: FT.AGGREGATE command added

---------

Signed-off-by: Prateek Kumar 
---
 CHANGELOG.md                                  |   1 +
 python/python/glide/__init__.py               |  20 +
 .../glide/async_commands/command_args.py      |   1 +
 .../glide/async_commands/server_modules/ft.py |  32 +-
 .../ft_options/ft_aggregate_options.py        | 293 ++++++++++++
 .../server_modules/ft_options/ft_constants.py |  15 +
 .../tests/tests_server_modules/test_ft.py     | 439 ++++++++++++++++++
 7 files changed, 800 insertions(+), 1 deletion(-)
 create mode 100644 python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3109af9f35..ba8bef87c4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,5 @@
 #### Changes
+* Python: FT.AGGREGATE command added ([#2530](https://github.com/valkey-io/valkey-glide/pull/2530))
 * Python: Add JSON.OBJLEN command ([#2495](https://github.com/valkey-io/valkey-glide/pull/2495))
 * Python: FT.EXPLAIN and FT.EXPLAINCLI commands added([#2508](https://github.com/valkey-io/valkey-glide/pull/2508))
 * Python: Python FT.INFO command added([#2429](https://github.com/valkey-io/valkey-glide/pull/2494))

diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py
index a4fdc27d67..9490bfe5b6 100644
--- a/python/python/glide/__init__.py
+++ b/python/python/glide/__init__.py
@@ -33,6 +33,17 @@
     UpdateOptions,
 )
 from glide.async_commands.server_modules 
import ft, json +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateApply, + FtAggregateClause, + FtAggregateFilter, + FtAggregateGroupBy, + FtAggregateLimit, + FtAggregateOptions, + FtAggregateReducer, + FtAggregateSortBy, + FtAggregateSortProperty, +) from glide.async_commands.server_modules.ft_options.ft_create_options import ( DataType, DistanceMetricType, @@ -273,4 +284,13 @@ "FtSearchLimit", "ReturnField", "FtSeachOptions", + "FtAggregateApply", + "FtAggregateFilter", + "FtAggregateClause", + "FtAggregateLimit", + "FtAggregateOptions", + "FtAggregateGroupBy", + "FtAggregateReducer", + "FtAggregateSortBy", + "FtAggregateSortProperty", ] diff --git a/python/python/glide/async_commands/command_args.py b/python/python/glide/async_commands/command_args.py index 92e0100665..05efe925ab 100644 --- a/python/python/glide/async_commands/command_args.py +++ b/python/python/glide/async_commands/command_args.py @@ -36,6 +36,7 @@ class OrderBy(Enum): This enum is used for the following commands: - `SORT`: General sorting in ascending or descending order. - `GEOSEARCH`: Sorting items based on their proximity to a center point. + - `FT.AGGREGATE`: Used in the SortBy clause of the FT.AGGREGATE command. """ ASC = "ASC" diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index cce57bd727..6240d626f7 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -3,8 +3,11 @@ module for `vector search` commands. """ -from typing import List, Mapping, Optional, Union, cast +from typing import Any, List, Mapping, Optional, Union, cast +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateOptions, +) from glide.async_commands.server_modules.ft_options.ft_constants import ( CommandNames, FtCreateKeywords, @@ -276,3 +279,30 @@ async def explaincli( """ args: List[TEncodable] = [CommandNames.FT_EXPLAINCLI, indexName, query] return cast(List[TEncodable], await client.custom_command(args)) + + +async def aggregate( + client: TGlideClient, + indexName: TEncodable, + query: TEncodable, + options: Optional[FtAggregateOptions], +) -> List[Mapping[TEncodable, Any]]: + """ + A superset of the FT.SEARCH command, it allows substantial additional processing of the keys selected by the query expression. + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + options (Optional[FtAggregateOptions]): The optional arguments for the command. + Returns: + List[Mapping[TEncodable, Any]]: An array containing a mapping of field name and associated value as returned after the last stage of the command. 
+
+    Examples:
+        >>> from glide import ft
+        >>> result = await ft.aggregate(glide_client, "myIndex", "*", FtAggregateOptions(loadFields=["__key"], clauses=[FtAggregateGroupBy(["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")])]))
+        [{b'condition': b'refurbished', b'bicycles': b'1'}, {b'condition': b'new', b'bicycles': b'5'}, {b'condition': b'used', b'bicycles': b'4'}]
+    """
+    args: List[TEncodable] = [CommandNames.FT_AGGREGATE, indexName, query]
+    if options:
+        args.extend(options.to_args())
+    return cast(List[Mapping[TEncodable, Any]], await client.custom_command(args))
diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py
new file mode 100644
index 0000000000..c121c10985
--- /dev/null
+++ b/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py
@@ -0,0 +1,293 @@
+# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import List, Mapping, Optional
+
+from glide.async_commands.command_args import OrderBy
+from glide.async_commands.server_modules.ft_options.ft_constants import (
+    FtAggregateKeywords,
+)
+from glide.constants import TEncodable
+
+
+class FtAggregateClause(ABC):
+    """
+    Abstract base class for the FT.AGGREGATE command clauses.
+    """
+
+    @abstractmethod
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the clause of the FT.AGGREGATE command.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the clause of the FT.AGGREGATE command.
+        """
+        args: List[TEncodable] = []
+        return args
+
+
+class FtAggregateLimit(FtAggregateClause):
+    """
+    A clause for limiting the number of retained records.
+    """
+
+    def __init__(self, offset: int, count: int):
+        """
+        Initialize a new FtAggregateLimit instance.
+
+        Args:
+            offset (int): The starting point from which records are retained.
+            count (int): The total number of records to be retained.
+        """
+        self.offset = offset
+        self.count = count
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the Limit clause.
+
+        Returns:
+            List[TEncodable]: A list of Limit clause arguments.
+        """
+        return [FtAggregateKeywords.LIMIT, str(self.offset), str(self.count)]
+
+
+class FtAggregateFilter(FtAggregateClause):
+    """
+    A clause for filtering the results using a predicate expression relating to values in each result. It is applied post-query and relates to the current state of the pipeline.
+    """
+
+    def __init__(self, expression: TEncodable):
+        """
+        Initialize a new FtAggregateFilter instance.
+
+        Args:
+            expression (TEncodable): The expression to filter the results.
+        """
+        self.expression = expression
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the Filter clause.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the Filter clause.
+        """
+        return [FtAggregateKeywords.FILTER, self.expression]
+
+
+class FtAggregateReducer:
+    """
+    A clause for reducing the matching results in each group using a reduction function. The matching results are reduced into a single record.
+    """
+
+    def __init__(
+        self,
+        function: TEncodable,
+        args: List[TEncodable],
+        name: Optional[TEncodable] = None,
+    ):
+        """
+        Initialize a new FtAggregateReducer instance.
+
+        Args:
+            function (TEncodable): The reduction function name for the respective group.
+            args (List[TEncodable]): The list of arguments for the reducer.
+            name (Optional[TEncodable]): A user-defined property name for the reducer.
+        """
+        self.function = function
+        self.args = args
+        self.name = name
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the Reducer.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the reducer.
+        """
+        args: List[TEncodable] = [
+            FtAggregateKeywords.REDUCE,
+            self.function,
+            str(len(self.args)),
+        ] + self.args
+        if self.name:
+            args.extend([FtAggregateKeywords.AS, self.name])
+        return args
+
+
+class FtAggregateGroupBy(FtAggregateClause):
+    """
+    A clause for grouping the results in the pipeline based on one or more properties.
+    """
+
+    def __init__(
+        self, properties: List[TEncodable], reducers: List[FtAggregateReducer]
+    ):
+        """
+        Initialize a new FtAggregateGroupBy instance.
+
+        Args:
+            properties (List[TEncodable]): The list of properties to be used for grouping the results in the pipeline.
+            reducers (List[FtAggregateReducer]): The list of functions that handle the group entries by performing multiple aggregate operations.
+        """
+        self.properties = properties
+        self.reducers = reducers
+
+    def to_args(self) -> List[TEncodable]:
+        args = [
+            FtAggregateKeywords.GROUPBY,
+            str(len(self.properties)),
+        ] + self.properties
+        if self.reducers:
+            for reducer in self.reducers:
+                args.extend(reducer.to_args())
+        return args
+
+
+class FtAggregateSortProperty:
+    """
+    This class represents a single property for the SortBy clause.
+    """
+
+    def __init__(self, property: TEncodable, order: OrderBy):
+        """
+        Initialize a new FtAggregateSortProperty instance.
+
+        Args:
+            property (TEncodable): The sorting parameter.
+            order (OrderBy): The order for the sorting. This option can be added for each property.
+        """
+        self.property = property
+        self.order = order
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the SortBy clause property.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the SortBy clause property.
+        """
+        return [self.property, self.order.value]
+
+
+class FtAggregateSortBy(FtAggregateClause):
+    """
+    A clause for sorting the pipeline up until the point of SORTBY, using a list of properties.
+    """
+
+    def __init__(
+        self, properties: List[FtAggregateSortProperty], max: Optional[int] = None
+    ):
+        """
+        Initialize a new FtAggregateSortBy instance.
+
+        Args:
+            properties (List[FtAggregateSortProperty]): A list of sorting parameters for the sort operation.
+            max (Optional[int]): The MAX value for optimizing the sorting, by sorting only for the n-largest elements.
+        """
+        self.properties = properties
+        self.max = max
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the SortBy clause.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the SortBy clause.
+        """
+        args: List[TEncodable] = [
+            FtAggregateKeywords.SORTBY,
+            str(len(self.properties) * 2),
+        ]
+        for property in self.properties:
+            args.extend(property.to_args())
+        if self.max:
+            args.extend([FtAggregateKeywords.MAX, str(self.max)])
+        return args
+
+
+class FtAggregateApply(FtAggregateClause):
+    """
+    A clause for applying a 1-to-1 transformation on one or more properties, storing the result as a new property down the pipeline or replacing any property using this transformation.
+    """
+
+    def __init__(self, expression: TEncodable, name: TEncodable):
+        """
+        Initialize a new FtAggregateApply instance.
+
+        Args:
+            expression (TEncodable): The expression to be transformed.
+            name (TEncodable): The new property name to store the result of apply.
+                This name can be referenced by further APPLY/SORTBY/GROUPBY/REDUCE operations down the pipeline.
+        """
+        self.expression = expression
+        self.name = name
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the arguments for the Apply clause.
+
+        Returns:
+            List[TEncodable]: A list of arguments for the Apply clause.
+        """
+        return [
+            FtAggregateKeywords.APPLY,
+            self.expression,
+            FtAggregateKeywords.AS,
+            self.name,
+        ]
+
+
+class FtAggregateOptions:
+    """
+    This class represents the optional arguments for the FT.AGGREGATE command.
+    """
+
+    def __init__(
+        self,
+        loadAll: Optional[bool] = False,
+        loadFields: Optional[List[TEncodable]] = None,
+        timeout: Optional[int] = None,
+        params: Optional[Mapping[TEncodable, TEncodable]] = None,
+        clauses: Optional[List[FtAggregateClause]] = None,
+    ):
+        """
+        Initialize a new FtAggregateOptions instance.
+
+        Args:
+            loadAll (Optional[bool]): An option to load all fields declared in the index.
+            loadFields (Optional[List[TEncodable]]): An option to load only the fields passed in this list.
+            timeout (Optional[int]): Overrides the timeout parameter of the module.
+            params (Optional[Mapping[TEncodable, TEncodable]]): The key/value pairs that can be referenced from within the query expression.
+            clauses (Optional[List[FtAggregateClause]]): FILTER, LIMIT, GROUPBY, SORTBY and APPLY clauses, that can be repeated multiple times in any order and be freely intermixed. They are applied in the order specified, with the output of one clause feeding the input of the next clause.
+        """
+        self.loadAll = loadAll
+        self.loadFields = loadFields
+        self.timeout = timeout
+        self.params = params
+        self.clauses = clauses
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the optional arguments for the FT.AGGREGATE command.
+
+        Returns:
+            List[TEncodable]: A list of optional arguments for the FT.AGGREGATE command.
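+
+        Example (an illustrative sketch; the output assumes only the options shown here):
+            >>> FtAggregateOptions(loadAll=True, clauses=[FtAggregateLimit(0, 10)]).to_args()
+            ['LOAD', '*', 'LIMIT', '0', '10']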
+ """ + args: List[TEncodable] = [] + if self.loadAll: + args.extend([FtAggregateKeywords.LOAD, "*"]) + elif self.loadFields: + args.extend([FtAggregateKeywords.LOAD, str(len(self.loadFields))]) + args.extend(self.loadFields) + if self.timeout: + args.extend([FtAggregateKeywords.TIMEOUT, str(self.timeout)]) + if self.params: + args.extend([FtAggregateKeywords.PARAMS, str(len(self.params) * 2)]) + for [name, value] in self.params.items(): + args.extend([name, value]) + if self.clauses: + for clause in self.clauses: + args.extend(clause.to_args()) + return args diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index 1755c6136e..fd703ffcaf 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -15,6 +15,7 @@ class CommandNames: FT_ALIASUPDATE = "FT.ALIASUPDATE" FT_EXPLAIN = "FT.EXPLAIN" FT_EXPLAINCLI = "FT.EXPLAINCLI" + FT_AGGREGATE = "FT.AGGREGATE" class FtCreateKeywords: @@ -51,3 +52,17 @@ class FtSeachKeywords: LIMIT = "LIMIT" COUNT = "COUNT" AS = "AS" + + +class FtAggregateKeywords: + LIMIT = "LIMIT" + FILTER = "FILTER" + GROUPBY = "GROUPBY" + REDUCE = "REDUCE" + AS = "AS" + SORTBY = "SORTBY" + MAX = "MAX" + APPLY = "APPLY" + LOAD = "LOAD" + TIMEOUT = "TIMEOUT" + PARAMS = "PARAMS" diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index ea6ad8261b..77d0f4079d 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -1,15 +1,28 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +import time import uuid from typing import List, Mapping, Union, cast import pytest +from glide.async_commands.command_args import OrderBy from glide.async_commands.server_modules import ft +from glide.async_commands.server_modules import json as GlideJson +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateApply, + FtAggregateClause, + FtAggregateGroupBy, + FtAggregateOptions, + FtAggregateReducer, + FtAggregateSortBy, + FtAggregateSortProperty, +) from glide.async_commands.server_modules.ft_options.ft_create_options import ( DataType, DistanceMetricType, Field, FtCreateOptions, NumericField, + TagField, TextField, VectorAlgorithm, VectorField, @@ -35,6 +48,8 @@ class TestFt: ] ] + sleep_wait_time = 1 # This value is in seconds + @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_ft_aliasadd(self, glide_client: GlideClusterClient): @@ -303,3 +318,427 @@ async def _create_test_index_for_ft_explain_commands( ) == OK ) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, protocol + ): + prefixBicycles = "{bicycles}:" + indexBicycles = prefixBicycles + str(uuid.uuid4()) + await TestFt._create_index_for_ft_aggregate_with_bicycles_data( + self=self, + glide_client=glide_client, + index_name=indexBicycles, + prefix=prefixBicycles, + ) + await TestFt._create_json_keys_for_ft_aggregate_with_bicycles_data( + self=self, glide_client=glide_client, prefix=prefixBicycles + ) + time.sleep(self.sleep_wait_time) + + # 
Run FT.AGGREGATE command with the following arguments: ['FT.AGGREGATE', '{bicycles}:1e15faab-a870-488e-b6cd-f2b76c6916a3', '*', 'LOAD', '1', '__key', 'GROUPBY', '1', '@condition', 'REDUCE', 'COUNT', '0', 'AS', 'bicycles'] + result = await ft.aggregate( + glide_client, + indexName=indexBicycles, + query="*", + options=FtAggregateOptions( + loadFields=["__key"], + clauses=[ + FtAggregateGroupBy( + ["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")] + ) + ], + ), + ) + assert await ft.dropindex(glide_client, indexName=indexBicycles) == OK + sortedResult = sorted(result, key=lambda x: (x[b"condition"], x[b"bicycles"])) + + expectedResult = sorted( + [ + { + b"condition": b"refurbished", + b"bicycles": b"1" if (protocol == ProtocolVersion.RESP2) else 1.0, + }, + { + b"condition": b"new", + b"bicycles": b"5" if (protocol == ProtocolVersion.RESP2) else 5.0, + }, + { + b"condition": b"used", + b"bicycles": b"4" if (protocol == ProtocolVersion.RESP2) else 4.0, + }, + ], + key=lambda x: (x[b"condition"], x[b"bicycles"]), + ) + assert sortedResult == expectedResult + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, protocol + ): + prefixMovies = "{movies}:" + indexMovies = prefixMovies + str(uuid.uuid4()) + # Create index for movies data. + await TestFt._create_index_for_ft_aggregate_with_movies_data( + self=self, + glide_client=glide_client, + index_name=indexMovies, + prefix=prefixMovies, + ) + # Set JSON keys with movies data. + await TestFt._create_hash_keys_for_ft_aggregate_with_movies_data( + self=self, glide_client=glide_client, prefix=prefixMovies + ) + # Wait for index to be updated. 
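+        # (A fixed sleep is a simple way to wait for indexing in tests; a sturdier
+        # alternative could poll FT.INFO until the index reports the expected document count.)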
+ time.sleep(self.sleep_wait_time) + + # Run FT.AGGREGATE command with the following arguments: + # ['FT.AGGREGATE', '{movies}:5a0e6257-3488-4514-96f2-f4c80f6cb0a9', '*', 'LOAD', '*', 'APPLY', 'ceil(@rating)', 'AS', 'r_rating', 'GROUPBY', '1', '@genre', 'REDUCE', 'COUNT', '0', 'AS', 'nb_of_movies', 'REDUCE', 'SUM', '1', 'votes', 'AS', 'nb_of_votes', 'REDUCE', 'AVG', '1', 'r_rating', 'AS', 'avg_rating', 'SORTBY', '4', '@avg_rating', 'DESC', '@nb_of_votes', 'DESC'] + + result = await ft.aggregate( + glide_client, + indexName=indexMovies, + query="*", + options=FtAggregateOptions( + loadAll=True, + clauses=[ + FtAggregateApply(expression="ceil(@rating)", name="r_rating"), + FtAggregateGroupBy( + ["@genre"], + [ + FtAggregateReducer("COUNT", [], "nb_of_movies"), + FtAggregateReducer("SUM", ["votes"], "nb_of_votes"), + FtAggregateReducer("AVG", ["r_rating"], "avg_rating"), + ], + ), + FtAggregateSortBy( + properties=[ + FtAggregateSortProperty("@avg_rating", OrderBy.DESC), + FtAggregateSortProperty("@nb_of_votes", OrderBy.DESC), + ] + ), + ], + ), + ) + assert await ft.dropindex(glide_client, indexName=indexMovies) == OK + sortedResult = sorted( + result, + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + expectedResultSet = sorted( + [ + { + b"genre": b"Drama", + b"nb_of_movies": ( + b"1" if (protocol == ProtocolVersion.RESP2) else 1.0 + ), + b"nb_of_votes": ( + b"1563839" if (protocol == ProtocolVersion.RESP2) else 1563839.0 + ), + b"avg_rating": ( + b"10" if (protocol == ProtocolVersion.RESP2) else 10.0 + ), + }, + { + b"genre": b"Action", + b"nb_of_movies": ( + b"2" if (protocol == ProtocolVersion.RESP2) else 2.0 + ), + b"nb_of_votes": ( + b"2033895" if (protocol == ProtocolVersion.RESP2) else 2033895.0 + ), + b"avg_rating": b"9" if (protocol == ProtocolVersion.RESP2) else 9.0, + }, + { + b"genre": b"Thriller", + b"nb_of_movies": ( + b"1" if (protocol == ProtocolVersion.RESP2) else 1.0 + ), + b"nb_of_votes": ( + b"559490" if (protocol == ProtocolVersion.RESP2) else 559490.0 + ), + b"avg_rating": b"9" if (protocol == ProtocolVersion.RESP2) else 9.0, + }, + ], + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + assert expectedResultSet == sortedResult + + async def _create_index_for_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, index_name: TEncodable, prefix + ): + fields: List[Field] = [ + TextField("$.model", "model"), + TextField("$.description", "description"), + NumericField("$.price", "price"), + TagField("$.condition", "condition", ","), + ] + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.JSON, prefixes=[prefix]), + ) + == OK + ) + + async def _create_json_keys_for_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, prefix + ): + assert ( + await GlideJson.set( + glide_client, + prefix + "0", + ".", + '{"brand": "Velorim", "model": "Jigger", "price": 270, "description":' + + ' "Small and powerful, the Jigger is the best ride for the smallest of tikes!' 
+ + " This is the tiniest kids\\u2019 pedal bike on the market available without a" + + " coaster brake, the Jigger is the vehicle of choice for the rare tenacious" + + ' little rider raring to go.", "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "1", + ".", + '{"brand": "Bicyk", "model": "Hillcraft", "price": 1200, "description":' + + ' "Kids want to ride with as little weight as possible. Especially on an' + + ' incline! They may be at the age when a 27.5\\" wheel bike is just too clumsy' + + ' coming off a 24\\" bike. The Hillcraft 26 is just the solution they need!",' + + ' "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "2", + ".", + '{"brand": "Nord", "model": "Chook air 5", "price": 815, "description":' + + ' "The Chook Air 5 gives kids aged six years and older a durable and' + + " uberlight mountain bike for their first experience on tracks and easy" + + " cruising through forests and fields. The lower top tube makes it easy to" + + " mount and dismount in any situation, giving your kids greater safety on the" + + ' trails.", "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "3", + ".", + '{"brand": "Eva", "model": "Eva 291", "price": 3400, "description": "The' + + " sister company to Nord, Eva launched in 2005 as the first and only" + + " women-dedicated bicycle brand. Designed by women for women, allEva bikes are" + + " optimized for the feminine physique using analytics from a body metrics" + + " database. If you like 29ers, try the Eva 291. It\\u2019s a brand new bike for" + + " 2022.. This full-suspension, cross-country ride has been designed for" + + " velocity. The 291 has 100mm of front and rear travel, a superlight aluminum" + + ' frame and fast-rolling 29-inch wheels. Yippee!", "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "4", + ".", + '{"brand": "Noka Bikes", "model": "Kahuna", "price": 3200, "description":' + + ' "Whether you want to try your hand at XC racing or are looking for a lively' + + " trail bike that's just as inspiring on the climbs as it is over rougher" + + " ground, the Wilder is one heck of a bike built specifically for short women." + + " Both the frames and components have been tweaked to include a women\\u2019s" + + ' saddle, different bars and unique colourway.", "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "5", + ".", + '{"brand": "Breakout", "model": "XBN 2.1 Alloy", "price": 810,' + + ' "description": "The XBN 2.1 Alloy is our entry-level road bike \\u2013 but' + + " that\\u2019s not to say that it\\u2019s a basic machine. With an internal" + + " weld aluminium frame, a full carbon fork, and the slick-shifting Claris gears" + + " from Shimano\\u2019s, this is a bike which doesn\\u2019t break the bank and" + + ' delivers craved performance.", "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "6", + ".", + '{"brand": "ScramBikes", "model": "WattBike", "price": 2300,' + + ' "description": "The WattBike is the best e-bike for people who still feel' + + " young at heart. It has a Bafang 1000W mid-drive system and a 48V 17.5AH" + + " Samsung Lithium-Ion battery, allowing you to ride for more than 60 miles on" + + " one charge. It\\u2019s great for tackling hilly terrain or if you just fancy" + + " a more leisurely ride. 
With three working modes, you can choose between"
+                + ' E-bike, assisted bicycle, and normal bike modes.", "condition": "new"}',
+            )
+            == OK
+        )
+
+        assert (
+            await GlideJson.set(
+                glide_client,
+                prefix + "7",
+                ".",
+                '{"brand": "Peaknetic", "model": "Secto", "price": 430, "description":'
+                + ' "If you struggle with stiff fingers or a kinked neck or back after a few'
+                + " minutes on the road, this lightweight, aluminum bike alleviates those issues"
+                + " and allows you to enjoy the ride. From the ergonomic grips to the"
+                + " lumbar-supporting seat position, the Roll Low-Entry offers incredible"
+                + " comfort. The rear-inclined seat tube facilitates stability by allowing you to"
+                + " put a foot on the ground to balance at a stop, and the low step-over frame"
+                + " makes it accessible for all ability and mobility levels. The saddle is very"
+                + " soft, with a wide back to support your hip joints and a cutout in the center"
+                + " to redistribute that pressure. Rim brakes deliver satisfactory braking"
+                + " control, and the wide tires provide a smooth, stable ride on paved roads and"
+                + " gravel. Rack and fender mounts facilitate setting up the Roll Low-Entry as"
+                + " your preferred commuter, and the BMX-like handlebar offers space for mounting"
+                + ' a flashlight, bell, or phone holder.", "condition": "new"}',
+            )
+            == OK
+        )
+
+        assert (
+            await GlideJson.set(
+                glide_client,
+                prefix + "8",
+                ".",
+                '{"brand": "nHill", "model": "Summit", "price": 1200, "description":'
+                + ' "This budget mountain bike from nHill performs well both on bike paths and'
+                + ' on the trail. The fork with 100mm of travel absorbs rough terrain. Fat Kenda'
+                + ' Booster tires give you grip in corners and on wet trails. The Shimano Tourney'
+                + ' drivetrain offered enough gears for finding a comfortable pace to ride'
+                + ' uphill, and the Tektro hydraulic disc brakes break smoothly. Whether you want'
+                + ' an affordable bike that you can take to work, but also take trail in'
+                + ' mountains on the weekends or you\\u2019re just after a stable, comfortable'
+                + ' ride for the bike path, the Summit gives a good value for money.",'
+                + ' "condition": "new"}',
+            )
+            == OK
+        )
+
+        assert (
+            await GlideJson.set(
+                glide_client,
+                prefix + "9",
+                ".",
+                '{"model": "ThrillCycle", "brand": "BikeShind", "price": 815,'
+                + ' "description": "An artsy, retro-inspired bicycle that\\u2019s as'
+                + " functional as it is pretty: The ThrillCycle steel frame offers a smooth ride."
+                + " A 9-speed drivetrain has enough gears for coasting in the city, but we"
+                + " wouldn\\u2019t suggest taking it to the mountains. Fenders protect you from"
+                + " mud, and a rear basket lets you transport groceries, flowers and books. 
The" + + " ThrillCycle comes with a limited lifetime warranty, so this little guy will" + + ' last you long past graduation.", "condition": "refurbished"}', + ) + == OK + ) + + async def _create_index_for_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, index_name: TEncodable, prefix + ): + fields: List[Field] = [ + TextField("title"), + NumericField("release_year"), + NumericField("rating"), + TagField("genre"), + NumericField("votes"), + ] + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.HASH, prefixes=[prefix]), + ) + == OK + ) + + async def _create_hash_keys_for_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, prefix + ): + await glide_client.hset( + prefix + "11002", + { + "title": "Star Wars: Episode V - The Empire Strikes Back", + "plot": "After the Rebels are brutally overpowered by the Empire on the ice planet Hoth," + + " Luke Skywalker begins Jedi training with Yoda, while his friends are" + + " pursued by Darth Vader and a bounty hunter named Boba Fett all over the" + + " galaxy.", + "release_year": "1980", + "genre": "Action", + "rating": "8.7", + "votes": "1127635", + "imdb_id": "tt0080684", + }, + ) + + await glide_client.hset( + prefix + "11003", + { + "title": "The Godfather", + "plot": "The aging patriarch of an organized crime dynasty transfers control of his" + + " clandestine empire to his reluctant son.", + "release_year": "1972", + "genre": "Drama", + "rating": "9.2", + "votes": "1563839", + "imdb_id": "tt0068646", + }, + ) + + await glide_client.hset( + prefix + "11004", + { + "title": "Heat", + "plot": "A group of professional bank robbers start to feel the heat from police when they" + + " unknowingly leave a clue at their latest heist.", + "release_year": "1995", + "genre": "Thriller", + "rating": "8.2", + "votes": "559490", + "imdb_id": "tt0113277", + }, + ) + + await glide_client.hset( + prefix + "11005", + { + "title": "Star Wars: Episode VI - Return of the Jedi", + "plot": "The Rebels dispatch to Endor to destroy the second Empire's Death Star.", + "release_year": "1983", + "genre": "Action", + "rating": "8.3", + "votes": "906260", + "imdb_id": "tt0086190", + }, + ) From 53743a9252a8493682eaf1258b5e69f6f55fca03 Mon Sep 17 00:00:00 2001 From: James Xin Date: Mon, 28 Oct 2024 18:40:47 -0700 Subject: [PATCH 073/180] Java: add JSON.TYPE (#2525) * Java: add JSON.TYPE --------- Signed-off-by: James Xin --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 108 ++++++++++++++++++ .../api/commands/servermodules/JsonTest.java | 84 ++++++++++++++ .../test/java/glide/modules/JsonTests.java | 41 +++++++ 4 files changed, 234 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba8bef87c4..3854d89b28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,7 @@ * Java: Added `JSON.CLEAR` ([#2519](https://github.com/valkey-io/valkey-glide/pull/2519)) * Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) * Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) +* Java: Added `JSON.TYPE` ([#2525](https://github.com/valkey-io/valkey-glide/pull/2525)) * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) * Node: Added `JSON.RESP` ([#2517](https://github.com/valkey-io/valkey-glide/pull/2517)) * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) diff --git 
a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 5edf8b76d3..939557307d 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -2495,6 +2495,114 @@ public static CompletableFuture resp( return executeCommand(client, new GlideString[] {gs(JSON_RESP), key, path}); } + /** + * Retrieves the type of the JSON value at the root of the JSON document stored at key + * . + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns the type of the JSON value at root. If key doesn't exist, + * null is returned. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, 3]");
        +     * assertEquals("array", Json.type(client, "doc").get());
        +     *
        +     * Json.set(client, "doc", "$", "{\"a\": 1}");
        +     * assertEquals("object", Json.type(client, "doc").get());
        +     *
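+     * Json.set(client, "doc", "$", "\"hello\"");
+     * assertEquals("string", Json.type(client, "doc").get());
+     *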
        +     * assertNull(Json.type(client, "non_existing_key").get());
        +     * }
        + */ + public static CompletableFuture type(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_TYPE, key}); + } + + /** + * Retrieves the type of the JSON value at the root of the JSON document stored at key + * . + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns the type of the JSON value at root. If key doesn't exist, + * null is returned. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "[1, 2, 3]");
        +     * assertEquals(gs("array"), Json.type(client, gs("doc")).get());
        +     *
        +     * Json.set(client, "doc", "$", "{\"a\": 1}");
        +     * assertEquals(gs("object"), Json.type(client, gs("doc")).get());
        +     *
        +     * assertNull(Json.type(client, gs("non_existing_key")).get());
        +     * }
        + */ + public static CompletableFuture type( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_TYPE), key}); + } + + /** + * Retrieves the type of the JSON value at the specified path within the JSON + * document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the type will be retrieved. + * @return + *
          + *
        • For JSONPath (path starts with $): Returns a list of string + * replies for every possible path, indicating the type of the JSON value. If `path` + * doesn't exist, an empty array will be returned. + *
• For legacy path (path doesn't start with $): Returns the
+     *           type of the JSON value at `path`. If multiple paths match, the type of the first JSON
+     *           value match is returned. If `path` doesn't exist, null will be returned.
        + * If key doesn't exist, null is returned. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
+     * assertArrayEquals(new Object[]{"object"}, (Object[]) Json.type(client, "doc", "$.nested").get());
+     * assertArrayEquals(new Object[]{"integer"}, (Object[]) Json.type(client, "doc", "$.nested.a").get());
+     * assertArrayEquals(new Object[]{"integer", "object"}, (Object[]) Json.type(client, "doc", "$[*]").get());
        +     * }
        + */ + public static CompletableFuture type( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + + return executeCommand(client, new String[] {JSON_TYPE, key, path}); + } + + /** + * Retrieves the type of the JSON value at the specified path within the JSON + * document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the type will be retrieved. + * @return + *
          + *
        • For JSONPath (path starts with $): Returns a list of string + * replies for every possible path, indicating the type of the JSON value. If `path` + * doesn't exist, an empty array will be returned. + *
• For legacy path (path doesn't start with $): Returns the
+     *           type of the JSON value at `path`. If multiple paths match, the type of the first JSON
+     *           value match is returned. If `path` doesn't exist, null will be returned.
        + * If key doesn't exist, null is returned. + * @example + *
        {@code
        +     * Json.set(client, "doc", "$", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
+     * assertArrayEquals(new Object[]{gs("object")}, (Object[]) Json.type(client, gs("doc"), gs("$.nested")).get());
+     * assertArrayEquals(new Object[]{gs("integer")}, (Object[]) Json.type(client, gs("doc"), gs("$.nested.a")).get());
+     * assertArrayEquals(new Object[]{gs("integer"), gs("object")}, (Object[]) Json.type(client, gs("doc"), gs("$[*]")).get());
        +     * }
        + */ + public static CompletableFuture type( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_TYPE), key, path}); + } + /** * A wrapper for custom command API. * diff --git a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java index f05a9b2bb6..884f1bed27 100644 --- a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java +++ b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java @@ -598,4 +598,88 @@ void resp_binary_with_path_returns_success() { assertEquals(expectedResponse, actualResponse); assertEquals(expectedResponseValue, actualResponseValue); } + + @Test + @SneakyThrows + void type_without_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.TYPE", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_binary_without_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new GlideString[] {gs("JSON.TYPE"), key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.TYPE", key, path})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key, path); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.TYPE"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key, path); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } } diff --git 
a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 040b8f1a9c..d2e7097e40 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -1070,4 +1070,45 @@ public void json_resp() { assertNull(Json.resp(client, "nonexistent_key", ".").get()); assertNull(Json.resp(client, "nonexistent_key").get()); } + + @Test + @SneakyThrows + public void json_type() { + String key = UUID.randomUUID().toString(); + String jsonValue = + "{\"key1\": \"value1\", \"key2\": 2, \"key3\": [1, 2, 3], \"key4\": {\"nested_key\":" + + " {\"key1\": [4, 5]}}, \"key5\": null, \"key6\": true, \"dec_key\": 2.3}"; + assertEquals(OK, Json.set(client, key, "$", jsonValue).get()); + + assertArrayEquals(new Object[] {"object"}, (Object[]) Json.type(client, key, "$").get()); + assertArrayEquals( + new Object[] {gs("string"), gs("array")}, + (Object[]) Json.type(client, gs(key), gs("$..key1")).get()); + assertArrayEquals(new Object[] {"integer"}, (Object[]) Json.type(client, key, "$.key2").get()); + assertArrayEquals(new Object[] {"array"}, (Object[]) Json.type(client, key, "$.key3").get()); + assertArrayEquals(new Object[] {"object"}, (Object[]) Json.type(client, key, "$.key4").get()); + assertArrayEquals( + new Object[] {"object"}, (Object[]) Json.type(client, key, "$.key4.nested_key").get()); + assertArrayEquals(new Object[] {"null"}, (Object[]) Json.type(client, key, "$.key5").get()); + assertArrayEquals(new Object[] {"boolean"}, (Object[]) Json.type(client, key, "$.key6").get()); + // Check for non-existent path in enhanced mode $.key7 + assertArrayEquals(new Object[] {}, (Object[]) Json.type(client, key, "$.key7").get()); + // Check for non-existent path within an existing key (array bound) + assertArrayEquals(new Object[] {}, (Object[]) Json.type(client, key, "$.key3[3]").get()); + // Legacy path (without $) - will return None for non-existing path + assertNull(Json.type(client, key, "key7").get()); + // Check for multiple path match in legacy + assertEquals("string", Json.type(client, key, "..key1").get()); + // Check for non-existent key with enhanced path + assertNull(Json.type(client, "non_existing_key", "$.key1").get()); + // Check for non-existent key with legacy path + assertNull(Json.type(client, "non_existing_key", "key1").get()); + // Check for all types in the JSON document using JSON Path + Object[] actualResult = (Object[]) Json.type(client, key, "$[*]").get(); + Object[] expectedResult = + new Object[] {"string", "integer", "array", "object", "null", "boolean", "number"}; + assertArrayEquals(expectedResult, actualResult); + // Check for all types in the JSON document using legacy path + assertEquals("string", Json.type(client, key, "[*]").get()); + } } From 1e0476cfcbc84ce13ae4ec6e987075a55871c218 Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:23:31 +0200 Subject: [PATCH 074/180] Python: adds JSON.ARRAPPEND command (#2382) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> --- CHANGELOG.md | 1 + .../async_commands/server_modules/json.py | 49 +++++++++++++++++++ .../tests/tests_server_modules/test_json.py | 41 ++++++++++++++++ 3 files changed, 91 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3854d89b28..86242a5d7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ * Python: Add `JSON.OBJKEYS` command 
([#2395](https://github.com/valkey-io/valkey-glide/pull/2395))
 * Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464))
 * Python: Add `JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457))
+* Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382))

 #### Breaking Changes

diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py
index eb75b6bfb3..80caa7979e 100644
--- a/python/python/glide/async_commands/server_modules/json.py
+++ b/python/python/glide/async_commands/server_modules/json.py
@@ -147,6 +147,55 @@ async def get(
     return cast(TJsonResponse[Optional[bytes]], await client.custom_command(args))
 
 
+async def arrappend(
+    client: TGlideClient,
+    key: TEncodable,
+    values: List[TEncodable],
+    path: Optional[TEncodable] = None,
+) -> TJsonResponse[int]:
+    """
+    Appends one or more `values` to the JSON array at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        values (List[TEncodable]): The values to append to the JSON array at the specified path.
+        path (Optional[TEncodable]): Represents the path within the JSON document where the `values` will be appended.
+            Defaults to None.
+            **Beware**: For AWS ElastiCache/MemoryDB the `path` parameter is required and not optional.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the new length of the array after appending `values`,
+                or None for JSON values matching the path that are not an array.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the new array after appending `values` to the array at `path`.
+                If multiple paths match, the length of the first updated array is returned.
+                If the JSON value at `path` is not an array or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
+        For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import json as valkeyJson
+        >>> import json
+        >>> await valkeyJson.set(client, "doc", "$", '{"a": 1, "b": ["one", "two"]}')
+            'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
+        >>> await valkeyJson.arrappend(client, "doc", ['"three"'], "$.b")
+            [3] # Returns the new length of the array at path '$.b' after appending the value.
+        >>> await valkeyJson.arrappend(client, "doc", ['"four"'], ".b")
+            4 # Returns the new length of the array at path '.b' after appending the value. 
+        >>> json.loads(await valkeyJson.get(client, "doc", "."))
+            {"a": 1, "b": ["one", "two", "three", "four"]} # Returns the updated JSON document
+    """
+    args = ["JSON.ARRAPPEND", key]
+    if path:
+        args.append(path)
+    args.extend(values)
+    return cast(TJsonResponse[int], await client.custom_command(args))
+
+
 async def arrinsert(
     client: TGlideClient,
     key: TEncodable,
diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py
index c1ad3dfbfc..dc8c050320 100644
--- a/python/python/tests/tests_server_modules/test_json.py
+++ b/python/python/tests/tests_server_modules/test_json.py
@@ -1344,3 +1344,44 @@ async def test_json_arrtrim(self, glide_client: TGlideClient):
         assert await json.arrtrim(glide_client, key, "$.empty", 0, 1) == [0]
         assert await json.arrtrim(glide_client, key, ".empty", 0, 1) == 0
         assert OuterJson.loads(await json.get(glide_client, key, "$.empty")) == [[]]
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_arrappend(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        initial_json_value = '{"a": 1, "b": ["one", "two"]}'
+        assert await json.set(glide_client, key, "$", initial_json_value) == OK
+
+        assert await json.arrappend(glide_client, key, ['"three"'], "$.b") == [3]
+        assert await json.arrappend(glide_client, key, ['"four"', '"five"'], ".b") == 5
+
+        result = await json.get(glide_client, key, "$")
+        assert isinstance(result, bytes)
+        assert OuterJson.loads(result) == [
+            {"a": 1, "b": ["one", "two", "three", "four", "five"]}
+        ]
+
+        assert await json.arrappend(glide_client, key, ['"value"'], "$.a") == [None]
+
+        # JSONPath, path doesn't exist
+        assert await json.arrappend(glide_client, key, ['"value"'], "$.c") == []
+        # Legacy path, `path` doesn't exist
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, key, ['"value"'], ".c")
+
+        # Legacy path, the JSON value at `path` is not an array
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, key, ['"value"'], ".a")
+
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, "non_existing_key", ['"six"'], "$.b")
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, "non_existing_key", ['"six"'], ".b")
+
+        # multiple path match
+        json_value = '[[], ["a"], ["a", "b"]]'
+        assert await json.set(glide_client, key, "$", json_value) == OK
+        assert await json.arrappend(glide_client, key, ['"c"'], "[*]") == 1
+        result = await json.get(glide_client, key, "$")
+        assert isinstance(result, bytes)
+        assert OuterJson.loads(result) == [[["c"], ["a", "c"], ["a", "b", "c"]]]

From 2284c757f68a89771cb7416c5ea7ac25380c6bbf Mon Sep 17 00:00:00 2001
From: Shoham Elias <116083498+shohamazon@users.noreply.github.com>
Date: Tue, 29 Oct 2024 15:13:02 +0200
Subject: [PATCH 075/180] Python: adds JSON.RESP command (#2451)

---------

Signed-off-by: Shoham Elias 
---
 CHANGELOG.md                                  |   1 +
 python/python/glide/__init__.py               |   4 +
 .../async_commands/server_modules/json.py     |  89 +++++++++++--
 python/python/glide/constants.py              |  19 +++
 .../tests/tests_server_modules/test_json.py   | 126 ++++++++++++++++++
 5 files changed, 225 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 86242a5d7d..5118e65015 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,7 @@
 * Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464))
 * Python: Add 
`JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457)) * Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) +* Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) #### Breaking Changes diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 9490bfe5b6..189f832e5e 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -120,6 +120,8 @@ TFunctionListResponse, TFunctionStatsFullResponse, TFunctionStatsSingleNodeResponse, + TJsonResponse, + TJsonUniversalResponse, TResult, TSingleNodeRoute, TXInfoStreamFullResponse, @@ -177,6 +179,8 @@ "TFunctionListResponse", "TFunctionStatsFullResponse", "TFunctionStatsSingleNodeResponse", + "TJsonResponse", + "TJsonUniversalResponse", "TOK", "TResult", "TXInfoStreamFullResponse", diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 80caa7979e..d115f62dad 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -18,7 +18,7 @@ from typing import List, Optional, Union, cast from glide.async_commands.core import ConditionalChange -from glide.constants import TOK, TEncodable, TJsonResponse +from glide.constants import TOK, TEncodable, TJsonResponse, TJsonUniversalResponse from glide.glide_client import TGlideClient from glide.protobuf.command_request_pb2 import RequestType @@ -309,7 +309,7 @@ async def arrtrim( end: int, ) -> TJsonResponse[int]: """ - Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive. + Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive. If `start` < 0, it is treated as 0. If `end` >= size (size of the array), it is treated as size-1. If `start` >= size or `start` > `end`, the array is emptied and 0 is returned. @@ -412,7 +412,7 @@ async def debug_fields( client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None, -) -> Optional[Union[int, List[int]]]: +) -> Optional[TJsonUniversalResponse[int]]: """ Returns the number of fields of the JSON value at the specified `path` within the JSON document stored at `key`. - **Primitive Values**: Each non-container JSON value (e.g., strings, numbers, booleans, and null) counts as one field. @@ -429,7 +429,7 @@ path (Optional[TEncodable]): The path within the JSON document. Defaults to root if not provided. Returns: - Optional[Union[int, List[int]]]: + Optional[TJsonUniversalResponse[int]]: For JSONPath (`path` starts with `$`): Returns an array of integers, each indicating the number of fields for each matched `path`. If `path` doesn't exist, an empty array will be returned. @@ -460,14 +460,16 @@ if path: args.append(path) - return cast(Optional[Union[int, List[int]]], await client.custom_command(args)) + return cast( + Optional[TJsonUniversalResponse[int]], await client.custom_command(args) + ) async def debug_memory( client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None, -) -> Optional[Union[int, List[int]]]: +) -> Optional[TJsonUniversalResponse[int]]: """ Reports memory usage in bytes of a JSON value at the specified `path` within the JSON document stored at `key`.
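As an aside (not part of the patch): a hedged sketch of how the two JSON.DEBUG subcommands above behave, assuming a client connected to a server with the JSON module loaded; the expected counts simply follow the field-counting rules quoted in the docstring:

```python
from glide import json

async def demo_debug(client):
    # Three top-level elements: a primitive, an object, and an array.
    await json.set(client, "doc", "$", '[1, {"a": 1, "b": 2}, [3, 4]]')
    # Per matched element: a primitive counts as one field, and each
    # container reports the number of values it holds.
    print(await json.debug_fields(client, "doc", "$[*]"))  # expected: [1, 2, 2]
    # At the root: 3 top-level fields plus the 2 + 2 nested values.
    print(await json.debug_fields(client, "doc", "."))  # expected: 7
    # Memory usage in bytes of the whole document (value is server-dependent).
    print(await json.debug_memory(client, "doc"))
```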
@@ -477,7 +479,7 @@ async def debug_memory( path (Optional[TEncodable]): The path within the JSON document. Defaults to None. Returns: - Optional[Union[int, List[int]]]: + Optional[TJsonUniversalResponse[int]]: For JSONPath (`path` starts with `$`): Returns an array of integers, indicating the memory usage in bytes of a JSON value for each matched `path`. If `path` doesn't exist, an empty array will be returned. @@ -506,7 +508,9 @@ if path: args.append(path) - return cast(Optional[Union[int, List[int]]], await client.custom_command(args)) + return cast( + Optional[TJsonUniversalResponse[int]], await client.custom_command(args) + ) async def delete( @@ -612,7 +616,7 @@ async def numincrby( >>> from glide import json >>> await json.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}') 'OK' - >>> await json.numincrby(client, "doc", "$.d[*]", 10) + >>> await json.numincrby(client, "doc", "$.d[*]", 10) b'[11,12,13]' # Increment each element in `d` array by 10. >>> await json.numincrby(client, "doc", ".c[1]", 10) b'12' # Increment the second element in the `c` array by 10. @@ -720,7 +724,7 @@ async def objkeys( client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None, -) -> Optional[Union[List[bytes], List[List[bytes]]]]: +) -> Optional[TJsonUniversalResponse[List[bytes]]]: """ Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`. @@ -731,7 +735,7 @@ Defaults to None. Returns: - Optional[Union[List[bytes], List[List[bytes]]]]: + Optional[TJsonUniversalResponse[List[bytes]]]: For JSONPath (`path` starts with `$`): Returns a list of arrays containing key names for each matching object. If a value matching the path is not an object, an empty array is returned. @@ -765,6 +769,61 @@ ) +async def resp( + client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None +) -> TJsonUniversalResponse[ + Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]] +]: + """ + Retrieves the JSON value at the specified `path` within the JSON document stored at `key`. + The returned result is in the Valkey or Redis OSS Serialization Protocol (RESP).\n + JSON null is mapped to the RESP Null Bulk String.\n + JSON Booleans are mapped to RESP Simple Strings.\n + JSON integers are mapped to RESP Integers.\n + JSON doubles are mapped to RESP Bulk Strings.\n + JSON strings are mapped to RESP Bulk Strings.\n + JSON arrays are represented as RESP arrays, where the first element is the simple string [, followed by the array's elements.\n + JSON objects are represented as RESP objects, where the first element is the simple string {, followed by key-value pairs, each of which is a RESP bulk string.\n + + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Defaults to None. + + Returns: + TJsonUniversalResponse[Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]]] + For JSONPath ('path' starts with '$'): + Returns a list of replies for every possible path, indicating the RESP form of the JSON value. + If `path` doesn't exist, returns an empty list. + For legacy path (`path` doesn't start with `$`): + Returns a single reply for the JSON value at the specified path, in its RESP form. + This can be a bytes object, an integer, None, or a list representing complex structures.
+ If multiple paths match, the first matching JSON value is returned, in its RESP form. + If `path` doesn't exist, an error is raised. + If `key` doesn't exist, None is returned. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}') + 'OK' + >>> await json.resp(client, "doc", "$..a") + [[b"[", 1, 2, 3],[b"[", 1, 2],42] + >>> await json.resp(client, "doc", "..a") + [b"[", 1, 2, 3] + """ + args = ["JSON.RESP", key] + if path: + args.append(path) + + return cast( + TJsonUniversalResponse[ + Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]] + ], + await client.custom_command(args), + ) + + async def strappend( client: TGlideClient, key: TEncodable, @@ -910,7 +969,7 @@ async def type( client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None, -) -> Optional[Union[bytes, List[bytes]]]: +) -> Optional[TJsonUniversalResponse[bytes]]: """ Retrieves the type of the JSON value at the specified `path` within the JSON document stored at `key`. Args: client (TGlideClient): The client to execute the command. key (TEncodable): The key of the JSON document. path (Optional[TEncodable]): The path within the JSON document. Defaults to None. Returns: - Optional[Union[bytes, List[bytes]]]: + Optional[TJsonUniversalResponse[bytes]]: For JSONPath ('path' starts with '$'): Returns a list of byte string replies for every possible path, indicating the type of the JSON value. If `path` doesn't exist, an empty array will be returned. @@ -945,4 +1004,6 @@ if path: args.append(path) - return cast(Optional[Union[bytes, List[bytes]]], await client.custom_command(args)) + return cast( + Optional[TJsonUniversalResponse[bytes]], await client.custom_command(args) + ) diff --git a/python/python/glide/constants.py b/python/python/glide/constants.py index 4ecd2003a3..7c28372053 100644 --- a/python/python/glide/constants.py +++ b/python/python/glide/constants.py @@ -33,8 +33,27 @@ TSingleNodeRoute = Union[RandomNode, SlotKeyRoute, SlotIdRoute, ByAddressRoute] # When specifying legacy path (path doesn't start with `$`), response will be T # Otherwise, (when specifying JSONPath), response will be List[Optional[T]]. +# +# TJsonResponse is designed to handle scenarios where some paths may not contain valid values, especially with JSONPath targeting multiple paths. +# In such cases, the response may include None values, represented as `Optional[T]` in the list. +# This type provides flexibility for commands where a subset of the paths may return None. +# +# For more information, see: https://redis.io/docs/data-types/json/path/ . TJsonResponse = Union[T, List[Optional[T]]] + +# When specifying legacy path (path doesn't start with `$`), response will be T +# Otherwise, (when specifying JSONPath), response will be List[T]. +# This type represents the response format for commands that apply to every path and every type in a JSON document. +# It covers both singular and multiple paths, ensuring that the command returns valid results for each matched path without None values. +# +# TJsonUniversalResponse is considered "universal" because it applies to every matched path and +# guarantees valid, non-null results across all paths, covering both singular and multiple paths. +# This type is used for commands that return results from all matched paths, ensuring that each +# path contains meaningful values without None entries (unless it's part of the command's response). +# It is typically used in scenarios where each target is expected to yield a valid response,
i.e. for commands that are valid for all target types. +# +# For more information, see: https://redis.io/docs/data-types/json/path/ . +TJsonUniversalResponse = Union[T, List[T]] TEncodable = Union[str, bytes] TFunctionListResponse = List[ Mapping[ diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index dc8c050320..4f4a4c09b7 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -1,6 +1,8 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +import copy import json as OuterJson +import random import typing import pytest @@ -14,6 +16,19 @@ from tests.test_async_client import get_random_string, parse_info_response +def get_random_value(value_type="str"): + if value_type == "int": + return random.randint(1, 100) + elif value_type == "float": + return round(random.uniform(1, 100), 2) + elif value_type == "str": + return "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)) + elif value_type == "bool": + return random.choice([True, False]) + elif value_type == "null": + return None + + @pytest.mark.asyncio class TestJson: @pytest.mark.parametrize("cluster_mode", [True, False]) @@ -1385,3 +1400,114 @@ async def test_json_arrappend(self, glide_client: TGlideClient): result = await json.get(glide_client, key, "$") assert isinstance(result, bytes) assert OuterJson.loads(result) == [[["c"], ["a", "c"], ["a", "b", "c"]]] + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_resp(self, glide_client: TGlideClient): + key = get_random_string(5) + + # Generate random JSON content with specified types + json_value = { + "obj": {"a": get_random_value("int"), "b": get_random_value("float")}, + "arr": [get_random_value("int") for _ in range(3)], + "str": get_random_value("str"), + "bool": get_random_value("bool"), + "int": get_random_value("int"), + "float": get_random_value("float"), + "nullVal": get_random_value("null"), + } + + json_value_expected = copy.deepcopy(json_value) + json_value_expected["obj"]["b"] = str(json_value["obj"]["b"]).encode() + json_value_expected["float"] = str(json_value["float"]).encode() + json_value_expected["str"] = str(json_value["str"]).encode() + json_value_expected["bool"] = str(json_value["bool"]).lower().encode() + assert ( + await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == "OK" + ) + + assert await json.resp(glide_client, key, "$.*") == [ + [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ], + [b"[", *json_value_expected["arr"]], + json_value_expected["str"], + json_value_expected["bool"], + json_value_expected["int"], + json_value_expected["float"], + json_value_expected["nullVal"], + ] + + # multiple paths match, the first one will be returned + assert await json.resp(glide_client, key, "*") == [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ] + + assert await json.resp(glide_client, key, "$") == [ + [ + b"{", + [ + b"obj", + [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ], + ], + [b"arr", [b"[", *json_value_expected["arr"]]], + [ + b"str", + json_value_expected["str"], + ], + [ + b"bool", + json_value_expected["bool"], + ], + [b"int", json_value["int"]], + [ + b"float", + json_value_expected["float"], + ], + [b"nullVal",
json_value["nullVal"]], + ], + ] + + assert await json.resp(glide_client, key, "$.str") == [ + json_value_expected["str"] + ] + assert await json.resp(glide_client, key, ".str") == json_value_expected["str"] + + # Further tests with a new random JSON structure + json_value = { + "a": [random.randint(1, 10) for _ in range(3)], + "b": { + "a": [random.randint(1, 10) for _ in range(2)], + "c": {"a": random.randint(1, 10)}, + }, + } + assert ( + await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == "OK" + ) + + # Multiple path match + assert await json.resp(glide_client, key, "$..a") == [ + [b"[", *json_value["a"]], + [b"[", *json_value["b"]["a"]], + json_value["b"]["c"]["a"], + ] + + assert await json.resp(glide_client, key, "..a") == [b"[", *json_value["a"]] + + # Test for non-existent paths + assert await json.resp(glide_client, key, "$.nonexistent") == [] + with pytest.raises(RequestError): + await json.resp(glide_client, key, "nonexistent") + + # Test for non-existent key + assert await json.resp(glide_client, "nonexistent_key", "$") is None + assert await json.resp(glide_client, "nonexistent_key", ".") is None + assert await json.resp(glide_client, "nonexistent_key") is None From ff392382829b2674f596414950a1cb4008f4d7c1 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Tue, 29 Oct 2024 10:21:06 -0700 Subject: [PATCH 076/180] Java: Fix script kill IT (#2523) * Fix script kill IT Signed-off-by: Yury-Fridlyand --- .../src/test/java/glide/TestUtilities.java | 15 +++++++++------ .../src/test/java/glide/cluster/CommandTests.java | 12 ++++++------ .../test/java/glide/standalone/CommandTests.java | 8 ++++---- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/java/integTest/src/test/java/glide/TestUtilities.java b/java/integTest/src/test/java/glide/TestUtilities.java index 1429f62c97..2535fcaa8e 100644 --- a/java/integTest/src/test/java/glide/TestUtilities.java +++ b/java/integTest/src/test/java/glide/TestUtilities.java @@ -27,6 +27,8 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; import java.util.stream.Collectors; import lombok.NonNull; import lombok.SneakyThrows; @@ -394,17 +396,18 @@ public static String createLongRunningLuaScript(int timeout, boolean readOnly) { return script.replace("$timeout", Integer.toString(timeout)); } - public static void waitForNotBusy(BaseClient client) { + /** + * Lock test until server completes a script/function execution. + * + * @param lambda Client api reference to use for terminating the script/function on the server. + */ + public static void waitForNotBusy(Supplier> lambda) { // If function wasn't killed, and it didn't time out - it blocks the server and cause rest // test to fail. 
boolean isBusy = true; do { try { - if (client instanceof GlideClusterClient) { - ((GlideClusterClient) client).functionKill().get(); - } else if (client instanceof GlideClient) { - ((GlideClient) client).functionKill().get(); - } + lambda.get().get(); } catch (Exception busy) { // should throw `notbusy` error, because the function should be killed before if (busy.getMessage().toLowerCase().contains("notbusy")) { diff --git a/java/integTest/src/test/java/glide/cluster/CommandTests.java b/java/integTest/src/test/java/glide/cluster/CommandTests.java index 8cbdabdb99..ee2682a259 100644 --- a/java/integTest/src/test/java/glide/cluster/CommandTests.java +++ b/java/integTest/src/test/java/glide/cluster/CommandTests.java @@ -1808,7 +1808,7 @@ public void functionKill_no_write_without_route() { assertTrue(functionKilled); } finally { - waitForNotBusy(clusterClient); + waitForNotBusy(clusterClient::functionKill); } } } @@ -1863,7 +1863,7 @@ public void functionKillBinary_no_write_without_route() { assertTrue(functionKilled); } finally { - waitForNotBusy(clusterClient); + waitForNotBusy(clusterClient::functionKill); } } } @@ -1915,7 +1915,7 @@ public void functionKill_no_write_with_route(boolean singleNodeRoute) { assertTrue(functionKilled); } finally { - waitForNotBusy(clusterClient); + waitForNotBusy(clusterClient::functionKill); } } } @@ -1969,7 +1969,7 @@ public void functionKillBinary_no_write_with_route(boolean singleNodeRoute) { assertTrue(functionKilled); } finally { - waitForNotBusy(clusterClient); + waitForNotBusy(clusterClient::functionKill); } } } @@ -3276,7 +3276,7 @@ public void scriptKill_with_route() { assertTrue(scriptKilled); } finally { - waitForNotBusy(clusterClient); + waitForNotBusy(clusterClient::scriptKill); } } @@ -3297,7 +3297,7 @@ public void scriptKill_unkillable() { String key = UUID.randomUUID().toString(); RequestRoutingConfiguration.Route route = new RequestRoutingConfiguration.SlotKeyRoute(key, PRIMARY); - String code = createLongRunningLuaScript(5, false); + String code = createLongRunningLuaScript(6, false); Script script = new Script(code, false); CompletableFuture promise = new CompletableFuture<>(); diff --git a/java/integTest/src/test/java/glide/standalone/CommandTests.java b/java/integTest/src/test/java/glide/standalone/CommandTests.java index 5e558a0273..f7e4943e3a 100644 --- a/java/integTest/src/test/java/glide/standalone/CommandTests.java +++ b/java/integTest/src/test/java/glide/standalone/CommandTests.java @@ -784,7 +784,7 @@ public void functionKill_no_write() { assertTrue(functionKilled); } finally { - waitForNotBusy(regularClient); + waitForNotBusy(regularClient::functionKill); } } } @@ -835,7 +835,7 @@ public void functionKillBinary_no_write() { assertTrue(functionKilled); } finally { - waitForNotBusy(regularClient); + waitForNotBusy(regularClient::functionKill); } } } @@ -1681,7 +1681,7 @@ public void scriptKill() { assertTrue(scriptKilled); } finally { - waitForNotBusy(regularClient); + waitForNotBusy(regularClient::scriptKill); } } @@ -1700,7 +1700,7 @@ public void scriptKill() { @Test public void scriptKill_unkillable() { String key = UUID.randomUUID().toString(); - String code = createLongRunningLuaScript(5, false); + String code = createLongRunningLuaScript(6, false); Script script = new Script(code, false); CompletableFuture promise = new CompletableFuture<>(); From 4af45f8bba988fa5289862af9c8f1a76e4d8f124 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Tue, 29 Oct 2024 13:03:54 -0700 Subject: [PATCH 077/180] Add node REPL (#2355) * Add node 
REPL Signed-off-by: Yury-Fridlyand --- node/DEVELOPER.md | 29 ++++++++++++++++++++++++++++- node/package.json | 3 ++- node/tsconfig.json | 9 +++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md index 4d06e391a1..15c8b53a14 100644 --- a/node/DEVELOPER.md +++ b/node/DEVELOPER.md @@ -149,6 +149,33 @@ run the following command: npm run test-modules ``` +### REPL (interactive shell) + +It is possible to run an interactive shell synced with the current client code to test and debug it: + +```bash +npx ts-node --project tsconfig.json +``` + +This shell allows executing TypeScript and JavaScript code line by line: + +```typescript +import { GlideClient, GlideClusterClient } from "."; +let client = await GlideClient.createClient({ + addresses: [{ host: "localhost", port: 6379 }], +}); +let clusterClient = await GlideClusterClient.createClient({ + addresses: [{ host: "localhost", port: 7000 }], +}); +await client.ping(); +``` + +After applying changes to the client code, you need to restart the shell. + +It has command history and bash-like search (`Ctrl+R`). + +The shell hangs on exit (`Ctrl+D`) if you don't close the clients. Use `Ctrl+C` to kill it, and/or close the clients before exiting. + ### Submodules After pulling new changes, ensure that you update the submodules by running the following command: @@ -181,7 +208,7 @@ Development on the Node wrapper may involve changes in either the TypeScript or # Run from the node folder npm run lint # To automatically apply ESLint and/or prettier recommendations - npx run lint:fix + npm run lint:fix ``` 2. Rust diff --git a/node/package.json b/node/package.json index 85d92b0476..06a2dc7743 100644 --- a/node/package.json +++ b/node/package.json @@ -60,7 +60,8 @@ "semver": "^7.6.3", "ts-jest": "^29.2.5", "typescript": "^5.5.4", - "uuid": "^10.0.0" + "uuid": "^10.0.0", + "ts-node": "^10.9.2" }, "author": "Valkey GLIDE Maintainers", "license": "Apache-2.0", diff --git a/node/tsconfig.json b/node/tsconfig.json index 4cd744701c..a1824416be 100644 --- a/node/tsconfig.json +++ b/node/tsconfig.json @@ -25,6 +25,15 @@ ] /* Specify a set of bundled library declaration files that describe the target runtime environment.
*/, "outDir": "./build-ts" /* Specify an output folder for all emitted files.*/ }, + "ts-node": { + "transpileOnly": true, + "compilerOptions": { + "module": "CommonJS", + "target": "ES2018", + "esModuleInterop": true + }, + "esm": true + }, "compileOnSave": false, "include": ["./*.ts", "src/*.ts", "src/*.js"], "exclude": ["node_modules", "build-ts"] From 5b90096f725d578b3be929883421a6921ba92f50 Mon Sep 17 00:00:00 2001 From: Muhammad Awawdi Date: Wed, 30 Oct 2024 18:46:46 +0200 Subject: [PATCH 078/180] Python: Add JSON.ARRINDEX Command (#2528) --------- Signed-off-by: Muhammad Awawdi Signed-off-by: Shoham Elias Co-authored-by: Shoham Elias --- CHANGELOG.md | 1 + python/python/glide/__init__.py | 6 +- .../async_commands/server_modules/json.py | 102 ++++++ .../tests/tests_server_modules/test_json.py | 331 +++++++++++++++++- 4 files changed, 438 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5118e65015..6791a5bccd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * Python: Add JSON.TYPE command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409)) * Python: Add JSON.NUMINCRBY command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448)) * Python: Add JSON.NUMMULTBY command ([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) +* Python: Add JSON.ARRINDEX command ([#2528](https://github.com/valkey-io/valkey-glide/pull/2528)) * Python: Add `JSON.DEBUG_MEMORY` and `JSON.DEBUG_FIELDS` commands ([#2481](https://github.com/valkey-io/valkey-glide/pull/2481)) * Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) * Java: Added `FT.INFO` ([#2405](https://github.com/valkey-io/valkey-glide/pull/2441)) diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 189f832e5e..46289fdacc 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -64,6 +64,7 @@ FtSearchLimit, ReturnField, ) +from glide.async_commands.server_modules.json import JsonArrIndexOptions, JsonGetOptions from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -221,7 +222,6 @@ "InfBound", "InfoSection", "InsertPosition", - "json", "ft", "LexBoundary", "Limit", @@ -250,6 +250,10 @@ "ClusterScanCursor" # PubSub "PubSubMsg", + # Json + "json", + "JsonGetOptions", + "JsonArrIndexOptions", # Logger "Logger", "LogLevel", diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index d115f62dad..a8d0dfcfcc 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -54,6 +54,36 @@ def get_options(self) -> List[str]: return args +class JsonArrIndexOptions: + """ + Options for the `JSON.ARRINDEX` command. + + Args: + start (int): The inclusive start index from which the search begins. Defaults to None. + end (Optional[int]): The exclusive end index where the search stops. Defaults to None. + + Note: + - If `start` is greater than `end`, the command returns `-1` to indicate that the value was not found. + - Indices that exceed the array bounds are automatically adjusted to the nearest valid position. + """ + + def __init__(self, start: int, end: Optional[int] = None): + self.start = start + self.end = end + + def to_args(self) -> List[str]: + """ + Get the options as a list of arguments for the JSON.ARRINDEX command. + + Returns: + List[str]: A list containing the start and end indices if specified. 
+ """ + args = [str(self.start)] + if self.end is not None: + args.append(str(self.end)) + return args + + async def set( client: TGlideClient, key: TEncodable, @@ -196,6 +226,78 @@ async def arrappend( return cast(TJsonResponse[int], await client.custom_command(args)) +async def arrindex( + client: TGlideClient, + key: TEncodable, + path: TEncodable, + value: TEncodable, + options: Optional[JsonArrIndexOptions] = None, +) -> TJsonResponse[int]: + """ + Searches for the first occurrence of a scalar JSON value (i.e., a value that is neither an object nor an array) within arrays at the specified `path` in the JSON document stored at `key`. + + If specified, `options.start` and `options.end` define an inclusive-to-exclusive search range within the array. + (Where `options.start` is inclusive and `options.end` is exclusive). + + Out-of-range indices adjust to the nearest valid position, and negative values count from the end (e.g., `-1` is the last element, `-2` the second last). + + Setting `options.end` to `0` behaves like `-1`, extending the range to the array's end (inclusive). + + If `options.start` exceeds `options.end`, `-1` is returned, indicating that the value was not found. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (TEncodable): The path within the JSON document. + value (TEncodable): The value to search for within the arrays. + options (Optional[JsonArrIndexOptions]): Options specifying an inclusive `start` index and an optional exclusive `end` index for a range-limited search. + Defaults to the full array if not provided. See `JsonArrIndexOptions`. + + Returns: + Optional[Union[int, List[int]]]: + For JSONPath (`path` starts with `$`): + Returns an array of integers for every possible path, indicating of the first occurrence of `value` within the array, + or None for JSON values matching the path that are not an array. + A returned value of `-1` indicates that the value was not found in that particular array. + If `path` does not exist, an empty array will be returned. + For legacy path (`path` doesn't start with `$`): + Returns an integer representing the index of the first occurrence of `value` within the array at the specified path. + A returned value of `-1` indicates that the value was not found in that particular array. + If multiple paths match, the index of the value from the first matching array is returned. + If the JSON value at the `path` is not an array or if `path` does not exist, an error is raised. + If `key` does not exist, an error is raised. + + Examples: + >>> from glide import json + >>> await json.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]') + 'OK' + >>> await json.arrindex(client, "doc", "$[*]", '"b"') + [-1, -1, 1, 1] + >>> await json.set(client, "doc", ".", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}') + 'OK' + >>> await json.arrindex(client, "doc", ".children", '"Tom"') + 2 + >>> await json.set(client, "doc", "$", '{"fruits": ["apple", "banana", "cherry", "banana", "grape"]}') + 'OK' + >>> await json.arrindex(client, "doc", "$.fruits", '"banana"', JsonArrIndexOptions(start=2, end=4)) + 3 + >>> await json.set(client, "k", ".", '[1, 2, "a", 4, "a", 6, 7, "b"]') + 'OK' + >>> await json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=0)) + 7 # "b" found at index 7 within the specified range, treating end=0 as the entire array's end. 
+ >>> await json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=-1)) + 7 # "b" found at index 7, with end=-1 covering the full array to its last element. + >>> await json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=7)) + -1 # "b" not found within the range from index 4 to exclusive end at index 7. + """ + args = ["JSON.ARRINDEX", key, path, value] + + if options: + args.extend(options.to_args()) + + return cast(TJsonResponse[int], await client.custom_command(args)) + + async def arrinsert( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 4f4a4c09b7..ad958ff88d 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -8,7 +8,7 @@ import pytest from glide.async_commands.core import ConditionalChange, InfoSection from glide.async_commands.server_modules import json -from glide.async_commands.server_modules.json import JsonGetOptions +from glide.async_commands.server_modules.json import JsonArrIndexOptions, JsonGetOptions from glide.config import ProtocolVersion from glide.constants import OK from glide.exceptions import RequestError @@ -1360,6 +1360,335 @@ async def test_json_arrtrim(self, glide_client: TGlideClient): assert await json.arrtrim(glide_client, key, ".empty", 0, 1) == 0 assert OuterJson.loads(await json.get(glide_client, key, "$.empty")) == [[]] + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrindex(self, glide_client: TGlideClient): + key = get_random_string(10) + + json_value = { + "empty_array": [], + "single_element": ["apple"], + "multiple_elements": ["banana", "cherry", "date"], + "nested_arrays": [ + ["alpha"], + ["beta", "gamma"], + ["delta", "epsilon", "zeta"], + ], + "mixed_types": [1, "two", True, None, 5.5], + "not_array": 5, + "nested_arrays2": [ + ["a"], + ["ab", "abc"], + ["abcd", "abcde", "abcdef", "abcdefg", "abcdefgh", 1, 2, None, "gamma"], + ], + "nested_structure": { + "level1": { + "level2": { + "level3": [ + ["gamma", "theta"], + ["iota", "kappa", "gamma"], + ] + } + } + }, + } + + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + # JSONPath Syntax Tests + # Search for "beta" in all arrays at the root level; non-array values return null + result = await json.arrindex(glide_client, key, "$[*]", '"beta"') + assert result == [-1, -1, -1, -1, -1, None, -1, None] + + # Search for a boolean + result = await json.arrindex(glide_client, key, "$.mixed_types", "true") + assert result == [2] # True found at index 2 in the "mixed_types" array + + # Search for a float + result = await json.arrindex(glide_client, key, "$.mixed_types", "5.5") + assert result == [4] # 5.5 found at index 4 in the "mixed_types" array + + # Search for "gamma" at nested level + result = await json.arrindex(glide_client, key, "$.nested_arrays[*]", '"gamma"') + assert result == [-1, 1, -1] # "gamma" found at index 1 in the second array + + # Search for "gamma" at nested level with a specified range + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays2[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=5), + ) + assert result == [-1, -1, -1] + + # Search for "gamma" at nested level with start > end + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[*]", + '"gamma"',
JsonArrIndexOptions(start=2, end=1), + ) + assert result == [-1, -1, -1] # Invalid range, returns -1 for all + + # Search for "omega" which does not exist + result = await json.arrindex(glide_client, key, "$[*]", '"omega"') + assert result == [-1, -1, -1, -1, -1, None, -1, None] # "omega" not found + + # Search for null values; null is found at the third index in the fifth array + result = await json.arrindex(glide_client, key, "$[*]", "null") + assert result == [-1, -1, -1, -1, 3, None, -1, None] + + # Search in mixed types; "two" is found at the first index in the fifth array + result = await json.arrindex(glide_client, key, "$[*]", '"two"') + assert result == [-1, -1, -1, -1, 1, None, -1, None] + + # Out of range check for "start" value + result = await json.arrindex( + glide_client, key, "$[*]", '"apple"', JsonArrIndexOptions(start=-200) + ) + assert result == [ + -1, + 0, + -1, + -1, + -1, + None, + -1, + None, + ] # Rounded to the array's start + + # Check for end = -1; tests that the function includes the last element: "gamma" is found at index 8 in the third array + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays2[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == [-1, -1, 8] + + # Check for non-existent key + with pytest.raises(RequestError): + await json.arrindex( + glide_client, + "Non_existent", + "$.nested_arrays2[*]", + '"abcdefg"', + JsonArrIndexOptions(start=0, end=-1), + ) + + # Check for non-existent path + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays3[*]", + '"abcdefg"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == [] + + # Using JSONPath syntax to search for "gamma" in nested_structure.level1.level2.level3 + result = await json.arrindex( + glide_client, key, "$.nested_structure.level1.level2.level3[*]", '"gamma"' + ) + assert result == [ + 0, + 2, + ] # "gamma" at index 0 in first array, index 2 in second array + + # Check for inclusive behavior of start in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_structure.level1.level2.level3[*]", + '"gamma"', + JsonArrIndexOptions(start=0), + ) + assert result == [ + 0, + 2, + ] # "gamma" at index 0 of level3[0] and index 2 of level3[1]. + + # Check for exclusive behavior of end in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_structure.level1.level2.level3[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=2), + ) + assert result == [ + 0, + -1, + ] # Only "gamma" at index 0 of level3[0] is found; gamma at index 2 of level3[1] is excluded as it's not within the search range.
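# Aside, not part of the original test: the range checks above and below all exercise the documented semantics: `start` is inclusive, `end` is exclusive, and an `end` of 0 or -1 is a special value that extends the search through the last element, which is why start=1, end=0 still scans to the array's end.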
+ # Check for passing start = 0, end = 0 in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=0, end=0), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = 0 (start>end) but end is a "special value" in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=0), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = -1 (start>end) but end is a "special value" in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=-1), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Restricted Path Syntax Tests + # Search for "abcd" in the "nested_arrays2" array + result = await json.arrindex(glide_client, key, ".nested_arrays2[2]", '"abcd"') + assert result == 0 # "abcd" found at index 0 + + # Search for "abcd" in the "nested_arrays2" array with specified range + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"abcd"', + JsonArrIndexOptions(start=1, end=4), + ) + assert result == -1 # "abcd" not found within the specified range + + # Search for "abcdefg" in the "nested_arrays2" with start > end + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"abcdefg"', + JsonArrIndexOptions(start=4, end=3), + ) + assert result == -1 + + # Search for "theta" which does not exist + result = await json.arrindex(glide_client, key, ".multiple_elements", '"theta"') + assert result == -1 # "theta" not found + + # Check for non_existent path + with pytest.raises(RequestError): + await json.arrindex(glide_client, key, ".non_existent", '"value"') + + # Search in an empty array + result = await json.arrindex(glide_client, key, ".empty_array", '"anything"') + assert result == -1 # Nothing to find in empty array + + # Search for a boolean + result = await json.arrindex(glide_client, key, ".mixed_types", "true") + assert result == 2 # True found at index 2 + + # Search for a float + result = await json.arrindex(glide_client, key, ".mixed_types", "5.5") + assert result == 4 # 5.5 found at index 4 + + # Search for null value + result = await json.arrindex(glide_client, key, ".mixed_types", "null") + assert result == 3 # null found at index 3 + + # Out of range check for "start" value + result = await json.arrindex( + glide_client, + key, + ".single_element", + '"apple"', + JsonArrIndexOptions(start=-200), + ) + assert result == 0 # Rounded to the array's start + + # Check for end = -1; tests that the function includes the last element + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"gamma"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == 8 + + # Check for non-existent key + with pytest.raises(RequestError): + await json.arrindex( + glide_client, "Non_existent", ".nested_arrays2[1]", '"abcdefg"' + ) + + # Check that an error is raised when the value at path is not an array + with pytest.raises(RequestError): + await json.arrindex(glide_client, key, ".not_array", "val") + + # Using legacy syntax to search for "gamma" in nested_structure + result = await json.arrindex( + glide_client, key, ".nested_structure.level1.level2.level3[*]", '"gamma"' + ) + assert result == 0 # Legacy syntax returns the index from the first matching
array + + # Check for inclusive behavior of start in legacy syntax + result = await json.arrindex( + glide_client, + key, + ".nested_arrays[2]", + '"epsilon"', + JsonArrIndexOptions(start=1), + ) + assert result == 1 # "epsilon" found at index 1 in nested_arrays[2]. + + # Check for exclusive behavior of end in legacy syntax + result = await json.arrindex( + glide_client, + key, + ".nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=2), + ) + assert result == -1 # "zeta" at index 2 is excluded due to exclusive end. + + # Check for passing start = 0, end = 0 + result = await json.arrindex( + glide_client, + key, + ".nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=0, end=0), + ) + assert result == 2 # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = 0 (start>end) but end is a "special value" + result = await json.arrindex( + glide_client, + key, + ".nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=0), + ) + assert result == 2 # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = -1 (start>end) but end is a "special value" + result = await json.arrindex( + glide_client, + key, + ".nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=-1), + ) + assert result == 2 # "zeta" found at index 2 as the whole range was searched + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_arrappend(self, glide_client: TGlideClient): From d4ef3c3699ca2cfc21b5ff016fdc0c4a3e82b5ab Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Wed, 30 Oct 2024 10:44:37 -0700 Subject: [PATCH 079/180] Node: `FT.INFO` (#2540) * `FT.INFO`. 
Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + .../glide/api/commands/servermodules/FT.java | 8 +- .../java/glide/modules/VectorSearchTests.java | 8 - node/npm/glide/index.ts | 2 + node/src/BaseClient.ts | 3 - node/src/server-modules/GlideFt.ts | 92 +++++++++- node/tests/ServerModules.test.ts | 159 ++++++++++++------ 7 files changed, 197 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6791a5bccd..a1510be7f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ * Java: Added `FT.PROFILE` ([#2473](https://github.com/valkey-io/valkey-glide/pull/2473)) * Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) * Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) +* Node: Added `FT.INFO` ([#2540](https://github.com/valkey-io/valkey-glide/pull/2540)) * Java: Added `JSON.DEBUG` ([#2520](https://github.com/valkey-io/valkey-glide/pull/2520)) * Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) * Java: Added `JSON.ARRPOP` ([#2486](https://github.com/valkey-io/valkey-glide/pull/2486)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java index 9d1a75e9ea..ff0d80f0c7 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/FT.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java @@ -540,7 +540,7 @@ public static CompletableFuture profile( * Map response = client.ftinfo("myIndex").get(); * // the response contains data in the following format: * Map data = Map.of( - * "index_name", gs("bcd97d68-4180-4bc5-98fe-5125d0abbcb8"), + * "index_name", gs("myIndex"), * "index_status", gs("AVAILABLE"), * "key_type", gs("JSON"), * "creation_timestamp", 1728348101728771L, @@ -566,7 +566,7 @@ public static CompletableFuture profile( * gs("dimension", 6L, * gs("block_size", 1024L, * gs("algorithm", gs("FLAT") - * ) + * ) * ), * Map.of( * gs("identifier"), gs("name"), @@ -599,7 +599,7 @@ public static CompletableFuture> info( * Map response = client.ftinfo(gs("myIndex")).get(); * // the response contains data in the following format: * Map data = Map.of( - * "index_name", gs("bcd97d68-4180-4bc5-98fe-5125d0abbcb8"), + * "index_name", gs("myIndex"), * "index_status", gs("AVAILABLE"), * "key_type", gs("JSON"), * "creation_timestamp", 1728348101728771L, @@ -625,7 +625,7 @@ public static CompletableFuture> info( * gs("dimension", 6L, * gs("block_size", 1024L, * gs("algorithm", gs("FLAT") - * ) + * ) * ), * Map.of( * gs("identifier"), gs("name"), diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 75151f103b..9f02df3680 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -723,14 +723,6 @@ public void ft_aggregate() { @Test @SneakyThrows public void ft_info() { - // TODO use FT.LIST when it is done - var indices = (Object[]) client.customCommand(new String[] {"FT._LIST"}).get().getSingleValue(); - - // check that we can get a response for all indices (no crashes on value conversion or so) - for (var idx : indices) { - FT.info(client, (String) idx).get(); - } - var index = UUID.randomUUID().toString(); assertEquals( OK, diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 7539524e32..4c370588df 
100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -126,6 +126,7 @@ function initialize() { VectorFieldAttributesFlat, VectorFieldAttributesHnsw, FtCreateOptions, + FtInfoReturnType, GlideRecord, GlideString, JsonGetOptions, @@ -244,6 +245,7 @@ function initialize() { VectorFieldAttributesFlat, VectorFieldAttributesHnsw, FtCreateOptions, + FtInfoReturnType, GlideRecord, GlideJson, GlideString, diff --git a/node/src/BaseClient.ts b/node/src/BaseClient.ts index 7923e71d2e..de31ee3c5e 100644 --- a/node/src/BaseClient.ts +++ b/node/src/BaseClient.ts @@ -5510,7 +5510,6 @@ export class BaseClient { * attributes of a consumer group for the stream at `key`. * @example * ```typescript - *
{@code
      * const result = await client.xinfoGroups("my_stream");
      * console.log(result); // Output:
      * // [
@@ -5963,13 +5962,11 @@ export class BaseClient {
      *
      * @example
      * ```typescript
-     * {@code
      * const entryId = await client.xadd("mystream", ["myfield", "mydata"]);
      * // read messages from streamId
      * const readResult = await client.xreadgroup(["myfield", "mydata"], "mygroup", "my0consumer");
      * // acknowledge messages on stream
      * console.log(await client.xack("mystream", "mygroup", [entryId])); // Output: 1
-     * }
        * ``` */ public async xack( diff --git a/node/src/server-modules/GlideFt.ts b/node/src/server-modules/GlideFt.ts index e78f961c8c..0d58cbfeb0 100644 --- a/node/src/server-modules/GlideFt.ts +++ b/node/src/server-modules/GlideFt.ts @@ -2,12 +2,28 @@ * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ -import { Decoder, DecoderOption, GlideString } from "../BaseClient"; +import { + convertGlideRecordToRecord, + Decoder, + DecoderOption, + GlideRecord, + GlideReturnType, + GlideString, +} from "../BaseClient"; import { GlideClient } from "../GlideClient"; import { GlideClusterClient } from "../GlideClusterClient"; import { Field, FtCreateOptions } from "./GlideFtOptions"; -/** Module for Vector Search commands */ +/** Data type of {@link GlideFt.info | info} command response. */ +type FtInfoReturnType = Record< + string, + | GlideString + | number + | GlideString[] + | Record[]> +>; + +/** Module for Vector Search commands. */ export class GlideFt { /** * Creates an index and initiates a backfill of that index. @@ -187,16 +203,82 @@ export class GlideFt { decoder: Decoder.String, }) as Promise<"OK">; } + + /** + * Returns information about a given index. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param options - (Optional) See {@link DecoderOption}. + * + * @returns Nested maps with info about the index. See example for more details. + * + * @example + * ```typescript + * const info = await GlideFt.info(client, "myIndex"); + * console.log(info); // Output: + * // { + * // index_name: 'myIndex', + * // index_status: 'AVAILABLE', + * // key_type: 'JSON', + * // creation_timestamp: 1728348101728771, + * // key_prefixes: [ 'json:' ], + * // num_indexed_vectors: 0, + * // space_usage: 653471, + * // num_docs: 0, + * // vector_space_usage: 653471, + * // index_degradation_percentage: 0, + * // fulltext_space_usage: 0, + * // current_lag: 0, + * // fields: [ + * // { + * // identifier: '$.vec', + * // type: 'VECTOR', + * // field_name: 'VEC', + * // option: '', + * // vector_params: { + * // data_type: 'FLOAT32', + * // initial_capacity: 1000, + * // current_capacity: 1000, + * // distance_metric: 'L2', + * // dimension: 6, + * // block_size: 1024, + * // algorithm: 'FLAT' + * // } + * // }, + * // { + * // identifier: 'name', + * // type: 'TEXT', + * // field_name: 'name', + * // option: '' + * // }, + * // ] + * // } + * ``` + */ + static async info( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + options?: DecoderOption, + ): Promise { + const args: GlideString[] = ["FT.INFO", indexName]; + + return ( + _handleCustomCommand(client, args, options) as Promise< + GlideRecord + > + ).then(convertGlideRecordToRecord); + } } /** * @internal */ -function _handleCustomCommand( +async function _handleCustomCommand( client: GlideClient | GlideClusterClient, args: GlideString[], - decoderOption: DecoderOption, -) { + decoderOption?: DecoderOption, +): Promise { return client instanceof GlideClient ? 
(client as GlideClient).customCommand(args, decoderOption) : (client as GlideClusterClient).customCommand(args, decoderOption); diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 4ec5757fd7..1016e2378c 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -12,6 +12,7 @@ import { import { v4 as uuidv4 } from "uuid"; import { ConditionalChange, + Decoder, GlideClusterClient, GlideFt, GlideJson, @@ -48,16 +49,16 @@ describe("Server Module Tests", () => { await cluster.close(); }, TIMEOUT); - describe("GlideJson", () => { - let client: GlideClusterClient; + describe.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "GlideJson", + (protocol) => { + let client: GlideClusterClient; - afterEach(async () => { - await flushAndCloseClient(true, cluster.getAddresses(), client); - }); + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "check modules loaded", - async (protocol) => { + it("check modules loaded", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -70,12 +71,9 @@ describe("Server Module Tests", () => { }); expect(info).toContain("# json_core_metrics"); expect(info).toContain("# search_index_stats"); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set and json.get tests", - async (protocol) => { + it("json.set and json.get tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -118,12 +116,9 @@ describe("Server Module Tests", () => { // JSON.get with non-existing path result = await GlideJson.get(client, key, { paths: ["$.d"] }); expect(result).toEqual("[]"); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set and json.get tests with multiple value", - async (protocol) => { + it("json.set and json.get tests with multiple value", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -164,12 +159,9 @@ describe("Server Module Tests", () => { "new_value", "new_value", ]); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.set conditional set", - async (protocol) => { + it("json.set conditional set", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -210,12 +202,9 @@ describe("Server Module Tests", () => { ).toBe("OK"); result = await GlideJson.get(client, key, { paths: [".a"] }); expect(result).toEqual("4.5"); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.get formatting", - async (protocol) => { + it("json.get formatting", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -254,12 +243,9 @@ describe("Server Module Tests", () => { const expectedResult2 = '[\n~{\n~~"a":*1,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]'; expect(result).toEqual(expectedResult2); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.toggle tests", - async (protocol) => { + it("json.toggle tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -312,12 +298,9 @@ describe("Server Module Tests", () => { await expect( GlideJson.toggle(client, 
"non_existing_key", { path: "$" }), ).rejects.toThrow(RequestError); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.del tests", - async (protocol) => { + it("json.del tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -414,12 +397,9 @@ describe("Server Module Tests", () => { path: ".", }), ).toBe(0); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.forget tests", - async (protocol) => { + it("json.forget tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -520,12 +500,9 @@ describe("Server Module Tests", () => { path: ".", }), ).toBe(0); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.type tests", - async (protocol) => { + it("json.type tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -588,12 +565,9 @@ describe("Server Module Tests", () => { expect( await GlideJson.type(client, "non_existing", { path: "." }), ).toBeNull(); - }, - ); + }); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "json.resp tests", - async (protocol) => { + it("json.resp tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -717,9 +691,9 @@ describe("Server Module Tests", () => { expect( await GlideJson.resp(client, "nonexistent_key"), ).toBeNull(); - }, - ); - }); + }); + }, + ); describe("GlideFt", () => { let client: GlideClusterClient; @@ -940,5 +914,78 @@ describe("Server Module Tests", () => { expect((e as Error).message).toContain("Index does not exist"); } }); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.INFO ft.info", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const index = uuidv4(); + expect( + await GlideFt.create( + client, + Buffer.from(index), + [ + { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + distanceMetric: "COSINE", + dimensions: 42, + }, + }, + { type: "TEXT", name: "$.name" }, + ], + { dataType: "JSON", prefixes: ["123"] }, + ), + ).toEqual("OK"); + + let response = await GlideFt.info(client, Buffer.from(index)); + + expect(response).toMatchObject({ + index_name: index, + key_type: "JSON", + key_prefixes: ["123"], + fields: [ + { + identifier: "$.name", + type: "TEXT", + field_name: "$.name", + option: "", + }, + { + identifier: "$.vec", + type: "VECTOR", + field_name: "VEC", + option: "", + vector_params: { + distance_metric: "COSINE", + dimension: 42, + }, + }, + ], + }); + + response = await GlideFt.info(client, index, { + decoder: Decoder.Bytes, + }); + expect(response).toMatchObject({ + index_name: Buffer.from(index), + }); + + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + // querying a missing index + await expect(GlideFt.info(client, index)).rejects.toThrow( + "Index not found", + ); + }, + ); }); }); From 8ef77b6284466c1a01cebfb55b7a13d8f1502e39 Mon Sep 17 00:00:00 2001 From: tjzhang-BQ <111323543+tjzhang-BQ@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:34:46 -0700 Subject: [PATCH 080/180] Node: Add command JSON.STRLEN and JSON.STRAPPEND (#2537) * Node: Add command JSON.STRLEN Signed-off-by: TJ Zhang --- CHANGELOG.md | 1 + node/npm/glide/index.ts | 4 + 
node/src/server-modules/GlideJson.ts | 171 +++++++++++++++---- node/tests/ServerModules.test.ts | 220 ++++++++++++++++++++++++--- 4 files changed, 343 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1510be7f1..83a1ebc0c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ * Python: Add `JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457)) * Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) * Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) +* Node: Add `JSON.STRLEN` and `JSON.STRAPPEND` commands ([#2537](https://github.com/valkey-io/valkey-glide/pull/2537)) #### Breaking Changes diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 4c370588df..90cf70f2b7 100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -213,6 +213,8 @@ function initialize() { ReturnTypeMap, ClusterResponse, ReturnTypeAttribute, + ReturnTypeJson, + UniversalReturnTypeJson, } = nativeBinding; module.exports = { @@ -347,6 +349,8 @@ function initialize() { ReturnTypeMap, ClusterResponse, ReturnTypeAttribute, + ReturnTypeJson, + UniversalReturnTypeJson, }; globalObject = Object.assign(global, nativeBinding); diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 92d3ae35a8..1d9837c297 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -8,13 +8,14 @@ import { GlideClient } from "../GlideClient"; import { GlideClusterClient, RouteOption } from "../GlideClusterClient"; export type ReturnTypeJson<T> = T | (T | null)[]; +export type UniversalReturnTypeJson<T> = T | T[]; /** * Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. */ export interface JsonGetOptions { /** The path or list of paths within the JSON document. Default is root `$`. */ - paths?: GlideString[]; + path?: GlideString | GlideString[]; /** Sets an indentation string for nested levels. */ indent?: GlideString; /** Sets a string that's printed at the end of each line. */ @@ -31,23 +32,27 @@ function _jsonGetOptionsToArgs(options: JsonGetOptions): GlideString[] { const result: GlideString[] = []; - if (options.paths !== undefined) { - result.push(...options.paths); + if (options.path) { + if (Array.isArray(options.path)) { + result.push(...options.path); + } else { + result.push(options.path); + } } - if (options.indent !== undefined) { + if (options.indent) { result.push("INDENT", options.indent); } - if (options.newline !== undefined) { + if (options.newline) { result.push("NEWLINE", options.newline); } - if (options.space !== undefined) { + if (options.space) { result.push("SPACE", options.space); } - if (options.noescape !== undefined) { + if (options.noescape) { result.push("NOESCAPE"); }
* console.log(jsonGetStr); // '[{"a":1.0,"b":2}]' * console.log(JSON.stringify(jsonGetStr)); // [{"a": 1.0, "b": 2}] # JSON object retrieved from the key `doc` * ``` @@ -129,7 +134,7 @@ export class GlideJson { * @param options - (Optional) Additional parameters: * - (Optional) Options for formatting the byte representation of the JSON data. See {@link JsonGetOptions}. * - (Optional) `decoder`: see {@link DecoderOption}. - * @returns ReturnTypeJson: + * @returns * - If one path is given: * - For JSONPath (path starts with `$`): * - Returns a stringified JSON list of bytes replies for every possible path, @@ -145,15 +150,15 @@ export class GlideJson { * * @example * ```typescript - * const jsonStr = await client.jsonGet('doc', '$'); + * const jsonStr = await GlideJson.get('doc', {path: '$'}); * console.log(JSON.parse(jsonStr as string)); * // Output: [{"a": 1.0, "b" :2}] - JSON object retrieved from the key `doc`. * - * const jsonData = await client.jsonGet('doc', '$'); + * const jsonData = await GlideJson.get(('doc', {path: '$'}); * console.log(jsonData); * // Output: '[{"a":1.0,"b":2}]' - Returns the value at path '$' in the JSON document stored at `doc`. * - * const formattedJson = await client.jsonGet('doc', { + * const formattedJson = await GlideJson.get(('doc', { * ['$.a', '$.b'] * indent: " ", * newline: "\n", @@ -162,7 +167,7 @@ export class GlideJson { * console.log(formattedJson); * // Output: "{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}" - Returns values at paths '$.a' and '$.b' with custom format. * - * const nonExistingPath = await client.jsonGet('doc', '$.non_existing_path'); + * const nonExistingPath = await GlideJson.get(('doc', {path: '$.non_existing_path'}); * console.log(nonExistingPath); * // Output: "[]" - Empty array since the path does not exist in the JSON document. * ``` @@ -192,7 +197,7 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) path - The JSONPath to specify. Defaults to the root if not specified. + * - (Optional) `path`: The JSONPath to specify. Defaults to root (`"."`) if not provided. * @returns - For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value, * or `null` for JSON values matching the path that are not boolean. * - For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`. @@ -220,7 +225,7 @@ export class GlideJson { * * // Without specifying a path, the path defaults to root. * console.log(await GlideJson.set(client, "doc2", ".", true)); // Output: "OK" - * console.log(await GlideJson.toggle(client,"doc2")); // Output: "false" + * console.log(await GlideJson.toggle(client, "doc2")); // Output: "false" * console.log(await GlideJson.toggle(client, "doc2")); // Output: "true" * ``` */ @@ -244,16 +249,16 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) path - If `null`, deletes the entire JSON document at `key`. + * - (Optional) `path`: If `null`, deletes the entire JSON document at `key`. * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0. 
* * @example * ```typescript * console.log(await GlideJson.set(client, "doc", "$", '{a: 1, nested: {a:2, b:3}}')); * // Output: "OK" - Indicates successful setting of the value at path '$' in the key stored at `doc`. - * console.log(await GlideJson.del(client, "doc", "$..a")); + * console.log(await GlideJson.del(client, "doc", {path: "$..a"})); * // Output: 2 - Indicates successful deletion of the specific values in the key stored at `doc`. - * console.log(await GlideJson.get(client, "doc", "$")); + * console.log(await GlideJson.get(client, "doc", {path: "$"})); * // Output: "[{nested: {b: 3}}]" - Returns the value at path '$' in the JSON document stored at `doc`. * console.log(await GlideJson.del(client, "doc")); * // Output: 1 - Deletes the entire JSON document stored at `doc`. @@ -280,7 +285,7 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) path - If `null`, deletes the entire JSON document at `key`. + * - (Optional) `path`: If `null`, deletes the entire JSON document at `key`. * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0. * * @example @@ -315,8 +320,8 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) path - defaults to root if not provided. - * @returns ReturnTypeJson: + * - (Optional) `path`: Defaults to root (`"."`) if not provided. + * @returns * - For JSONPath (path starts with `$`): * - Returns an array of strings that represents the type of value at each path. * The type is one of "null", "boolean", "string", "number", "integer", "object" and "array". @@ -336,7 +341,7 @@ export class GlideJson { * // Output: ["integer", "number", "string", "boolean", null, "object", "array"]; * console.log(await GlideJson.set(client, "doc2", ".", "{Name: 'John', Age: 27}")); * console.log(await GlideJson.type(client, "doc2")); // Output: "object" - * console.log(await GlideJson.type(client, "doc2", ".Age")); // Output: "integer" + * console.log(await GlideJson.type(client, "doc2", {path: ".Age"})); // Output: "integer" * ``` */ static async type( @@ -367,8 +372,9 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) path - The path within the JSON document, Defaults to root if not provided. - * @returns ReturnTypeJson: + * - (Optional) `path`: The path within the JSON document, defaults to root (`"."`) if not provided. + * - (Optional) `decoder`: see {@link DecoderOption}. + * @returns * - For JSONPath (path starts with `$`): * - Returns an array of replies for every possible path, indicating the RESP form of the JSON value. * If `path` doesn't exist, returns an empty array. @@ -381,23 +387,128 @@ export class GlideJson { * ```typescript * console.log(await GlideJson.set(client, "doc", ".", "{a: [1, 2, 3], b: {a: [1, 2], c: {a: 42}}}")); * // Output: 'OK' - Indicates successful setting of the value at path '.' in the key stored at `doc`. 
- * const result = await GlideJson.resp(client, "doc", "$..a"); + * const result = await GlideJson.resp(client, "doc", {path: "$..a"}); * console.log(result); - * // Output: [ ["[", 1L, 2L, 3L], ["[", 1L, 2L], [42L]]; - * console.log(await GlideJson.type(client, "doc", "..a")); // Output: ["[", 1L, 2L, 3L] + * // Output: [ ["[", 1, 2, 3], ["[", 1, 2], [42]]; + * console.log(await GlideJson.type(client, "doc", {path: "..a"})); // Output: ["[", 1, 2, 3] * ``` */ static async resp( client: BaseClient, key: GlideString, - options?: { path: GlideString }, - ): Promise> { + options?: { path: GlideString } & DecoderOption, + ): Promise< + UniversalReturnTypeJson< + (number | GlideString) | (number | GlideString | null) | null + > + > { const args = ["JSON.RESP", key]; if (options) { args.push(options.path); } - return _executeCommand>(client, args); + return _executeCommand(client, args, options); + } + + /** + * Returns the length of the JSON string value stored at the specified `path` within + * the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) `path`: The path within the JSON document, Defaults to root (`"."`) if not provided. + * @returns + * - For JSONPath (path starts with `$`): + * - Returns a list of integer replies for every possible path, indicating the length of + * the JSON string value, or null for JSON values matching the path that + * are not string. + * - For legacy path (path doesn't start with `$`): + * - Returns the length of the JSON value at `path` or `null` if `key` doesn't exist. + * - If multiple paths match, the length of the first matched string is returned. + * - If the JSON value at`path` is not a string or if `path` doesn't exist, an error is raised. + * - If `key` doesn't exist, `null` is returned. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}")); + * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * console.log(await GlideJson.strlen(client, "doc", {path: "$..a"})); + * // Output: [3, 5, null] - The length of the string values at path '$..a' in the key stored at `doc`. + * + * console.log(await GlideJson.strlen(client, "doc", {path: "nested.a"})); + * // Output: 5 - The length of the JSON value at path 'nested.a' in the key stored at `doc`. + * + * console.log(await GlideJson.strlen(client, "doc", {path: "$"})); + * // Output: [null] - Returns an array with null since the value at root path does in the JSON document stored at `doc` is not a string. + * + * console.log(await GlideJson.strlen(client, "non_existent_key", {path: "."})); + * // Output: null - return null if key does not exist. + * ``` + */ + static async strlen( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.STRLEN", key]; + + if (options) { + args.push(options.path); + } + + return _executeCommand(client, args); + } + + /** + * Appends the specified `value` to the string stored at the specified `path` within the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param value - The value to append to the string. Must be wrapped with single quotes. For example, to append "foo", pass '"foo"'. 
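+ *     (A value produced by `JSON.stringify()` is already correctly quoted, which is how the tests below build it.)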
+ * @param options - (Optional) Additional parameters: + * - (Optional) `path`: The path within the JSON document, defaults to root (`"."`) if not provided. + * @returns + * - For JSONPath (path starts with `$`): + * - Returns a list of integer replies for every possible path, indicating the length of the resulting string after appending `value`, + * or None for JSON values matching the path that are not string. + * - If `key` doesn't exist, an error is raised. + * - For legacy path (path doesn't start with `$`): + * - Returns the length of the resulting string after appending `value` to the string at `path`. + * - If multiple paths match, the length of the last updated string is returned. + * - If the JSON value at `path` is not a string of if `path` doesn't exist, an error is raised. + * - If `key` doesn't exist, an error is raised. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}")); + * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * console.log(await GlideJson.strappend(client, "doc", jsonpy.dumps("baz"), {path: "$..a"})) + * // Output: [6, 8, null] - The new length of the string values at path '$..a' in the key stored at `doc` after the append operation. + * + * console.log(await GlideJson.strappend(client, "doc", '"foo"', {path: "nested.a"})); + * // Output: 11 - The length of the string value after appending "foo" to the string at path 'nested.array' in the key stored at `doc`. + * + * const result = JSON.parse(await GlideJson.get(client, "doc", {path: "$"})); + * console.log(result); + * // Output: [{"a":"foobaz", "nested": {"a": "hellobazfoo"}, "nested2": {"a": 31}}] - The updated JSON value in the key stored at `doc`. + * ``` + */ + static async strappend( + client: BaseClient, + key: GlideString, + value: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.STRAPPEND", key]; + + if (options) { + args.push(options.path); + } + + args.push(value); + + return _executeCommand>(client, args); } } diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 1016e2378c..ea1f831e39 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -94,12 +94,12 @@ describe("Server Module Tests", () => { ).toBe("OK"); // JSON.get - let result = await GlideJson.get(client, key, { paths: ["."] }); + let result = await GlideJson.get(client, key, { path: "." 
}); expect(JSON.parse(result.toString())).toEqual(jsonValue); // JSON.get with array of paths result = await GlideJson.get(client, key, { - paths: ["$.a", "$.b"], + path: ["$.a", "$.b"], }); expect(JSON.parse(result.toString())).toEqual({ "$.a": [1.0], @@ -109,12 +109,12 @@ describe("Server Module Tests", () => { // JSON.get with non-existing key expect( await GlideJson.get(client, "non_existing_key", { - paths: ["$"], + path: ["$"], }), ); // JSON.get with non-existing path - result = await GlideJson.get(client, key, { paths: ["$.d"] }); + result = await GlideJson.get(client, key, { path: "$.d" }); expect(result).toEqual("[]"); }); @@ -143,7 +143,7 @@ describe("Server Module Tests", () => { // JSON.get with deep path let result = await GlideJson.get(client, key, { - paths: ["$..c"], + path: "$..c", }); expect(JSON.parse(result.toString())).toEqual([true, 1, 2]); @@ -153,7 +153,7 @@ describe("Server Module Tests", () => { ).toBe("OK"); // verify JSON.set result - result = await GlideJson.get(client, key, { paths: ["$..c"] }); + result = await GlideJson.get(client, key, { path: "$..c" }); expect(JSON.parse(result.toString())).toEqual([ "new_value", "new_value", @@ -191,7 +191,7 @@ describe("Server Module Tests", () => { }), ).toBeNull(); let result = await GlideJson.get(client, key, { - paths: [".a"], + path: ".a", }); expect(result).toEqual("1"); @@ -200,7 +200,7 @@ describe("Server Module Tests", () => { conditionalChange: ConditionalChange.ONLY_IF_EXISTS, }), ).toBe("OK"); - result = await GlideJson.get(client, key, { paths: [".a"] }); + result = await GlideJson.get(client, key, { path: ".a" }); expect(result).toEqual("4.5"); }); @@ -223,7 +223,7 @@ describe("Server Module Tests", () => { ).toBe("OK"); // JSON.get with formatting options let result = await GlideJson.get(client, key, { - paths: ["$"], + path: "$", indent: " ", newline: "\n", space: " ", @@ -234,7 +234,7 @@ describe("Server Module Tests", () => { expect(result).toEqual(expectedResult1); // JSON.get with different formatting options result = await GlideJson.get(client, key, { - paths: ["$"], + path: "$", indent: "~", newline: "\n", space: "*", @@ -327,13 +327,13 @@ describe("Server Module Tests", () => { await GlideJson.del(client, key, { path: "..path" }), ).toBe(0); - // deleting existing paths + // deleting existing path expect(await GlideJson.del(client, key, { path: "$..a" })).toBe( 2, ); - expect( - await GlideJson.get(client, key, { paths: ["$..a"] }), - ).toBe("[]"); + expect(await GlideJson.get(client, key, { path: "$..a" })).toBe( + "[]", + ); expect( await GlideJson.set( client, @@ -346,12 +346,12 @@ describe("Server Module Tests", () => { 2, ); await expect( - GlideJson.get(client, key, { paths: ["..a"] }), + GlideJson.get(client, key, { path: "..a" }), ).rejects.toThrow(RequestError); // verify result const result = await GlideJson.get(client, key, { - paths: ["$"], + path: "$", }); expect(JSON.parse(result as string)).toEqual([ { b: { b: 2.5, c: true } }, @@ -383,7 +383,7 @@ describe("Server Module Tests", () => { expect(await GlideJson.del(client, key)).toBe(1); expect(await GlideJson.del(client, key)).toBe(0); expect( - await GlideJson.get(client, key, { paths: ["$"] }), + await GlideJson.get(client, key, { path: "$" }), ).toBeNull(); // non-existing keys @@ -430,9 +430,9 @@ describe("Server Module Tests", () => { expect( await GlideJson.forget(client, key, { path: "$..a" }), ).toBe(2); - expect( - await GlideJson.get(client, key, { paths: ["$..a"] }), - ).toBe("[]"); + expect(await GlideJson.get(client, 
key, { path: "$..a" })).toBe( + "[]", + ); expect( await GlideJson.set( client, @@ -445,12 +445,12 @@ describe("Server Module Tests", () => { await GlideJson.forget(client, key, { path: "..a" }), ).toBe(2); await expect( - GlideJson.get(client, key, { paths: ["..a"] }), + GlideJson.get(client, key, { path: "..a" }), ).rejects.toThrow(RequestError); // verify result const result = await GlideJson.get(client, key, { - paths: ["$"], + path: "$", }); expect(JSON.parse(result as string)).toEqual([ { b: { b: 2.5, c: true } }, @@ -486,7 +486,7 @@ describe("Server Module Tests", () => { expect(await GlideJson.forget(client, key)).toBe(1); expect(await GlideJson.forget(client, key)).toBe(0); expect( - await GlideJson.get(client, key, { paths: ["$"] }), + await GlideJson.get(client, key, { path: "$" }), ).toBeNull(); // non-existing keys @@ -692,6 +692,180 @@ describe("Server Module Tests", () => { await GlideJson.resp(client, "nonexistent_key"), ).toBeNull(); }); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.strlen tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + a: "foo", + nested: { a: "hello" }, + nested2: { a: 31 }, + }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.strlen(client, key, { path: "$..a" }), + ).toEqual([3, 5, null]); + expect( + await GlideJson.strlen(client, key, { path: "a" }), + ).toBe(3); + + expect( + await GlideJson.strlen(client, key, { + path: "$.nested", + }), + ).toEqual([null]); + expect( + await GlideJson.strlen(client, key, { path: "$..a" }), + ).toEqual([3, 5, null]); + + expect( + await GlideJson.strlen(client, "non_existing_key", { + path: ".", + }), + ).toBeNull(); + expect( + await GlideJson.strlen(client, "non_existing_key", { + path: "$", + }), + ).toBeNull(); + expect( + await GlideJson.strlen(client, key, { + path: "$.non_existing_path", + }), + ).toEqual([]); + + // error case + await expect( + GlideJson.strlen(client, key, { path: "nested" }), + ).rejects.toThrow(RequestError); + await expect(GlideJson.strlen(client, key)).rejects.toThrow( + RequestError, + ); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.strappend tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + a: "foo", + nested: { a: "hello" }, + nested2: { a: 31 }, + }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.strappend(client, key, '"bar"', { + path: "$..a", + }), + ).toEqual([6, 8, null]); + expect( + await GlideJson.strappend( + client, + key, + JSON.stringify("foo"), + { + path: "a", + }, + ), + ).toBe(9); + + expect( + await GlideJson.get(client, key, { path: "." 
}), + ).toEqual( + JSON.stringify({ + a: "foobarfoo", + nested: { a: "hellobar" }, + nested2: { a: 31 }, + }), + ); + + expect( + await GlideJson.strappend( + client, + key, + JSON.stringify("bar"), + { + path: "$.nested", + }, + ), + ).toEqual([null]); + + await expect( + GlideJson.strappend( + client, + key, + JSON.stringify("bar"), + { + path: ".nested", + }, + ), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.strappend(client, key, JSON.stringify("bar")), + ).rejects.toThrow(RequestError); + + expect( + await GlideJson.strappend( + client, + key, + JSON.stringify("try"), + { + path: "$.non_existing_path", + }, + ), + ).toEqual([]); + + await expect( + GlideJson.strappend( + client, + key, + JSON.stringify("try"), + { + path: ".non_existing_path", + }, + ), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.strappend( + client, + "non_existing_key", + JSON.stringify("try"), + ), + ).rejects.toThrow(RequestError); + }, + ); }, ); From 6b061383e61c98a35ab55594696f8075bf93a732 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Wed, 30 Oct 2024 19:34:25 -0700 Subject: [PATCH 081/180] Node: `JSON.ARRINSERT`, `JSON.ARRPOP` and `JSON.ARRLEN`. (#2542) * `JSON.ARRINSERT`, `JSON.ARRPOP` and `JSON.ARRLEN`. Signed-off-by: Yury-Fridlyand --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 12 +- node/npm/glide/index.ts | 2 + node/src/BaseClient.ts | 2 +- node/src/server-modules/GlideJson.ts | 168 ++++++++++++++- node/tests/ServerModules.test.ts | 194 ++++++++++++++++++ .../async_commands/server_modules/json.py | 4 +- 7 files changed, 363 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83a1ebc0c7..bf0e71a3d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ * Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) * Java: Added `JSON.ARRTRIM` ([#2518](https://github.com/valkey-io/valkey-glide/pull/2518)) * Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) +* Node: Added `JSON.ARRINSERT`, `JSON.ARRPOP` and `JSON.ARRLEN` ([#2542](https://github.com/valkey-io/valkey-glide/pull/2542)) * Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) * Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) * Java: Added `JSON.STRAPPEND` and `JSON.STRLEN` ([#2522](https://github.com/valkey-io/valkey-glide/pull/2522)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index 939557307d..eef7a8d4c8 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -1009,8 +1009,7 @@ public static CompletableFuture arrpop( *
      • For JSONPath (path starts with $):
 *        Returns an array with a strings for every possible path, representing the popped JSON
 *        values, or null for JSON values matching the path that are not an array
- *        or an empty array. If a value is not an array, its corresponding return value is
- *        "null".
+ *        or an empty array.
 *      • For legacy path (path doesn't start with $):
 *        Returns a string representing the popped JSON value, or null if the
 *        array at path is empty. If multiple paths are matched, the value from
@@ -1044,8 +1043,7 @@ public static CompletableFuture arrpop(
 *      • For JSONPath (path starts with $):
 *        Returns an array with a strings for every possible path, representing the popped JSON
 *        values, or null for JSON values matching the path that are not an array
- *        or an empty array. If a value is not an array, its corresponding return value is
- *        "null".
+ *        or an empty array.
 *      • For legacy path (path doesn't start with $):
 *        Returns a string representing the popped JSON value, or null if the
 *        array at path is empty. If multiple paths are matched, the value from
@@ -1081,8 +1079,7 @@ public static CompletableFuture arrpop(
 *      • For JSONPath (path starts with $):
 *        Returns an array with a strings for every possible path, representing the popped JSON
 *        values, or null for JSON values matching the path that are not an array
- *        or an empty array. If a value is not an array, its corresponding return value is
- *        "null".
+ *        or an empty array.
 *      • For legacy path (path doesn't start with $):
 *        Returns a string representing the popped JSON value, or null if the
 *        array at path is empty. If multiple paths are matched, the value from
@@ -1121,8 +1118,7 @@ public static CompletableFuture arrpop(
 *      • For JSONPath (path starts with $):
 *        Returns an array with a strings for every possible path, representing the popped JSON
 *        values, or null for JSON values matching the path that are not an array
- *        or an empty array. If a value is not an array, its corresponding return value is
- *        "null".
+ *        or an empty array.
 *      • For legacy path (path doesn't start with $):
        * Returns a string representing the popped JSON value, or null if the * array at path is empty. If multiple paths are matched, the value from diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 90cf70f2b7..0dc40fd055 100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -130,6 +130,7 @@ function initialize() { GlideRecord, GlideString, JsonGetOptions, + JsonArrPopOptions, SortedSetDataType, StreamEntryDataType, HashDataType, @@ -252,6 +253,7 @@ function initialize() { GlideJson, GlideString, JsonGetOptions, + JsonArrPopOptions, SortedSetDataType, StreamEntryDataType, HashDataType, diff --git a/node/src/BaseClient.ts b/node/src/BaseClient.ts index de31ee3c5e..768f995119 100644 --- a/node/src/BaseClient.ts +++ b/node/src/BaseClient.ts @@ -2286,7 +2286,7 @@ export class BaseClient { * * @param key - The key of the set. * @param cursor - The cursor that points to the next iteration of results. A value of `"0"` indicates the start of the search. - * @param options - (Optional) The {@link HScanOptions}. + * @param options - (Optional) See {@link HScanOptions} and {@link DecoderOption}. * @returns An array of the `cursor` and the subset of the hash held by `key`. * The first element is always the `cursor` for the next iteration of results. `"0"` will be the `cursor` * returned on the last iteration of the hash. The second element is always an array of the subset of the diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 1d9837c297..3bf942a4bb 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -11,7 +11,7 @@ export type ReturnTypeJson = T | (T | null)[]; export type UniversalReturnTypeJson = T | T[]; /** - * Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. + * Represents options for formatting JSON data, to be used in the {@link GlideJson.get | JSON.GET} command. */ export interface JsonGetOptions { /** The path or list of paths within the JSON document. Default is root `$`. */ @@ -26,6 +26,14 @@ export interface JsonGetOptions { noescape?: boolean; } +/** Additional options for {@link GlideJson.arrpop | JSON.ARRPOP} command. */ +export interface JsonArrPopOptions { + /** The path within the JSON document. */ + path: GlideString; + /** The index of the element to pop. Out of boundary indexes are rounded to their respective array boundaries. */ + index?: number; +} + /** * @internal */ @@ -127,7 +135,7 @@ export class GlideJson { } /** - * Retrieves the JSON value at the specified `paths` stored at `key`. + * Retrieves the JSON value at the specified `paths` stored at `key`. * * @param client The client to execute the command. * @param key - The key of the JSON document. @@ -184,11 +192,153 @@ export class GlideJson { args.push(...optionArgs); } - return _executeCommand>( - client, - args, - options, - ); + return _executeCommand(client, args, options); + } + + /** + * Inserts one or more values into the array at the specified `path` within the JSON + * document stored at `key`, before the given `index`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param path - The path within the JSON document. + * @param index - The array index before which values are inserted. + * @param values - The JSON values to be inserted into the array, in JSON formatted bytes or str. + * JSON string values must be wrapped with quotes. 
For example, to append `"foo"`, pass `"\"foo\""`. + * @returns + * - For JSONPath (path starts with `$`): + * Returns an array with a list of integers for every possible path, + * indicating the new length of the array, or `null` for JSON values matching + * the path that are not an array. If `path` does not exist, an empty array + * will be returned. + * - For legacy path (path doesn't start with `$`): + * Returns an integer representing the new length of the array. If multiple paths are + * matched, returns the length of the first modified array. If `path` doesn't + * exist or the value at `path` is not an array, an error is raised. + * - If the index is out of bounds or `key` doesn't exist, an error is raised. + * + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]'); + * const result = await GlideJson.arrinsert(client, "doc", "$[*]", 0, ['"c"', '{"key": "value"}', "true", "null", '["bar"]']); + * console.log(result); // Output: [5, 6, 7] + * const doc = await json.get(client, "doc"); + * console.log(doc); // Output: '[["c",{"key":"value"},true,null,["bar"]],["c",{"key":"value"},true,null,["bar"],"a"],["c",{"key":"value"},true,null,["bar"],"a","b"]]' + * ``` + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]'); + * const result = await GlideJson.arrinsert(client, "doc", ".", 0, ['"c"']) + * console.log(result); // Output: 4 + * const doc = await json.get(client, "doc"); + * console.log(doc); // Output: '[\"c\",[],[\"a\"],[\"a\",\"b\"]]' + * ``` + */ + static async arrinsert( + client: BaseClient, + key: GlideString, + path: GlideString, + index: number, + values: GlideString[], + ): Promise> { + const args = ["JSON.ARRINSERT", key, path, index.toString(), ...values]; + + return _executeCommand(client, args); + } + + /** + * Pops an element from the array located at `path` in the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) See {@link JsonArrPopOptions} and {@link DecoderOption}. + * @returns + * - For JSONPath (path starts with `$`): + * Returns an array with a strings for every possible path, representing the popped JSON + * values, or `null` for JSON values matching the path that are not an array + * or an empty array. + * - For legacy path (path doesn't start with `$`): + * Returns a string representing the popped JSON value, or `null` if the + * array at `path` is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If `path` doesn't + * exist or the value at `path` is not an array, an error is raised. + * - If the index is out of bounds or `key` doesn't exist, an error is raised. 
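+ *     (If the options object is omitted entirely, the server applies its defaults: the root path and index `-1`, i.e. the last element.)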
+ * + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}'); + * let result = await GlideJson.arrpop(client, "doc", { path: "$.a", index: 1 }); + * console.log(result); // Output: ['2'] - Popped second element from array at path `$.a` + * result = await GlideJson.arrpop(client, "doc", { path: "$..a" }); + * console.log(result); // Output: ['true', '5', null] - Popped last elements from all arrays matching path `$..a` + * + * result = await GlideJson.arrpop(client, "doc", { path: "..a" }); + * console.log(result); // Output: "1" - First match popped (from array at path ..a) + * // Even though only one value is returned from `..a`, subsequent arrays are also affected + * console.log(await GlideJson.get(client, "doc", "$..a")); // Output: "[[], [3, 4], 42]" + * ``` + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b", "c"]]'); + * let result = await GlideJson.arrpop(client, "doc", { path: ".", index: -1 }); + * console.log(result); // Output: '["a","b","c"]' - Popped last elements at path `.` + * ``` + */ + static async arrpop( + client: BaseClient, + key: GlideString, + options?: JsonArrPopOptions & DecoderOption, + ): Promise> { + const args = ["JSON.ARRPOP", key]; + if (options?.path) args.push(options?.path); + if (options && "index" in options && options.index) + args.push(options?.index.toString()); + + return _executeCommand(client, args, options); + } + + /** + * Retrieves the length of the array at the specified `path` within the JSON document stored at `key`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param options - (Optional) Additional parameters: + * - (Optional) `path`: The path within the JSON document. Defaults to the root (`"."`) if not specified. + * @returns + * - For JSONPath (path starts with `$`): + * Returns an array with a list of integers for every possible path, + * indicating the length of the array, or `null` for JSON values matching + * the path that are not an array. If `path` does not exist, an empty array + * will be returned. + * - For legacy path (path doesn't start with `$`): + * Returns an integer representing the length of the array. If multiple paths are + * matched, returns the length of the first matching array. If `path` doesn't + * exist or the value at `path` is not an array, an error is raised. + * - If the index is out of bounds or `key` doesn't exist, an error is raised. + * + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}'); + * console.log(await GlideJson.arrlen(client, "doc", { path: "$" })); // Output: [null] - No array at the root path. + * console.log(await GlideJson.arrlen(client, "doc", { path: "$.a" })); // Output: [3] - Retrieves the length of the array at path $.a. + * console.log(await GlideJson.arrlen(client, "doc", { path: "$..a" })); // Output: [3, 2, null] - Retrieves lengths of arrays found at all levels of the path `$..a`. + * console.log(await GlideJson.arrlen(client, "doc", { path: "..a" })); // Output: 3 - Legacy path retrieves the first array match at path `..a`. + * ``` + * @example + * ```typescript + * await GlideJson.set(client, "doc", "$", '[1, 2, 3, 4]'); + * console.log(await GlideJson.arrlen(client, "doc")); // Output: 4 - the length of array at root. 
+ * ``` + */ + static async arrlen( + client: BaseClient, + key: GlideString, + options?: { path: GlideString }, + ): Promise> { + const args = ["JSON.ARRLEN", key]; + if (options?.path) args.push(options?.path); + + return _executeCommand(client, args); } /** @@ -197,7 +347,7 @@ export class GlideJson { * @param client - The client to execute the command. * @param key - The key of the JSON document. * @param options - (Optional) Additional parameters: - * - (Optional) `path`: The JSONPath to specify. Defaults to root (`"."`) if not provided. + * - (Optional) `path`: The path within the JSON document. Defaults to the root (`"."`) if not specified. * @returns - For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value, * or `null` for JSON values matching the path that are not boolean. * - For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`. @@ -240,7 +390,7 @@ export class GlideJson { args.push(options.path); } - return _executeCommand>(client, args); + return _executeCommand(client, args); } /** diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index ea1f831e39..4baf854d54 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -245,6 +245,200 @@ describe("Server Module Tests", () => { expect(result).toEqual(expectedResult2); }); + it("json.arrinsert", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const doc = { + a: [], + b: { a: [1, 2, 3, 4] }, + c: { a: "not an array" }, + d: [{ a: ["x", "y"] }, { a: [["foo"]] }], + e: [{ a: 42 }, { a: {} }], + f: { a: [true, false, null] }, + }; + expect( + await GlideJson.set(client, key, "$", JSON.stringify(doc)), + ).toBe("OK"); + + const result = await GlideJson.arrinsert( + client, + key, + "$..a", + 0, + [ + '"string_value"', + "123", + '{"key": "value"}', + "true", + "null", + '["bar"]', + ], + ); + expect(result).toEqual([6, 10, null, 8, 7, null, null, 9]); + + const expected = { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + ], + b: { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + 1, + 2, + 3, + 4, + ], + }, + c: { a: "not an array" }, + d: [ + { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + "x", + "y", + ], + }, + { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + ["foo"], + ], + }, + ], + e: [{ a: 42 }, { a: {} }], + f: { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + true, + false, + null, + ], + }, + }; + expect( + JSON.parse((await GlideJson.get(client, key)) as string), + ).toEqual(expected); + }); + + it("json.arrpop", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + let doc = + '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}'; + expect(await GlideJson.set(client, key, "$", doc)).toBe("OK"); + + let res = await GlideJson.arrpop(client, key, { + path: "$.a", + index: 1, + }); + expect(res).toEqual(["2"]); + + res = await GlideJson.arrpop(client, Buffer.from(key), { + path: "$..a", + }); + expect(res).toEqual(["true", "5", null]); + + res = await GlideJson.arrpop(client, key, { + path: "..a", + decoder: 
Decoder.Bytes, + }); + expect(res).toEqual(Buffer.from("1")); + + // Even if only one array element was returned, ensure second array at `..a` was popped + doc = (await GlideJson.get(client, key, { + path: ["$..a"], + })) as string; + expect(doc).toEqual("[[],[3,4],42]"); + + // Out of index + res = await GlideJson.arrpop(client, key, { + path: Buffer.from("$..a"), + index: 10, + }); + expect(res).toEqual([null, "4", null]); + + // pop without options + expect(await GlideJson.set(client, key, "$", doc)).toEqual( + "OK", + ); + expect(await GlideJson.arrpop(client, key)).toEqual("42"); + }); + + it("json.arrlen", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const doc = + '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}'; + expect(await GlideJson.set(client, key, "$", doc)).toBe("OK"); + + expect( + await GlideJson.arrlen(client, key, { path: "$.a" }), + ).toEqual([3]); + expect( + await GlideJson.arrlen(client, key, { path: "$..a" }), + ).toEqual([3, 2, null]); + // Legacy path retrieves the first array match at ..a + expect( + await GlideJson.arrlen(client, key, { path: "..a" }), + ).toEqual(3); + // Value at path is not an array + expect( + await GlideJson.arrlen(client, key, { path: "$" }), + ).toEqual([null]); + + await expect( + GlideJson.arrlen(client, key, { path: "." }), + ).rejects.toThrow(); + + expect( + await GlideJson.set(client, key, "$", "[1, 2, 3, 4]"), + ).toBe("OK"); + expect(await GlideJson.arrlen(client, key)).toEqual(4); + }); + it("json.toggle tests", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index a8d0dfcfcc..2b57239a41 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -383,7 +383,7 @@ async def arrlen( >>> await json.arrlen(client, "doc", "$.a") [3] # Retrieves the length of the array at path $.a. >>> await json.arrlen(client, "doc", "$..a") - [3, 2, None] # Retrieves lengths of arrays found at all levels of the path `..a`. + [3, 2, None] # Retrieves lengths of arrays found at all levels of the path `$..a`. >>> await json.arrlen(client, "doc", "..a") 3 # Legacy path retrieves the first array match at path `..a`. >>> await json.arrlen(client, "non_existing_key", "$.a") @@ -392,7 +392,7 @@ async def arrlen( >>> await json.set(client, "doc", "$", '[1, 2, 3, 4]') 'OK' # JSON is successfully set for doc >>> await json.arrlen(client, "doc") - 4 # Retrieves lengths of arrays in root. + 4 # Retrieves lengths of array in root. """ args = ["JSON.ARRLEN", key] if path: From d598e244bcc5efc851af5fb314becce6dfc72222 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Thu, 31 Oct 2024 13:30:08 +0200 Subject: [PATCH 082/180] Fixes https://github.com/valkey-io/valkey-glide/issues/2556 : Mutable default parameter in python example. 
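
Background: Python evaluates default argument values once, when the function is
defined, so a mutable default such as `[("localhost", 6379)]` becomes a single
list object shared by every call that omits the argument. A minimal sketch of
the pitfall (the `add_node` helper is hypothetical, not from this repo):

    def add_node(node, nodes=[]):  # default list is created once, at def time
        nodes.append(node)
        return nodes

    add_node(("a", 1))  # [("a", 1)]
    add_node(("b", 2))  # [("a", 1), ("b", 2)] - state leaks between calls

The conventional remedy, applied in the diff below, is a `None` sentinel with
the real default constructed inside the function body.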
Signed-off-by: ikolomi --- CHANGELOG.md | 1 + examples/python/cluster_example.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf0e71a3d7..4fa3bf7597 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,7 @@ * Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) * Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) * Node: Add `JSON.STRLEN` and `JSON.STRAPPEND` command ([#2537](https://github.com/valkey-io/valkey-glide/pull/2537)) +* Python: Fix example ([#2556](https://github.com/valkey-io/valkey-glide/issues/2556)) #### Breaking Changes diff --git a/examples/python/cluster_example.py b/examples/python/cluster_example.py index c3cefbd14f..01f916963b 100644 --- a/examples/python/cluster_example.py +++ b/examples/python/cluster_example.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Tuple +from typing import List, Tuple, Optional from glide import ( AllNodes, @@ -17,7 +17,7 @@ async def create_client( - nodes_list: List[Tuple[str, int]] = [("localhost", 6379)] + nodes_list: Optional[List[Tuple[str, int]]] = None ) -> GlideClusterClient: """ Creates and returns a GlideClusterClient instance. @@ -33,6 +33,8 @@ async def create_client( Returns: GlideClusterClient: An instance of GlideClusterClient connected to the discovered nodes. """ + if nodes_list is None: + nodes_list = [("localhost", 6379)] addresses = [NodeAddress(host, port) for host, port in nodes_list] # Check `GlideClusterClientConfiguration` for additional options. config = GlideClusterClientConfiguration( From 8a20b139a64563dd8e1e2bbb86697e28972c957d Mon Sep 17 00:00:00 2001 From: prateek-kumar-improving Date: Thu, 31 Oct 2024 14:54:50 -0700 Subject: [PATCH 083/180] Python: FT.PROFILE command added (#2543) * Python: FT.PROFILE command added --------- Signed-off-by: Prateek Kumar --- CHANGELOG.md | 1 + python/python/glide/__init__.py | 10 ++ .../glide/async_commands/server_modules/ft.py | 75 ++++++++++-- .../server_modules/ft_options/ft_constants.py | 14 +++ .../ft_options/ft_profile_options.py | 108 ++++++++++++++++++ python/python/glide/constants.py | 12 +- .../search/test_ft_search.py | 90 +++++++++++---- .../tests/tests_server_modules/test_ft.py | 105 ++++++++++++----- 8 files changed, 349 insertions(+), 66 deletions(-) create mode 100644 python/python/glide/async_commands/server_modules/ft_options/ft_profile_options.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fa3bf7597..bc832f60f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ #### Changes +* Python: Python: FT.PROFILE command added ([#2543](https://github.com/valkey-io/valkey-glide/pull/2543)) * Python: Python: FT.AGGREGATE command added([#2530](https://github.com/valkey-io/valkey-glide/pull/2530)) * Python: Add JSON.OBJLEN command ([#2495](https://github.com/valkey-io/valkey-glide/pull/2495)) * Python: FT.EXPLAIN and FT.EXPLAINCLI commands added([#2508](https://github.com/valkey-io/valkey-glide/pull/2508)) diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 46289fdacc..bde4f33401 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -59,6 +59,9 @@ VectorFieldAttributesHnsw, VectorType, ) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) from glide.async_commands.server_modules.ft_options.ft_search_options import ( FtSeachOptions, FtSearchLimit, @@ -115,7 +118,10 @@ from 
glide.constants import ( OK, TOK, + FtAggregateResponse, FtInfoResponse, + FtProfileResponse, + FtSearchResponse, TClusterResponse, TEncodable, TFunctionListResponse, @@ -186,7 +192,10 @@ "TResult", "TXInfoStreamFullResponse", "TXInfoStreamResponse", + "FtAggregateResponse", "FtInfoResponse", + "FtProfileResponse", + "FtSearchResponse", # Commands "BitEncoding", "BitFieldGet", @@ -301,4 +310,5 @@ "FtAggregateReducer", "FtAggregateSortBy", "FtAggregateSortProperty", + "FtProfileOptions", ] diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py index 6240d626f7..c8a757a979 100644 --- a/python/python/glide/async_commands/server_modules/ft.py +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -16,10 +16,20 @@ Field, FtCreateOptions, ) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) from glide.async_commands.server_modules.ft_options.ft_search_options import ( FtSeachOptions, ) -from glide.constants import TOK, FtInfoResponse, TEncodable +from glide.constants import ( + TOK, + FtAggregateResponse, + FtInfoResponse, + FtProfileResponse, + FtSearchResponse, + TEncodable, +) from glide.glide_client import TGlideClient @@ -85,7 +95,7 @@ async def search( indexName: TEncodable, query: TEncodable, options: Optional[FtSeachOptions], -) -> List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]: +) -> FtSearchResponse: """ Uses the provided query expression to locate keys within an index. Once located, the count and/or the content of indexed fields within those keys can be returned. @@ -96,7 +106,7 @@ async def search( options (Optional[FtSeachOptions]): The search options. See `FtSearchOptions`. Returns: - List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]: A two element array, where first element is count of documents in result set, and the second element, which has the format Mapping[TEncodable, Mapping[TEncodable, TEncodable]] is a mapping between document names and map of their attributes. + FtSearchResponse: A two element array, where first element is count of documents in result set, and the second element, which has the format Mapping[TEncodable, Mapping[TEncodable, TEncodable]] is a mapping between document names and map of their attributes. If count(option in `FtSearchOptions`) is set to true or limit(option in `FtSearchOptions`) is set to FtSearchLimit(0, 0), the command returns array with only one element - the count of the documents. Examples: @@ -111,10 +121,7 @@ async def search( args: List[TEncodable] = [CommandNames.FT_SEARCH, indexName, query] if options: args.extend(options.toArgs()) - return cast( - List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]], - await client.custom_command(args), - ) + return cast(FtSearchResponse, await client.custom_command(args)) async def aliasadd( @@ -286,23 +293,69 @@ async def aggregate( indexName: TEncodable, query: TEncodable, options: Optional[FtAggregateOptions], -) -> List[Mapping[TEncodable, Any]]: +) -> FtAggregateResponse: """ A superset of the FT.SEARCH command, it allows substantial additional processing of the keys selected by the query expression. + Args: client (TGlideClient): The client to execute the command. indexName (TEncodable): The index name for which the query is written. query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. 
options (Optional[FtAggregateOptions]): The optional arguments for the command. + Returns: - List[Mapping[TEncodable, Any]]: An array containing a mapping of field name and associated value as returned after the last stage of the command. + FtAggregateResponse: An array containing a mapping of field name and associated value as returned after the last stage of the command. Examples: >>> from glide import ft - >>> result = await ft.aggregate(glide_client, myIndex"", "*", FtAggregateOptions(loadFields=["__key"], clauses=[GroupBy(["@condition"], [Reducer("COUNT", [], "bicycles")])])) + >>> result = await ft.aggregate(glide_client, "myIndex", "*", FtAggregateOptions(loadFields=["__key"], clauses=[GroupBy(["@condition"], [Reducer("COUNT", [], "bicycles")])])) [{b'condition': b'refurbished', b'bicycles': b'1'}, {b'condition': b'new', b'bicycles': b'5'}, {b'condition': b'used', b'bicycles': b'4'}] """ args: List[TEncodable] = [CommandNames.FT_AGGREGATE, indexName, query] if options: args.extend(options.to_args()) - return cast(List[Mapping[TEncodable, Any]], await client.custom_command(args)) + return cast(FtAggregateResponse, await client.custom_command(args)) + + +async def profile( + client: TGlideClient, indexName: TEncodable, options: FtProfileOptions +) -> FtProfileResponse: + """ + Runs a search or aggregation query and collects performance profiling information. + + Args: + client (TGlideClient): The client to execute the command. + indexName (TEncodable): The index name + options (FtProfileOptions): Options for the command. + + Returns: + FtProfileResponse: A two-element array. The first element contains results of query being profiled, the second element stores profiling information. + + Examples: + >>> ftSearchOptions = FtSeachOptions(return_fields=[ReturnField(field_identifier="a", alias="a_new"), ReturnField(field_identifier="b", alias="b_new")]) + >>> ftProfileResult = await ft.profile(glide_client, "myIndex", FtProfileOptions.from_query_options(query="*", queryOptions=ftSearchOptions)) + [ + [ + 2, + { + b'key1': { + b'a': b'11111', + b'b': b'2' + }, + b'key2': { + b'a': b'22222', + b'b': b'2' + } + } + ], + { + b'all.count': 2, + b'sync.time': 1, + b'query.time': 7, + b'result.count': 2, + b'result.time': 0 + } + ] + """ + args: List[TEncodable] = [CommandNames.FT_PROFILE, indexName] + options.to_args() + return cast(FtProfileResponse, await client.custom_command(args)) diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py index fd703ffcaf..15a978eac8 100644 --- a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -16,6 +16,7 @@ class CommandNames: FT_EXPLAIN = "FT.EXPLAIN" FT_EXPLAINCLI = "FT.EXPLAINCLI" FT_AGGREGATE = "FT.AGGREGATE" + FT_PROFILE = "FT.PROFILE" class FtCreateKeywords: @@ -55,6 +56,10 @@ class FtSeachKeywords: class FtAggregateKeywords: + """ + Keywords used in the FT.AGGREGATE command. + """ + LIMIT = "LIMIT" FILTER = "FILTER" GROUPBY = "GROUPBY" @@ -66,3 +71,12 @@ class FtAggregateKeywords: LOAD = "LOAD" TIMEOUT = "TIMEOUT" PARAMS = "PARAMS" + + +class FtProfileKeywords: + """ + Keywords used in the FT.PROFILE command. 
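+    QUERY marks the start of the profiled query text, and LIMITED requests a brief version of the profiling output.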
+ """ + + QUERY = "QUERY" + LIMITED = "LIMITED" diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_profile_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_profile_options.py new file mode 100644 index 0000000000..d6ab8ceb7b --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_profile_options.py @@ -0,0 +1,108 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +from enum import Enum +from typing import List, Optional, Union, cast + +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateOptions, +) +from glide.async_commands.server_modules.ft_options.ft_constants import ( + FtProfileKeywords, +) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSeachOptions, +) +from glide.constants import TEncodable + + +class QueryType(Enum): + """ + This class represents the query type being profiled. + """ + + AGGREGATE = "AGGREGATE" + """ + If the query being profiled is for the FT.AGGREGATE command. + """ + SEARCH = "SEARCH" + """ + If the query being profiled is for the FT.SEARCH command. + """ + + +class FtProfileOptions: + """ + This class represents the arguments/options for the FT.PROFILE command. + """ + + def __init__( + self, + query: TEncodable, + queryType: QueryType, + queryOptions: Optional[Union[FtSeachOptions, FtAggregateOptions]] = None, + limited: Optional[bool] = False, + ): + """ + Initialize a new FtProfileOptions instance. + + Args: + query (TEncodable): The query that is being profiled. This is the query argument from the FT.AGGREGATE/FT.SEARCH command. + queryType (Optional[QueryType]): The type of query to be profiled. + queryOptions (Optional[Union[FtSeachOptions, FtAggregateOptions]]): The arguments/options for the FT.AGGREGATE/FT.SEARCH command being profiled. + limited (Optional[bool]): To provide some brief version of the output, otherwise a full verbose output is provided. + """ + self.query = query + self.queryType = queryType + self.queryOptions = queryOptions + self.limited = limited + + @classmethod + def from_query_options( + cls, + query: TEncodable, + queryOptions: Union[FtSeachOptions, FtAggregateOptions], + limited: Optional[bool] = False, + ): + """ + A class method to create FtProfileOptions with FT.SEARCH/FT.AGGREGATE options. + + Args: + query (TEncodable): The query that is being profiled. This is the query argument from the FT.AGGREGATE/FT.SEARCH command. + queryOptions (Optional[Union[FtSeachOptions, FtAggregateOptions]]): The arguments/options for the FT.AGGREGATE/FT.SEARCH command being profiled. + limited (Optional[bool]): To provide some brief version of the output, otherwise a full verbose output is provided. + """ + queryType: QueryType = QueryType.SEARCH + if type(queryOptions) == FtAggregateOptions: + queryType = QueryType.AGGREGATE + return cls(query, queryType, queryOptions, limited) + + @classmethod + def from_query_type( + cls, query: TEncodable, queryType: QueryType, limited: Optional[bool] = False + ): + """ + A class method to create FtProfileOptions with QueryType. + + Args: + query (TEncodable): The query that is being profiled. This is the query argument from the FT.AGGREGATE/FT.SEARCH command. + queryType (QueryType): The type of query to be profiled. + limited (Optional[bool]): To provide some brief version of the output, otherwise a full verbose output is provided. 
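+
+        Returns:
+            FtProfileOptions: An instance with no FT.SEARCH/FT.AGGREGATE options attached, so only the query type, the optional LIMITED flag and the raw query are sent.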
+ """ + return cls(query, queryType, None, limited) + + def to_args(self) -> List[TEncodable]: + """ + Get the remaining arguments for the FT.PROFILE command. + + Returns: + List[TEncodable]: A list of remaining arguments for the FT.PROFILE command. + """ + args: List[TEncodable] = [self.queryType.value] + if self.limited: + args.append(FtProfileKeywords.LIMITED) + args.extend([FtProfileKeywords.QUERY, self.query]) + if self.queryOptions: + if type(self.queryOptions) == FtAggregateOptions: + args.extend(cast(FtAggregateOptions, self.queryOptions).to_args()) + else: + args.extend(cast(FtSeachOptions, self.queryOptions).toArgs()) + return args diff --git a/python/python/glide/constants.py b/python/python/glide/constants.py index 7c28372053..9740ac8cf6 100644 --- a/python/python/glide/constants.py +++ b/python/python/glide/constants.py @@ -1,6 +1,6 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 -from typing import Dict, List, Literal, Mapping, Optional, Set, TypeVar, Union +from typing import Any, Dict, List, Literal, Mapping, Optional, Set, TypeVar, Union from glide.protobuf.command_request_pb2 import CommandRequest from glide.protobuf.connection_request_pb2 import ConnectionRequest @@ -108,3 +108,13 @@ ], ], ] + +FtSearchResponse = List[ + Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]] +] + +FtAggregateResponse = List[Mapping[TEncodable, Any]] + +FtProfileResponse = List[ + Union[FtSearchResponse, FtAggregateResponse, Mapping[str, int]] +] diff --git a/python/python/tests/tests_server_modules/search/test_ft_search.py b/python/python/tests/tests_server_modules/search/test_ft_search.py index 80d8319676..bece8e1434 100644 --- a/python/python/tests/tests_server_modules/search/test_ft_search.py +++ b/python/python/tests/tests_server_modules/search/test_ft_search.py @@ -13,12 +13,15 @@ FtCreateOptions, NumericField, ) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) from glide.async_commands.server_modules.ft_options.ft_search_options import ( FtSeachOptions, ReturnField, ) from glide.config import ProtocolVersion -from glide.constants import OK, TEncodable +from glide.constants import OK, FtSearchResponse, TEncodable from glide.glide_client import GlideClusterClient @@ -38,7 +41,7 @@ async def test_ft_search(self, glide_client: GlideClusterClient): prefixes.append(prefix) index = prefix + str(uuid.uuid4()) - # Create an index + # Create an index. assert ( await ft.create( glide_client, @@ -52,7 +55,7 @@ async def test_ft_search(self, glide_client: GlideClusterClient): == OK ) - # Create a json key + # Create a json key. assert ( await GlideJson.set(glide_client, json_key1, "$", json.dumps(json_value1)) == OK @@ -65,22 +68,43 @@ async def test_ft_search(self, glide_client: GlideClusterClient): # Wait for index to be updated to avoid this error - ResponseError: The index is under construction. time.sleep(self.sleep_wait_time) - # Search the index for string inputs - result1 = await ft.search( + ftSearchOptions = FtSeachOptions( + return_fields=[ + ReturnField(field_identifier="a", alias="a_new"), + ReturnField(field_identifier="b", alias="b_new"), + ] + ) + + # Search the index for string inputs. + result1 = await ft.search(glide_client, index, "*", options=ftSearchOptions) + # Check if we get the expected result from ft.search for string inputs. 
+ TestFtSearch._ft_search_deep_compare_result( + self, + result=result1, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + # Test FT.PROFILE for the above mentioned FT.SEARCH query and search options. + + ftProfileResult = await ft.profile( glide_client, index, - "*", - options=FtSeachOptions( - return_fields=[ - ReturnField(field_identifier="a", alias="a_new"), - ReturnField(field_identifier="b", alias="b_new"), - ] + FtProfileOptions.from_query_options( + query="*", queryOptions=ftSearchOptions ), ) - # Check if we get the expected result from ft.search for string inputs + print(ftProfileResult) + assert len(ftProfileResult) > 0 + + # Check if we get the expected result from FT.PROFILE for string inputs. TestFtSearch._ft_search_deep_compare_result( self, - result=result1, + result=cast(FtSearchResponse, ftProfileResult[0]), json_key1=json_key1, json_key2=json_key2, json_value1=json_value1, @@ -88,24 +112,44 @@ async def test_ft_search(self, glide_client: GlideClusterClient): fieldName1="a", fieldName2="b", ) + ftSearchOptionsByteInput = FtSeachOptions( + return_fields=[ + ReturnField(field_identifier=b"a", alias=b"a_new"), + ReturnField(field_identifier=b"b", alias=b"b_new"), + ] + ) - # Search the index for byte inputs + # Search the index for byte type inputs. result2 = await ft.search( + glide_client, bytes(index, "utf-8"), b"*", options=ftSearchOptionsByteInput + ) + + # Check if we get the expected result from ft.search for byte type inputs. + TestFtSearch._ft_search_deep_compare_result( + self, + result=result2, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + # Test FT.PROFILE for the above mentioned FT.SEARCH query and search options for byte type inputs. + ftProfileResult = await ft.profile( glide_client, - bytes(index, "utf-8"), - b"*", - options=FtSeachOptions( - return_fields=[ - ReturnField(field_identifier=b"a", alias=b"a_new"), - ReturnField(field_identifier=b"b", alias=b"b_new"), - ] + index, + FtProfileOptions.from_query_options( + query=b"*", queryOptions=ftSearchOptionsByteInput ), ) + assert len(ftProfileResult) > 0 - # Check if we get the expected result from ft.search from byte inputs + # Check if we get the expected result from FT.PROFILE for byte type inputs. 
TestFtSearch._ft_search_deep_compare_result( self, - result=result2, + result=cast(FtSearchResponse, ftProfileResult[0]), json_key1=json_key1, json_key2=json_key2, json_value1=json_value1, diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py index 77d0f4079d..5c49e5e7c2 100644 --- a/python/python/tests/tests_server_modules/test_ft.py +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -29,6 +29,9 @@ VectorFieldAttributesHnsw, VectorType, ) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) from glide.config import ProtocolVersion from glide.constants import OK, TEncodable from glide.exceptions import RequestError @@ -337,23 +340,23 @@ async def test_ft_aggregate_with_bicycles_data( ) time.sleep(self.sleep_wait_time) + ftAggregateOptions: FtAggregateOptions = FtAggregateOptions( + loadFields=["__key"], + clauses=[ + FtAggregateGroupBy( + ["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")] + ) + ], + ) + # Run FT.AGGREGATE command with the following arguments: ['FT.AGGREGATE', '{bicycles}:1e15faab-a870-488e-b6cd-f2b76c6916a3', '*', 'LOAD', '1', '__key', 'GROUPBY', '1', '@condition', 'REDUCE', 'COUNT', '0', 'AS', 'bicycles'] result = await ft.aggregate( glide_client, indexName=indexBicycles, query="*", - options=FtAggregateOptions( - loadFields=["__key"], - clauses=[ - FtAggregateGroupBy( - ["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")] - ) - ], - ), + options=ftAggregateOptions, ) - assert await ft.dropindex(glide_client, indexName=indexBicycles) == OK sortedResult = sorted(result, key=lambda x: (x[b"condition"], x[b"bicycles"])) - expectedResult = sorted( [ { @@ -373,6 +376,22 @@ async def test_ft_aggregate_with_bicycles_data( ) assert sortedResult == expectedResult + # Test FT.PROFILE for the above mentioned FT.AGGREGATE query + ftProfileResult = await ft.profile( + glide_client, + indexBicycles, + FtProfileOptions.from_query_options( + query="*", queryOptions=ftAggregateOptions + ), + ) + assert len(ftProfileResult) > 0 + assert ( + sorted(ftProfileResult[0], key=lambda x: (x[b"condition"], x[b"bicycles"])) + == expectedResult + ) + + assert await ft.dropindex(glide_client, indexName=indexBicycles) == OK + @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_ft_aggregate_with_movies_data( @@ -397,32 +416,32 @@ async def test_ft_aggregate_with_movies_data( # Run FT.AGGREGATE command with the following arguments: # ['FT.AGGREGATE', '{movies}:5a0e6257-3488-4514-96f2-f4c80f6cb0a9', '*', 'LOAD', '*', 'APPLY', 'ceil(@rating)', 'AS', 'r_rating', 'GROUPBY', '1', '@genre', 'REDUCE', 'COUNT', '0', 'AS', 'nb_of_movies', 'REDUCE', 'SUM', '1', 'votes', 'AS', 'nb_of_votes', 'REDUCE', 'AVG', '1', 'r_rating', 'AS', 'avg_rating', 'SORTBY', '4', '@avg_rating', 'DESC', '@nb_of_votes', 'DESC'] + ftAggregateOptions: FtAggregateOptions = FtAggregateOptions( + loadAll=True, + clauses=[ + FtAggregateApply(expression="ceil(@rating)", name="r_rating"), + FtAggregateGroupBy( + ["@genre"], + [ + FtAggregateReducer("COUNT", [], "nb_of_movies"), + FtAggregateReducer("SUM", ["votes"], "nb_of_votes"), + FtAggregateReducer("AVG", ["r_rating"], "avg_rating"), + ], + ), + FtAggregateSortBy( + properties=[ + FtAggregateSortProperty("@avg_rating", OrderBy.DESC), + FtAggregateSortProperty("@nb_of_votes", OrderBy.DESC), + ] + ), + ], + ) result = await ft.aggregate( glide_client, 
indexName=indexMovies, query="*", - options=FtAggregateOptions( - loadAll=True, - clauses=[ - FtAggregateApply(expression="ceil(@rating)", name="r_rating"), - FtAggregateGroupBy( - ["@genre"], - [ - FtAggregateReducer("COUNT", [], "nb_of_movies"), - FtAggregateReducer("SUM", ["votes"], "nb_of_votes"), - FtAggregateReducer("AVG", ["r_rating"], "avg_rating"), - ], - ), - FtAggregateSortBy( - properties=[ - FtAggregateSortProperty("@avg_rating", OrderBy.DESC), - FtAggregateSortProperty("@nb_of_votes", OrderBy.DESC), - ] - ), - ], - ), + options=ftAggregateOptions, ) - assert await ft.dropindex(glide_client, indexName=indexMovies) == OK sortedResult = sorted( result, key=lambda x: ( @@ -476,6 +495,30 @@ async def test_ft_aggregate_with_movies_data( ) assert expectedResultSet == sortedResult + # Test FT.PROFILE for the above mentioned FT.AGGREGATE query + ftProfileResult = await ft.profile( + glide_client, + indexMovies, + FtProfileOptions.from_query_options( + query="*", queryOptions=ftAggregateOptions + ), + ) + assert len(ftProfileResult) > 0 + assert ( + sorted( + ftProfileResult[0], + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + == expectedResultSet + ) + + assert await ft.dropindex(glide_client, indexName=indexMovies) == OK + async def _create_index_for_ft_aggregate_with_bicycles_data( self, glide_client: GlideClusterClient, index_name: TEncodable, prefix ): From 23689db8aa3bd0aa2795fca6b829289b53cf47e9 Mon Sep 17 00:00:00 2001 From: Yi-Pin Chen Date: Thu, 31 Oct 2024 16:07:39 -0700 Subject: [PATCH 084/180] Node: add ARRTRIM command (#2550) * Node: add ARRTRIM command Signed-off-by: Yi-Pin Chen --- CHANGELOG.md | 1 + .../api/commands/servermodules/Json.java | 4 +- node/src/server-modules/GlideJson.ts | 71 +++++++++- node/tests/ServerModules.test.ts | 133 ++++++++++++++++++ 4 files changed, 203 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc832f60f4..478aa96339 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ * Java: Added `JSON.TYPE` ([#2525](https://github.com/valkey-io/valkey-glide/pull/2525)) * Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) * Node: Added `JSON.RESP` ([#2517](https://github.com/valkey-io/valkey-glide/pull/2517)) +* Node: Added `JSON.ARRTRIM` ([#2550](https://github.com/valkey-io/valkey-glide/pull/2550)) * Python: Add `JSON.STRAPPEND` , `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) * Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395)) * Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464)) diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java index eef7a8d4c8..02ea0ff07b 100644 --- a/java/client/src/main/java/glide/api/commands/servermodules/Json.java +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -1145,7 +1145,7 @@ public static CompletableFuture arrpop( } /** - * Trims an array at the specified path within the JSON document started at key + * Trims an array at the specified path within the JSON document stored at key * so that it becomes a subarray [start, end], both inclusive. *
        * If start < 0, it is treated as 0.
        @@ -1193,7 +1193,7 @@ public static CompletableFuture arrtrim( } /** - * Trims an array at the specified path within the JSON document started at key + * Trims an array at the specified path within the JSON document stored at key * so that it becomes a subarray [start, end], both inclusive. *
        * If start < 0, it is treated as 0.
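For reference, the clamping rules above can be modeled with a short Python sketch (an illustrative model of the documented behavior, not the server's actual implementation):

    def arrtrim_model(arr, start, end):
        # Per the docs: start < 0 is treated as 0; end >= size is treated as size - 1.
        size = len(arr)
        start = max(start, 0)
        end = min(end, size - 1)
        if start >= size or start > end:
            arr[:] = []  # the array is emptied and the new length (0) is returned
        else:
            arr[:] = arr[start:end + 1]  # keep the inclusive subarray [start, end]
        return len(arr)  # JSON.ARRTRIM replies with the new array length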
        diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts index 3bf942a4bb..2db3e453ae 100644 --- a/node/src/server-modules/GlideJson.ts +++ b/node/src/server-modules/GlideJson.ts @@ -341,6 +341,69 @@ export class GlideJson { return _executeCommand(client, args); } + /** + * Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive. + * If `start` < 0, it is treated as 0. + * If `end` >= size (size of the array), it is treated as size-1. + * If `start` >= size or `start` > `end`, the array is emptied and 0 is returned. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param path - The path within the JSON document. + * @param start - The start index, inclusive. + * @param end - The end index, inclusive. + * @returns + * - For JSONPath (`path` starts with `$`): + * - Returns a list of integer replies for every possible path, indicating the new length of the array, + * or `null` for JSON values matching the path that are not an array. + * - If the array is empty, its corresponding return value is 0. + * - If `path` doesn't exist, an empty array will be returned. + * - If an index argument is out of bounds, an error is raised. + * - For legacy path (`path` doesn't start with `$`): + * - Returns an integer representing the new length of the array. + * - If the array is empty, its corresponding return value is 0. + * - If multiple paths match, the length of the first trimmed array match is returned. + * - If `path` doesn't exist, or the value at `path` is not an array, an error is raised. + * - If an index argument is out of bounds, an error is raised. + * + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]'); + * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * const result = await GlideJson.arrtrim(client, "doc", "$[*]", 0, 1); + * console.log(result); + * // Output: [0, 1, 2, 2] + * console.log(await GlideJson.get(client, "doc", "$")); + * // Output: '[[],["a"],["a","b"],["a","b"]]' - Returns the value at path '$' in the JSON document stored at `doc`. + * ``` + * @example + * ```typescript + * console.log(await GlideJson.set(client, "doc", "$", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}'); + * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. + * result = await GlideJson.arrtrim(client, "doc", ".children", 0, 1); + * console.log(result); + * // Output: 2 + * console.log(await GlideJson.get(client, "doc", ".children")); + * // Output: '["John", "Jack"]' - Returns the value at path '$' in the JSON document stored at `doc`. + * ``` + */ + static async arrtrim( + client: BaseClient, + key: GlideString, + path: GlideString, + start: number, + end: number, + ): Promise> { + const args: GlideString[] = [ + "JSON.ARRTRIM", + key, + path, + start.toString(), + end.toString(), + ]; + return _executeCommand>(client, args); + } + /** * Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`. 
* @@ -484,7 +547,7 @@ export class GlideJson { * * @example * ```typescript - * console.log(await GlideJson.set(client, "doc", "$", "[1, 2.3, "foo", true, null, {}, []]")); + * console.log(await GlideJson.set(client, "doc", "$", '[1, 2.3, "foo", true, null, {}, []]')); * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. * const result = await GlideJson.type(client, "doc", "$[*]"); * console.log(result); @@ -535,7 +598,7 @@ export class GlideJson { * * @example * ```typescript - * console.log(await GlideJson.set(client, "doc", ".", "{a: [1, 2, 3], b: {a: [1, 2], c: {a: 42}}}")); + * console.log(await GlideJson.set(client, "doc", ".", '{a: [1, 2, 3], b: {a: [1, 2], c: {a: 42}}}')); * // Output: 'OK' - Indicates successful setting of the value at path '.' in the key stored at `doc`. * const result = await GlideJson.resp(client, "doc", {path: "$..a"}); * console.log(result); @@ -582,7 +645,7 @@ export class GlideJson { * * @example * ```typescript - * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}")); + * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}')); * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. * console.log(await GlideJson.strlen(client, "doc", {path: "$..a"})); * // Output: [3, 5, null] - The length of the string values at path '$..a' in the key stored at `doc`. @@ -632,7 +695,7 @@ export class GlideJson { * * @example * ```typescript - * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}")); + * console.log(await GlideJson.set(client, "doc", "$", '{a:"foo", nested: {a: "hello"}, nested2: {a: 31}}')); * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`. * console.log(await GlideJson.strappend(client, "doc", jsonpy.dumps("baz"), {path: "$..a"})) * // Output: [6, 8, null] - The new length of the string values at path '$..a' in the key stored at `doc` after the append operation. 
diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 4baf854d54..314158426f 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -887,6 +887,139 @@ describe("Server Module Tests", () => { ).toBeNull(); }); + it("json.arrtrim tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const jsonValue = { + a: [0, 1, 2, 3, 4, 5, 6, 7, 8], + b: { a: [0, 9, 10, 11, 12, 13], c: { a: 42 } }, + }; + + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // Basic trim + expect( + await GlideJson.arrtrim(client, key, "$..a", 1, 7), + ).toEqual([7, 5, null]); + + // Test end >= size (should be treated as size-1) + expect( + await GlideJson.arrtrim(client, key, "$.a", 0, 10), + ).toEqual([7]); + expect( + await GlideJson.arrtrim(client, key, ".a", 0, 10), + ).toEqual(7); + + // Test negative start (should be treated as 0) + expect( + await GlideJson.arrtrim(client, key, "$.a", -1, 5), + ).toEqual([6]); + expect( + await GlideJson.arrtrim(client, key, ".a", -1, 5), + ).toEqual(6); + + // Test start >= size (should empty the array) + expect( + await GlideJson.arrtrim(client, key, "$.a", 7, 10), + ).toEqual([0]); + const jsonValue2 = ["a", "b", "c"]; + expect( + await GlideJson.set( + client, + key, + ".a", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, ".a", 7, 10), + ).toEqual(0); + + // Test start > end (should empty the array) + expect( + await GlideJson.arrtrim(client, key, "$..a", 2, 1), + ).toEqual([0, 0, null]); + const jsonValue3 = ["a", "b", "c", "d"]; + expect( + await GlideJson.set( + client, + key, + "..a", + JSON.stringify(jsonValue3), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "..a", 2, 1), + ).toEqual(0); + + // Multiple path match + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "..a", 1, 10), + ).toEqual(8); + + // Test with non-existent path + await expect( + GlideJson.arrtrim(client, key, "nonexistent", 0, 1), + ).rejects.toThrow(RequestError); + expect( + await GlideJson.arrtrim(client, key, "$.nonexistent", 0, 1), + ).toEqual([]); + + // Test with non-array path + expect(await GlideJson.arrtrim(client, key, "$", 0, 1)).toEqual( + [null], + ); + await expect( + GlideJson.arrtrim(client, key, ".", 0, 1), + ).rejects.toThrow(RequestError); + + // Test with non-existent key + await expect( + GlideJson.arrtrim(client, "non_existing_key", "$", 0, 1), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.arrtrim(client, "non_existing_key", ".", 0, 1), + ).rejects.toThrow(RequestError); + + // Test empty array + expect( + await GlideJson.set( + client, + key, + "$.empty", + JSON.stringify([]), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "$.empty", 0, 1), + ).toEqual([0]); + expect( + await GlideJson.arrtrim(client, key, ".empty", 0, 1), + ).toEqual(0); + }); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( "json.strlen tests", async (protocol) => { From 788b6c8c9f7d4cf470e1bf6e48970f21467d73bb Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Fri, 1 Nov 2024 16:05:48 -0700 Subject: [PATCH 085/180] Node: Add `FT.SEARCH` command (#2551) * Node: add FT.SEARCH Signed-off-by: Andrew Carbonetto --------- 
Signed-off-by: Andrew Carbonetto --- CHANGELOG.md | 1 + node/npm/glide/index.ts | 4 + node/src/server-modules/GlideFt.ts | 144 +++++++++++++-- node/src/server-modules/GlideFtOptions.ts | 45 ++++- node/tests/ServerModules.test.ts | 210 +++++++++++++++++++++- node/tests/TestUtilities.ts | 1 - 6 files changed, 391 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 478aa96339..a1740d9e75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ * Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) * Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) * Node: Add `JSON.STRLEN` and `JSON.STRAPPEND` command ([#2537](https://github.com/valkey-io/valkey-glide/pull/2537)) +* Node: Add `FT.SEARCH` ([#2551](https://github.com/valkey-io/valkey-glide/pull/2551)) * Python: Fix example ([#2556](https://github.com/valkey-io/valkey-glide/issues/2556)) #### Breaking Changes diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts index 0dc40fd055..781fd26594 100644 --- a/node/npm/glide/index.ts +++ b/node/npm/glide/index.ts @@ -126,7 +126,9 @@ function initialize() { VectorFieldAttributesFlat, VectorFieldAttributesHnsw, FtCreateOptions, + FtSearchOptions, FtInfoReturnType, + FtSearchReturnType, GlideRecord, GlideString, JsonGetOptions, @@ -248,7 +250,9 @@ function initialize() { VectorFieldAttributesFlat, VectorFieldAttributesHnsw, FtCreateOptions, + FtSearchOptions, FtInfoReturnType, + FtSearchReturnType, GlideRecord, GlideJson, GlideString, diff --git a/node/src/server-modules/GlideFt.ts b/node/src/server-modules/GlideFt.ts index 0d58cbfeb0..60c7f72459 100644 --- a/node/src/server-modules/GlideFt.ts +++ b/node/src/server-modules/GlideFt.ts @@ -12,10 +12,10 @@ import { } from "../BaseClient"; import { GlideClient } from "../GlideClient"; import { GlideClusterClient } from "../GlideClusterClient"; -import { Field, FtCreateOptions } from "./GlideFtOptions"; +import { Field, FtCreateOptions, FtSearchOptions } from "./GlideFtOptions"; -/** Data type of {@link GlideFt.info | info} command response. */ -type FtInfoReturnType = Record< +/** Response type of {@link GlideFt.info | ft.info} command. */ +export type FtInfoReturnType = Record< string, | GlideString | number @@ -23,15 +23,23 @@ type FtInfoReturnType = Record< | Record[]> >; +/** + * Response type for the {@link GlideFt.search | ft.search} command. + */ +export type FtSearchReturnType = [ + number, + GlideRecord>, +]; + /** Module for Vector Search commands. */ export class GlideFt { /** * Creates an index and initiates a backfill of that index. * - * @param client The client to execute the command. - * @param indexName The index name for the index to be created. - * @param schema The fields of the index schema, specifying the fields and their types. - * @param options Optional arguments for the `FT.CREATE` command. See {@link FtCreateOptions}. + * @param client - The client to execute the command. + * @param indexName - The index name for the index to be created. + * @param schema - The fields of the index schema, specifying the fields and their types. + * @param options - (Optional) Options for the `FT.CREATE` command. See {@link FtCreateOptions}. * * @returns If the index is successfully created, returns "OK". * @@ -182,8 +190,8 @@ export class GlideFt { /** * Deletes an index and associated content. Indexed document keys are unaffected. * - * @param client The client to execute the command. 
- * @param indexName The index name. + * @param client - The client to execute the command. + * @param indexName - The index name. * * @returns "OK" * @@ -269,6 +277,122 @@ export class GlideFt { > ).then(convertGlideRecordToRecord); } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client - The client to execute the command. + * @param indexName - The index name to search into. + * @param query - The text query to search. + * @param options - (Optional) See {@link FtSearchOptions} and {@link DecoderOption}. + * + * @returns A two-element array, where the first element is the number of documents in the result set, and the + * second element has the format: `GlideRecord>`: + * a mapping between document names and a map of their attributes. + * + * If `count` or `limit` with values `{offset: 0, count: 0}` is + * set, the command returns array with only one element: the number of documents. + * + * @example + * ```typescript + * // + * const vector = Buffer.alloc(24); + * const result = await GlideFt.search(client, "json_idx1", "*=>[KNN 2 @VEC $query_vec]", {params: [{key: "query_vec", value: vector}]}); + * console.log(result); // Output: + * // [ + * // 2, + * // [ + * // { + * // key: "json:2", + * // value: [ + * // { + * // key: "$", + * // value: '{"vec":[1.1,1.2,1.3,1.4,1.5,1.6]}', + * // }, + * // { + * // key: "__VEC_score", + * // value: "11.1100006104", + * // }, + * // ], + * // }, + * // { + * // key: "json:0", + * // value: [ + * // { + * // key: "$", + * // value: '{"vec":[1,2,3,4,5,6]}', + * // }, + * // { + * // key: "__VEC_score", + * // value: "91", + * // }, + * // ], + * // }, + * // ], + * // ] + * ``` + */ + static async search( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: FtSearchOptions & DecoderOption, + ): Promise { + const args: GlideString[] = ["FT.SEARCH", indexName, query]; + + if (options) { + // RETURN + if (options.returnFields) { + const returnFields: GlideString[] = []; + options.returnFields.forEach((returnField) => + returnField.alias + ? returnFields.push( + returnField.fieldIdentifier, + "AS", + returnField.alias, + ) + : returnFields.push(returnField.fieldIdentifier), + ); + args.push( + "RETURN", + returnFields.length.toString(), + ...returnFields, + ); + } + + // TIMEOUT + if (options.timeout) { + args.push("TIMEOUT", options.timeout.toString()); + } + + // PARAMS + if (options.params) { + args.push("PARAMS", (options.params.length * 2).toString()); + options.params.forEach((param) => + args.push(param.key, param.value), + ); + } + + // LIMIT + if (options.limit) { + args.push( + "LIMIT", + options.limit.offset.toString(), + options.limit.count.toString(), + ); + } + + // COUNT + if (options.count) { + args.push("COUNT"); + } + } + + return _handleCustomCommand(client, args, options) as Promise< + [number, GlideRecord>] + >; + } } /** @@ -277,7 +401,7 @@ export class GlideFt { async function _handleCustomCommand( client: GlideClient | GlideClusterClient, args: GlideString[], - decoderOption?: DecoderOption, + decoderOption: DecoderOption = {}, ): Promise { return client instanceof GlideClient ? 
(client as GlideClient).customCommand(args, decoderOption) diff --git a/node/src/server-modules/GlideFtOptions.ts b/node/src/server-modules/GlideFtOptions.ts index 24846da6d2..fffccd11c1 100644 --- a/node/src/server-modules/GlideFtOptions.ts +++ b/node/src/server-modules/GlideFtOptions.ts @@ -2,7 +2,7 @@ * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ -import { GlideString } from "../BaseClient"; +import { GlideRecord, GlideString } from "../BaseClient"; interface BaseField { /** The name of the field. */ @@ -118,3 +118,46 @@ export interface FtCreateOptions { /** The prefix of the key to be indexed. */ prefixes?: GlideString[]; } + +/** + * Represents the input options to be used in the FT.SEARCH command. + * All fields in this class are optional inputs for FT.SEARCH. + */ +export type FtSearchOptions = { + /** + * Add a field to be returned. + * @param fieldIdentifier field name to return. + * @param alias optional alias for the field name to return. + */ + returnFields?: { fieldIdentifier: GlideString; alias?: GlideString }[]; + + /** Query timeout in milliseconds. */ + timeout?: number; + + /** + * Query parameters, which could be referenced in the query by `$` sign, followed by + * the parameter name. + */ + params?: GlideRecord; +} & ( + | { + /** + * Configure query pagination. By default only first 10 documents are returned. + * + * @param offset Zero-based offset. + * @param count Number of elements to return. + */ + limit?: { offset: number; count: number }; + /** `limit` and `count` are mutually exclusive. */ + count?: never; + } + | { + /** + * Once set, the query will return only the number of documents in the result set without actually + * returning them. + */ + count?: boolean; + /** `limit` and `count` are mutually exclusive. 
*/ + limit?: never; + } +); diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts index 314158426f..24ecfd9435 100644 --- a/node/tests/ServerModules.test.ts +++ b/node/tests/ServerModules.test.ts @@ -13,6 +13,7 @@ import { v4 as uuidv4 } from "uuid"; import { ConditionalChange, Decoder, + FtSearchReturnType, GlideClusterClient, GlideFt, GlideJson, @@ -32,6 +33,7 @@ import { } from "./TestUtilities"; const TIMEOUT = 50000; +const DATA_PROCESSING_TIMEOUT = 1000; describe("Server Module Tests", () => { let cluster: ValkeyCluster; @@ -1217,7 +1219,7 @@ describe("Server Module Tests", () => { expect(info).toContain("# search_index_stats"); }); - it("Ft.Create test", async () => { + it("FT.CREATE test", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -1371,7 +1373,7 @@ describe("Server Module Tests", () => { } }); - it("Ft.DROPINDEX test", async () => { + it("FT.DROPINDEX test", async () => { client = await GlideClusterClient.createClient( getClientConfigurationOption( cluster.getAddresses(), @@ -1488,5 +1490,209 @@ describe("Server Module Tests", () => { ); }, ); + + it("FT.SEARCH binary test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + const prefix = "{" + uuidv4() + "}:"; + const index = prefix + "index"; + + // setup a hash index: + expect( + await GlideFt.create( + client, + index, + [ + { + type: "VECTOR", + name: "vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + distanceMetric: "L2", + dimensions: 2, + }, + }, + ], + { + dataType: "HASH", + prefixes: [prefix], + }, + ), + ).toEqual("OK"); + + const binaryValue1 = Buffer.alloc(8); + expect( + await client.hset(Buffer.from(prefix + "0"), [ + // value of + { field: "vec", value: binaryValue1 }, + ]), + ).toEqual(1); + + const binaryValue2: Buffer = Buffer.alloc(8); + binaryValue2[6] = 0x80; + binaryValue2[7] = 0xbf; + expect( + await client.hset(Buffer.from(prefix + "1"), [ + // value of + { field: "vec", value: binaryValue2 }, + ]), + ).toEqual(1); + + // let server digest the data and update index + const sleep = new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + await sleep; + + // With the `COUNT` parameters - returns only the count + const binaryResultCount: FtSearchReturnType = await GlideFt.search( + client, + index, + "*=>[KNN 2 @VEC $query_vec]", + { + params: [{ key: "query_vec", value: binaryValue1 }], + timeout: 10000, + count: true, + decoder: Decoder.Bytes, + }, + ); + expect(binaryResultCount).toEqual([2]); + + const binaryResult: FtSearchReturnType = await GlideFt.search( + client, + index, + "*=>[KNN 2 @VEC $query_vec]", + { + params: [{ key: "query_vec", value: binaryValue1 }], + timeout: 10000, + decoder: Decoder.Bytes, + }, + ); + + const expectedBinaryResult: FtSearchReturnType = [ + 2, + [ + { + key: Buffer.from(prefix + "1"), + value: [ + { + key: Buffer.from("vec"), + value: binaryValue2, + }, + { + key: Buffer.from("__VEC_score"), + value: Buffer.from("1"), + }, + ], + }, + { + key: Buffer.from(prefix + "0"), + value: [ + { + key: Buffer.from("vec"), + value: binaryValue1, + }, + { + key: Buffer.from("__VEC_score"), + value: Buffer.from("0"), + }, + ], + }, + ], + ]; + expect(binaryResult).toEqual(expectedBinaryResult); + }); + + it("FT.SEARCH string test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + 
cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + const prefix = "{" + uuidv4() + "}:"; + const index = prefix + "index"; + + // set string values + expect( + await GlideJson.set( + client, + prefix + "1", + "$", + '[{"arr": 42}, {"val": "hello"}, {"val": "world"}]', + ), + ).toEqual("OK"); + + // setup a json index: + expect( + await GlideFt.create( + client, + index, + [ + { + type: "NUMERIC", + name: "$..arr", + alias: "arr", + }, + { + type: "TEXT", + name: "$..val", + alias: "val", + }, + ], + { + dataType: "JSON", + prefixes: [prefix], + }, + ), + ).toEqual("OK"); + + // let server digest the data and update index + const sleep = new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + await sleep; + + const stringResult: FtSearchReturnType = await GlideFt.search( + client, + index, + "*", + { + returnFields: [ + { fieldIdentifier: "$..arr", alias: "myarr" }, + { fieldIdentifier: "$..val", alias: "myval" }, + ], + timeout: 10000, + decoder: Decoder.String, + limit: { offset: 0, count: 2 }, + }, + ); + const expectedStringResult: FtSearchReturnType = [ + 1, + [ + { + key: prefix + "1", + value: [ + { + key: "myarr", + value: "42", + }, + { + key: "myval", + value: "hello", + }, + ], + }, + ], + ]; + expect(stringResult).toEqual(expectedStringResult); + }); }); }); diff --git a/node/tests/TestUtilities.ts b/node/tests/TestUtilities.ts index c3fac91e09..0b64b31a04 100644 --- a/node/tests/TestUtilities.ts +++ b/node/tests/TestUtilities.ts @@ -177,7 +177,6 @@ export function flushallOnPort(port: number): Promise { */ export const parseEndpoints = (endpointsStr: string): [string, number][] => { try { - console.log(endpointsStr); const endpoints: string[][] = endpointsStr .split(",") .map((endpoint) => endpoint.split(":")); From 58454c7362d7359ff823d35bacf595dec7873a0d Mon Sep 17 00:00:00 2001 From: Muhammad Awawdi Date: Sun, 3 Nov 2024 11:31:50 +0200 Subject: [PATCH 086/180] PYTHON: Json.numincrby update (#2577) --------- Signed-off-by: Muhammad Awawdi --- python/python/glide/async_commands/server_modules/json.py | 6 +++--- python/python/tests/tests_server_modules/test_json.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 2b57239a41..9bd8d870c0 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -691,7 +691,7 @@ async def numincrby( key: TEncodable, path: TEncodable, number: Union[int, float], -) -> Optional[bytes]: +) -> bytes: """ Increments or decrements the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`. @@ -702,7 +702,7 @@ async def numincrby( number (Union[int, float]): The number to increment or decrement by. Returns: - Optional[bytes]: + bytes: For JSONPath (`path` starts with `$`): Returns a bytes string representation of an array of bulk strings, indicating the new values after incrementing for each matched `path`. If a value is not a number, its corresponding return value will be `null`. 
@@ -725,7 +725,7 @@ async def numincrby( """ args = ["JSON.NUMINCRBY", key, path, str(number)] - return cast(Optional[bytes], await client.custom_command(args)) + return cast(bytes, await client.custom_command(args)) async def nummultby( diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index ad958ff88d..cce7a5f3f5 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -585,6 +585,10 @@ async def test_json_numincrby(self, glide_client: TGlideClient): result = await json.numincrby(glide_client, key, "$.key1", -0.5) assert result == b"[-0.5]" # Expect 0 - 0.5 = -0.5 + # Check 'null' value + result = await json.numincrby(glide_client, key, "$.key7", 5) + assert result == b"[null]" # Expect 'null' + # Test Legacy Path # Increment float value (key1) by 5 (integer) result = await json.numincrby(glide_client, key, "key1", 5) @@ -632,6 +636,10 @@ async def test_json_numincrby(self, glide_client: TGlideClient): with pytest.raises(RequestError): await json.numincrby(glide_client, key, ".key9", 1.7976931348623157e308) + # Check 'null' value + with pytest.raises(RequestError): + await json.numincrby(glide_client, key, ".key7", 5) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_nummultby(self, glide_client: TGlideClient): From c2d5e4a34e947111c69a563f3c320c0ef0d57174 Mon Sep 17 00:00:00 2001 From: Muhammad Awawdi Date: Sun, 3 Nov 2024 11:35:56 +0200 Subject: [PATCH 087/180] PYTHON: Json.nummultby update (#2578) Signed-off-by: Muhammad Awawdi --- python/python/glide/async_commands/server_modules/json.py | 6 +++--- python/python/tests/tests_server_modules/test_json.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 9bd8d870c0..0b2cfbe4e0 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -733,7 +733,7 @@ async def nummultby( key: TEncodable, path: TEncodable, number: Union[int, float], -) -> Optional[bytes]: +) -> bytes: """ Multiplies the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`. @@ -744,7 +744,7 @@ async def nummultby( number (Union[int, float]): The number to multiply by. Returns: - Optional[bytes]: + bytes: For JSONPath (`path` starts with `$`): Returns a bytes string representation of an array of bulk strings, indicating the new values after multiplication for each matched `path`. If a value is not a number, its corresponding return value will be `null`. 
@@ -767,7 +767,7 @@ async def nummultby( """ args = ["JSON.NUMMULTBY", key, path, str(number)] - return cast(Optional[bytes], await client.custom_command(args)) + return cast(bytes, await client.custom_command(args)) async def objlen( diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index cce7a5f3f5..783b33cb2c 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -766,6 +766,10 @@ async def test_json_nummultby(self, glide_client: TGlideClient): result = await json.get(glide_client, key, "$..key1") # type: ignore assert result == b"[-16500,[140,175],1380]" + # Check 'null' in legacy + with pytest.raises(RequestError): + await json.nummultby(glide_client, key, ".key7", 5) + # Check for non-existent path in legacy with pytest.raises(RequestError): await json.nummultby(glide_client, key, ".key10", 51) From b0918041cc56f25aa9058f66f4fe8289934d511c Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Sun, 3 Nov 2024 12:40:13 +0200 Subject: [PATCH 088/180] Python: adds JSON.ARRPOP command (#2407) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> --- CHANGELOG.md | 1 + python/python/glide/__init__.py | 7 +- .../async_commands/server_modules/json.py | 89 +++++++++++++++ .../tests/tests_server_modules/test_json.py | 104 +++++++++++++++++- 4 files changed, 199 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1740d9e75..4d809b3801 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ * Python: Add `JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457)) * Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) * Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) +* Python: Add `JSON.ARRPOP` command ([#2407](https://github.com/valkey-io/valkey-glide/pull/2407)) * Node: Add `JSON.STRLEN` and `JSON.STRAPPEND` command ([#2537](https://github.com/valkey-io/valkey-glide/pull/2537)) * Node: Add `FT.SEARCH` ([#2551](https://github.com/valkey-io/valkey-glide/pull/2551)) * Python: Fix example ([#2556](https://github.com/valkey-io/valkey-glide/issues/2556)) diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index bde4f33401..4db35d0e30 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -67,7 +67,11 @@ FtSearchLimit, ReturnField, ) -from glide.async_commands.server_modules.json import JsonArrIndexOptions, JsonGetOptions +from glide.async_commands.server_modules.json import ( + JsonArrIndexOptions, + JsonArrPopOptions, + JsonGetOptions, +) from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -263,6 +267,7 @@ "json", "JsonGetOptions", "JsonArrIndexOptions", + "JsonArrPopOptions", # Logger "Logger", "LogLevel", diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py index 0b2cfbe4e0..3c98672d9f 100644 --- a/python/python/glide/async_commands/server_modules/json.py +++ b/python/python/glide/async_commands/server_modules/json.py @@ -84,6 +84,33 @@ def to_args(self) -> List[str]: return args +class JsonArrPopOptions: + """ + Options for the JSON.ARRPOP command. + + Args: + path (TEncodable): The path within the JSON document. 
+        index (Optional[int]): The index of the element to pop. If not specified, the last element is popped.
+            Out-of-bounds indexes are rounded to their respective array boundaries. Defaults to None.
+    """
+
+    def __init__(self, path: TEncodable, index: Optional[int] = None):
+        self.path = path
+        self.index = index
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the options as a list of arguments for the JSON.ARRPOP command.
+
+        Returns:
+            List[TEncodable]: A list containing the path and, if specified, the index.
+        """
+        args = [self.path]
+        if self.index is not None:
+            args.append(str(self.index))
+        return args
+
+
 async def set(
     client: TGlideClient,
     key: TEncodable,
@@ -403,6 +430,68 @@ async def arrlen(
     )
 
 
+async def arrpop(
+    client: TGlideClient,
+    key: TEncodable,
+    options: Optional[JsonArrPopOptions] = None,
+) -> Optional[TJsonResponse[bytes]]:
+    """
+    Pops an element from the array located at the specified path within the JSON document stored at `key`.
+    If `options.index` is provided, it pops the element at that index instead of the last element.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        options (Optional[JsonArrPopOptions]): Options including the path and optional index. See `JsonArrPopOptions`. Defaults to None.
+            If not specified, attempts to pop the last element from the root value if it's an array.
+            If the root value is not an array, an error will be raised.
+
+    Returns:
+        Optional[TJsonResponse[bytes]]:
+            For JSONPath (`options.path` starts with `$`):
+                Returns a list of bytes string replies for every possible path, representing the popped JSON values,
+                or None for JSON values matching the path that are not an array or are an empty array.
+                If `options.path` doesn't exist, an empty list will be returned.
+            For legacy path (`options.path` doesn't start with `$`):
+                Returns a bytes string representing the popped JSON value, or None if the array at `options.path` is empty.
+                If multiple paths match, the value from the first matching array that is not empty is returned.
+                If the JSON value at `options.path` is not an array or if `options.path` doesn't exist, an error is raised.
+            If `key` doesn't exist, an error is raised.
+
+    Examples:
+        >>> from glide import json
+        >>> await json.set(client, "doc", "$", '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}')
+            b'OK'
+        >>> await json.arrpop(client, "doc", JsonArrPopOptions(path="$.a", index=1))
+            [b'2']  # Pop the second element from the array at path $.a
+        >>> await json.arrpop(client, "doc", JsonArrPopOptions(path="$..a"))
+            [b'true', b'5', None]  # Pop the last elements from all arrays matching path `$..a`
+
+        #### Using a legacy path (..)
to pop the first matching array + >>> await json.arrpop(client, "doc", JsonArrPopOptions(path="..a")) + b"1" # First match popped (from array at path ..a) + + #### Even though only one value is returned from `..a`, subsequent arrays are also affected + >>> await json.get(client, "doc", "$..a") + b"[[], [3, 4], 42]" # Remaining elements after pop show the changes + + >>> await json.set(client, "doc", "$", '[[], ["a"], ["a", "b", "c"]]') + b'OK' # JSON is successfully set + >>> await json.arrpop(client, "doc", JsonArrPopOptions(path=".", index=-1)) + b'["a","b","c"]' # Pop last elements at path `.` + >>> await json.arrpop(client, "doc") + b'["a"]' # Pop last elements at path `.` + """ + args = ["JSON.ARRPOP", key] + if options: + args.extend(options.to_args()) + + return cast( + Optional[TJsonResponse[bytes]], + await client.custom_command(args), + ) + + async def arrtrim( client: TGlideClient, key: TEncodable, diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index 783b33cb2c..a77797af93 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -8,7 +8,11 @@ import pytest from glide.async_commands.core import ConditionalChange, InfoSection from glide.async_commands.server_modules import json -from glide.async_commands.server_modules.json import JsonArrIndexOptions, JsonGetOptions +from glide.async_commands.server_modules.json import ( + JsonArrIndexOptions, + JsonArrPopOptions, + JsonGetOptions, +) from glide.config import ProtocolVersion from glide.constants import OK from glide.exceptions import RequestError @@ -1852,3 +1856,101 @@ async def test_json_resp(self, glide_client: TGlideClient): assert await json.resp(glide_client, "nonexistent_key", "$") is None assert await json.resp(glide_client, "nonexistent_key", ".") is None assert await json.resp(glide_client, "nonexistent_key") is None + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrpop(self, glide_client: TGlideClient): + key = get_random_string(5) + key2 = get_random_string(5) + + json_value = '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false] ,5], "c": {"a": 42}}}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.arrpop( + glide_client, key, JsonArrPopOptions(path="$.a", index=1) + ) == [b"2"] + assert ( + await json.arrpop(glide_client, key, JsonArrPopOptions(path="$..a")) + ) == [b"true", b"5", None] + + assert ( + await json.arrpop(glide_client, key, JsonArrPopOptions(path="..a")) == b"1" + ) + # Even if only one array element was returned, ensure second array at `..a` was popped + assert await json.get(glide_client, key, "$..a") == b"[[],[3,4],42]" + + # Out of index + assert await json.arrpop( + glide_client, key, JsonArrPopOptions(path="$..a", index=10) + ) == [None, b"4", None] + + assert ( + await json.arrpop( + glide_client, key, JsonArrPopOptions(path="..a", index=-10) + ) + == b"3" + ) + + # Path is not an array + assert await json.arrpop(glide_client, key, JsonArrPopOptions(path="$")) == [ + None + ] + with pytest.raises(RequestError): + assert await json.arrpop(glide_client, key, JsonArrPopOptions(path=".")) + with pytest.raises(RequestError): + assert await json.arrpop(glide_client, key) + + # Non existing path + assert ( + await json.arrpop( + glide_client, key, JsonArrPopOptions(path="$.non_existing_path") 
+ ) + == [] + ) + with pytest.raises(RequestError): + assert await json.arrpop( + glide_client, key, JsonArrPopOptions(path="non_existing_path") + ) + + # Non existing key + with pytest.raises(RequestError): + await json.arrpop( + glide_client, "non_existing_key", JsonArrPopOptions(path="$.a") + ) + with pytest.raises(RequestError): + await json.arrpop( + glide_client, "non_existing_key", JsonArrPopOptions(path=".a") + ) + + assert ( + await json.set(glide_client, key2, "$", '[[], ["a"], ["a", "b", "c"]]') + == OK + ) + assert ( + await json.arrpop(glide_client, key2, JsonArrPopOptions(path=".", index=-1)) + == b'["a","b","c"]' + ) + assert await json.arrpop(glide_client, key2) == b'["a"]' + + # pop from an empty array + assert await json.arrpop(glide_client, key2, JsonArrPopOptions("$[0]")) == [ + None + ] + assert await json.arrpop(glide_client, key2, JsonArrPopOptions("$[0]", 10)) == [ + None + ] + assert await json.arrpop(glide_client, key2, JsonArrPopOptions("[0]")) == None + assert ( + await json.arrpop(glide_client, key2, JsonArrPopOptions("[0]", 10)) == None + ) + + # non jsonpath pops from all matching paths, even if one result is being returned + assert ( + await json.set( + glide_client, key2, "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]' + ) + == OK + ) + + assert await json.arrpop(glide_client, key2, JsonArrPopOptions("[*]")) == b'"a"' + assert await json.get(glide_client, key2, ".") == b'[[],[],["a"],["a","b"]]' From 9340ea46f6ad85b5627ffc71880ef80b01d7c334 Mon Sep 17 00:00:00 2001 From: barshaul Date: Mon, 4 Nov 2024 13:46:42 +0000 Subject: [PATCH 089/180] Fix benchmark CI error MISCONF Signed-off-by: barshaul --- .github/workflows/test-benchmark/action.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-benchmark/action.yml b/.github/workflows/test-benchmark/action.yml index 91cc36697f..3bd50dc0f2 100644 --- a/.github/workflows/test-benchmark/action.yml +++ b/.github/workflows/test-benchmark/action.yml @@ -11,7 +11,8 @@ runs: steps: - shell: bash - run: redis-server & + # Disable RDB snapshots to avoid configuration errors + run: redis-server --save "" --daemonize "yes" - shell: bash working-directory: ./benchmarks From db7b1bb8ed01e43cc91b5825b5bc0a0fedb4ffdb Mon Sep 17 00:00:00 2001 From: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Date: Mon, 4 Nov 2024 20:57:25 +0200 Subject: [PATCH 090/180] Python: Round map values in tests for floating-point comparison (#2592) --------- Signed-off-by: Ubuntu Signed-off-by: Shoham Elias Co-authored-by: Ubuntu --- python/python/tests/test_async_client.py | 10 ++++++++++ python/python/tests/utils/utils.py | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index c9744157d6..24768423b3 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -105,6 +105,7 @@ get_random_string, is_single_response, parse_info_response, + round_values, ) @@ -2687,6 +2688,7 @@ async def test_geosearchstore_by_box(self, glide_client: TGlideClient): ) expected_map = {member: value[1] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a box search, unit: kilometes, from a geospatial data, with distance @@ -2706,6 +2708,8 @@ async def test_geosearchstore_by_box(self, 
glide_client: TGlideClient): ) expected_map = {member: value[0] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) + sorted_expected_map = round_values(sorted_expected_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a box search, unit: kilometes, from a geospatial data, with count @@ -2746,6 +2750,8 @@ async def test_geosearchstore_by_box(self, glide_client: TGlideClient): b"Palermo": 166274.15156960033, b"edge2": 236529.17986494553, } + zrange_map = round_values(zrange_map, 9) + expected_distances = round_values(expected_distances, 9) assert compare_maps(zrange_map, expected_distances) is True # Test search by box, unit: feet, from a member, with limited ANY count to 2, with hash @@ -2827,6 +2833,8 @@ async def test_geosearchstore_by_radius(self, glide_client: TGlideClient): b"Catania": 0.0, b"Palermo": 166274.15156960033, } + zrange_map = round_values(zrange_map, 9) + expected_distances = round_values(expected_distances, 9) assert compare_maps(zrange_map, expected_distances) is True # Test search by radius, unit: miles, from a geospatial data @@ -2860,6 +2868,8 @@ async def test_geosearchstore_by_radius(self, glide_client: TGlideClient): ) expected_map = {member: value[0] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) + sorted_expected_map = round_values(sorted_expected_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a radius search, unit: kilometers, from a geospatial data, with limited ANY count to 1 diff --git a/python/python/tests/utils/utils.py b/python/python/tests/utils/utils.py index 497342b5c7..f912d5f6bd 100644 --- a/python/python/tests/utils/utils.py +++ b/python/python/tests/utils/utils.py @@ -137,6 +137,11 @@ def compare_maps( ) +def round_values(map_data: dict, decimal_places: int) -> dict: + """Round the values in a map to the specified number of decimal places.""" + return {key: round(value, decimal_places) for key, value in map_data.items()} + + def convert_bytes_to_string_object( # TODO: remove the str options byte_string_dict: Optional[ From 2b6578a8ddf244c6f1e6da9f4290ae1dfdc3ad9e Mon Sep 17 00:00:00 2001 From: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Date: Tue, 5 Nov 2024 11:29:53 +0200 Subject: [PATCH 091/180] Merge ci from main (#2584) * Update README.md nodejs platform support (#2397) * Update README.md Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> * Update README.md - lint fix Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> * Update README.md ling Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> * Update node/README.md Co-authored-by: Yury-Fridlyand Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> --------- Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Co-authored-by: Yury-Fridlyand * CI - Minimal and full CI matrix impl (#2051) * CI - Minimal and full CI matrix impl Signed-off-by: avifenesh * Fix mypy failing (#2453) --------- Signed-off-by: Shoham Elias Signed-off-by: avifenesh * Python: adds JSON.ARRLEN command (#2403) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Signed-off-by: avifenesh --------- Signed-off-by: avifenesh Signed-off-by: Shoham Elias 
Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> * CI - Minimal and full CI matrix impl (#2051) * CI - Minimal and full CI matrix impl Signed-off-by: avifenesh * Fix mypy failing (#2453) --------- Signed-off-by: Shoham Elias Signed-off-by: avifenesh * Python: adds JSON.ARRLEN command (#2403) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Signed-off-by: avifenesh --------- Signed-off-by: avifenesh Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> * CI - Minimal and full CI matrix impl (#2051) * CI - Minimal and full CI matrix impl Signed-off-by: avifenesh * Fix mypy failing (#2453) --------- Signed-off-by: Shoham Elias Signed-off-by: avifenesh * Python: adds JSON.ARRLEN command (#2403) --------- Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Signed-off-by: avifenesh --------- Signed-off-by: avifenesh Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> * Fix for - Minimal and full CI matrix impl #2051 (#2500) * Refactor tests to use async cleanup and improve error handling Signed-off-by: avifenesh * Enhance Jest configuration and add test setup file; update build scripts and dependencies Signed-off-by: avifenesh * Update devDependencies in package.json for hybrid-node-tests to latest versions Signed-off-by: avifenesh * Enhance test utilities and command tests with improved wait logic and version checks Signed-off-by: avifenesh * Refactor tests to assert expected replica reads are less than or equal to actual reads; update connection handling in utilities and allow unused imports in types Signed-off-by: avifenesh * Update dependencies and enhance PyO3 bindings; add new features and improve type handling Signed-off-by: avifenesh * Update GitHub workflows: enhance linting configurations, adjust engine version requirements, and remove obsolete Redis installation workflow Signed-off-by: avifenesh --------- Signed-off-by: avifenesh * fixes for CI (#2552) Signed-off-by: avifenesh * Refactor CI configuration for consistency and clarity Signed-off-by: avifenesh * Refactor CI configuration for consistency and clarity Signed-off-by: avifenesh --------- Signed-off-by: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Signed-off-by: avifenesh Signed-off-by: Shoham Elias Signed-off-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> Co-authored-by: Yury-Fridlyand Co-authored-by: Shoham Elias <116083498+shohamazon@users.noreply.github.com> --- .github/DEVELOPER.md | 144 ++++++++++ .github/ISSUE_TEMPLATE/bug-report.yml | 236 ++++++++-------- .github/ISSUE_TEMPLATE/feature-request.yml | 100 +++---- .github/json_matrices/build-matrix.json | 67 ++--- .github/json_matrices/engine-matrix.json | 26 +- .../supported-languages-versions.json | 27 ++ .github/pull_request_template.md | 14 +- .../workflows/build-node-wrapper/action.yml | 18 +- .../workflows/build-python-wrapper/action.yml | 2 +- .github/workflows/codeql.yml | 4 +- .../workflows/create-test-matrices/action.yml | 80 ++++++ .github/workflows/csharp.yml | 174 ++++++++---- .github/workflows/full-matrix-tests.yml | 126 +++++++++ .github/workflows/go.yml | 230 
 .github/workflows/install-engine/action.yml  | 102 +++++++
 .github/workflows/install-redis/action.yml   |  46 ----
 .../install-rust-and-protoc/action.yml       |   1 -
 .../install-shared-dependencies/action.yml   |  13 +-
 .github/workflows/install-valkey/action.yml  |  79 ------
 .github/workflows/java-cd.yml                |  24 +-
 .github/workflows/java.yml                   | 146 +++++-----
 .github/workflows/lint-rust/action.yml       |   4 +-
 .github/workflows/lint-ts/action.yml         |   2 +-
 .../node-create-package-file/action.yml      |   2 +-
 .github/workflows/node.yml                   | 245 ++++++++---------
 .github/workflows/npm-cd.yml                 | 123 +++++----
 .github/workflows/ort.yml                    | 187 +++++------
 .github/workflows/pypi-cd.yml                |  31 +--
 .github/workflows/python.yml                 | 228 ++++++++--------
 .github/workflows/redis-rs.yml               |   7 +-
 .github/workflows/run-ort-tools/action.yml   |  16 +-
 .github/workflows/rust.yml                   |  67 +++--
 .github/workflows/semgrep.yml                |   2 +-
 .../workflows/setup-musl-on-linux/action.yml |  18 +-
 .../start-self-hosted-runner/action.yml      |   6 +-
 .prettierignore                              |   1 +
 .vscode/settings.json                        |   5 +-
 .../redis-rs/scripts/get_command_info.py     |   1 +
 glide-core/src/client/types.rs               |   2 +
 glide-core/tests/test_standalone_client.rs   |   2 +-
 glide-core/tests/utilities/mod.rs            |   2 +-
 .../test/java/glide/TestConfiguration.java   |   3 +
 node/README.md                               |   8 +-
 .../commonjs-test/package.json               |   6 +-
 .../ecmascript-test/package.json             |   6 +-
 node/jest.config.js                          |   5 +-
 node/package.json                            |   6 +-
 node/tests/AsyncClient.test.ts               |   4 +-
 node/tests/GlideClient.test.ts               |  75 ++++--
 node/tests/GlideClientInternals.test.ts      |   7 +-
 node/tests/GlideClusterClient.test.ts        |  35 ++-
 node/tests/PubSub.test.ts                    |  30 ++-
 node/tests/ScanTest.test.ts                  |   6 +-
 node/tests/SharedTests.ts                    | 251 ++++++++----------
 node/tests/TestUtilities.ts                  | 129 +++------
 node/tests/setup.js                          |   7 +
 python/Cargo.toml                            |  17 +-
 python/python/tests/conftest.py              |   2 +-
 python/python/tests/test_pubsub.py           |  14 +-
 python/src/lib.rs                            |  79 +++---
 utils/TestUtils.ts                           |  58 ++--
 61 files changed, 1921 insertions(+), 1437 deletions(-)
 create mode 100644 .github/DEVELOPER.md
 create mode 100644 .github/json_matrices/supported-languages-versions.json
 create mode 100644 .github/workflows/create-test-matrices/action.yml
 create mode 100644 .github/workflows/full-matrix-tests.yml
 create mode 100644 .github/workflows/install-engine/action.yml
 delete mode 100644 .github/workflows/install-redis/action.yml
 delete mode 100644 .github/workflows/install-valkey/action.yml
 create mode 100644 .prettierignore
 create mode 100644 node/tests/setup.js

diff --git a/.github/DEVELOPER.md b/.github/DEVELOPER.md
new file mode 100644
index 0000000000..2acc4ccb68
--- /dev/null
+++ b/.github/DEVELOPER.md
@@ -0,0 +1,144 @@
+# CI/CD Workflow Guide
+
+### Overview
+
+Our CI/CD pipeline tests and builds our project across multiple languages, versions, and environments. This guide outlines the key components and processes of our workflow.
+
+### Workflow Triggers
+
+- Pull requests
+- Pushes to `main` or release branches (PR merges)
+- Scheduled runs (daily) - start the CI pipelines for all clients
+- Manual trigger (`workflow_dispatch`) - a developer can start a single client's pipeline, or the scheduled one to run all pipelines on demand
+
+_Diagram: job triggers_
+
+### Test coverage
+
+There are two levels of testing: basic and full (_aka_ `full-matrix`).
+The basic set of tests runs on every opened and merged PR; the full set runs in the daily scheduled job.
+A developer can also choose the level when starting a job manually, whether for a single client's pipeline or for the scheduled workflow that runs them all.
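+
+As a rough illustration (a local sketch, not part of the CI itself - it assumes `jq` is installed and that the commands are run from the repository root), the same `jq` filter the workflows use can preview which engine configurations each level would test:
+
+```bash
+# Basic level: only configurations marked `"run": "always"` (used on PRs and pushes).
+jq -c '[.[] | select(.run == "always")]' .github/json_matrices/engine-matrix.json
+
+# Full level: every configuration in the matrix (scheduled or full-matrix runs).
+jq -c . .github/json_matrices/engine-matrix.json
+```
+
+The CI performs this same selection in `.github/workflows/create-test-matrices/action.yml`, described below.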
+
+_Diagram: matrices_
+
+### Language-Specific Workflows
+
+Each language has its own workflow file with a similar structure but language-specific steps - for example, `python.yml` for Python or `java.yml` for Java.
+
+### Shared Components
+
+#### Matrix Files
+
+While workflows are language-specific, the matrix files are shared across all workflows.
+Workflows start by loading the matrix files from the `.github/json_matrices` directory.
+
+- `engine-matrix.json`: Defines the versions of the engine to test against.
+- `build-matrix.json`: Defines the host environments for testing.
+- `supported-languages-versions.json`: Defines the supported versions of languages.
+
+All matrices have a `run`-like field which specifies whether the configuration should be tested on every workflow run.
+This allows for flexible control over which configurations are tested in different scenarios, optimizing CI/CD performance and resource usage.
+
+#### Engine Matrix (engine-matrix.json)
+
+Defines the versions of the Valkey engine to test against:
+
+```json
+[
+    { "type": "valkey", "version": "7.2.5", "run": "always" }
+    // ... other configurations
+]
+```
+
+- `type`: The type of engine (e.g., Valkey, Redis).
+- `version`: Specifies the engine version that the workflow should check out.
+  For example, "7.2.5" represents a release tag, while "7.0" denotes a branch name. The workflow should use this parameter to check out the specific release version or branch and build the engine with the appropriate version.
+- `run`: Specifies whether the engine version should be tested on every workflow run.
+
+#### Build Matrix (build-matrix.json)
+
+Defines the host environments for testing:
+
+```json
+[
+    {
+        "OS": "ubuntu",
+        "RUNNER": "ubuntu-latest",
+        "TARGET": "x86_64-unknown-linux-gnu",
+        "run": ["always", "python", "node", "java"]
+    }
+    // ... other configurations
+]
+```
+
+- `OS`: The operating system of the host.
+- `RUNNER`: The GitHub runner to use.
+- `TARGET`: The target environment as defined in Rust. To see a list of available targets, run `rustup target list`.
+- `run`: Specifies which language workflows should use this host configuration. The value `always` indicates that the configuration should be used for every workflow trigger.
+
+#### Supported Languages Versions (supported-languages-versions.json)
+
+Defines the supported versions of languages:
+
+```json
+[
+    {
+        "language": "java",
+        "versions": ["11", "17"],
+        "always-run-versions": ["17"]
+    }
+    // ... other configurations
+]
+```
+
+- `language`: The language for which the versions are supported.
+- `versions`: The full set of supported versions of the language, tested against on scheduled runs.
+- `always-run-versions`: The versions that will always be tested, regardless of the workflow trigger.
+
+#### Triggering Workflows
+
+- Push to `main` (by merging a PR) or open a new pull request to run workflows automatically.
+- Use `workflow_dispatch` for manual triggers, accepting a boolean configuration parameter to run all configurations.
+- Scheduled runs are triggered daily to ensure regular testing of all configurations.
+
+### Shared vs. Language-Specific Components
+
+#### Shared
+
+- Matrix files - `.github/json_matrices/`
+- Shared dependencies installation - `.github/workflows/install-shared-dependencies/action.yml`
+- Rust linters - `.github/workflows/lint-rust/action.yml`
+
+#### Language-Specific
+
+- Package manager commands
+- Testing frameworks
+- Build processes
+
+### Customizing Workflows
+
+Modify `.yml` files to adjust language-specific steps.
+Update matrix files to change tested versions or environments. +Adjust cron schedules in workflow files for different timing of scheduled runs. + +### Workflow Matrices + +We use dynamic matrices for our CI/CD workflows, which are created using the `create-test-matrices` action. This action is defined in `.github/workflows/create-test-matrices/action.yml`. + +#### How it works + +1. The action is called with a `language-name` input and `dispatch-run-full-matrix` input. +2. It reads the `engine-matrix.json`, `build-matrix.json`, and `supported-languages-version.json` files. +3. It filters the matrices based on the inputs and the event type. +4. It generates three matrices: + - Engine matrix: Defines the types and versions of the engine to test against, for example Valkey 7.2.5. + - Host matrix: Defines the host platforms to run the tests on, for example Ubuntu on ARM64. + - Language-version matrix: Defines the supported versions of languages, for example python 3.8. + +#### Outputs + +- `engine-matrix-output`: The generated engine matrix. +- `host-matrix-output`: The generated host matrix. +- `language-version-matrix-output`: The generated language version matrix. + +This dynamic matrix generation allows for flexible and efficient CI/CD workflows, adapting the test configurations based on the type of change and the specific language being tested. diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 5f63253d51..98a018219c 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -5,134 +5,134 @@ title: "(topic): (short issue description)" labels: [bug, needs-triage] assignees: [] body: - - type: textarea - id: description - attributes: - label: Describe the bug - description: What is the problem? A clear and concise description of the bug. - validations: - required: true + - type: textarea + id: description + attributes: + label: Describe the bug + description: What is the problem? A clear and concise description of the bug. + validations: + required: true - - type: textarea - id: expected - attributes: - label: Expected Behavior - description: | - What did you expect to happen? - validations: - required: true + - type: textarea + id: expected + attributes: + label: Expected Behavior + description: | + What did you expect to happen? + validations: + required: true - - type: textarea - id: current - attributes: - label: Current Behavior - description: | - What actually happened? - - Please include full errors, uncaught exceptions, stack traces, and relevant logs. - If service responses are relevant, please include wire logs. - validations: - required: true + - type: textarea + id: current + attributes: + label: Current Behavior + description: | + What actually happened? - - type: textarea - id: reproduction - attributes: - label: Reproduction Steps - description: | - Provide a self-contained, concise snippet of code that can be used to reproduce the issue. - For more complex issues provide a repo with the smallest sample that reproduces the bug. - - Avoid including business logic or unrelated code, it makes diagnosis more difficult. - The code sample should be an SSCCE. See http://sscce.org/ for details. In short, please provide a code sample that we can copy/paste, run and reproduce. - validations: - required: true + Please include full errors, uncaught exceptions, stack traces, and relevant logs. + If service responses are relevant, please include wire logs. 
+ validations: + required: true - - type: textarea - id: solution - attributes: - label: Possible Solution - description: | - Suggest a fix/reason for the bug - validations: - required: false + - type: textarea + id: reproduction + attributes: + label: Reproduction Steps + description: | + Provide a self-contained, concise snippet of code that can be used to reproduce the issue. + For more complex issues provide a repo with the smallest sample that reproduces the bug. - - type: textarea - id: context - attributes: - label: Additional Information/Context - description: | - Anything else that might be relevant for troubleshooting this bug. Providing context helps us come up with a solution that is most useful in the real world. - validations: - required: false + Avoid including business logic or unrelated code, it makes diagnosis more difficult. + The code sample should be an SSCCE. See http://sscce.org/ for details. In short, please provide a code sample that we can copy/paste, run and reproduce. + validations: + required: true - - type: input - id: client-version - attributes: - label: Client version used - validations: - required: true + - type: textarea + id: solution + attributes: + label: Possible Solution + description: | + Suggest a fix/reason for the bug + validations: + required: false - - type: input - id: engine-version - attributes: - label: Engine type and version - description: E.g. Valkey 7.0 - validations: - required: true + - type: textarea + id: context + attributes: + label: Additional Information/Context + description: | + Anything else that might be relevant for troubleshooting this bug. Providing context helps us come up with a solution that is most useful in the real world. + validations: + required: false - - type: input - id: operating-system - attributes: - label: OS - validations: - required: true + - type: input + id: client-version + attributes: + label: Client version used + validations: + required: true - - type: dropdown - id: language - attributes: - label: Language - multiple: true - options: - - TypeScript - - Python - - Java - - Rust - - Go - - .Net - validations: - required: true + - type: input + id: engine-version + attributes: + label: Engine type and version + description: E.g. Valkey 7.0 + validations: + required: true - - type: input - id: language-version - attributes: - label: Language Version - description: E.g. TypeScript (5.2.2) | Python (3.9) - validations: - required: true + - type: input + id: operating-system + attributes: + label: OS + validations: + required: true - - type: textarea - id: cluster-info - attributes: - label: Cluster information - description: | - Cluster information, cluster topology, number of shards, number of replicas, used data types. - validations: - required: false + - type: dropdown + id: language + attributes: + label: Language + multiple: true + options: + - TypeScript + - Python + - Java + - Rust + - Go + - .Net + validations: + required: true - - type: textarea - id: logs - attributes: - label: Logs - description: | - Client and/or server logs. - validations: - required: false + - type: input + id: language-version + attributes: + label: Language Version + description: E.g. TypeScript (5.2.2) | Python (3.9) + validations: + required: true - - type: textarea - id: other - attributes: - label: Other information - description: | - e.g. detailed explanation, stacktraces, related issues, suggestions how to fix, links for us to have context, eg. 
associated pull-request, stackoverflow, etc - validations: - required: false + - type: textarea + id: cluster-info + attributes: + label: Cluster information + description: | + Cluster information, cluster topology, number of shards, number of replicas, used data types. + validations: + required: false + + - type: textarea + id: logs + attributes: + label: Logs + description: | + Client and/or server logs. + validations: + required: false + + - type: textarea + id: other + attributes: + label: Other information + description: | + e.g. detailed explanation, stacktraces, related issues, suggestions how to fix, links for us to have context, eg. associated pull-request, stackoverflow, etc + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml index 64684f1d1c..a7607565aa 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yml +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -5,55 +5,55 @@ title: "(topic): (short issue description)" labels: [feature-request, needs-triage] assignees: [] body: - - type: textarea - id: description - attributes: - label: Describe the feature - description: A clear and concise description of the feature you are proposing. - validations: - required: true - - type: textarea - id: use-case - attributes: - label: Use Case - description: | - Why do you need this feature? - validations: - required: true - - type: textarea - id: solution - attributes: - label: Proposed Solution - description: | - Suggest how to implement the addition or change. Please include prototype/workaround/sketch/reference implementation. - validations: - required: false - - type: textarea - id: other - attributes: - label: Other Information - description: | - Any alternative solutions or features you considered, a more detailed explanation, stack traces, related issues, links for context, etc. - validations: - required: false - - type: checkboxes - id: ack - attributes: - label: Acknowledgements - options: - - label: I may be able to implement this feature request + - type: textarea + id: description + attributes: + label: Describe the feature + description: A clear and concise description of the feature you are proposing. + validations: + required: true + - type: textarea + id: use-case + attributes: + label: Use Case + description: | + Why do you need this feature? + validations: + required: true + - type: textarea + id: solution + attributes: + label: Proposed Solution + description: | + Suggest how to implement the addition or change. Please include prototype/workaround/sketch/reference implementation. + validations: required: false - - label: This feature might incur a breaking change + - type: textarea + id: other + attributes: + label: Other Information + description: | + Any alternative solutions or features you considered, a more detailed explanation, stack traces, related issues, links for context, etc. + validations: required: false - - type: input - id: client-version - attributes: - label: Client version used - validations: - required: true - - type: input - id: environment - attributes: - label: Environment details (OS name and version, etc.) 
- validations: - required: true + - type: checkboxes + id: ack + attributes: + label: Acknowledgements + options: + - label: I may be able to implement this feature request + required: false + - label: This feature might incur a breaking change + required: false + - type: input + id: client-version + attributes: + label: Client version used + validations: + required: true + - type: input + id: environment + attributes: + label: Environment details (OS name and version, etc.) + validations: + required: true diff --git a/.github/json_matrices/build-matrix.json b/.github/json_matrices/build-matrix.json index fc02093b9f..59c50617f9 100644 --- a/.github/json_matrices/build-matrix.json +++ b/.github/json_matrices/build-matrix.json @@ -5,40 +5,18 @@ "RUNNER": "ubuntu-latest", "ARCH": "x64", "TARGET": "x86_64-unknown-linux-gnu", - "PACKAGE_MANAGERS": [ - "pypi", - "npm", - "maven" - ] + "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], + "run": "always", + "languages": ["python", "node", "java", "go", "dotnet"] }, { "OS": "ubuntu", "NAMED_OS": "linux", - "RUNNER": [ - "self-hosted", - "Linux", - "ARM64" - ], + "RUNNER": ["self-hosted", "Linux", "ARM64"], "ARCH": "arm64", "TARGET": "aarch64-unknown-linux-gnu", - "PACKAGE_MANAGERS": [ - "pypi", - "npm", - "maven" - ], - "CONTAINER": "2_28" - }, - { - "OS": "macos", - "NAMED_OS": "darwin", - "RUNNER": "macos-12", - "ARCH": "x64", - "TARGET": "x86_64-apple-darwin", - "PACKAGE_MANAGERS": [ - "pypi", - "npm", - "maven" - ] + "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], + "languages": ["python", "node", "java", "go", "dotnet"] }, { "OS": "macos", @@ -46,27 +24,19 @@ "RUNNER": "macos-latest", "ARCH": "arm64", "TARGET": "aarch64-apple-darwin", - "PACKAGE_MANAGERS": [ - "pypi", - "npm", - "maven" - ] + "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], + "languages": ["python", "node", "java", "go", "dotnet"] }, { "OS": "ubuntu", "NAMED_OS": "linux", "ARCH": "arm64", "TARGET": "aarch64-unknown-linux-musl", - "RUNNER": [ - "self-hosted", - "Linux", - "ARM64" - ], + "RUNNER": ["self-hosted", "Linux", "ARM64"], "IMAGE": "node:lts-alpine3.19", "CONTAINER_OPTIONS": "--user root --privileged --rm", - "PACKAGE_MANAGERS": [ - "npm" - ] + "PACKAGE_MANAGERS": ["npm"], + "languages": ["node"] }, { "OS": "ubuntu", @@ -76,8 +46,17 @@ "RUNNER": "ubuntu-latest", "IMAGE": "node:lts-alpine3.19", "CONTAINER_OPTIONS": "--user root --privileged", - "PACKAGE_MANAGERS": [ - "npm" - ] + "PACKAGE_MANAGERS": ["npm"], + "languages": ["node"] + }, + { + "OS": "amazon-linux", + "NAMED_OS": "linux", + "RUNNER": "ubuntu-latest", + "ARCH": "x64", + "TARGET": "x86_64-unknown-linux-gnu", + "IMAGE": "amazonlinux:latest", + "PACKAGE_MANAGERS": [], + "languages": ["python", "node", "java", "go", "dotnet"] } ] diff --git a/.github/json_matrices/engine-matrix.json b/.github/json_matrices/engine-matrix.json index 464aedf31a..06a8a27fd9 100644 --- a/.github/json_matrices/engine-matrix.json +++ b/.github/json_matrices/engine-matrix.json @@ -1,10 +1,20 @@ [ - { - "type": "valkey", - "version": "7.2.5" - }, - { - "type": "valkey", - "version": "8.0.0" - } + { + "type": "valkey", + "version": "8.0", + "run": "always" + }, + { + "type": "valkey", + "version": "7.2" + }, + { + "type": "redis", + "version": "7.0" + }, + { + "type": "redis", + "version": "6.2", + "run": "always" + } ] diff --git a/.github/json_matrices/supported-languages-versions.json b/.github/json_matrices/supported-languages-versions.json new file mode 100644 index 0000000000..aadc99328a --- /dev/null +++ 
b/.github/json_matrices/supported-languages-versions.json @@ -0,0 +1,27 @@ +[ + { + "language": "java", + "versions": ["11", "17"], + "always-run-versions": ["17"] + }, + { + "language": "python", + "versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"], + "always-run-versions": ["3.8", "3.13"] + }, + { + "language": "node", + "versions": ["16.x", "17.x", "18.x", "19.x", "20.x", "21.x", "22.x"], + "always-run-versions": ["16.x", "22.x"] + }, + { + "language": "dotnet", + "versions": ["8.0", "6.0"], + "always-run-versions": ["8.0"] + }, + { + "language": "go", + "versions": ["1.22.0", "1.18.10"], + "always-run-versions": ["1.22.0", "1.18.10"] + } +] diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 8e6d8dd2b3..ff120235e3 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,16 +7,16 @@ here](https://github.com/valkey-io/valkey-glide/blob/main/CONTRIBUTING.md) --> ### Issue link + This Pull Request is linked to issue (URL): [REPLACE ME] ### Checklist Before submitting the PR make sure the following are checked: -* [ ] This Pull Request is related to one issue. -* [ ] Commit message has a detailed description of what changed and why. -* [ ] Tests are added or updated. -* [ ] CHANGELOG.md and documentation files are updated. -* [ ] Destination branch is correct - main or release -* [ ] Commits will be squashed upon merging. - +- [ ] This Pull Request is related to one issue. +- [ ] Commit message has a detailed description of what changed and why. +- [ ] Tests are added or updated. +- [ ] CHANGELOG.md and documentation files are updated. +- [ ] Destination branch is correct - main or release +- [ ] Commits will be squashed upon merging. diff --git a/.github/workflows/build-node-wrapper/action.yml b/.github/workflows/build-node-wrapper/action.yml index aa1200fbd5..9d2f14d59f 100644 --- a/.github/workflows/build-node-wrapper/action.yml +++ b/.github/workflows/build-node-wrapper/action.yml @@ -65,20 +65,20 @@ runs: - name: Create package.json file uses: ./.github/workflows/node-create-package-file with: - release_version: ${{ env.RELEASE_VERSION }} - os: ${{ inputs.os }} - named_os: ${{ inputs.named_os }} - arch: ${{ inputs.arch }} - npm_scope: ${{ inputs.npm_scope }} - target: ${{ inputs.target }} - + release_version: ${{ env.RELEASE_VERSION }} + os: ${{ inputs.os }} + named_os: ${{ inputs.named_os }} + arch: ${{ inputs.arch }} + npm_scope: ${{ inputs.npm_scope }} + target: ${{ inputs.target }} + - name: npm install shell: bash working-directory: ./node run: | - rm -rf node_modules && npm install --frozen-lockfile + rm -rf node_modules && rm -rf package-lock.json && npm install cd rust-client - npm install --frozen-lockfile + rm -rf node_modules && rm -rf package-lock.json && npm install - name: Build shell: bash diff --git a/.github/workflows/build-python-wrapper/action.yml b/.github/workflows/build-python-wrapper/action.yml index 72863c6a43..25c7e20b7d 100644 --- a/.github/workflows/build-python-wrapper/action.yml +++ b/.github/workflows/build-python-wrapper/action.yml @@ -15,7 +15,7 @@ inputs: required: true engine-version: description: "Engine version to install" - required: true + required: false type: string publish: description: "Enable building the wrapper in release mode" diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 36ac59f664..fde89563bf 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -2,7 +2,7 @@ name: "CodeQL" on: push: - branches: + 
branches: - "main" - "v.?[0-9]+.[0-9]+.[0-9]+" - "v.?[0-9]+.[0-9]+" @@ -10,7 +10,7 @@ on: - "v?[0-9]+.[0-9]+" - release-* pull_request: - branches: + branches: - "main" - "v.?[0-9]+.[0-9]+.[0-9]+" - "v.?[0-9]+.[0-9]+" diff --git a/.github/workflows/create-test-matrices/action.yml b/.github/workflows/create-test-matrices/action.yml new file mode 100644 index 0000000000..5bd777b5a1 --- /dev/null +++ b/.github/workflows/create-test-matrices/action.yml @@ -0,0 +1,80 @@ +inputs: + language-name: + description: "Language name" + required: true + type: choice + options: + - java + - node + - python + - go + - C# + run-full-matrix: + description: "Run the full matrix" + required: true + type: boolean + containers: + description: "Run in containers" + required: true + default: false + type: boolean + +outputs: + engine-matrix-output: + description: "Engine matrix" + value: ${{ steps.load-engine-matrix.outputs.engine-matrix }} + host-matrix-output: + description: "Host matrix" + value: ${{ steps.load-host-matrix.outputs.host-matrix }} + version-matrix-output: + description: "Version matrix" + value: ${{ steps.create-lang-version-matrix.outputs.version-matrix }} + +runs: + using: "composite" + steps: + - name: Load engine matrix + id: load-engine-matrix + shell: bash + run: | + set -o pipefail + echo 'Select server engines to run tests against' + if [[ "${{ github.event_name }}" == "pull_request" || "${{ github.event_name }}" == "push" || "${{ inputs.run-full-matrix }}" == "false" ]]; then + echo 'Pick engines marked as `"run": "always"` only - on PR, push or manually triggered job which does not require full matrix' + jq -c '[.[] | select(.run == "always")]' < .github/json_matrices/engine-matrix.json | awk '{ printf "engine-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + else + echo 'Pick all engines - on cron (schedule) or if manually triggered job requires a full matrix' + jq -c . < .github/json_matrices/engine-matrix.json | awk '{ printf "engine-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + fi + cat $GITHUB_OUTPUT + + - name: Load host matrix + id: load-host-matrix + shell: bash + run: | + set -o pipefail + [[ "${{ inputs.containers }}" == "true" ]] && CONDITION=".IMAGE?" || CONDITION=".IMAGE == null" + echo 'Select runners (VMs) to run tests on' + if [[ "${{ github.event_name }}" == "pull_request" || "${{ github.event_name }}" == "push" || "${{ inputs.run-full-matrix }}" == "false" ]]; then + echo 'Pick runners marked as '"run": "always"' only - on PR, push or manually triggered job which does not require full matrix' + jq -c '[.[] | select(.run == "always")]' < .github/json_matrices/build-matrix.json | awk '{ printf "host-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + else + echo 'Pick all runners assigned for the chosen client (language) - on cron (schedule) or if manually triggered job requires a full matrix' + jq -c "[.[] | select(.languages? and any(.languages[] == \"${{ inputs.language-name }}\"; .) 
and $CONDITION)]" < .github/json_matrices/build-matrix.json | awk '{ printf "host-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + fi + cat $GITHUB_OUTPUT + + - name: Create language version matrix + id: create-lang-version-matrix + shell: bash + run: | + set -o pipefail + echo 'Select language (framework/SDK) versions to run tests on' + if [[ "${{ github.event_name }}" == "pull_request" || "${{ github.event_name }}" == "push" || "${{ inputs.run-full-matrix }}" == "false" ]]; then + echo 'Pick language versions listed in 'always-run-versions' only - on PR, push or manually triggered job which does not require full matrix' + jq -c '[.[] | select(.language == "${{ inputs.language-name }}") | .["always-run-versions"]][0] // []' < .github/json_matrices/supported-languages-versions.json | awk '{ printf "version-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + else + echo 'Pick language versions listed in 'versions' - on cron (schedule) or if manually triggered job requires a full matrix' + jq -c '[.[] | select(.language == "${{ inputs.language-name }}") | .versions][0]' < .github/json_matrices/supported-languages-versions.json | awk '{ printf "version-matrix=%s\n", $1 }' | tee -a $GITHUB_OUTPUT + fi + cat $GITHUB_OUTPUT diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index eab61c6dc1..1cd5778a5c 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -10,79 +10,89 @@ on: - csharp/** - glide-core/src/** - glide-core/redis-rs/redis/src/** + - utils/cluster_manager.py - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** pull_request: paths: - csharp/** - glide-core/src/** - glide-core/redis-rs/redis/src/** + - utils/cluster_manager.py - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine, host, and language version matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: permissions: contents: read concurrency: - group: C#-${{ github.head_ref || github.ref }} + group: C#-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true +run-name: + # Set custom name if job is started manually and name is given + ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }} + +env: + CARGO_TERM_COLOR: always + jobs: - load-engine-matrix: + get-matrices: runs-on: ubuntu-latest outputs: - matrix: ${{ steps.load-engine-matrix.outputs.matrix }} + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} steps: - - name: Checkout 
- uses: actions/checkout@v4 - - - name: Load the engine matrix - id: load-engine-matrix - shell: bash - run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT - - run-tests: - needs: load-engine-matrix - timeout-minutes: 25 + - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: dotnet + # Run full test matrix if job started by cron or it was explictly specified by a person who triggered the workflow + run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} + + test-csharp: + needs: get-matrices + timeout-minutes: 35 strategy: fail-fast: false matrix: - engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }} - dotnet: - # - '6.0' - - '8.0' - host: - - { - OS: ubuntu, - RUNNER: ubuntu-latest, - TARGET: x86_64-unknown-linux-gnu - } - # - { - # OS: macos, - # RUNNER: macos-latest, - # TARGET: aarch64-apple-darwin - # } - + dotnet: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }} + engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} runs-on: ${{ matrix.host.RUNNER }} steps: - uses: actions/checkout@v4 - - name: Set up dotnet ${{ matrix.dotnet }} uses: actions/setup-dotnet@v4 with: dotnet-version: ${{ matrix.dotnet }} - + - name: Install shared software dependencies uses: ./.github/workflows/install-shared-dependencies with: @@ -91,12 +101,8 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} engine-version: ${{ matrix.engine.version }} - - name: Format - working-directory: ./csharp - run: dotnet format --verify-no-changes --verbosity diagnostic - - name: Test dotnet ${{ matrix.dotnet }} - working-directory: ./csharp + working-directory: csharp run: dotnet test --framework net${{ matrix.dotnet }} "-l:html;LogFileName=TestReport.html" --results-directory . 
-warnaserror - uses: ./.github/workflows/test-benchmark @@ -108,21 +114,95 @@ jobs: continue-on-error: true uses: actions/upload-artifact@v4 with: - name: test-reports-dotnet-${{ matrix.dotnet }}-redis-${{ matrix.redis }}-${{ matrix.host.RUNNER }} + name: test-reports-dotnet-${{ matrix.dotnet }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ matrix.host.RUNNER }} path: | csharp/TestReport.html benchmarks/results/* utils/clusters/** -# TODO Add amazonlinux + get-containers: + runs-on: ubuntu-latest + if: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} + outputs: + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} + + steps: + - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: dotnet + run-full-matrix: true + containers: true + + test-csharp-container: + runs-on: ${{ matrix.host.RUNNER }} + needs: [get-containers] + timeout-minutes: 25 + strategy: + fail-fast: false + matrix: + # Don't use generated matrix for dotnet until net6.0 compatibility issues resolved on amazon linux + # dotnet: ${{ fromJson(needs.get-containers.outputs.version-matrix-output) }} + dotnet: ["8.0"] + engine: ${{ fromJson(needs.get-containers.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-containers.outputs.host-matrix-output) }} + container: + image: ${{ matrix.host.IMAGE }} + options: ${{ join(' -q ', matrix.host.CONTAINER_OPTIONS) }} # adding `-q` to bypass empty options + steps: + - name: Install git + run: | + yum update + yum install -y git tar findutils libicu + echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV + # Replace `:` in the variable otherwise it can't be used in `upload-artifact` + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up dotnet ${{ matrix.dotnet }} + uses: actions/setup-dotnet@v4 + with: + dotnet-version: ${{ matrix.dotnet }} + + - name: Install shared software dependencies + uses: ./.github/workflows/install-shared-dependencies + with: + os: ${{ matrix.host.OS }} + target: ${{ matrix.host.TARGET }} + github-token: ${{ secrets.GITHUB_TOKEN }} + engine-version: ${{ matrix.engine.version }} + + - name: Test dotnet ${{ matrix.dotnet }} + working-directory: csharp + run: dotnet test --framework net${{ matrix.dotnet }} "-l:html;LogFileName=TestReport.html" --results-directory . 
-warnaserror + + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: test-reports-dotnet-${{ matrix.dotnet }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ env.IMAGE }}-${{ matrix.host.ARCH }} + path: | + csharp/TestReport.html + benchmarks/results/* + utils/clusters/** - lint-rust: + lint: timeout-minutes: 10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - - uses: ./.github/workflows/lint-rust + - name: lint rust + uses: ./.github/workflows/lint-rust with: - cargo-toml-folder: ./csharp/lib + cargo-toml-folder: csharp/lib + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Format + working-directory: csharp + run: dotnet format --verify-no-changes --verbosity diagnostic diff --git a/.github/workflows/full-matrix-tests.yml b/.github/workflows/full-matrix-tests.yml new file mode 100644 index 0000000000..85fa3dbe08 --- /dev/null +++ b/.github/workflows/full-matrix-tests.yml @@ -0,0 +1,126 @@ +name: Full Matrix tests + +on: + workflow_dispatch: # note: if started manually, it won't run all matrix + inputs: + full-matrix: + description: "Run the full engine and host matrix" + type: boolean + default: false + # GHA supports up to 10 inputs, there is no option for multi-choice + core: + description: "Test GLIDE core" + type: boolean + default: true + redis-rs: + description: "Test Redis-RS client" + type: boolean + default: true + node: + description: "Test Node client" + type: boolean + default: true + python: + description: "Test Python client" + type: boolean + default: true + java: + description: "Test Java client" + type: boolean + default: true + csharp: + description: "Test C# client" + type: boolean + default: false + go: + description: "Test Golang client" + type: boolean + default: false + + schedule: + - cron: "0 3 * * *" + +concurrency: + group: full-matrix-tests + cancel-in-progress: false + +# TODO matrix by workflow (`uses`) - not supported yet by GH +jobs: + check-running-workflow: + runs-on: ubuntu-latest + outputs: + is_running: ${{ steps.check.outputs.is_running }} + steps: + - name: Check if the same workflow is running + id: check + uses: actions/github-script@v6 + with: + script: | + const { data } = await github.rest.actions.listWorkflowRuns({ + owner: context.repo.owner, + repo: context.repo.repo, + workflow_id: context.workflow, + status: 'in_progress', + }); + + const isRunning = data.workflow_runs.some(run => run.id !== context.runId); + + core.setOutput('is_running', isRunning.toString()); + + check-input: + runs-on: ubuntu-latest + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' + steps: + - name: No tests selected + run: echo "No tests selected." 
+ if: github.event_name == 'workflow_dispatch' && inputs.core == 'false' && inputs.java == 'false' && inputs.python == 'false' && inputs.node == 'false' && inputs.csharp == 'false' && inputs.go == 'false' + + run-full-tests-for-core: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.core == 'true' + uses: ./.github/workflows/rust.yml + name: Run CI for GLIDE core lib + secrets: inherit + + run-full-tests-for-redis-rs: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.redis-rs == 'true' + uses: ./.github/workflows/redis-rs.yml + name: Run CI for Redis-RS client + secrets: inherit + + run-full-tests-for-java: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.java == 'true' + uses: ./.github/workflows/java.yml + name: Run CI for Java client + secrets: inherit + + run-full-tests-for-python: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.python == 'true' + uses: ./.github/workflows/python.yml + name: Run CI for Python client + secrets: inherit + + run-full-tests-for-node: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.node == 'true' + uses: ./.github/workflows/node.yml + name: Run CI for Node client + secrets: inherit + + run-full-tests-for-csharp: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.csharp == 'true' + uses: ./.github/workflows/csharp.yml + name: Run CI for C# client + secrets: inherit + + run-full-tests-for-go: + needs: check-running-workflow + if: needs.check-running-workflow.outputs.is_running == 'false' && inputs.go == 'true' + uses: ./.github/workflows/go.yml + name: Run CI for Go client + secrets: inherit diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3290839a6a..1d17640188 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -9,71 +9,81 @@ on: paths: - glide-core/src/** - glide-core/redis-rs/redis/src/** + - utils/cluster_manager.py - go/** - .github/workflows/go.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** pull_request: paths: - glide-core/src/** - glide-core/redis-rs/redis/src/** + - utils/cluster_manager.py - go/** - .github/workflows/go.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine, host, and language version matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: concurrency: - group: go-${{ github.head_ref || github.ref }} + group: go-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true +run-name: + # Set custom name if job is started manually and 
name is given + ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }} + +env: + CARGO_TERM_COLOR: always + jobs: - load-engine-matrix: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.load-engine-matrix.outputs.matrix }} - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Load the engine matrix - id: load-engine-matrix - shell: bash - run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT - - build-and-test-go-client: - needs: load-engine-matrix + get-matrices: + runs-on: ubuntu-latest + outputs: + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} + steps: + - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: go + # Run full test matrix if job started by cron or it was explictly specified by a person who triggered the workflow + run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} + + test-go: + needs: get-matrices timeout-minutes: 35 strategy: - # Run all jobs fail-fast: false matrix: - go: - - '1.22.0' - engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }} - host: - - { - OS: ubuntu, - RUNNER: ubuntu-latest, - TARGET: x86_64-unknown-linux-gnu - } - # - { - # OS: macos, - # RUNNER: macos-latest, - # TARGET: aarch64-apple-darwin - # } - + go: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }} + engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} runs-on: ${{ matrix.host.RUNNER }} steps: - uses: actions/checkout@v4 - - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 @@ -101,14 +111,9 @@ jobs: working-directory: ./go run: make build - - name: Run linters - working-directory: ./go - run: make lint-ci - - name: Run tests working-directory: ./go - run: | - make test + run: make test - uses: ./.github/workflows/test-benchmark with: @@ -119,93 +124,106 @@ jobs: continue-on-error: true uses: actions/upload-artifact@v4 with: - name: reports-go-${{ matrix.go }}-redis-${{ matrix.redis }}-${{ matrix.os }} + name: test-report-go-${{ matrix.go }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ matrix.host.RUNNER }} path: | utils/clusters/** benchmarks/results/** - build-amazonlinux-latest: - if: github.repository_owner == 'valkey-io' + lint: + timeout-minutes: 10 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: ./.github/workflows/lint-rust + with: + cargo-toml-folder: go + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v5 + with: + go-version: "1.22.0" + cache-dependency-path: go/go.sum + + - name: Install protoc + uses: ./.github/workflows/install-rust-and-protoc + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install and run linters + working-directory: go + run: | + make install-dev-tools install-build-tools build lint-ci + + get-containers: + runs-on: ubuntu-latest + if: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} + outputs: + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ 
steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} + + steps: + - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: go + run-full-matrix: true + containers: true + + test-go-container: + runs-on: ${{ matrix.host.RUNNER }} + needs: [get-containers] + timeout-minutes: 25 strategy: - # Run all jobs fail-fast: false matrix: - go: - - 1.22.0 - runs-on: ubuntu-latest - container: amazonlinux:latest - timeout-minutes: 15 + go: ${{ fromJson(needs.get-containers.outputs.version-matrix-output) }} + engine: ${{ fromJson(needs.get-containers.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-containers.outputs.host-matrix-output) }} + container: + image: ${{ matrix.host.IMAGE }} + options: ${{ join(' -q ', matrix.host.CONTAINER_OPTIONS) }} # adding `-q` to bypass empty options steps: - name: Install git run: | - yum -y remove git - yum -y remove git-* - yum -y install https://packages.endpointdev.com/rhel/7/os/x86_64/endpoint-repo.x86_64.rpm yum update - yum install -y git - git --version - + yum install -y git tar + git config --global --add safe.directory "$GITHUB_WORKSPACE" + echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV + # Replace `:` in the variable otherwise it can't be used in `upload-artifact` - uses: actions/checkout@v4 + with: + submodules: recursive - - name: Checkout submodules - run: | - git config --global --add safe.directory "$GITHUB_WORKSPACE" - git submodule update --init --recursive + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + cache-dependency-path: go/go.sum - name: Install shared software dependencies uses: ./.github/workflows/install-shared-dependencies with: - os: "amazon-linux" - target: "x86_64-unknown-linux-gnu" + os: ${{ matrix.host.OS }} + target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" - - - name: Install Go - run: | - yum -y install wget - yum -y install tar - wget https://go.dev/dl/go${{ matrix.go }}.linux-amd64.tar.gz - tar -C /usr/local -xzf go${{ matrix.go }}.linux-amd64.tar.gz - echo "/usr/local/go/bin" >> $GITHUB_PATH - echo "$HOME/go/bin" >> $GITHUB_PATH - - - name: Install tools for Go ${{ matrix.go }} - working-directory: ./go - run: make install-tools-go${{ matrix.go }} - - - name: Set LD_LIBRARY_PATH - run: echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$GITHUB_WORKSPACE/go/target/release/deps/" >> $GITHUB_ENV - - - name: Build client - working-directory: ./go - run: make build - - - name: Run linters - working-directory: ./go - run: make lint-ci + engine-version: ${{ matrix.engine.version }} - - name: Run tests - working-directory: ./go + - name: Install & build & test + working-directory: go run: | - make test + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$GITHUB_WORKSPACE/go/target/release/deps/ + make install-tools-go${{ matrix.go }} build test - - name: Upload cluster manager logs + - name: Upload test reports if: always() continue-on-error: true uses: actions/upload-artifact@v4 with: - name: cluster-manager-logs-${{ matrix.go }}-redis-6-amazonlinux + name: test-reports-go-${{ matrix.go }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ env.IMAGE }}-${{ matrix.host.ARCH }} path: | utils/clusters/** - - lint-rust: - timeout-minutes: 15 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - - uses: ./.github/workflows/lint-rust - with: - cargo-toml-folder: 
./go - name: lint go rust + benchmarks/results/** diff --git a/.github/workflows/install-engine/action.yml b/.github/workflows/install-engine/action.yml new file mode 100644 index 0000000000..6f28e02d11 --- /dev/null +++ b/.github/workflows/install-engine/action.yml @@ -0,0 +1,102 @@ +name: Install Engine + +inputs: + engine-version: + description: "Engine version to install" + required: true + type: string + target: + description: "Specified target toolchain, ex. x86_64-unknown-linux-gnu" + type: string + required: true + options: + - x86_64-unknown-linux-gnu + - aarch64-unknown-linux-gnu + - x86_64-apple-darwin + - aarch64-apple-darwin + - aarch64-unknown-linux-musl + - x86_64-unknown-linux-musl + +env: + CARGO_TERM_COLOR: always + +runs: + using: "composite" + + # TODO: self-hosted runners are actually cloning the repo, using the cache from the previous run + # will not work as expected. We need to find a way to cache the valkey repo on the runner itself. + steps: + - name: Cache Valkey + if: ${{!contains(inputs.target, 'aarch64-unknown') }} + uses: actions/cache@v4 + id: cache-valkey + with: + path: | + ~/valkey + key: valkey-${{ inputs.engine-version }}-${{ inputs.target }} + + - name: Build Valkey for ARM + if: ${{ contains(inputs.target, 'aarch64-unknown') }} + shell: bash + working-directory: ~ + run: | + cd ~ + echo "Building valkey ${{ inputs.engine-version }}" + # check if the valkey repo is already cloned + if [[ ! -d valkey ]]; then + git clone https://github.com/valkey-io/valkey.git + else + # check if the branch=version is already checked out + if [[ $(git branch --show-current) != ${{ inputs.engine-version }} ]]; then + cd valkey + make clean + make distclean + sudo rm -rf /usr/local/bin/redis-* /usr/local/bin/valkey-* ./valkey-* ./redis-* ./dump.rdb + git fetch --all + git checkout ${{ inputs.engine-version }} + git pull + fi + fi + # if no cache hit, build the engine + - name: Build Valkey + if: ${{ steps.cache-valkey.outputs.cache-hit != 'true' && !contains(inputs.target, 'aarch64-unknown') }} + shell: bash + run: | + echo "Building valkey ${{ inputs.engine-version }}" + cd ~ + git clone https://github.com/valkey-io/valkey.git + cd valkey + git checkout ${{ inputs.engine-version }} + + - name: Install engine + shell: bash + run: | + cd ~/valkey + make BUILD_TLS=yes + if command -v sudo &> /dev/null + then + echo "sudo command exists" + sudo make install + else + echo "sudo command does not exist" + make install + fi + echo 'export PATH=/usr/local/bin:$PATH' >>~/.bash_profile + + # TODO: This seems redundant to me. Is it necessary? Do we check that the Python we install is the correct version? + # Why here and not elsewhere? All Git git repos were created equal + - name: Verify Valkey installation and symlinks + if: ${{ !contains(inputs.engine-version, '-rc') }} + shell: bash + run: | + # In Valkey releases, the engine is built with symlinks from valkey-server and valkey-cli + # to redis-server and redis-cli. This step ensures that the engine is properly installed + # with the expected version and that Valkey symlinks are correctly created. + EXPECTED_VERSION=`echo ${{ inputs.engine-version }} | sed -e "s/^redis-//"` + INSTALLED_VER=$(redis-server -v) + if [[ $INSTALLED_VER != *"${EXPECTED_VERSION}"* ]]; then + echo "Wrong version has been installed. 
Expected: $EXPECTED_VERSION, Installed: $INSTALLED_VER" + exit 1 + else + echo "Successfully installed the server: $INSTALLED_VER" + fi diff --git a/.github/workflows/install-redis/action.yml b/.github/workflows/install-redis/action.yml deleted file mode 100644 index b60f0687b5..0000000000 --- a/.github/workflows/install-redis/action.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Install Redis - -inputs: - redis-version: - description: "redis version to install" - required: true - type: string - -env: - CARGO_TERM_COLOR: always - -runs: - using: "composite" - - steps: - - run: mkdir -p ~/redis-binaries/${{ inputs.redis-version }} - shell: bash - - - uses: actions/checkout@v4 - - - - uses: actions/cache@v3 - id: cache-redis - with: - path: | - ~/redis-binaries/${{ inputs.redis-version }}/redis-cli - ~/redis-binaries/${{ inputs.redis-version }}/redis-server - key: ${{ runner.os }}-${{ inputs.redis-version }}-install-redis - - - name: Install redis - shell: bash - if: steps.cache-redis.outputs.cache-hit != 'true' - run: | - sudo apt-get update - wget https://github.com/redis/redis/archive/${{ inputs.redis-version }}.tar.gz; - tar -xzvf ${{ inputs.redis-version }}.tar.gz; - pushd redis-${{ inputs.redis-version }} && BUILD_TLS=yes make && sudo mv src/redis-server src/redis-cli ~/redis-binaries/${{ inputs.redis-version }} && popd; - - - name: Remove the source package - shell: bash - if: steps.cache-redis.outputs.cache-hit != 'true' - run: sudo rm -r redis-${{ inputs.redis-version }} - - - name: Copy executable to place - shell: bash - run: sudo cp ~/redis-binaries/${{ inputs.redis-version }}/redis-server ~/redis-binaries/${{ inputs.redis-version }}/redis-cli /usr/bin/ diff --git a/.github/workflows/install-rust-and-protoc/action.yml b/.github/workflows/install-rust-and-protoc/action.yml index e1222ffd9d..31987ba04a 100644 --- a/.github/workflows/install-rust-and-protoc/action.yml +++ b/.github/workflows/install-rust-and-protoc/action.yml @@ -16,7 +16,6 @@ inputs: type: string required: true - runs: using: "composite" steps: diff --git a/.github/workflows/install-shared-dependencies/action.yml b/.github/workflows/install-shared-dependencies/action.yml index ed065e9840..1e64d939a3 100644 --- a/.github/workflows/install-shared-dependencies/action.yml +++ b/.github/workflows/install-shared-dependencies/action.yml @@ -30,7 +30,6 @@ inputs: required: true type: string - runs: using: "composite" steps: @@ -47,7 +46,7 @@ runs: run: | sudo apt update -y sudo apt install -y git gcc pkg-config openssl libssl-dev - + - name: Install software dependencies for Ubuntu MUSL shell: bash if: "${{ contains(inputs.target, 'musl') }}" @@ -67,12 +66,12 @@ runs: if: "${{ !contains(inputs.target, 'musl') }}" uses: ./.github/workflows/install-rust-and-protoc with: - target: ${{ inputs.target }} - github-token: ${{ inputs.github-token }} + target: ${{ inputs.target }} + github-token: ${{ inputs.github-token }} - - name: Install Valkey - if: ${{ inputs.engine-version != '' }} - uses: ./.github/workflows/install-valkey + - name: Install engine + if: "${{ inputs.engine-version }}" + uses: ./.github/workflows/install-engine with: engine-version: ${{ inputs.engine-version }} target: ${{ inputs.target }} diff --git a/.github/workflows/install-valkey/action.yml b/.github/workflows/install-valkey/action.yml deleted file mode 100644 index 74c75572a4..0000000000 --- a/.github/workflows/install-valkey/action.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: Install Valkey - -inputs: - engine-version: - description: "Engine version to install" 
- required: true - type: string - target: - description: "Specified target toolchain, ex. x86_64-unknown-linux-gnu" - type: string - required: true - options: - - x86_64-unknown-linux-gnu - - aarch64-unknown-linux-gnu - - x86_64-apple-darwin - - aarch64-apple-darwin - - aarch64-unknown-linux-musl - - x86_64-unknown-linux-musl - -env: - CARGO_TERM_COLOR: always - VALKEY_MIN_VERSION: "7.2.5" - -runs: - using: "composite" - - steps: - - name: Cache Valkey - # TODO: remove the musl ARM64 limitation when https://github.com/actions/runner/issues/801 is resolved - if: ${{ inputs.target != 'aarch64-unknown-linux-musl' }} - uses: actions/cache@v4 - id: cache-valkey - with: - path: | - ~/valkey - key: valkey-${{ inputs.engine-version }}-${{ inputs.target }} - - - name: Build Valkey - if: ${{ steps.cache-valkey.outputs.cache-hit != 'true' }} - shell: bash - run: | - echo "Building valkey ${{ inputs.engine-version }}" - cd ~ - rm -rf valkey - git clone https://github.com/valkey-io/valkey.git - cd valkey - git checkout ${{ inputs.engine-version }} - make BUILD_TLS=yes - - - name: Install Valkey - shell: bash - run: | - cd ~/valkey - if command -v sudo &> /dev/null - then - echo "sudo command exists" - sudo make install - else - echo "sudo command does not exist" - make install - fi - echo 'export PATH=/usr/local/bin:$PATH' >>~/.bash_profile - - - name: Verify Valkey installation and symlinks - if: ${{ !contains(inputs.engine-version, '-rc') }} - shell: bash - run: | - # In Valkey releases, the engine is built with symlinks from valkey-server and valkey-cli - # to redis-server and redis-cli. This step ensures that the engine is properly installed - # with the expected version and that Valkey symlinks are correctly created. - EXPECTED_VERSION=`echo ${{ inputs.engine-version }} | sed -e "s/^redis-//"` - INSTALLED_VER=$(redis-server -v) - if [[ $INSTALLED_VER != *"${EXPECTED_VERSION}"* ]]; then - echo "Wrong version has been installed. 
Expected: $EXPECTED_VERSION, Installed: $INSTALLED_VER" - exit 1 - else - echo "Successfully installed the server: $INSTALLED_VER" - fi - diff --git a/.github/workflows/java-cd.yml b/.github/workflows/java-cd.yml index d3f2038313..f4c0146342 100644 --- a/.github/workflows/java-cd.yml +++ b/.github/workflows/java-cd.yml @@ -86,8 +86,6 @@ jobs: echo "No cleaning needed" fi - uses: actions/checkout@v4 - - - name: Set up JDK uses: actions/setup-java@v4 with: @@ -211,7 +209,12 @@ jobs: exit 1 test-deployment-on-all-architectures: - needs: [set-release-version, load-platform-matrix, publish-to-maven-central-deployment] + needs: + [ + set-release-version, + load-platform-matrix, + publish-to-maven-central-deployment, + ] env: JAVA_VERSION: "11" RELEASE_VERSION: ${{ needs.set-release-version.outputs.RELEASE_VERSION }} @@ -224,7 +227,6 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Set up JDK uses: actions/setup-java@v4 @@ -236,7 +238,7 @@ jobs: uses: ./.github/workflows/install-shared-dependencies with: os: ${{ matrix.host.OS }} - engine-version: "7.2.5" + engine-version: "7.2" target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} @@ -265,7 +267,11 @@ jobs: publish-release-to-maven: if: ${{ inputs.maven_publish == true || github.event_name == 'push' }} - needs: [publish-to-maven-central-deployment, test-deployment-on-all-architectures] + needs: + [ + publish-to-maven-central-deployment, + test-deployment-on-all-architectures, + ] runs-on: ubuntu-latest environment: AWS_ACTIONS env: @@ -279,7 +285,11 @@ jobs: drop-deployment-if-validation-fails: if: ${{ failure() }} - needs: [publish-to-maven-central-deployment, test-deployment-on-all-architectures] + needs: + [ + publish-to-maven-central-deployment, + test-deployment-on-all-architectures, + ] runs-on: ubuntu-latest env: DEPLOYMENT_ID: ${{ needs.publish-to-maven-central-deployment.outputs.DEPLOYMENT_ID }} diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index 2c31562f78..66c99cca3e 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -10,71 +10,79 @@ on: - glide-core/src/** - glide-core/redis-rs/redis/src/** - java/** + - utils/cluster_manager.py - .github/workflows/java.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** + pull_request: paths: - glide-core/src/** - glide-core/redis-rs/redis/src/** - java/** + - utils/cluster_manager.py - .github/workflows/java.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** + workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine, host, and language version matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: concurrency: - group: java-${{ github.head_ref || github.ref }} + group: java-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true +run-name: + # 
Set custom name if job is started manually and name is given
+    ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }}
+
 jobs:
-    load-engine-matrix:
+    get-matrices:
         runs-on: ubuntu-latest
         outputs:
-            matrix: ${{ steps.load-engine-matrix.outputs.matrix }}
+            engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }}
+            host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }}
+            version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }}
         steps:
-            - name: Checkout
-              uses: actions/checkout@v4
-
-            - name: Load the engine matrix
-              id: load-engine-matrix
-              shell: bash
-              run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT
+            - uses: actions/checkout@v4
+            - id: get-matrices
+              uses: ./.github/workflows/create-test-matrices
+              with:
+                  language-name: java
+                  # Run the full test matrix if the job was started by cron or explicitly requested by the person who triggered the workflow
+                  run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }}

-    build-and-test-java-client:
-        needs: load-engine-matrix
+    test-java:
+        needs: get-matrices
         timeout-minutes: 35
         strategy:
-            # Run all jobs
             fail-fast: false
             matrix:
-                java:
-                    # - 11
-                    - 17
-                engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }}
-                host:
-                    - {
-                          OS: ubuntu,
-                          RUNNER: ubuntu-latest,
-                          TARGET: x86_64-unknown-linux-gnu,
-                      }
-                # - {
-                #       OS: macos,
-                #       RUNNER: macos-latest,
-                #       TARGET: aarch64-apple-darwin
-                #   }
-
+                java: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }}
+                engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }}
+                host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }}
     runs-on: ${{ matrix.host.RUNNER }}

     steps:
         - uses: actions/checkout@v4
-
         - uses: gradle/actions/wrapper-validation@v3
@@ -107,6 +115,7 @@ jobs:
           run: ./gradlew spotlessDiagnose | grep 'All formatters are well behaved for all files'

         - uses: ./.github/workflows/test-benchmark
+          if: ${{ matrix.engine.version == '8.0' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.java == '17' }}
           with:
               language-flag: -java

@@ -123,42 +132,54 @@ jobs:
               benchmarks/results/**
               java/client/build/reports/spotbugs/**

-    build-amazonlinux-latest:
-        if: github.repository_owner == 'valkey-io'
+    get-containers:
+        runs-on: ubuntu-latest
+        if: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }}
+        outputs:
+            engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }}
+            host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }}
+            version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }}
+
+        steps:
+            - uses: actions/checkout@v4
+            - id: get-matrices
+              uses: ./.github/workflows/create-test-matrices
+              with:
+                  language-name: java
+                  run-full-matrix: true
+                  containers: true
+
+    test-java-container:
+        runs-on: ${{ matrix.host.RUNNER }}
+        needs: [get-containers]
+        timeout-minutes: 25
     strategy:
-            # Run all jobs
         fail-fast: false
         matrix:
-                java:
-                    # - 11
-                    - 17
-        runs-on: ubuntu-latest
-        container: amazonlinux:latest
-        timeout-minutes: 35
+            java: ${{ fromJson(needs.get-containers.outputs.version-matrix-output) }}
+            engine: ${{ fromJson(needs.get-containers.outputs.engine-matrix-output) }}
+            host: ${{ fromJson(needs.get-containers.outputs.host-matrix-output) }}
+        container:
+            image: ${{ matrix.host.IMAGE }}
+            options: ${{ join(' -q ', 
matrix.host.CONTAINER_OPTIONS) }} # adding `-q` to bypass empty options steps: - name: Install git run: | - yum -y remove git - yum -y remove git-* - yum -y install https://packages.endpointdev.com/rhel/7/os/x86_64/endpoint-repo.x86_64.rpm yum update - yum install -y git - git --version - + yum install -y git tar java-${{ matrix.java }}-amazon-corretto-devel.x86_64 + echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV + # Replace `:` in the variable otherwise it can't be used in `upload-artifact` - uses: actions/checkout@v4 - - - name: Checkout submodules - run: | - git config --global --add safe.directory "$GITHUB_WORKSPACE" - git submodule update --init --recursive + with: + submodules: recursive - name: Install shared software dependencies uses: ./.github/workflows/install-shared-dependencies with: - os: "amazon-linux" - target: "x86_64-unknown-linux-gnu" + os: ${{ matrix.host.OS }} + target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" + engine-version: ${{ matrix.engine.version }} - name: Install protoc (protobuf) uses: arduino/setup-protoc@v3 @@ -166,10 +187,6 @@ jobs: version: "26.1" repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Install Java - run: | - yum install -y java-${{ matrix.java }}-amazon-corretto-devel.x86_64 - - name: Build java wrapper working-directory: java run: ./gradlew --continue build -x javadoc @@ -179,7 +196,7 @@ jobs: continue-on-error: true uses: actions/upload-artifact@v4 with: - name: test-reports-${{ matrix.java }}-amazon-linux + name: test-reports-java-${{ matrix.java }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ env.IMAGE }}-${{ matrix.host.ARCH }} path: | java/client/build/reports/** java/integTest/build/reports/** @@ -190,11 +207,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: ./.github/workflows/lint-rust with: - cargo-toml-folder: ./java + cargo-toml-folder: java + github-token: ${{ secrets.GITHUB_TOKEN }} name: lint java rust test-modules: @@ -208,7 +225,6 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 - - name: Set up JDK uses: actions/setup-java@v4 diff --git a/.github/workflows/lint-rust/action.yml b/.github/workflows/lint-rust/action.yml index 06b0b7a75a..35c5e313c5 100644 --- a/.github/workflows/lint-rust/action.yml +++ b/.github/workflows/lint-rust/action.yml @@ -12,10 +12,8 @@ inputs: runs: using: "composite" - steps: - uses: actions/checkout@v4 - - name: Install Rust toolchain and protoc uses: ./.github/workflows/install-rust-and-protoc @@ -35,7 +33,7 @@ runs: # We run clippy without features - run: cargo clippy --all-targets -- -D warnings working-directory: ${{ inputs.cargo-toml-folder }} - shell: bash + shell: bash - run: | cargo update diff --git a/.github/workflows/lint-ts/action.yml b/.github/workflows/lint-ts/action.yml index 834e3d7ec9..2f4b47e358 100644 --- a/.github/workflows/lint-ts/action.yml +++ b/.github/workflows/lint-ts/action.yml @@ -12,7 +12,7 @@ runs: steps: - uses: actions/checkout@v4 - - run: cp eslint.config.mjs ${{ inputs.package-folder }} + - run: cp eslint.config.mjs ${{ inputs.package-folder }} shell: bash - run: | diff --git a/.github/workflows/node-create-package-file/action.yml b/.github/workflows/node-create-package-file/action.yml index 1da7510aab..8c9cc9d2f2 100644 --- a/.github/workflows/node-create-package-file/action.yml +++ b/.github/workflows/node-create-package-file/action.yml @@ -84,4 +84,4 @@ runs: mv package.json 
package.json.tmpl envsubst < package.json.tmpl > "package.json" cat package.json - echo $(ls *json*) + echo $(ls *json*) diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index 1927649247..1ed9488e01 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -16,9 +16,9 @@ on: - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json - - .github/workflows/start-self-hosted-runner/action.yml + - .github/workflows/install-engine/action.yml + - .github/json_matrices/** + - .github/workflows/create-test-matrices/action.yml pull_request: paths: - glide-core/src/** @@ -30,55 +30,76 @@ on: - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json - - .github/workflows/start-self-hosted-runner/action.yml + - .github/workflows/install-engine/action.yml + - .github/json_matrices/** + - .github/workflows/create-test-matrices/action.yml workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine, host, and language version matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: concurrency: - group: node-${{ github.head_ref || github.ref }} + group: node-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true env: CARGO_TERM_COLOR: always +run-name: + # Set custom name if job is started manually and name is given + ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }} + jobs: - load-engine-matrix: + get-matrices: runs-on: ubuntu-latest outputs: - matrix: ${{ steps.load-engine-matrix.outputs.matrix }} - steps: - - name: Checkout - uses: actions/checkout@v4 + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} - - name: Load the engine matrix - id: load-engine-matrix - shell: bash - run: echo "matrix=$(jq -c . 
< .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT + steps: + - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: node + run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} - test-ubuntu-latest: - runs-on: ubuntu-latest - needs: load-engine-matrix + test-node: + runs-on: ${{ matrix.host.RUNNER }} + needs: [get-matrices] timeout-minutes: 25 strategy: fail-fast: false matrix: - engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }} - + engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} + node: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }} steps: - uses: actions/checkout@v4 - - - name: Use Node.js 16.x + - name: Setup Node uses: actions/setup-node@v4 + env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true with: - node-version: 16.x + node-version: ${{ matrix.node }} - name: Build Node wrapper uses: ./.github/workflows/build-node-wrapper with: - os: "ubuntu" - target: "x86_64-unknown-linux-gnu" + os: ${{ matrix.host.OS }} + named_os: ${{ matrix.host.NAMED_OS }} + arch: ${{ matrix.host.ARCH }} + target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} engine-version: ${{ matrix.engine.version }} @@ -87,24 +108,25 @@ jobs: working-directory: ./node - name: test hybrid node modules - commonjs + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} run: | - npm install --package-lock-only - npm ci - npm run build-and-test + npm install + npm run test working-directory: ./node/hybrid-node-tests/commonjs-test env: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-commonjs.html - name: test hybrid node modules - ecma + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} run: | - npm install --package-lock-only - npm ci - npm run build-and-test + npm install + npm run test working-directory: ./node/hybrid-node-tests/ecmascript-test env: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-ecma.html - uses: ./.github/workflows/test-benchmark + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} with: language-flag: -node @@ -113,7 +135,7 @@ jobs: continue-on-error: true uses: actions/upload-artifact@v4 with: - name: test-report-node-${{ matrix.engine.type }}-${{ matrix.engine.version }}-ubuntu + name: test-report-node-${{ matrix.node }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ matrix.host.OS }}-${{ matrix.host.ARCH }} path: | node/test-report*.html utils/clusters/** @@ -124,7 +146,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: lint node rust uses: ./.github/workflows/lint-rust @@ -132,134 +153,77 @@ jobs: cargo-toml-folder: ./node/rust-client github-token: ${{ secrets.GITHUB_TOKEN }} - # build-macos-latest: - # runs-on: macos-latest - # timeout-minutes: 25 - # steps: - # - uses: actions/checkout@v4 - # with: - # submodules: recursive - # - name: Set up Homebrew - # uses: Homebrew/actions/setup-homebrew@master - - # - name: Install NodeJS - # run: | - # brew update - # brew upgrade || true - # brew install node - - # - name: Downgrade npm major version to 8 - # run: | - # npm i -g npm@8 - - # - name: Build Node wrapper - # uses: 
./.github/workflows/build-node-wrapper
    #       with:
    #           os: "macos"
    #           named_os: "darwin"
    #           arch: "arm64"
    #           target: "aarch64-apple-darwin"
    #           github-token: ${{ secrets.GITHUB_TOKEN }}
    #           engine-version: "7.2.5"

    #     - name: Test compatibility
    #       run: npm test -- -t "set and get flow works"
    #       working-directory: ./node

    #     - name: Upload test reports
    #       if: always()
    #       continue-on-error: true
    #       uses: actions/upload-artifact@v4
    #       with:
    #           name: test-report-node-${{ matrix.engine.type }}-${{ matrix.engine.version }}-macos
    #           path: |
    #               node/test-report*.html
    #               utils/clusters/**
    #               benchmarks/results/**

-    build-amazonlinux-latest:
+    get-containers:
         runs-on: ubuntu-latest
-        container: amazonlinux:latest
-        timeout-minutes: 15
-        steps:
-            - name: Install git
-              run: |
-                  yum -y remove git
-                  yum -y remove git-*
-                  yum -y install https://packages.endpointdev.com/rhel/7/os/x86_64/endpoint-repo.x86_64.rpm
-                  yum install -y git
-                  git --version
+        if: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }}
+        outputs:
+            engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }}
+            host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }}
+            version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }}

+        steps:
         - uses: actions/checkout@v4
-
-        - name: Checkout submodules
-          run: |
-              git config --global --add safe.directory "$GITHUB_WORKSPACE"
-              git submodule update --init --recursive
-
-        - name: Install NodeJS
-          run: |
-              yum install -y nodejs
-
-        - name: Build Node wrapper
-          uses: ./.github/workflows/build-node-wrapper
-          with:
-              os: "amazon-linux"
-              target: "x86_64-unknown-linux-gnu"
-              github-token: ${{ secrets.GITHUB_TOKEN }}
-              engine-version: "7.2.5"
-
-        - name: Test compatibility
-          run: npm test -- -t "set and get flow works"
-          working-directory: ./node
-
-        - name: Upload test reports
-          if: always()
-          continue-on-error: true
-          uses: actions/upload-artifact@v4
+        - id: get-matrices
+          uses: ./.github/workflows/create-test-matrices
           with:
-              name: test-report-node-amazonlinux
-              path: |
-                  node/test-report*.html
-                  utils/clusters/**
-                  benchmarks/results/**
+              language-name: node
+              run-full-matrix: true
+              containers: true

-    build-and-test-linux-musl-on-x86:
-        name: Build and test Node wrapper on Linux musl
-        runs-on: ubuntu-latest
+    test-node-container:
+        runs-on: ${{ matrix.host.RUNNER }}
+        needs: [get-containers]
+        timeout-minutes: 25
+        strategy:
+            fail-fast: false
+            matrix:
+                node: ${{ fromJson(needs.get-containers.outputs.version-matrix-output) }}
+                engine: ${{ fromJson(needs.get-containers.outputs.engine-matrix-output) }}
+                host: ${{ fromJson(needs.get-containers.outputs.host-matrix-output) }}
         container:
-            image: node:alpine
-            options: --user root --privileged
-
+            image: ${{ matrix.host.IMAGE }}
+            options: ${{ join(' -q ', matrix.host.CONTAINER_OPTIONS) }} # adding `-q` to bypass empty options
         steps:
             - name: Install git
               run: |
-                  apk update
-                  apk add git
-
+                  if [[ ${{ contains(matrix.host.TARGET, 'musl') }} == true ]]; then
+                    apk update
+                    apk add --no-cache git tar
+                  elif [[ ${{ contains(matrix.host.IMAGE, 'amazonlinux') }} == true ]]; then
+                    yum update
+                    yum install -y git tar
+                  fi
+                  echo IMAGE=${{ matrix.host.IMAGE }} | sed -r 's/:/-/g' >> $GITHUB_ENV
+              # Replace `:` in the variable otherwise it can't be used in `upload-artifact`
             - uses: actions/checkout@v4

             - name: Setup musl on Linux
+              if: ${{ contains(matrix.host.TARGET, 'musl') }}
               uses: ./.github/workflows/setup-musl-on-linux
               with:
                   workspace: $GITHUB_WORKSPACE 
npm-scope: ${{ secrets.NPM_SCOPE }} npm-auth-token: ${{ secrets.NPM_AUTH_TOKEN }} + - name: Setup Node + uses: actions/setup-node@v4 + env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true + with: + node-version: ${{ matrix.node }} + - name: Build Node wrapper uses: ./.github/workflows/build-node-wrapper with: - os: ubuntu - named_os: linux - arch: x64 - target: x86_64-unknown-linux-musl + os: ${{ matrix.host.OS }} + named_os: ${{ matrix.host.NAMED_OS }} + target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" + engine-version: ${{ matrix.engine.version }} + arch: ${{ matrix.host.ARCH }} - - name: Test compatibility - shell: bash - run: npm test -- -t "set and get flow works" + - name: test + run: npm test working-directory: ./node - name: Upload test reports @@ -267,7 +231,7 @@ jobs: continue-on-error: true uses: actions/upload-artifact@v4 with: - name: test-report-node-linux-musl + name: test-report-node-${{ matrix.node }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ env.IMAGE }}-${{ matrix.host.ARCH }} path: | node/test-report*.html utils/clusters/** @@ -285,7 +249,6 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 - - name: Use Node.js 16.x uses: actions/setup-node@v4 diff --git a/.github/workflows/npm-cd.yml b/.github/workflows/npm-cd.yml index 362117affb..8ff2936fbc 100644 --- a/.github/workflows/npm-cd.yml +++ b/.github/workflows/npm-cd.yml @@ -4,29 +4,30 @@ name: NPM - Continuous Deployment on: pull_request: - paths: - - .github/workflows/npm-cd.yml - - .github/workflows/build-node-wrapper/action.yml - - .github/workflows/start-self-hosted-runner/action.yml - - .github/workflows/install-rust-and-protoc/action.yml - - .github/workflows/install-shared-dependencies/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + paths: + - .github/workflows/npm-cd.yml + - .github/workflows/build-node-wrapper/action.yml + - .github/workflows/start-self-hosted-runner/action.yml + - .github/workflows/install-rust-and-protoc/action.yml + - .github/workflows/install-shared-dependencies/action.yml + - .github/workflows/install-engine/action.yml + - .github/json_matrices/** + - .github/workflows/create-test-matrices/action.yml push: tags: - "v*.*" workflow_dispatch: - inputs: - version: - description: 'The release version of GLIDE, formatted as *.*.* or *.*.*-rc*' - required: true + inputs: + version: + description: "The release version of GLIDE, formatted as *.*.* or *.*.*-rc*" + required: true concurrency: - group: npm-${{ github.head_ref || github.ref }} + group: node-cd-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true permissions: - id-token: write + id-token: write jobs: start-self-hosted-runner: @@ -34,17 +35,17 @@ jobs: runs-on: ubuntu-latest environment: AWS_ACTIONS steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Start self hosted EC2 runner - uses: ./.github/workflows/start-self-hosted-runner - with: - role-to-assume: ${{ secrets.ROLE_TO_ASSUME }} - aws-region: ${{ secrets.AWS_REGION }} - ec2-instance-id: ${{ secrets.AWS_EC2_INSTANCE_ID }} + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Start self hosted EC2 runner + uses: ./.github/workflows/start-self-hosted-runner + with: + role-to-assume: ${{ secrets.ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + ec2-instance-id: ${{ 
secrets.AWS_EC2_INSTANCE_ID }} load-platform-matrix: runs-on: ubuntu-latest @@ -58,9 +59,9 @@ jobs: id: load-platform-matrix shell: bash run: | - # Get the matrix from the matrix.json file, without the object that has the IMAGE key - export "PLATFORM_MATRIX=$(jq 'map(select(.PACKAGE_MANAGERS | contains(["npm"])))' < .github/json_matrices/build-matrix.json | jq -c .)" - echo "PLATFORM_MATRIX=${PLATFORM_MATRIX}" >> $GITHUB_OUTPUT + # Get the matrix from the matrix.json file, without the object that has the IMAGE key + export "PLATFORM_MATRIX=$(jq 'map(select(.PACKAGE_MANAGERS != null and (.PACKAGE_MANAGERS | contains(["npm"]))))' < .github/json_matrices/build-matrix.json | jq -c .)" + echo "PLATFORM_MATRIX=${PLATFORM_MATRIX}" >> $GITHUB_OUTPUT publish-binaries: needs: [start-self-hosted-runner, load-platform-matrix] @@ -73,7 +74,7 @@ jobs: strategy: fail-fast: false matrix: - build: ${{fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX)}} + build: ${{fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX)}} steps: - name: Setup self-hosted runner access if: ${{ contains(matrix.build.RUNNER, 'self-hosted') && matrix.build.TARGET != 'aarch64-unknown-linux-musl' }} @@ -85,7 +86,7 @@ jobs: run: | apk update apk add git - + - name: Checkout if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl' }} uses: actions/checkout@v4 @@ -100,7 +101,7 @@ jobs: workspace: $GITHUB_WORKSPACE npm-scope: ${{ vars.NPM_SCOPE }} npm-auth-token: ${{ secrets.NPM_AUTH_TOKEN }} - arch: ${{ matrix.build.ARCH }} + arch: ${{ matrix.build.ARCH }} - name: Set the release version shell: bash @@ -115,8 +116,8 @@ jobs: fi echo "RELEASE_VERSION=${R_VERSION}" >> $GITHUB_ENV env: - EVENT_NAME: ${{ github.event_name }} - INPUT_VERSION: ${{ github.event.inputs.version }} + EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} - name: Setup node if: ${{ !contains(matrix.build.TARGET, 'musl') }} @@ -128,15 +129,15 @@ jobs: scope: "${{ vars.NPM_SCOPE }}" always-auth: true token: ${{ secrets.NPM_AUTH_TOKEN }} - + - name: Setup node for publishing if: ${{ !contains(matrix.build.TARGET, 'musl') }} working-directory: ./node run: | - npm config set registry https://registry.npmjs.org/ - npm config set '//registry.npmjs.org/:_authToken' ${{ secrets.NPM_AUTH_TOKEN }} - npm config set scope ${{ vars.NPM_SCOPE }} - + npm config set registry https://registry.npmjs.org/ + npm config set '//registry.npmjs.org/:_authToken' ${{ secrets.NPM_AUTH_TOKEN }} + npm config set scope ${{ vars.NPM_SCOPE }} + - name: Update package version in config.toml uses: ./.github/workflows/update-glide-version with: @@ -153,8 +154,7 @@ jobs: npm_scope: ${{ vars.NPM_SCOPE }} publish: "true" github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" - + - name: Check if RC and set a distribution tag for the package shell: bash run: | @@ -171,9 +171,9 @@ jobs: - name: Check that the release version dont have typo init if: ${{ github.event_name != 'pull_request' && contains(env.RELEASE_VERSION, '-') && !contains(env.RELEASE_VERSION, 'rc') }} run: | - echo "The release version "${GITHUB_REF:11}" contains a typo, please fix it" - echo "The release version should be in the format v{major-version}.{minor-version}.{patch-version}-rc{release-candidate-number} when it a release candidate or v{major-version}.{minor-version}.{patch-version} in a stable release." 
-              exit 1
+              echo "The release version "${GITHUB_REF:11}" contains a typo, please fix it"
+              echo "The release version should be in the format v{major-version}.{minor-version}.{patch-version}-rc{release-candidate-number} when it is a release candidate, or v{major-version}.{minor-version}.{patch-version} in a stable release."
+              exit 1

       - name: Publish to NPM
         if: github.event_name != 'pull_request'
@@ -203,11 +203,11 @@ jobs:
           if: ${{ matrix.build.ARCH == 'arm64' }}
           shell: bash
           run: |
-            echo "Resetting repository"
-            git clean -xdf
-            git reset --hard
-            git fetch
-            git checkout ${{ github.sha }}
+              echo "Resetting repository"
+              git clean -xdf
+              git reset --hard
+              git fetch
+              git checkout ${{ github.sha }}

     publish-base-to-npm:
         if: github.event_name != 'pull_request'
@@ -260,8 +260,7 @@ jobs:
                   os: ubuntu
                   target: "x86_64-unknown-linux-gnu"
                   github-token: ${{ secrets.GITHUB_TOKEN }}
-                  engine-version: "7.2.5"
-
+
           - name: Check if RC and set a distribution tag for the package
             shell: bash
             run: |
@@ -274,7 +273,7 @@ jobs:
                       export npm_tag="latest"
                   fi
                   echo "NPM_TAG=${npm_tag}" >> $GITHUB_ENV
-
+
           - name: Publish the base package
             if: github.event_name != 'pull_request'
             shell: bash
@@ -299,7 +298,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        build: ${{fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX)}} 
+        build: ${{fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX)}}
     steps:
       - name: Setup self-hosted runner access
         if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-gnu' }}
         run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide

       - name: install Redis and git for alpine
         if: ${{ contains(matrix.build.TARGET, 'musl') }}
         run: |
-          apk update
-          apk add redis git
-          node -v
-
+            apk update
+            apk add redis git
+            node -v
+
       - name: install Redis and Python for ubuntu
         if: ${{ contains(matrix.build.TARGET, 'linux-gnu') }}
         run: |
-          sudo apt-get update
-          sudo apt-get install redis-server python3
+            sudo apt-get update
+            sudo apt-get install redis-server python3

       - name: install Redis, Python for macos
         if: ${{ contains(matrix.build.RUNNER, 'mac') }}
         run: |
-          brew install redis python3
+            brew install redis python3

       - name: Checkout
         if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl'}}
         uses: actions/checkout@v4
-        with:
-          submodules: "true"

       - name: Setup for musl
         if: ${{ contains(matrix.build.TARGET, 'musl') }}
@@ -384,5 +381,5 @@ jobs:
         if: ${{ contains(matrix.build.RUNNER, 'self-hosted') }}
         shell: bash
         run: |
-          git reset --hard
-          git clean -xdf
+            git reset --hard
+            git clean -xdf
diff --git a/.github/workflows/ort.yml b/.github/workflows/ort.yml
index fcea61ee6b..2eff2a3f1a 100644
--- a/.github/workflows/ort.yml
+++ b/.github/workflows/ort.yml
@@ -1,34 +1,33 @@
-
 name: The OSS Review Toolkit (ORT)
 on:
     schedule:
-        - cron: "0 0 * * *"
+        - cron: "0 0 * * *"
     pull_request:
-        paths:
-            - .github/workflows/ort.yml
-            - .github/workflows/run-ort-tools/action.yml
-            - utils/get_licenses_from_ort.py
+        paths:
+            - .github/workflows/ort.yml
+            - .github/workflows/run-ort-tools/action.yml
+            - utils/get_licenses_from_ort.py
     workflow_dispatch:
-        inputs:
-            branch:
-                description: 'The branch to run against the ORT tool'
-                required: true
-            version:
-                description: 'The release version of GLIDE'
-                required: true
+        inputs:
+            branch:
+                description: "The branch to run against the ORT tool"
+                required: true
+            version:
+                description: "The release version of GLIDE"
+                required: true

 jobs:
     run-ort:
         if: github.repository_owner == 'valkey-io'
         name: Create attribution files
         runs-on: ubuntu-latest
         strategy:
-            fail-fast: false
-        env:
-            PYTHON_ATTRIBUTIONS: 
"python/THIRD_PARTY_LICENSES_PYTHON" - NODE_ATTRIBUTIONS: "node/THIRD_PARTY_LICENSES_NODE" - RUST_ATTRIBUTIONS: "glide-core/THIRD_PARTY_LICENSES_RUST" - JAVA_ATTRIBUTIONS: "java/THIRD_PARTY_LICENSES_JAVA" + fail-fast: false + env: + PYTHON_ATTRIBUTIONS: "python/THIRD_PARTY_LICENSES_PYTHON" + NODE_ATTRIBUTIONS: "node/THIRD_PARTY_LICENSES_NODE" + RUST_ATTRIBUTIONS: "glide-core/THIRD_PARTY_LICENSES_RUST" + JAVA_ATTRIBUTIONS: "java/THIRD_PARTY_LICENSES_JAVA" steps: - name: Set the release version shell: bash @@ -36,17 +35,17 @@ jobs: export version=`if [ "$EVENT_NAME" == 'schedule' ] || [ "$EVENT_NAME" == 'pull_request' ]; then echo '255.255.255'; else echo "$INPUT_VERSION"; fi` echo "RELEASE_VERSION=${version}" >> $GITHUB_ENV env: - EVENT_NAME: ${{ github.event_name }} - INPUT_VERSION: ${{ github.event.inputs.version }} - + EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} + - name: Set the base branch run: | - export BASE_BRANCH=`if [ "$EVENT_NAME" == 'schedule' ]; then echo 'main'; elif [ "$EVENT_NAME" == 'workflow_dispatch' ]; then echo "$INPUT_BRANCH"; else echo ""; fi` - echo "Base branch is: ${BASE_BRANCH}" - echo "BASE_BRANCH=${BASE_BRANCH}" >> $GITHUB_ENV + export BASE_BRANCH=`if [ "$EVENT_NAME" == 'schedule' ]; then echo 'main'; elif [ "$EVENT_NAME" == 'workflow_dispatch' ]; then echo "$INPUT_BRANCH"; else echo ""; fi` + echo "Base branch is: ${BASE_BRANCH}" + echo "BASE_BRANCH=${BASE_BRANCH}" >> $GITHUB_ENV env: - EVENT_NAME: ${{ github.event_name }} - INPUT_BRANCH: ${{ github.event.inputs.branch }} + EVENT_NAME: ${{ github.event_name }} + INPUT_BRANCH: ${{ github.event.inputs.branch }} - name: Checkout uses: actions/checkout@v4 @@ -64,73 +63,73 @@ jobs: uses: actions/cache@v4 id: cache-ort with: - path: | - ./ort - ~/.gradle/caches - ~/.gradle/wrapper - key: ${{ runner.os }}-ort + path: | + ./ort + ~/.gradle/caches + ~/.gradle/wrapper + key: ${{ runner.os }}-ort - name: Checkout ORT Repository if: steps.cache-ort.outputs.cache-hit != 'true' uses: actions/checkout@v4 - with: + with: repository: "oss-review-toolkit/ort" path: "./ort" ref: "26.0.0" submodules: recursive - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@1.78 + uses: dtolnay/rust-toolchain@1.78 - name: Install ORT if: steps.cache-ort.outputs.cache-hit != 'true' working-directory: ./ort/ run: | - export JAVA_OPTS="$JAVA_OPTS -Xmx8g" - ./gradlew installDist + export JAVA_OPTS="$JAVA_OPTS -Xmx8g" + ./gradlew installDist - name: Create ORT config file run: | - mkdir -p ~/.ort/config - cat << EOF > ~/.ort/config/config.yml - ort: - analyzer: - allowDynamicVersions: true - enabledPackageManagers: [Cargo, NPM, PIP, GradleInspector] - EOF - cat ~/.ort/config/config.yml + mkdir -p ~/.ort/config + cat << EOF > ~/.ort/config/config.yml + ort: + analyzer: + allowDynamicVersions: true + enabledPackageManagers: [Cargo, NPM, PIP, GradleInspector] + EOF + cat ~/.ort/config/config.yml - ### NodeJS ### + ### NodeJS ### - name: Set up Node.js 16.x uses: actions/setup-node@v4 with: node-version: 16.x - - name: Create package.json file for the Node wrapper + - name: Create package.json file for the Node wrapper uses: ./.github/workflows/node-create-package-file with: - release_version: ${{ env.RELEASE_VERSION }} - os: "ubuntu-latest" + release_version: ${{ env.RELEASE_VERSION }} + os: "ubuntu-latest" - name: Fix Node base NPM package.json file for ORT working-directory: ./node/npm/glide run: | - # Remove the glide-rs dependency to avoid duplication - sed -i '/ "glide-rs":/d' 
../../package.json
-                export pkg_name=valkey-glide-base
-                export package_version="${{ env.RELEASE_VERSION }}"
-                export scope=`if [ "$NPM_SCOPE" != '' ]; then echo "$NPM_SCOPE/"; fi`
-                mv package.json package.json.tmpl
-                envsubst < package.json.tmpl > "package.json"
-                cat package.json
-
+              # Remove the glide-rs dependency to avoid duplication
+              sed -i '/ "glide-rs":/d' ../../package.json
+              export pkg_name=valkey-glide-base
+              export package_version="${{ env.RELEASE_VERSION }}"
+              export scope=`if [ "$NPM_SCOPE" != '' ]; then echo "$NPM_SCOPE/"; fi`
+              mv package.json package.json.tmpl
+              envsubst < package.json.tmpl > "package.json"
+              cat package.json
+
       - name: Run ORT tools for Node
         uses: ./.github/workflows/run-ort-tools
         with:
-          folder_path: "${{ github.workspace }}/node"
-
-      ### Python ###
+              folder_path: "${{ github.workspace }}/node"
+
+      ### Python ###

       - name: Set up Python 3.10
         uses: actions/setup-python@v5
@@ -146,14 +145,14 @@ jobs:
       - name: Run ORT tools for Python
         uses: ./.github/workflows/run-ort-tools
         with:
-          folder_path: "${{ github.workspace }}/python"
+              folder_path: "${{ github.workspace }}/python"

       ### Rust ###

       - name: Run ORT tools for Rust
         uses: ./.github/workflows/run-ort-tools
         with:
-          folder_path: "${{ github.workspace }}/glide-core"
+              folder_path: "${{ github.workspace }}/glide-core"

       ### Java ###

@@ -163,60 +162,60 @@ jobs:
           distribution: "temurin"
           java-version: 11

-      - name: Run ORT tools for Java
+      - name: Run ORT tools for Java
         uses: ./.github/workflows/run-ort-tools
         with:
-          folder_path: "${{ github.workspace }}/java"
+              folder_path: "${{ github.workspace }}/java"

       ### Process results ###

       - name: Check for diff
         run: |
-          cp python/ort_results/NOTICE_DEFAULT $PYTHON_ATTRIBUTIONS
-          cp node/ort_results/NOTICE_DEFAULT $NODE_ATTRIBUTIONS
-          cp glide-core/ort_results/NOTICE_DEFAULT $RUST_ATTRIBUTIONS
-          cp java/ort_results/NOTICE_DEFAULT $JAVA_ATTRIBUTIONS
-          GIT_DIFF=`git diff $PYTHON_ATTRIBUTIONS $NODE_ATTRIBUTIONS $RUST_ATTRIBUTIONS $JAVA_ATTRIBUTIONS`
-          if [ -n "$GIT_DIFF" ]; then
-            echo "FOUND_DIFF=true" >> $GITHUB_ENV
-          else
-            echo "FOUND_DIFF=false" >> $GITHUB_ENV
-          fi
+              cp python/ort_results/NOTICE_DEFAULT $PYTHON_ATTRIBUTIONS
+              cp node/ort_results/NOTICE_DEFAULT $NODE_ATTRIBUTIONS
+              cp glide-core/ort_results/NOTICE_DEFAULT $RUST_ATTRIBUTIONS
+              cp java/ort_results/NOTICE_DEFAULT $JAVA_ATTRIBUTIONS
+              GIT_DIFF=`git diff $PYTHON_ATTRIBUTIONS $NODE_ATTRIBUTIONS $RUST_ATTRIBUTIONS $JAVA_ATTRIBUTIONS`
+              if [ -n "$GIT_DIFF" ]; then
+                  echo "FOUND_DIFF=true" >> $GITHUB_ENV
+              else
+                  echo "FOUND_DIFF=false" >> $GITHUB_ENV
+              fi

       - name: Retrieve licenses list
         working-directory: ./utils
         run: |
-          {
-            echo 'LICENSES_LIST<<EOF'
-            python3 get_licenses_from_ort.py
-            echo EOF
-          } >> "$GITHUB_ENV"
+              {
+                  echo 'LICENSES_LIST<<EOF'
+                  python3 get_licenses_from_ort.py
+                  echo EOF
+              } >> "$GITHUB_ENV"

       ### Create PR ###

       - name: Create pull request
         if: ${{ env.FOUND_DIFF == 'true' && github.event_name != 'pull_request' }}
         run: |
-          export BRANCH_NAME=`if [ "$EVENT_NAME" == 'schedule' ] || [ "$EVENT_NAME" == 'pull_request' ]; then echo 'scheduled-ort'; else echo "ort-v$INPUT_VERSION"; fi`
-          echo "Creating pull request from branch ${BRANCH_NAME} to branch ${{ env.BASE_BRANCH }}"
-          git config --global user.email "valkey-glide@lists.valkey.io"
-          git config --global user.name "ort-bot"
-          git checkout -b ${BRANCH_NAME}
-          git add $PYTHON_ATTRIBUTIONS $NODE_ATTRIBUTIONS $RUST_ATTRIBUTIONS $JAVA_ATTRIBUTIONS
-          git commit -m "Updated attribution files" -s
-          git push --set-upstream origin ${BRANCH_NAME} -f
-          title="Updated attribution files for ${BRANCH_NAME}"
-          gh pr create -B ${{ env.BASE_BRANCH }} -H ${BRANCH_NAME} --title 
"${title}" --body 'Created by Github action.\n${{ env.LICENSES_LIST }}' + export BRANCH_NAME=`if [ "$EVENT_NAME" == 'schedule' ] || [ "$EVENT_NAME" == 'pull_request' ]; then echo 'scheduled-ort'; else echo "ort-v$INPUT_VERSION"; fi` + echo "Creating pull request from branch ${BRANCH_NAME} to branch ${{ env.BASE_BRANCH }}" + git config --global user.email "valkey-glide@lists.valkey.io" + git config --global user.name "ort-bot" + git checkout -b ${BRANCH_NAME} + git add $PYTHON_ATTRIBUTIONS $NODE_ATTRIBUTIONS $RUST_ATTRIBUTIONS $JAVA_ATTRIBUTIONS + git commit -m "Updated attribution files" -s + git push --set-upstream origin ${BRANCH_NAME} -f + title="Updated attribution files for ${BRANCH_NAME}" + gh pr create -B ${{ env.BASE_BRANCH }} -H ${BRANCH_NAME} --title "${title}" --body 'Created by Github action.\n${{ env.LICENSES_LIST }}' env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - EVENT_NAME: ${{ github.event_name }} - INPUT_VERSION: ${{ github.event.inputs.version }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} - name: Get current date id: date run: | - CURR_DATE=$(date +'%Y-%m-%d-%H') - echo "date=${CURR_DATE}" >> $GITHUB_OUTPUT + CURR_DATE=$(date +'%Y-%m-%d-%H') + echo "date=${CURR_DATE}" >> $GITHUB_OUTPUT - name: Upload the final package list continue-on-error: true diff --git a/.github/workflows/pypi-cd.yml b/.github/workflows/pypi-cd.yml index 5d547ac7f5..b72da2a211 100644 --- a/.github/workflows/pypi-cd.yml +++ b/.github/workflows/pypi-cd.yml @@ -9,8 +9,9 @@ on: - .github/workflows/build-python-wrapper/action.yml - .github/workflows/start-self-hosted-runner/action.yml - .github/workflows/install-shared-dependencies/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/install-engine/action.yml + - .github/json_matrices/** + - .github/workflows/create-test-matrices/action.yml push: tags: - "v*.*" @@ -41,7 +42,7 @@ jobs: shell: bash run: | # Get the matrix from the matrix.json file, without the object that has the IMAGE key - export "PLATFORM_MATRIX=$(jq 'map(select(.PACKAGE_MANAGERS | contains(["pypi"])))' < .github/json_matrices/build-matrix.json | jq -c .)" + export "PLATFORM_MATRIX=$(jq 'map(select(.PACKAGE_MANAGERS != null and (.PACKAGE_MANAGERS | contains(["pypi"]))))' < .github/json_matrices/build-matrix.json | jq -c .)" echo "PLATFORM_MATRIX=${PLATFORM_MATRIX}" >> $GITHUB_OUTPUT start-self-hosted-runner: @@ -51,6 +52,9 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Start self hosted EC2 runner uses: ./.github/workflows/start-self-hosted-runner with: @@ -114,13 +118,6 @@ jobs: with: python-version: "3.12" - - name: Setup Python for self-hosted Ubuntu runners - if: contains(matrix.build.RUNNER, 'self-hosted') - run: | - sudo apt update -y - sudo apt upgrade -y - sudo apt install python3 python3-venv python3-pip -y - - name: Update package version in config.toml uses: ./.github/workflows/update-glide-version with: @@ -134,7 +131,6 @@ jobs: target: ${{ matrix.build.TARGET }} publish: "true" github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2" - name: Include protobuf files in the package working-directory: ./python @@ -150,7 +146,7 @@ jobs: with: working-directory: ./python target: ${{ matrix.build.TARGET }} - args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12' || 'python3.10' }} 
+ args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} manylinux: auto container: ${{ matrix.build.CONTAINER != '' && matrix.build.CONTAINER || '2014' }} before-script-linux: | @@ -178,9 +174,10 @@ jobs: if: startsWith(matrix.build.NAMED_OS, 'darwin') uses: PyO3/maturin-action@v1 with: + maturin-version: latest working-directory: ./python target: ${{ matrix.build.TARGET }} - args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12' || 'python3.12' }} + args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} - name: Upload Python wheels if: github.event_name != 'pull_request' @@ -220,10 +217,6 @@ jobs: matrix: build: ${{ fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX) }} steps: - - name: Setup self-hosted runner access - if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-gnu' }} - run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - - name: checkout uses: actions/checkout@v4 @@ -232,8 +225,8 @@ jobs: with: python-version: 3.12 - - name: Install ValKey - uses: ./.github/workflows/install-valkey + - name: Install engine + uses: ./.github/workflows/install-engine with: version: "8.0" diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 6c45d7707c..46c3d8e252 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -16,10 +16,10 @@ on: - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json - - .github/json_matrices/engine-matrix.json + - .github/workflows/install-engine/action.yml - .github/workflows/start-self-hosted-runner/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** pull_request: paths: @@ -32,62 +32,68 @@ on: - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml - .github/workflows/lint-rust/action.yml - - .github/workflows/install-valkey/action.yml - - .github/json_matrices/build-matrix.json - - .github/json_matrices/engine-matrix.json + - .github/workflows/install-engine/action.yml - .github/workflows/start-self-hosted-runner/action.yml + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** + workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine, host, and language version matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: concurrency: - group: python-${{ github.head_ref || github.ref }} + group: python-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true +permissions: + contents: read + # Allows the GITHUB_TOKEN to make an API call to generate an OIDC token. 
+ id-token: write
+
+run-name:
+    # Set custom name if job is started manually and name is given
+    ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }}
+
jobs:
-    load-engine-matrix:
+    get-matrices:
        runs-on: ubuntu-latest
        outputs:
-            matrix: ${{ steps.load-engine-matrix.outputs.matrix }}
+            engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }}
+            host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }}
+            version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }}
        steps:
-            - name: Checkout
-              uses: actions/checkout@v4
-
-            - name: Load the engine matrix
-              id: load-engine-matrix
-              shell: bash
-              run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT
-
-    test:
+            - uses: actions/checkout@v4
+            - id: get-matrices
+              uses: ./.github/workflows/create-test-matrices
+              with:
+                  language-name: python
+                  # Run the full test matrix if the job was started by cron or explicitly requested by the person who triggered the workflow
+                  run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }}
+
+    test-python:
        runs-on: ${{ matrix.host.RUNNER }}
-        needs: load-engine-matrix
+        needs: get-matrices
        timeout-minutes: 35
        strategy:
            fail-fast: false
            matrix:
-                engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }}
-                python:
-                    # - "3.8"
-                    # - "3.9"
-                    # - "3.10"
-                    # - "3.11"
-                    - "3.12"
-                host:
-                    - {
-                          OS: ubuntu,
-                          RUNNER: ubuntu-latest,
-                          TARGET: x86_64-unknown-linux-gnu
-                      }
-                # - {
-                #       OS: macos,
-                #       RUNNER: macos-latest,
-                #       TARGET: aarch64-apple-darwin
-                #   }
-
+                python: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }}
+                engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }}
+                host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }}
        steps:
            - uses: actions/checkout@v4
-
            - name: Set up Python
-              uses: actions/setup-python@v4
+              uses: actions/setup-python@v5
              with:
                  python-version: ${{ matrix.python }}

@@ -105,17 +111,6 @@ jobs:
                  github-token: ${{ secrets.GITHUB_TOKEN }}
                  engine-version: ${{ matrix.engine.version }}

-            - name: Type check with mypy
-              working-directory: ./python
-              run: |
-                  # The type check should run inside the virtual env to get
-                  # all installed dependencies and build files
-                  source .env/bin/activate
-                  pip install mypy types-protobuf
-                  # Install the benchmark requirements
-                  pip install -r ../benchmarks/python/requirements.txt
-                  python -m mypy ..
-
            - name: Test with pytest
              working-directory: ./python
              run: |
@@ -124,9 +119,22 @@ jobs:
                  pytest --asyncio-mode=auto --html=pytest_report.html --self-contained-html

            - uses: ./.github/workflows/test-benchmark
+              if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.python == '3.12' }}
              with:
                  language-flag: -python

+            - name: Type check with mypy
+              if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.python == '3.12' }}
+              working-directory: ./python
+              run: |
+                  # The type check should run inside the virtual env to get
+                  # all installed dependencies and build files
+                  source .env/bin/activate
+                  pip install mypy types-protobuf
+                  # Install the benchmark requirements
+                  pip install -r ../benchmarks/python/requirements.txt
+                  python -m mypy .. 
+ - name: Upload test reports if: always() continue-on-error: true @@ -138,41 +146,25 @@ jobs: utils/clusters/** benchmarks/results/** - test-pubsub: + # run pubsub tests in another job - they take too much time + test-pubsub-python: runs-on: ${{ matrix.host.RUNNER }} - needs: load-engine-matrix + needs: get-matrices timeout-minutes: 35 strategy: fail-fast: false matrix: - engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }} - python: - # - "3.8" - # - "3.9" - # - "3.10" - # - "3.11" - - "3.12" - host: - - { - OS: ubuntu, - RUNNER: ubuntu-latest, - TARGET: x86_64-unknown-linux-gnu - } - # - { - # OS: macos, - # RUNNER: macos-latest, - # TARGET: aarch64-apple-darwin - # } - + python: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }} + engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} steps: - uses: actions/checkout@v4 - - + - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - + - name: Build Python wrapper uses: ./.github/workflows/build-python-wrapper with: @@ -180,14 +172,14 @@ jobs: target: ${{ matrix.host.TARGET }} github-token: ${{ secrets.GITHUB_TOKEN }} engine-version: ${{ matrix.engine.version }} - + - name: Test pubsub with pytest working-directory: ./python run: | source .env/bin/activate cd python/tests/ pytest --asyncio-mode=auto -k test_pubsub --html=pytest_report.html --self-contained-html - + - name: Upload test reports if: always() continue-on-error: true @@ -202,12 +194,12 @@ jobs: timeout-minutes: 15 steps: - uses: actions/checkout@v4 - - name: lint rust uses: ./.github/workflows/lint-rust with: - cargo-toml-folder: ./python + cargo-toml-folder: python + github-token: ${{ secrets.GITHUB_TOKEN }} - name: Install dependencies if: always() @@ -236,66 +228,89 @@ jobs: run: | black --check --diff . 
- build-amazonlinux-latest: + get-containers: runs-on: ubuntu-latest - container: amazonlinux:latest - timeout-minutes: 15 - steps: - - name: Install git - run: | - yum -y remove git - yum -y remove git-* - yum -y install https://packages.endpointdev.com/rhel/7/os/x86_64/endpoint-repo.x86_64.rpm - yum install -y git - git --version + if: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }} + outputs: + engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }} + host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }} + version-matrix-output: ${{ steps.get-matrices.outputs.version-matrix-output }} + steps: - uses: actions/checkout@v4 + - id: get-matrices + uses: ./.github/workflows/create-test-matrices + with: + language-name: python + run-full-matrix: true + containers: true - - name: Checkout submodules - run: | - git config --global --add safe.directory "$GITHUB_WORKSPACE" - git submodule update --init --recursive - - - name: Install python + test-python-container: + runs-on: ${{ matrix.host.RUNNER }} + needs: [get-containers] + timeout-minutes: 25 + strategy: + fail-fast: false + matrix: + # Don't use generated matrix for python until compatibility issues resolved on amazon linux + # python: ${{ fromJson(needs.get-containers.outputs.version-matrix-output) }} + engine: ${{ fromJson(needs.get-containers.outputs.engine-matrix-output) }} + host: ${{ fromJson(needs.get-containers.outputs.host-matrix-output) }} + container: + image: ${{ matrix.host.IMAGE }} + options: ${{ join(' -q ', matrix.host.CONTAINER_OPTIONS) }} # adding `-q` to bypass empty options + steps: + - name: Install git and python run: | - yum install -y python3 + yum update + yum install -y git tar python3 + python3 -m ensurepip --upgrade + python3 -m pip install --upgrade pip + python3 -m pip install mypy-protobuf virtualenv + echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV + # Replace `:` in the variable otherwise it can't be used in `upload-artifact` + - uses: actions/checkout@v4 - name: Build Python wrapper uses: ./.github/workflows/build-python-wrapper with: - os: "amazon-linux" - target: "x86_64-unknown-linux-gnu" + os: ubuntu + target: aarch64-unknown-linux-gnu github-token: ${{ secrets.GITHUB_TOKEN }} - engine-version: "7.2.5" - - name: Test compatibility with pytest + - name: Test with pytest working-directory: ./python run: | source .env/bin/activate - pytest --asyncio-mode=auto -m smoke_test --html=pytest_report.html --self-contained-html + cd python/tests/ + pytest --asyncio-mode=auto --html=pytest_report.html --self-contained-html - name: Upload test reports if: always() continue-on-error: true uses: actions/upload-artifact@v4 with: - name: smoke-test-report-amazon-linux + name: test-report-python-${{ matrix.python }}-${{ matrix.engine.type }}-${{ matrix.engine.version }}-${{ env.IMAGE }}-${{ matrix.host.ARCH }} path: | python/python/tests/pytest_report.html - + utils/clusters/** + benchmarks/results/** + test-modules: if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io' environment: AWS_ACTIONS name: Running Module Tests runs-on: [self-hosted, linux, ARM64] timeout-minutes: 15 - + steps: - name: Setup self-hosted runner access if: ${{ contains(matrix.host.RUNNER, 'self-hosted') }} run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 + with: + submodules: recursive - name: Build 
Python wrapper uses: ./.github/workflows/build-python-wrapper @@ -303,7 +318,6 @@ jobs: os: ubuntu target: aarch64-unknown-linux-gnu github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Test with pytest working-directory: ./python run: | diff --git a/.github/workflows/redis-rs.yml b/.github/workflows/redis-rs.yml index cdc4967759..d0c1b00830 100644 --- a/.github/workflows/redis-rs.yml +++ b/.github/workflows/redis-rs.yml @@ -45,12 +45,11 @@ jobs: engine-version: "7.2.5" github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Cache dependencies + - name: Cache dependencies uses: Swatinem/rust-cache@v2 with: - cache-on-failure: true - workspaces: ./glide-core/redis-rs/redis - + cache-on-failure: true + workspaces: ./glide-core/redis-rs/redis - name: Build project run: cargo build --release diff --git a/.github/workflows/run-ort-tools/action.yml b/.github/workflows/run-ort-tools/action.yml index 5686f21b58..2b55517700 100644 --- a/.github/workflows/run-ort-tools/action.yml +++ b/.github/workflows/run-ort-tools/action.yml @@ -13,11 +13,11 @@ runs: working-directory: ./ort/ shell: bash run: | - echo "Running ORT tools for ${{ inputs.folder_path }}" - FOLDER=${{ inputs.folder_path }} - mkdir $FOLDER/ort_results - # Analyzer (analyzer-result.json) - ./gradlew cli:run --args="analyze -i $FOLDER -o $FOLDER/ort_results -f JSON" - - # NOTICE DEFAULT - ./gradlew cli:run --args="report -i $FOLDER/ort_results/analyzer-result.json -o $FOLDER/ort_results/ -f PlainTextTemplate" + echo "Running ORT tools for ${{ inputs.folder_path }}" + FOLDER=${{ inputs.folder_path }} + mkdir $FOLDER/ort_results + # Analyzer (analyzer-result.json) + ./gradlew cli:run --args="analyze -i $FOLDER -o $FOLDER/ort_results -f JSON" + + # NOTICE DEFAULT + ./gradlew cli:run --args="report -i $FOLDER/ort_results/analyzer-result.json -o $FOLDER/ort_results/ -f PlainTextTemplate" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2fdfa77c1f..0c71fa2f86 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -13,9 +13,10 @@ on: - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml - - .github/workflows/install-valkey/action.yml + - .github/workflows/install-engine/action.yml - .github/workflows/lint-rust/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** - deny.toml pull_request: paths: @@ -25,45 +26,66 @@ on: - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml - - .github/workflows/install-valkey/action.yml + - .github/workflows/install-engine/action.yml - .github/workflows/lint-rust/action.yml - - .github/json_matrices/build-matrix.json + - .github/workflows/create-test-matrices/action.yml + - .github/json_matrices/** - deny.toml workflow_dispatch: + inputs: + full-matrix: + description: "Run the full engine and host matrix" + type: boolean + default: false + name: + required: false + type: string + description: "(Optional) Test run name" + + workflow_call: concurrency: - group: rust-${{ github.head_ref || github.ref }} + group: rust-${{ github.head_ref || github.ref }}-${{ toJson(inputs) }} cancel-in-progress: true env: CARGO_TERM_COLOR: always +run-name: + # Set custom name if job is started manually and name is given + ${{ github.event_name == 'workflow_dispatch' && (inputs.name == '' && format('{0} @ {1} {2}', github.ref_name, github.sha, toJson(inputs)) || inputs.name) || '' }} + 
jobs:
-    load-engine-matrix:
-        runs-on: ubuntu-latest
-        outputs:
-            matrix: ${{ steps.load-engine-matrix.outputs.matrix }}
-        steps:
-            - name: Checkout
-              uses: actions/checkout@v4
-
-            - name: Load the engine matrix
-              id: load-engine-matrix
-              shell: bash
-              run: echo "matrix=$(jq -c . < .github/json_matrices/engine-matrix.json)" >> $GITHUB_OUTPUT
-
-    build:
+    get-matrices:
        runs-on: ubuntu-latest
-        needs: load-engine-matrix
+        # Avoid running on schedule for forks
+        if: (github.repository_owner == 'valkey-io' || github.event_name != 'schedule') || github.event_name == 'push' || github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch'
+        outputs:
+            engine-matrix-output: ${{ steps.get-matrices.outputs.engine-matrix-output }}
+            host-matrix-output: ${{ steps.get-matrices.outputs.host-matrix-output }}
+            # language version matrix is omitted
+
+        steps:
+            - uses: actions/checkout@v4
+            - id: get-matrices
+              uses: ./.github/workflows/create-test-matrices
+              with:
+                  language-name: rust
+                  # Run the full test matrix if the job was started by cron or explicitly requested by the person who triggered the workflow
+                  run-full-matrix: ${{ github.event.inputs.full-matrix == 'true' || github.event_name == 'schedule' }}
+
+    tests:
+        runs-on: ${{ matrix.host.RUNNER }}
+        needs: get-matrices
        timeout-minutes: 15
        strategy:
            fail-fast: false
            matrix:
-                engine: ${{ fromJson(needs.load-engine-matrix.outputs.matrix) }}
+                engine: ${{ fromJson(needs.get-matrices.outputs.engine-matrix-output) }}
+                host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }}
        steps:
            - uses: actions/checkout@v4
-
            - name: Install shared software dependencies
              uses: ./.github/workflows/install-shared-dependencies
@@ -98,7 +120,6 @@ jobs:
        timeout-minutes: 30
        steps:
            - uses: actions/checkout@v4
-
            - uses: ./.github/workflows/lint-rust
              with:
diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml
index a9cf3db6df..10523666fa 100644
--- a/.github/workflows/semgrep.yml
+++ b/.github/workflows/semgrep.yml
@@ -2,7 +2,7 @@ name: Semgrep
 on:
   # Scan changed files in PRs (diff-aware scanning):
-  pull_request:
+  pull_request: {}
   # Scan on-demand through GitHub Actions interface:
   workflow_dispatch:
     inputs:
diff --git a/.github/workflows/setup-musl-on-linux/action.yml b/.github/workflows/setup-musl-on-linux/action.yml
index f270c27507..5d69f4063d 100644
--- a/.github/workflows/setup-musl-on-linux/action.yml
+++ b/.github/workflows/setup-musl-on-linux/action.yml
@@ -30,28 +30,28 @@ runs:
       run: |
           apk update
           apk add bash git sed python3
-
+
     - name: Skip all steps if not on ARM64
       shell: bash
       if: ${{ inputs.arch != 'arm64' }}
       run: exit 0

-    # Currently "Checkout" action is not supported for musl on ARM64, so the checkout is happening on the runner and
+    # Currently "Checkout" action is not supported for musl on ARM64, so the checkout is happening on the runner and
     # here we just making sure we getting the clean repo
     - name: Clean repository for musl on ARM64
       shell: bash
       run: |
-          git config --global --add safe.directory "${{ inputs.workspace }}"
-          git fetch origin ${{ github.sha }}
-          git checkout ${{ github.sha }}
-          git clean -xdf
-          git reset --hard
-
+          git config --global --add safe.directory "${{ inputs.workspace }}"
+          git fetch origin ${{ github.sha }}
+          git checkout ${{ github.sha }}
+          git clean -xdf
+          git reset --hard
+
     - name: Set up access for musl on ARM
       shell: bash
       run: |
           chown -R $(whoami):$(whoami) $GITHUB_WORKSPACE
-
+
     - name: Setup node
       shell: bash
       working-directory: ./node
diff --git 
diff --git a/.github/workflows/start-self-hosted-runner/action.yml b/.github/workflows/start-self-hosted-runner/action.yml
index 45038b2d1d..e929e1d2d4 100644
--- a/.github/workflows/start-self-hosted-runner/action.yml
+++ b/.github/workflows/start-self-hosted-runner/action.yml
@@ -19,14 +19,14 @@ runs:
        with:
            role-to-assume: ${{ inputs.role-to-assume }}
            aws-region: ${{ inputs.aws-region }}
-
+
      - name: Start EC2 self hosted runner
        shell: bash
        run: |
            sudo snap refresh
            sudo snap install aws-cli --classic
            command_id=$(aws ssm send-command --instance-ids ${{ inputs.ec2-instance-id }} --document-name StartGithubSelfHostedRunner --query Command.CommandId --output text)
-
+
            while [[ "$invoke_status" != "Success" && "$invoke_status" != "Failed" ]]; do
                invoke_status=$(aws ssm list-command-invocations --command-id $command_id --query 'CommandInvocations[0].Status' --output text)
                echo "Current Status: $invoke_status"
@@ -41,4 +41,4 @@ runs:
                fi
            done

-            echo "Final Command Status: $invoke_status"
+            echo "Final Command Status: $invoke_status"
diff --git a/.prettierignore b/.prettierignore
new file mode 100644
index 0000000000..dd449725e1
--- /dev/null
+++ b/.prettierignore
@@ -0,0 +1 @@
+*.md
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 72bcb0d6f7..229045495f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -27,7 +27,7 @@
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true,
    "[github-actions-workflow]": {
-        "editor.defaultFormatter": "redhat.vscode-yaml"
+        "editor.defaultFormatter": "esbenp.prettier-vscode"
    },
    "yaml.schemas": {
        "https://json.schemastore.org/github-workflow.json": [
@@ -58,5 +58,6 @@
        "black"
    ],
    "rust-analyzer.cargo.features": "all",
-    "dotnet.defaultSolution": "csharp/csharp.sln"
+    "dotnet.defaultSolution": "csharp/csharp.sln",
+    "java.compile.nullAnalysis.mode": "automatic"
 }
diff --git a/glide-core/redis-rs/scripts/get_command_info.py b/glide-core/redis-rs/scripts/get_command_info.py
index dcba666bff..4c719dd4d4 100644
--- a/glide-core/redis-rs/scripts/get_command_info.py
+++ b/glide-core/redis-rs/scripts/get_command_info.py
@@ -1,3 +1,4 @@
+# type: ignore
 import argparse
 import json
 import os
diff --git a/glide-core/src/client/types.rs b/glide-core/src/client/types.rs
index 2fa46037da..ef4be661e6 100644
--- a/glide-core/src/client/types.rs
+++ b/glide-core/src/client/types.rs
@@ -2,7 +2,9 @@
 * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
 */

+#[allow(unused_imports)]
 use logger_core::log_warn;
+#[allow(unused_imports)]
 use std::collections::HashSet;
 use std::time::Duration;

diff --git a/glide-core/tests/test_standalone_client.rs b/glide-core/tests/test_standalone_client.rs
index 5b269dd42c..8001ccab0c 100644
--- a/glide-core/tests/test_standalone_client.rs
+++ b/glide-core/tests/test_standalone_client.rs
@@ -251,7 +251,7 @@ mod standalone_client_tests {
            .map(|mock| mock.get_number_of_received_commands())
            .collect();
        replica_reads.sort();
-        assert_eq!(config.expected_replica_reads, replica_reads);
+        assert!(config.expected_replica_reads <= replica_reads);
    }

    #[rstest]
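The start-self-hosted-runner action above kicks off the runner through an SSM document and then polls the invocation status until it settles. A hedged boto3 sketch of the same loop (region, instance id, and the 5-second poll interval are assumptions; the document name comes from the action):

    # Rough boto3 equivalent of the SSM polling loop in
    # start-self-hosted-runner above. Error handling is elided.
    import time

    import boto3

    ssm = boto3.client("ssm", region_name="us-east-1")  # region is an assumption

    command_id = ssm.send_command(
        InstanceIds=["i-0123456789abcdef0"],  # placeholder instance id
        DocumentName="StartGithubSelfHostedRunner",
    )["Command"]["CommandId"]

    status = "Pending"
    while status not in ("Success", "Failed"):
        invocations = ssm.list_command_invocations(CommandId=command_id)["CommandInvocations"]
        status = invocations[0]["Status"] if invocations else "Pending"
        print(f"Current Status: {status}")
        time.sleep(5)

    print(f"Final Command Status: {status}")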
diff --git a/glide-core/tests/utilities/mod.rs b/glide-core/tests/utilities/mod.rs
index 4e73c0ea93..765c1cffb1 100644
--- a/glide-core/tests/utilities/mod.rs
+++ b/glide-core/tests/utilities/mod.rs
@@ -734,7 +734,7 @@ pub async fn kill_connection(client: &mut impl glide_core::client::GlideClientFo
        &client_kill_cmd,
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::AllNodes,
-            None,
+            Some(redis::cluster_routing::ResponsePolicy::AllSucceeded),
        ))),
    )
    .await
diff --git a/java/integTest/src/test/java/glide/TestConfiguration.java b/java/integTest/src/test/java/glide/TestConfiguration.java
index e2f77e6547..864e384e1d 100644
--- a/java/integTest/src/test/java/glide/TestConfiguration.java
+++ b/java/integTest/src/test/java/glide/TestConfiguration.java
@@ -8,6 +8,7 @@
 import glide.api.BaseClient;
 import glide.api.GlideClient;
 import glide.api.GlideClusterClient;
+import glide.api.logging.Logger;

 public final class TestConfiguration {
    // All servers are hosted on localhost
@@ -19,6 +20,8 @@ public final class TestConfiguration {
    public static final boolean TLS = Boolean.parseBoolean(System.getProperty("test.server.tls", ""));

    static {
+        Logger.init(Logger.Level.OFF);
+        Logger.setLoggerConfig(Logger.Level.OFF);
        try {
            BaseClient client =
                    !STANDALONE_HOSTS[0].isEmpty()
diff --git a/node/README.md b/node/README.md
index d215bc102f..661e742b96 100644
--- a/node/README.md
+++ b/node/README.md
@@ -14,12 +14,16 @@ The release of Valkey GLIDE was tested on the following platforms:

 Linux:

-- Ubuntu 22.04.1 (x86_64)
+- Ubuntu 22.04.1 (x86_64 and aarch64)
 - Amazon Linux 2023 (AL2023) (x86_64)

 macOS:

-- macOS 12.7 (Apple silicon/aarch_64 and Intel/x86_64)
+- macOS (12.7 and latest) (Apple silicon/aarch_64 and Intel/x86_64)
+
+Alpine:
+
+- node:alpine (default on aarch64 and x86_64)

 ## NodeJS supported version
diff --git a/node/hybrid-node-tests/commonjs-test/package.json b/node/hybrid-node-tests/commonjs-test/package.json
index f45541c135..6e1291a680 100644
--- a/node/hybrid-node-tests/commonjs-test/package.json
+++ b/node/hybrid-node-tests/commonjs-test/package.json
@@ -19,8 +19,8 @@
        "valkey-glide": "file:../../../node/build-ts/cjs"
    },
    "devDependencies": {
-        "@types/node": "^18.7.9",
-        "prettier": "^2.8.8",
-        "typescript": "^4.8.4"
+        "@types/node": "^22.8",
+        "prettier": "^3.3",
+        "typescript": "^5.6"
    }
 }
diff --git a/node/hybrid-node-tests/ecmascript-test/package.json b/node/hybrid-node-tests/ecmascript-test/package.json
index df578128a7..d1d91d5a97 100644
--- a/node/hybrid-node-tests/ecmascript-test/package.json
+++ b/node/hybrid-node-tests/ecmascript-test/package.json
@@ -19,8 +19,8 @@
        "valkey-glide": "file:../../../node/build-ts/mjs"
    },
    "devDependencies": {
-        "@types/node": "^18.7.9",
-        "prettier": "^2.8.8",
-        "typescript": "^4.8.4"
+        "@types/node": "^22.8",
+        "prettier": "^3.3",
+        "typescript": "^5.6"
    }
 }
diff --git a/node/jest.config.js b/node/jest.config.js
index e73e7ce439..6952aecfca 100644
--- a/node/jest.config.js
+++ b/node/jest.config.js
@@ -1,7 +1,9 @@
 /* eslint no-undef: off */
 module.exports = {
    preset: "ts-jest",
-    transform: { "^.+\\.ts?$": "ts-jest" },
+    transform: {
+        "^.+\\.(t|j)s$": ["ts-jest", { isolatedModules: true }],
+    },
    testEnvironment: "node",
    testRegex: "/tests/.*\\.(test|spec)?\\.(ts|tsx)$",
    moduleFileExtensions: [
@@ -27,4 +29,5 @@ module.exports = {
            },
        ],
    ],
+    setupFilesAfterEnv: ["./tests/setup.js"],
 };
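The `kill_connection` fix above routes CLIENT KILL to all nodes and only treats the call as successful when every node succeeds (the `AllSucceeded` response policy). A sketch of a comparable fan-out in the Python client, assuming an existing `GlideClusterClient` named `client` (`AllNodes` and `custom_command` are real glide-py APIs; treating a raised error as the equivalent of a failed response policy is an approximation):

    # Sketch: fan a command out to every node, akin to AllNodes + AllSucceeded
    # in the Rust test helper above.
    from glide import AllNodes

    async def kill_connections(client) -> None:
        # custom_command surfaces a RequestError if a node fails, which is
        # roughly what the AllSucceeded policy enforces on the Rust side.
        await client.custom_command(["CLIENT", "KILL", "TYPE", "normal"], AllNodes())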
diff --git a/node/package.json b/node/package.json
index 06a2dc7743..685d1338fd 100644
--- a/node/package.json
+++ b/node/package.json
@@ -32,15 +32,15 @@
        "compile-protobuf-files": "cd src && pbjs -t static-module -o ProtobufMessage.js ../../glide-core/src/protobuf/*.proto && pbts -o ProtobufMessage.d.ts ProtobufMessage.js",
        "clean": "rm -rf build-ts rust-client/target docs glide-logs rust-client/glide-rs.*.node rust-client/index.* src/ProtobufMessage.*",
        "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? writer.fork() : writer).ldelim' src/ProtobufMessage.js",
-        "test": "npm run build-test-utils && jest --verbose --runInBand --testPathIgnorePatterns='ServerModules'",
+        "test": "npm run build-test-utils && jest --verbose --testPathIgnorePatterns='ServerModules'",
        "test-minimum": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='^(.(?!(GlideJson|GlideFt|pubsub|kill)))*$'",
-        "test-modules": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='(GlideJson|GlideFt)'",
+        "test-modules": "npm run build-test-utils && jest --verbose --testNamePattern='(GlideJson|GlideFt)'",
        "build-test-utils": "cd ../utils && npm i && npm run build",
        "lint:fix": "npm run install-linting && npx eslint -c ../eslint.config.mjs --fix && npm run prettier:format",
        "lint": "npm run install-linting && npx eslint -c ../eslint.config.mjs && npm run prettier:check:ci",
        "install-linting": "cd ../ & npm install",
        "prepack": "npmignore --auto",
-        "prereq": "git submodule update --init --recursive && npm install",
+        "prereq": "npm install",
        "prettier:check:ci": "npx prettier --check . --ignore-unknown '!**/*.{js,d.ts}'",
        "prettier:format": "npx prettier --write . --ignore-unknown '!**/*.{js,d.ts}'"
    },
diff --git a/node/tests/AsyncClient.test.ts b/node/tests/AsyncClient.test.ts
index 6d0d3b4246..c810b07f91 100644
--- a/node/tests/AsyncClient.test.ts
+++ b/node/tests/AsyncClient.test.ts
@@ -35,8 +35,8 @@ describe("AsyncClient", () => {
        await flushallOnPort(port);
    });

-    afterAll(() => {
-        server.close();
+    afterAll(async () => {
+        await server.close();
    });

    runCommonTests({
diff --git a/node/tests/GlideClient.test.ts b/node/tests/GlideClient.test.ts
index e0ac65e169..e1043c0657 100644
--- a/node/tests/GlideClient.test.ts
+++ b/node/tests/GlideClient.test.ts
@@ -20,6 +20,7 @@ import {
    GlideClient,
    GlideRecord,
    GlideString,
+    ListDirection,
    ProtocolVersion,
    RequestError,
    Script,
@@ -34,9 +35,8 @@ import {
    convertStringArrayToBuffer,
    createLongRunningLuaScript,
    createLuaLibWithLongRunningFunction,
-    DumpAndRestureTest,
+    DumpAndRestoreTest,
    encodableTransactionTest,
-    encodedTransactionTest,
    flushAndCloseClient,
    generateLuaLibCode,
    getClientConfigurationOption,
@@ -74,6 +74,8 @@ describe("GlideClient", () => {
    afterAll(async () => {
        if (testsFailed === 0) {
            await cluster.close();
+        } else {
+            await cluster.close(true);
        }
    }, TIMEOUT);

@@ -135,6 +137,45 @@ describe("GlideClient", () => {
        },
    );

+    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
+        "check that blocking commands returns never timeout_%p",
+        async (protocol) => {
+            client = await GlideClient.createClient(
+                getClientConfigurationOption(cluster.getAddresses(), protocol, {
+                    requestTimeout: 300,
+                }),
+            );
+
+            const promiseList = [
+                client.blmove(
+                    "source",
+                    "destination",
+                    ListDirection.LEFT,
+                    ListDirection.LEFT,
+                    0.1,
+                ),
+                client.bzpopmax(["key1", "key2"], 0),
+                client.bzpopmin(["key1", "key2"], 0),
+            ];
+
+            try {
+                for (const promise of promiseList) {
+                    const timeoutPromise = new Promise((resolve) => {
+                        setTimeout(resolve, 500);
+                    });
+                    await Promise.race([promise, timeoutPromise]);
+                }
+            } finally {
+                for (const promise of promiseList) {
+                    await Promise.resolve([promise]);
+                }
+
+                client.close();
+            }
+        },
+        5000,
+    );
+
    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
        "select dbsize flushdb test %p",
        async (protocol) => {
@@ -239,25 +280,6 @@ describe("GlideClient", () => {
        },
    );

-    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
-        `can get Bytes decoded transactions_%p`,
-        async (protocol) => {
-            client = await GlideClient.createClient(
-                getClientConfigurationOption(cluster.getAddresses(), protocol),
-            );
-            const transaction = new Transaction();
-            const expectedRes = await encodedTransactionTest(transaction);
-            transaction.select(0);
-            const result = await client.exec(transaction, {
-                decoder: Decoder.Bytes,
-            });
-            expectedRes.push(["select(0)", "OK"]);
-
-            validateTransactionResponse(result, expectedRes);
-            client.close();
-        },
-    );
-
    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
        `dump and restore transactions_%p`,
        async (protocol) => {
@@ -265,9 +287,12 @@ describe("GlideClient", () => {
                getClientConfigurationOption(cluster.getAddresses(), protocol),
            );
            const bytesTransaction = new Transaction();
-            const expectedBytesRes = await DumpAndRestureTest(
+            await client.set("key", "value");
+            const dumpValue: Buffer = (await client.dump("key")) as Buffer;
+            await client.del(["key"]);
+            const expectedBytesRes = await DumpAndRestoreTest(
                bytesTransaction,
-                Buffer.from("value"),
+                dumpValue,
            );
            bytesTransaction.select(0);
            const result = await client.exec(bytesTransaction, {
@@ -278,14 +303,14 @@ describe("GlideClient", () => {
            validateTransactionResponse(result, expectedBytesRes);

            const stringTransaction = new Transaction();
-            await DumpAndRestureTest(stringTransaction, "value");
+            await DumpAndRestoreTest(stringTransaction, dumpValue);
            stringTransaction.select(0);

            // Since DUMP gets binary results, we cannot use the string decoder here, so we expected to get an error.
            await expect(
                client.exec(stringTransaction, { decoder: Decoder.String }),
            ).rejects.toThrowError(
-                "invalid utf-8 sequence of 1 bytes from index 9",
+                /invalid utf-8 sequence of 1 bytes from index/,
            );

            client.close();
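The new test above verifies that blocking commands are exempt from the client's `requestTimeout` by racing each call against a timer and expecting the timer to win. Roughly the same check in the Python client could look like this (a sketch; `bzpopmax` exists on the async client, the client fixture and timings are assumed):

    # Sketch: verify a blocking command outlives the request timeout by
    # racing it against a timer, mirroring the Node test above. Assumes an
    # existing async GlideClient `client` created with a short request timeout.
    import asyncio

    async def assert_blocking_call_does_not_time_out(client) -> None:
        # timeout=0 blocks server-side indefinitely; it must not be cancelled
        # by the client-side request timeout.
        task = asyncio.create_task(client.bzpopmax(["key1", "key2"], 0))
        _, pending = await asyncio.wait({task}, timeout=0.5)
        assert task in pending  # still blocked, so no client timeout fired
        task.cancel()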
diff --git a/node/tests/GlideClientInternals.test.ts b/node/tests/GlideClientInternals.test.ts
index 9d8f093d0a..b12fd007d4 100644
--- a/node/tests/GlideClientInternals.test.ts
+++ b/node/tests/GlideClientInternals.test.ts
@@ -2,7 +2,7 @@
 * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
 */

-import { beforeAll, describe, expect, it } from "@jest/globals";
+import { describe, expect, it } from "@jest/globals";
 import fs from "fs";
 import {
    createLeakedArray,
@@ -32,7 +32,6 @@ import {
    GlideReturnType,
    InfoOptions,
    isGlideRecord,
-    Logger,
    RequestError,
    SlotKeyTypes,
    TimeUnit,
@@ -45,10 +44,6 @@ import {
 import { convertStringArrayToBuffer } from "./TestUtilities";
 const { RequestType, CommandRequest } = command_request;

-beforeAll(() => {
-    Logger.init("info");
-});
-
 enum ResponseType {
    /** Type of a response that returns a null. */
    Null,
diff --git a/node/tests/GlideClusterClient.test.ts b/node/tests/GlideClusterClient.test.ts
index 6f29f99884..3707b46f74 100644
--- a/node/tests/GlideClusterClient.test.ts
+++ b/node/tests/GlideClusterClient.test.ts
@@ -84,6 +84,8 @@ describe("GlideClusterClient", () => {
    afterAll(async () => {
        if (testsFailed === 0) {
            await cluster.close();
+        } else {
+            await cluster.close(true);
        }
    });

@@ -246,7 +248,7 @@ describe("GlideClusterClient", () => {
            expect(await client.set(key, value)).toEqual("OK");
            // Since DUMP gets binary results, we cannot use the default decoder (string) here, so we expected to get an error.
            await expect(client.customCommand(["DUMP", key])).rejects.toThrow(
-                "invalid utf-8 sequence of 1 bytes from index 9",
+                "invalid utf-8 sequence of 1 bytes from index",
            );

            const dumpResult = await client.customCommand(["DUMP", key], {
@@ -370,6 +372,20 @@ describe("GlideClusterClient", () => {
            const client = await GlideClusterClient.createClient(
                getClientConfigurationOption(cluster.getAddresses(), protocol),
            );
+            const lmpopArr = [];
+
+            if (!cluster.checkIfServerVersionLessThan("7.0.0")) {
+                lmpopArr.push(
+                    client.lmpop(["abc", "def"], ListDirection.LEFT, {
+                        count: 1,
+                    }),
+                );
+                lmpopArr.push(
+                    client.blmpop(["abc", "def"], ListDirection.RIGHT, 0.1, {
+                        count: 1,
+                    }),
+                );
+            }

            const promises: Promise<unknown>[] = [
                client.blpop(["abc", "zxy", "lkn"], 0.1),
@@ -391,10 +407,7 @@ describe("GlideClusterClient", () => {
                client.sdiffstore("abc", ["zxy", "lkn"]),
                client.sortStore("abc", "zyx"),
                client.sortStore("abc", "zyx", { isAlpha: true }),
-                client.lmpop(["abc", "def"], ListDirection.LEFT, { count: 1 }),
-                client.blmpop(["abc", "def"], ListDirection.RIGHT, 0.1, {
-                    count: 1,
-                }),
+                ...lmpopArr,
                client.bzpopmax(["abc", "def"], 0.5),
                client.bzpopmin(["abc", "def"], 0.5),
                client.xread({ abc: "0-0", zxy: "0-0", lkn: "0-0" }),
@@ -440,9 +453,15 @@ describe("GlideClusterClient", () => {
                );
            }

-            for (const promise of promises) {
-                await expect(promise).rejects.toThrowError(/crossslot/i);
-            }
+            await Promise.allSettled(promises).then((results) => {
+                results.forEach((result) => {
+                    expect(result.status).toBe("rejected");
+
+                    if (result.status === "rejected") {
+                        expect(result.reason.message).toContain("CrossSlot");
+                    }
+                });
+            });

            client.close();
        },
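Switching to `Promise.allSettled` above collects every rejection instead of awaiting the promises one at a time, so a stray unhandled rejection can no longer abort the suite early. The asyncio counterpart is `gather(..., return_exceptions=True)`; a sketch assuming a connected cluster client and two illustrative cross-slot calls:

    # Sketch: assert that every cross-slot command fails, collecting all
    # results first, like Promise.allSettled above. `client` is assumed to
    # be a connected GlideClusterClient.
    import asyncio

    from glide.exceptions import RequestError

    async def assert_all_cross_slot(client) -> None:
        results = await asyncio.gather(
            client.sdiffstore("abc", ["zxy", "lkn"]),
            client.smove("abc", "zxy", "value"),
            return_exceptions=True,  # keep going; exceptions become results
        )

        for result in results:
            assert isinstance(result, RequestError)
            assert "CrossSlot" in str(result)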
diff --git a/node/tests/PubSub.test.ts b/node/tests/PubSub.test.ts
index 5e8d40207a..8e87b7ee1c 100644
--- a/node/tests/PubSub.test.ts
+++ b/node/tests/PubSub.test.ts
@@ -80,12 +80,22 @@ describe("PubSub", () => {
            : await ValkeyCluster.createCluster(true, 3, 1, getServerVersion);
    }, 40000);
    afterEach(async () => {
-        await flushAndCloseClient(false, cmdCluster.getAddresses());
-        await flushAndCloseClient(true, cmeCluster.getAddresses());
+        if (cmdCluster) {
+            await flushAndCloseClient(false, cmdCluster.getAddresses());
+        }
+
+        if (cmeCluster) {
+            await flushAndCloseClient(true, cmeCluster.getAddresses());
+        }
    });
    afterAll(async () => {
-        await cmeCluster.close();
-        await cmdCluster.close();
+        if (cmdCluster) {
+            await cmdCluster.close();
+        }
+
+        if (cmeCluster) {
+            await cmeCluster.close();
+        }
    });

    async function createClients(
@@ -3956,6 +3966,12 @@ describe("PubSub", () => {
        let client2: TGlideClient | null = null;

        try {
+            const minVersion = "7.0.0";
+
+            if (cmeCluster.checkIfServerVersionLessThan(minVersion)) {
+                return; // Skip test if server version is less than required
+            }
+
            const regularChannel = "regular_channel";
            const shardChannel = "shard_channel";

@@ -3977,12 +3993,6 @@ describe("PubSub", () => {
                pubSub,
            );

-            const minVersion = "7.0.0";
-
-            if (cmeCluster.checkIfServerVersionLessThan(minVersion)) {
-                return; // Skip test if server version is less than required
-            }
-
            // Test pubsubChannels
            const regularChannels = await client2.pubsubChannels();
            expect(regularChannels).toEqual([regularChannel]);
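Hoisting the version check to the top of the try block above means unsupported servers bail out before any pub/sub clients are created. The Python suite gates tests the same way; a sketch built from helpers that appear elsewhere in this patch (`check_if_server_version_lt` and the exact skip message come from test_pubsub.py below; the fixture wiring is assumed):

    # Sketch: skip a sharded-pubsub test before any subscriber clients are
    # created. `client` is assumed to be an existing cluster client fixture.
    import pytest

    from tests.utils.utils import check_if_server_version_lt

    async def skip_if_unsupported(client) -> None:
        min_version = "7.0.0"

        if await check_if_server_version_lt(client, min_version):
            pytest.skip(reason=f"Valkey version required >= {min_version}")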
diff --git a/node/tests/ScanTest.test.ts b/node/tests/ScanTest.test.ts
index e1ccb9bb06..5c975cacdc 100644
--- a/node/tests/ScanTest.test.ts
+++ b/node/tests/ScanTest.test.ts
@@ -40,7 +40,7 @@ describe("Scan GlideClusterClient", () => {
              )
            : // setting replicaCount to 1 to facilitate tests routed to replicas
              await ValkeyCluster.createCluster(true, 3, 1, getServerVersion);
-    }, 20000);
+    }, 40000);

    afterEach(async () => {
        await flushAndCloseClient(true, cluster.getAddresses(), client);
@@ -49,6 +49,8 @@ describe("Scan GlideClusterClient", () => {
    afterAll(async () => {
        if (testsFailed === 0) {
            await cluster.close();
+        } else {
+            await cluster.close(true);
        }
    });

@@ -401,6 +403,8 @@ describe("Scan GlideClient", () => {
    afterAll(async () => {
        if (testsFailed === 0) {
            await cluster.close();
+        } else {
+            await cluster.close(true);
        }
    });

diff --git a/node/tests/SharedTests.ts b/node/tests/SharedTests.ts
index 7b1af14ed1..92e07c30ce 100644
--- a/node/tests/SharedTests.ts
+++ b/node/tests/SharedTests.ts
@@ -1656,7 +1656,6 @@ export function runBaseTests(config: {
            await expect(client.hscan(key1, "-1")).rejects.toThrow(
                RequestError,
            );
-
            await expect(client.sscan(key1, "-1")).rejects.toThrow(
                RequestError,
            );
@@ -6435,7 +6434,7 @@ export function runBaseTests(config: {
                        [key3],
                        cluster.checkIfServerVersionLessThan("6.0.0")
                            ? 1.0
-                            : 0.001,
+                            : 0.01,
                    ),
                ).toBeNull();

@@ -6485,7 +6484,7 @@ export function runBaseTests(config: {
                        [key3],
                        cluster.checkIfServerVersionLessThan("6.0.0")
                            ? 1.0
-                            : 0.001,
+                            : 0.01,
                    ),
                ).toBeNull();

@@ -7526,148 +7525,126 @@ export function runBaseTests(config: {
    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
        `xinfo stream xinfosream test_%p`,
        async (protocol) => {
-            await runTest(async (client: BaseClient) => {
-                const key = uuidv4();
-                const groupName = `group-${uuidv4()}`;
-                const consumerName = `consumer-${uuidv4()}`;
-                const streamId0_0 = "0-0";
-                const streamId1_0 = "1-0";
-                const streamId1_1 = "1-1";
+            await runTest(
+                async (client: BaseClient, cluster: ValkeyCluster) => {
+                    const key = uuidv4();
+                    const groupName = `group-${uuidv4()}`;
+                    const consumerName = `consumer-${uuidv4()}`;
+                    const streamId0_0 = "0-0";
+                    const streamId1_0 = "1-0";
+                    const streamId1_1 = "1-1";

-                // Setup: add stream entry, create consumer group and consumer, read from stream with consumer
-                expect(
-                    await client.xadd(
-                        key,
-                        [
-                            ["a", "b"],
-                            ["c", "d"],
-                        ],
-                        { id: streamId1_0 },
-                    ),
-                ).toEqual(streamId1_0);
+                    expect(
+                        await client.xadd(
+                            key,
+                            [
+                                ["a", "b"],
+                                ["c", "d"],
+                            ],
+                            { id: streamId1_0 },
+                        ),
+                    ).toEqual(streamId1_0);

-                expect(
-                    await client.xgroupCreate(key, groupName, streamId0_0),
-                ).toEqual("OK");
+                    expect(
+                        await client.xgroupCreate(key, groupName, streamId0_0),
+                    ).toEqual("OK");

-                await client.xreadgroup(groupName, consumerName, {
-                    [key]: ">",
-                });
+                    await client.xreadgroup(groupName, consumerName, {
+                        [key]: ">",
+                    });

-                // test xinfoStream base (non-full) case:
-                const result = (await client.xinfoStream(key)) as {
-                    length: number;
-                    "radix-tree-keys": number;
-                    "radix-tree-nodes": number;
-                    "last-generated-id": string;
-                    "max-deleted-entry-id": string;
-                    "entries-added": number;
-                    "recorded-first-entry-id": string;
-                    "first-entry": (string | number | string[])[];
-                    "last-entry": (string | number | string[])[];
-                    groups: number;
-                };
+                    const result = (await client.xinfoStream(key)) as {
+                        length: number;
+                        "radix-tree-keys": number;
+                        "radix-tree-nodes": number;
+                        "last-generated-id": string;
+                        "max-deleted-entry-id": string;
+                        "entries-added": number;
+                        "recorded-first-entry-id": string;
+                        "first-entry": (string | number | string[])[];
+                        "last-entry": (string | number | string[])[];
+                        groups: number;
+                    };

-                // verify result:
-                expect(result.length).toEqual(1);
-                const expectedFirstEntry = ["1-0", ["a", "b", "c", "d"]];
-                expect(result["first-entry"]).toEqual(expectedFirstEntry);
expect(result["last-entry"]).toEqual(expectedFirstEntry); - expect(result.groups).toEqual(1); + expect(result.length).toEqual(1); + const expectedFirstEntry = ["1-0", ["a", "b", "c", "d"]]; + expect(result["first-entry"]).toEqual(expectedFirstEntry); + expect(result["last-entry"]).toEqual(expectedFirstEntry); + expect(result.groups).toEqual(1); - // Add one more entry - expect( - await client.xadd(key, [["foo", "bar"]], { - id: streamId1_1, - }), - ).toEqual(streamId1_1); - const fullResult = (await client.xinfoStream(Buffer.from(key), { - fullOptions: 1, - })) as { - length: number; - "radix-tree-keys": number; - "radix-tree-nodes": number; - "last-generated-id": string; - "max-deleted-entry-id": string; - "entries-added": number; - "recorded-first-entry-id": string; - entries: (string | number | string[])[][]; - groups: [ + expect( + await client.xadd(key, [["foo", "bar"]], { + id: streamId1_1, + }), + ).toEqual(streamId1_1); + const fullResult = (await client.xinfoStream( + Buffer.from(key), { - name: string; - "last-delivered-id": string; - "entries-read": number; - lag: number; - "pel-count": number; - pending: (string | number)[][]; - consumers: [ - { - name: string; - "seen-time": number; - "active-time": number; - "pel-count": number; - pending: (string | number)[][]; - }, - ]; + fullOptions: 1, }, - ]; - }; + )) as { + length: number; + "radix-tree-keys": number; + "radix-tree-nodes": number; + "last-generated-id": string; + "max-deleted-entry-id": string; + "entries-added": number; + "recorded-first-entry-id": string; + entries: (string | number | string[])[][]; + groups: [ + { + name: string; + "last-delivered-id": string; + "entries-read": number; + lag: number; + "pel-count": number; + pending: (string | number)[][]; + consumers: [ + { + name: string; + "seen-time": number; + "active-time": number; + "pel-count": number; + pending: (string | number)[][]; + }, + ]; + }, + ]; + }; - // verify full result like: - // { - // length: 2, - // 'radix-tree-keys': 1, - // 'radix-tree-nodes': 2, - // 'last-generated-id': '1-1', - // 'max-deleted-entry-id': '0-0', - // 'entries-added': 2, - // 'recorded-first-entry-id': '1-0', - // entries: [ [ '1-0', ['a', 'b', ...] 
-                //     ] ],
-                //     groups: [ {
-                //         name: 'group',
-                //         'last-delivered-id': '1-0',
-                //         'entries-read': 1,
-                //         lag: 1,
-                //         'pel-count': 1,
-                //         pending: [ [ '1-0', 'consumer', 1722624726802, 1 ] ],
-                //         consumers: [ {
-                //             name: 'consumer',
-                //             'seen-time': 1722624726802,
-                //             'active-time': 1722624726802,
-                //             'pel-count': 1,
-                //             pending: [ [ '1-0', 'consumer', 1722624726802, 1 ] ],
-                //         }
-                //         ]
-                //     }
-                //     ]
-                // }
-                expect(fullResult.length).toEqual(2);
-                expect(fullResult["recorded-first-entry-id"]).toEqual(
-                    streamId1_0,
-                );
-
-                // Only the first entry will be returned since we passed count: 1
-                expect(fullResult.entries).toEqual([expectedFirstEntry]);
-
-                // compare groupName, consumerName, and pending messages from the full info result:
-                const fullResultGroups = fullResult.groups;
-                expect(fullResultGroups.length).toEqual(1);
-                expect(fullResultGroups[0]["name"]).toEqual(groupName);
-
-                const pendingResult = fullResultGroups[0]["pending"];
-                expect(pendingResult.length).toEqual(1);
-                expect(pendingResult[0][0]).toEqual(streamId1_0);
-                expect(pendingResult[0][1]).toEqual(consumerName);
-
-                const consumersResult = fullResultGroups[0]["consumers"];
-                expect(consumersResult.length).toEqual(1);
-                expect(consumersResult[0]["name"]).toEqual(consumerName);
-
-                const consumerPendingResult = fullResultGroups[0]["pending"];
-                expect(consumerPendingResult.length).toEqual(1);
-                expect(consumerPendingResult[0][0]).toEqual(streamId1_0);
-                expect(consumerPendingResult[0][1]).toEqual(consumerName);
-            }, protocol);
+                    expect(fullResult.length).toEqual(2);
+
+                    if (cluster.checkIfServerVersionLessThan("7.0.0")) {
+                        expect(
+                            fullResult["max-deleted-entry-id"],
+                        ).toBeUndefined();
+                        expect(fullResult["entries-added"]).toBeUndefined();
+                        expect(
+                            fullResult.groups[0]["entries-read"],
+                        ).toBeUndefined();
+                        expect(fullResult.groups[0]["lag"]).toBeUndefined();
+                    } else if (cluster.checkIfServerVersionLessThan("7.2.0")) {
+                        expect(fullResult["recorded-first-entry-id"]).toEqual(
+                            streamId1_0,
+                        );
+
+                        expect(
+                            fullResult.groups[0].consumers[0]["active-time"],
+                        ).toBeUndefined();
+                        expect(
+                            fullResult.groups[0].consumers[0]["seen-time"],
+                        ).toBeDefined();
+                    } else {
+                        expect(
+                            fullResult.groups[0].consumers[0]["active-time"],
+                        ).toBeDefined();
+                        expect(
+                            fullResult.groups[0].consumers[0]["seen-time"],
+                        ).toBeDefined();
+                    }
+                },
+                protocol,
+            );
        },
        config.timeout,
    );
@@ -10660,7 +10637,7 @@ export function runBaseTests(config: {
            expect(result[0].pending).toEqual(1);
            expect(result[0].idle).toBeGreaterThan(0);

-            if (cluster.checkIfServerVersionLessThan("7.2.0")) {
+            if (!cluster.checkIfServerVersionLessThan("7.2.0")) {
                expect(result[0].inactive).toBeGreaterThan(0);
            }

diff --git a/node/tests/TestUtilities.ts b/node/tests/TestUtilities.ts
index 0b64b31a04..3ac7743cec 100644
--- a/node/tests/TestUtilities.ts
+++ b/node/tests/TestUtilities.ts
@@ -600,86 +600,42 @@ export async function encodableTransactionTest(
    return responseData;
 }

-/**
- * Populates a transaction with commands to test the decoded response.
- * @param baseTransaction - A transaction.
- * @returns Array of tuples, where first element is a test name/description, second - expected return value.
- */
-export async function encodedTransactionTest(
-    baseTransaction: Transaction | ClusterTransaction,
-): Promise<[string, GlideReturnType][]> {
-    const key1 = "{key}" + uuidv4(); // string
-    const key2 = "{key}" + uuidv4(); // string
-    const key = "dumpKey";
-    const dumpResult = Buffer.from([
-        0, 5, 118, 97, 108, 117, 101, 11, 0, 232, 41, 124, 75, 60, 53, 114, 231,
-    ]);
-    const value = "value";
-    const valueEncoded = Buffer.from(value);
-    // array of tuples - first element is test name/description, second - expected return value
-    const responseData: [string, GlideReturnType][] = [];
-
-    baseTransaction.set(key1, value);
-    responseData.push(["set(key1, value)", "OK"]);
-    baseTransaction.set(key2, value);
-    responseData.push(["set(key2, value)", "OK"]);
-    baseTransaction.get(key1);
-    responseData.push(["get(key1)", valueEncoded]);
-    baseTransaction.get(key2);
-    responseData.push(["get(key2)", valueEncoded]);
-
-    baseTransaction.set(key, value);
-    responseData.push(["set(key, value)", "OK"]);
-    baseTransaction.customCommand(["DUMP", key]);
-    responseData.push(['customCommand(["DUMP", key])', dumpResult]);
-    baseTransaction.del([key]);
-    responseData.push(["del(key)", 1]);
-    baseTransaction.get(key);
-    responseData.push(["get(key)", null]);
-    baseTransaction.customCommand(["RESTORE", key, "0", dumpResult]);
-    responseData.push([
-        'customCommand(["RESTORE", key, "0", dumpResult])',
-        "OK",
-    ]);
-    baseTransaction.get(key);
-    responseData.push(["get(key)", valueEncoded]);
-
-    return responseData;
-}
-
 /** Populates a transaction with dump and restore commands
 *
 * @param baseTransaction - A transaction
 * @param valueResponse - Represents the encoded response of "value" to compare
 * @returns Array of tuples, where first element is a test name/description, second - expected return value.
 */
-export async function DumpAndRestureTest(
+export async function DumpAndRestoreTest(
    baseTransaction: Transaction,
-    valueResponse: GlideString,
+    dumpValue: Buffer | null,
 ): Promise<[string, GlideReturnType][]> {
-    const key = "dumpKey";
-    const dumpResult = Buffer.from([
-        0, 5, 118, 97, 108, 117, 101, 11, 0, 232, 41, 124, 75, 60, 53, 114, 231,
-    ]);
-    const value = "value";
+    if (dumpValue == null) {
+        throw new Error("dumpValue is null");
+    }
+
+    const key = "{key}-" + uuidv4(); // string
+    const buffValue = Buffer.from("value");
    // array of tuples - first element is test name/description, second - expected return value
    const responseData: [string, GlideReturnType][] = [];

-    baseTransaction.set(key, value);
-    responseData.push(["set(key, value)", "OK"]);
+    baseTransaction.set(key, "value");
+    responseData.push(["set(key, stringValue)", "OK"]);
    baseTransaction.customCommand(["DUMP", key]);
-    responseData.push(['customCommand(["DUMP", key])', dumpResult]);
+    responseData.push(['customCommand(["DUMP", key])', dumpValue]);
+    baseTransaction.get(key);
+    responseData.push(["get(key)", buffValue]);
    baseTransaction.del([key]);
    responseData.push(["del(key)", 1]);
    baseTransaction.get(key);
    responseData.push(["get(key)", null]);
-    baseTransaction.customCommand(["RESTORE", key, "0", dumpResult]);
+    baseTransaction.customCommand(["RESTORE", key, "0", dumpValue]);
    responseData.push([
-        'customCommand(["RESTORE", key, "0", dumpResult])',
+        'customCommand(["RESTORE", buffKey, "0", stringValue])',
        "OK",
    ]);
    baseTransaction.get(key);
-    responseData.push(["get(key)", valueResponse]);
+    responseData.push(["get(key)", buffValue]);

    return responseData;
 }

@@ -874,20 +830,20 @@ export async function transactionTest(
    ]);
    responseData.push(["lpush(key5, [1, 2, 3, 4])", 4]);

-    if (gte("7.0.0", version)) {
+    if (gte(version, "7.0.0")) {
        baseTransaction.lpush(key24, [field + "1", field + "2"]);
        responseData.push(["lpush(key22, [1, 2])", 2]);
        baseTransaction.lmpop([key24], ListDirection.LEFT);
        responseData.push([
            "lmpop([key22], ListDirection.LEFT)",
-            { [key24]: [field + "2"] },
+            [{ key: key24, value: [field + "2"] }],
        ]);
        baseTransaction.lpush(key24, [field + "2"]);
        responseData.push(["lpush(key22, [2])", 2]);
        baseTransaction.blmpop([key24], ListDirection.LEFT, 0.1, 1);
        responseData.push([
            "blmpop([key22], ListDirection.LEFT, 0.1, 1)",
-            { [key24]: [field + "2"] },
+            [{ key: key24, value: [field + "2"] }],
        ]);
    }

@@ -1393,6 +1349,9 @@ export async function transactionTest(
    baseTransaction.xgroupDestroy(key9, groupName2);
    responseData.push(["xgroupDestroy(key9, groupName2)", true]);

+    baseTransaction.wait(1, 200);
+    responseData.push(["wait(1, 200)", 1]);
+
    baseTransaction.rename(key9, key10);
    responseData.push(["rename(key9, key10)", "OK"]);
    baseTransaction.exists([key10]);
@@ -1490,31 +1449,31 @@ export async function transactionTest(
    baseTransaction.geoadd(
        key18,
        new Map([
-            ["Palermo", { longitude: 13.361389, latitude: 38.115556 }],
-            ["Catania", { longitude: 15.087269, latitude: 37.502669 }],
+            ["palermo", { longitude: 13.361389, latitude: 38.115556 }],
+            ["catania", { longitude: 15.087269, latitude: 37.502669 }],
        ]),
    );
-    responseData.push(["geoadd(key18, { Palermo: ..., Catania: ... })", 2]);
-    baseTransaction.geopos(key18, ["Palermo", "Catania"]);
})", 2]); + baseTransaction.geopos(key18, ["palermo", "catania"]); responseData.push([ - 'geopos(key18, ["Palermo", "Catania"])', + 'geopos(key18, ["palermo", "catania"])', [ [13.36138933897018433, 38.11555639549629859], [15.08726745843887329, 37.50266842333162032], ], ]); - baseTransaction.geodist(key18, "Palermo", "Catania"); - responseData.push(['geodist(key18, "Palermo", "Catania")', 166274.1516]); - baseTransaction.geodist(key18, "Palermo", "Catania", { + baseTransaction.geodist(key18, "palermo", "catania"); + responseData.push(['geodist(key18, "palermo", "catania")', 166274.1516]); + baseTransaction.geodist(key18, "palermo", "catania", { unit: GeoUnit.KILOMETERS, }); responseData.push([ - 'geodist(key18, "Palermo", "Catania", { unit: GeoUnit.KILOMETERS })', + 'geodist(key18, "palermo", "catania", { unit: GeoUnit.KILOMETERS })', 166.2742, ]); - baseTransaction.geohash(key18, ["Palermo", "Catania", "NonExisting"]); + baseTransaction.geohash(key18, ["palermo", "catania", "NonExisting"]); responseData.push([ - 'geohash(key18, ["Palermo", "Catania", "NonExisting"])', + 'geohash(key18, ["palermo", "catania", "NonExisting"])', ["sqc8b49rny0", "sqdtr74hyu0", null], ]); baseTransaction.zadd(key23, { one: 1.0 }); @@ -1533,7 +1492,7 @@ export async function transactionTest( baseTransaction .geosearch( key18, - { member: "Palermo" }, + { member: "palermo" }, { radius: 200, unit: GeoUnit.KILOMETERS }, { sortOrder: SortOrder.ASC }, ) @@ -1544,7 +1503,7 @@ export async function transactionTest( ) .geosearch( key18, - { member: "Palermo" }, + { member: "palermo" }, { radius: 200, unit: GeoUnit.KILOMETERS }, { sortOrder: SortOrder.ASC, @@ -1567,18 +1526,18 @@ export async function transactionTest( }, ); responseData.push([ - 'geosearch(key18, "Palermo", R200 KM, ASC)', - ["Palermo", "Catania"], + 'geosearch(key18, "palermo", R200 KM, ASC)', + ["palermo", "catania"], ]); responseData.push([ "geosearch(key18, (15, 37), 400x400 KM, ASC)", - ["Palermo", "Catania"], + ["palermo", "catania"], ]); responseData.push([ - 'geosearch(key18, "Palermo", R200 KM, ASC 2 3x true)', + 'geosearch(key18, "palermo", R200 KM, ASC 2 3x true)', [ [ - "Palermo", + "palermo", [ 0.0, 3479099956230698, @@ -1586,7 +1545,7 @@ export async function transactionTest( ], ], [ - "Catania", + "catania", [ 166.2742, 3479447370796909, @@ -1599,7 +1558,7 @@ export async function transactionTest( "geosearch(key18, (15, 37), 400x400 KM, ASC 2 3x true)", [ [ - "Catania", + "catania", [ 56.4413, 3479447370796909, @@ -1607,7 +1566,7 @@ export async function transactionTest( ], ], [ - "Palermo", + "palermo", [ 190.4424, 3479099956230698, @@ -1757,8 +1716,6 @@ export async function transactionTest( responseData.push(["sortReadOnly(key21)", ["1", "2", "3"]]); } - baseTransaction.wait(1, 200); - responseData.push(["wait(1, 200)", 1]); return responseData; } diff --git a/node/tests/setup.js b/node/tests/setup.js new file mode 100644 index 0000000000..d9b8d0b74b --- /dev/null +++ b/node/tests/setup.js @@ -0,0 +1,7 @@ +import { beforeAll } from "@jest/globals"; +import { Logger } from ".."; + +beforeAll(() => { + Logger.init("off"); + Logger.setLoggerConfig("off"); +}); diff --git a/python/Cargo.toml b/python/Cargo.toml index 3945322cd2..aaf49762df 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -11,11 +11,20 @@ name = "glide" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "^0.20", features = ["extension-module", "num-bigint"] } -bytes = { version = "1.6.0" } -redis = { path = "../glide-core/redis-rs/redis", features = ["aio", 
"tokio-comp", "connection-manager","tokio-rustls-comp"] } +pyo3 = { version = "^0.22", features = [ + "extension-module", + "num-bigint", + "gil-refs", +] } +bytes = { version = "^1.8" } +redis = { path = "../glide-core/redis-rs/redis", features = [ + "aio", + "tokio-comp", + "connection-manager", + "tokio-rustls-comp", +] } glide-core = { path = "../glide-core", features = ["socket-layer"] } -logger_core = {path = "../logger_core"} +logger_core = { path = "../logger_core" } [package.metadata.maturin] python-source = "python" diff --git a/python/python/tests/conftest.py b/python/python/tests/conftest.py index 65ef1eb540..9b1db487da 100644 --- a/python/python/tests/conftest.py +++ b/python/python/tests/conftest.py @@ -18,7 +18,7 @@ DEFAULT_HOST = "localhost" DEFAULT_PORT = 6379 -DEFAULT_TEST_LOG_LEVEL = logLevel.WARN +DEFAULT_TEST_LOG_LEVEL = logLevel.OFF Logger.set_logger_config(DEFAULT_TEST_LOG_LEVEL) diff --git a/python/python/tests/test_pubsub.py b/python/python/tests/test_pubsub.py index 4d1d344c0e..c6335ae242 100644 --- a/python/python/tests/test_pubsub.py +++ b/python/python/tests/test_pubsub.py @@ -13,9 +13,9 @@ GlideClusterClientConfiguration, ProtocolVersion, ) -from glide.constants import OK, TEncodable +from glide.constants import OK from glide.exceptions import ConfigurationError -from glide.glide_client import BaseClient, GlideClient, GlideClusterClient, TGlideClient +from glide.glide_client import GlideClient, GlideClusterClient, TGlideClient from tests.conftest import create_client from tests.utils.utils import check_if_server_version_lt, get_random_string @@ -2471,7 +2471,7 @@ async def test_pubsub_shardchannels(self, request, cluster_mode: bool): This test verifies that the pubsub_shardchannels command correctly returns the active sharded channels matching a specified pattern. 
""" - client1, client2, client = None, None, None + pub_sub, client1, client2, client = None, None, None, None try: channel1 = "test_shardchannel1" channel2 = "test_shardchannel2" @@ -2506,10 +2506,6 @@ async def test_pubsub_shardchannels(self, request, cluster_mode: bool): request, cluster_mode, pub_sub ) - min_version = "7.0.0" - if await check_if_server_version_lt(client1, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - assert type(client2) == GlideClusterClient # Test pubsub_shardchannels without pattern @@ -2596,10 +2592,6 @@ async def test_pubsub_shardnumsub(self, request, cluster_mode: bool): request, cluster_mode, pub_sub1, pub_sub2 ) - min_version = "7.0.0" - if await check_if_server_version_lt(client1, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - client3, client4 = await create_two_clients_with_pubsub( request, cluster_mode, pub_sub3 ) diff --git a/python/src/lib.rs b/python/src/lib.rs index 994c7f7b4e..8a1a0d3444 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -10,12 +10,13 @@ use pyo3::prelude::*; use pyo3::types::{PyAny, PyBool, PyBytes, PyDict, PyFloat, PyList, PySet}; use pyo3::Python; use redis::Value; +use std::sync::Arc; pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 = glide_core::client::DEFAULT_RESPONSE_TIMEOUT.as_millis() as u32; pub const MAX_REQUEST_ARGS_LEN: u32 = MAX_REQUEST_ARGS_LENGTH as u32; -#[pyclass] +#[pyclass(eq, eq_int)] #[derive(PartialEq, Eq, PartialOrd, Clone)] pub enum Level { Error = 0, @@ -48,6 +49,7 @@ pub struct ClusterScanCursor { #[pymethods] impl ClusterScanCursor { #[new] + #[pyo3(signature = (new_cursor=None))] fn new(new_cursor: Option) -> Self { match new_cursor { Some(cursor) => ClusterScanCursor { cursor }, @@ -78,7 +80,7 @@ pub struct Script { #[pymethods] impl Script { #[new] - fn new(code: &PyAny) -> PyResult { + fn new(code: &Bound) -> PyResult { let hash = if let Ok(code_str) = code.extract::() { glide_core::scripts_container::add_script(code_str.as_bytes()) } else if let Ok(code_bytes) = code.extract::<&PyBytes>() { @@ -103,7 +105,7 @@ impl Script { /// A Python module implemented in Rust. #[pymodule] -fn glide(_py: Python, m: &PyModule) -> PyResult<()> { +fn glide(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::