diff --git a/CHANGELOG.md b/CHANGELOG.md index 784ab9ec49..365fa9f677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * Python: Added GEOPOS command ([#1301](https://github.com/aws/glide-for-redis/pull/1301)) * Node: Added PFADD command ([#1317](https://github.com/aws/glide-for-redis/pull/1317)) * Python: Added PFADD command ([#1315](https://github.com/aws/glide-for-redis/pull/1315)) +* Python: Added ZMSCORE command ([#1357](https://github.com/aws/glide-for-redis/pull/1357)) * Python: Added HRANDFIELD command ([#1334](https://github.com/aws/glide-for-redis/pull/1334)) #### Fixes diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index 994a0551f0..154846d9fc 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -17,6 +17,7 @@ pub(crate) enum ExpectedReturnType { ZrankReturnType, JsonToggleReturnType, ArrayOfBools, + ArrayOfDoubleOrNull, Lolwut, ArrayOfArraysOfDoubleOrNull, ArrayOfKeyValuePairs, @@ -164,16 +165,7 @@ pub(crate) fn convert_to_expected_type( .into()), }, ExpectedReturnType::ArrayOfBools => match value { - Value::Array(array) => { - let array_of_bools = array - .iter() - .map(|v| { - convert_to_expected_type(v.clone(), Some(ExpectedReturnType::Boolean)) - .unwrap() - }) - .collect(); - Ok(Value::Array(array_of_bools)) - } + Value::Array(array) => convert_array_elements(array, ExpectedReturnType::Boolean), _ => Err(( ErrorKind::TypeError, "Response couldn't be converted to an array of boolean", @@ -181,6 +173,15 @@ pub(crate) fn convert_to_expected_type( ) .into()), }, + ExpectedReturnType::ArrayOfDoubleOrNull => match value { + Value::Array(array) => convert_array_elements(array, ExpectedReturnType::DoubleOrNull), + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted to an array of doubles", + format!("(response was {:?})", value), + ) + .into()), + }, ExpectedReturnType::ArrayOfArraysOfDoubleOrNull => match value { // This is used for GEOPOS command. Value::Array(array) => { @@ -303,6 +304,21 @@ fn convert_lolwut_string(data: &str) -> String { } } +/// Converts elements in an array to the specified type. +/// +/// `array` is an array of values. +/// `element_type` is the type that the array elements should be converted to. +fn convert_array_elements( + array: Vec, + element_type: ExpectedReturnType, +) -> RedisResult { + let converted_array = array + .iter() + .map(|v| convert_to_expected_type(v.clone(), Some(element_type)).unwrap()) + .collect(); + Ok(Value::Array(converted_array)) +} + fn convert_array_to_map( array: Vec, key_expected_return_type: Option, @@ -385,6 +401,7 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { b"SMISMEMBER" => Some(ExpectedReturnType::ArrayOfBools), b"SMEMBERS" | b"SINTER" => Some(ExpectedReturnType::Set), b"ZSCORE" | b"GEODIST" => Some(ExpectedReturnType::DoubleOrNull), + b"ZMSCORE" => Some(ExpectedReturnType::ArrayOfDoubleOrNull), b"ZPOPMIN" | b"ZPOPMAX" => Some(ExpectedReturnType::MapOfStringToDouble), b"JSON.TOGGLE" => Some(ExpectedReturnType::JsonToggleReturnType), b"GEOPOS" => Some(ExpectedReturnType::ArrayOfArraysOfDoubleOrNull), @@ -637,6 +654,35 @@ mod tests { assert!(expected_type_for_cmd(redis::cmd("ZREVRANK").arg("key").arg("member")).is_none()); } + #[test] + fn convert_zmscore() { + assert!(matches!( + expected_type_for_cmd(redis::cmd("ZMSCORE").arg("key").arg("member")), + Some(ExpectedReturnType::ArrayOfDoubleOrNull) + )); + + let array_response = Value::Array(vec![ + Value::Nil, + Value::Double(1.5), + Value::BulkString(b"2.5".to_vec()), + ]); + let converted_response = convert_to_expected_type( + array_response, + Some(ExpectedReturnType::ArrayOfDoubleOrNull), + ) + .unwrap(); + let expected_response = + Value::Array(vec![Value::Nil, Value::Double(1.5), Value::Double(2.5)]); + assert_eq!(expected_response, converted_response); + + let unexpected_response_type = Value::Double(0.5); + assert!(convert_to_expected_type( + unexpected_response_type, + Some(ExpectedReturnType::ArrayOfDoubleOrNull) + ) + .is_err()); + } + #[test] fn convert_smove_to_bool() { assert!(matches!( diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 3af69406e0..f6dc886fc2 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -2402,6 +2402,33 @@ async def zscore(self, key: str, member: str) -> Optional[float]: await self._execute_command(RequestType.ZScore, [key, member]), ) + async def zmscore( + self, + key: str, + members: List[str], + ) -> List[Optional[float]]: + """ + Returns the scores associated with the specified `members` in the sorted set stored at `key`. + + See https://valkey.io/commands/zmscore for more details. + + Args: + key (str): The key of the sorted set. + members (List[str]): A list of members in the sorted set. + + Returns: + List[Optional[float]]: A list of scores corresponding to `members`. + If a member does not exist in the sorted set, the corresponding value in the list will be None. + + Examples: + >>> await client.zmscore("my_sorted_set", ["one", "non_existent_member", "three"]) + [1.0, None, 3.0] + """ + return cast( + List[Optional[float]], + await self._execute_command(RequestType.ZMScore, [key] + members), + ) + async def invoke_script( self, script: Script, diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index df78a529f4..a90527e586 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -1807,6 +1807,22 @@ def zscore(self: TTransaction, key: str, member: str) -> TTransaction: """ return self.append_command(RequestType.ZScore, [key, member]) + def zmscore(self: TTransaction, key: str, members: List[str]) -> TTransaction: + """ + Returns the scores associated with the specified `members` in the sorted set stored at `key`. + + See https://valkey.io/commands/zmscore for more details. + + Args: + key (str): The key of the sorted set. + members (List[str]): A list of members in the sorted set. + + Command response: + List[Optional[float]]: A list of scores corresponding to `members`. + If a member does not exist in the sorted set, the corresponding value in the list will be None. + """ + return self.append_command(RequestType.ZMScore, [key] + members) + def dbsize(self: TTransaction) -> TTransaction: """ Returns the number of keys in the currently selected database. diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 92cd21bb19..eee211a238 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -1778,6 +1778,28 @@ async def test_zscore(self, redis_client: TRedisClient): await redis_client.zscore("non_existing_key", "non_existing_member") is None ) + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_zmscore(self, redis_client: TRedisClient): + key1 = get_random_string(10) + key2 = get_random_string(10) + members_scores = {"one": 1, "two": 2, "three": 3} + + assert await redis_client.zadd(key1, members_scores=members_scores) == 3 + assert await redis_client.zmscore(key1, ["one", "two", "three"]) == [ + 1.0, + 2.0, + 3.0, + ] + assert await redis_client.zmscore( + key1, ["one", "non_existing_member", "non_existing_member", "three"] + ) == [1.0, None, None, 3.0] + assert await redis_client.zmscore("non_existing_key", ["one"]) == [None] + + assert await redis_client.set(key2, "value") == OK + with pytest.raises(RequestError): + await redis_client.zmscore(key2, ["one"]) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_zpopmin(self, redis_client: TRedisClient): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index dd1003212b..24aba819ff 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -209,6 +209,8 @@ async def transaction_test( args.append(["two", "three", "four"]) transaction.zrange_withscores(key8, RangeByIndex(start=0, stop=-1)) args.append({"two": 2, "three": 3, "four": 4}) + transaction.zmscore(key8, ["two", "three"]) + args.append([2.0, 3.0]) transaction.zpopmin(key8) args.append({"two": 2.0}) transaction.zpopmax(key8)