Skip to content

Commit

Permalink
Python: add ZMSCORE command (#1357)
Browse files Browse the repository at this point in the history
* Python: add ZMSCORE command (#250)

* Update CHANGELOG with PR link

* PR suggestions

* Create convert_array_elements function

* Fix rust formatting

* PR suggestions

* PR suggestions

* Fix rust formatting

---------

Co-authored-by: Andrew Carbonetto <[email protected]>
  • Loading branch information
aaron-congo and acarbonetto authored May 1, 2024
1 parent d8bdfa7 commit 47364d3
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Python: Added GEOPOS command ([#1301](https://github.com/aws/glide-for-redis/pull/1301))
* Node: Added PFADD command ([#1317](https://github.com/aws/glide-for-redis/pull/1317))
* Python: Added PFADD command ([#1315](https://github.com/aws/glide-for-redis/pull/1315))
* Python: Added ZMSCORE command ([#1357](https://github.com/aws/glide-for-redis/pull/1357))
* Python: Added HRANDFIELD command ([#1334](https://github.com/aws/glide-for-redis/pull/1334))

#### Fixes
Expand Down
66 changes: 56 additions & 10 deletions glide-core/src/client/value_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub(crate) enum ExpectedReturnType {
ZrankReturnType,
JsonToggleReturnType,
ArrayOfBools,
ArrayOfDoubleOrNull,
Lolwut,
ArrayOfArraysOfDoubleOrNull,
ArrayOfKeyValuePairs,
Expand Down Expand Up @@ -164,23 +165,23 @@ pub(crate) fn convert_to_expected_type(
.into()),
},
ExpectedReturnType::ArrayOfBools => match value {
Value::Array(array) => {
let array_of_bools = array
.iter()
.map(|v| {
convert_to_expected_type(v.clone(), Some(ExpectedReturnType::Boolean))
.unwrap()
})
.collect();
Ok(Value::Array(array_of_bools))
}
Value::Array(array) => convert_array_elements(array, ExpectedReturnType::Boolean),
_ => Err((
ErrorKind::TypeError,
"Response couldn't be converted to an array of boolean",
format!("(response was {:?})", value),
)
.into()),
},
ExpectedReturnType::ArrayOfDoubleOrNull => match value {
Value::Array(array) => convert_array_elements(array, ExpectedReturnType::DoubleOrNull),
_ => Err((
ErrorKind::TypeError,
"Response couldn't be converted to an array of doubles",
format!("(response was {:?})", value),
)
.into()),
},
ExpectedReturnType::ArrayOfArraysOfDoubleOrNull => match value {
// This is used for GEOPOS command.
Value::Array(array) => {
Expand Down Expand Up @@ -303,6 +304,21 @@ fn convert_lolwut_string(data: &str) -> String {
}
}

/// Converts elements in an array to the specified type.
///
/// `array` is an array of values.
/// `element_type` is the type that the array elements should be converted to.
fn convert_array_elements(
array: Vec<Value>,
element_type: ExpectedReturnType,
) -> RedisResult<Value> {
let converted_array = array
.iter()
.map(|v| convert_to_expected_type(v.clone(), Some(element_type)).unwrap())
.collect();
Ok(Value::Array(converted_array))
}

fn convert_array_to_map(
array: Vec<Value>,
key_expected_return_type: Option<ExpectedReturnType>,
Expand Down Expand Up @@ -385,6 +401,7 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option<ExpectedReturnType> {
b"SMISMEMBER" => Some(ExpectedReturnType::ArrayOfBools),
b"SMEMBERS" | b"SINTER" => Some(ExpectedReturnType::Set),
b"ZSCORE" | b"GEODIST" => Some(ExpectedReturnType::DoubleOrNull),
b"ZMSCORE" => Some(ExpectedReturnType::ArrayOfDoubleOrNull),
b"ZPOPMIN" | b"ZPOPMAX" => Some(ExpectedReturnType::MapOfStringToDouble),
b"JSON.TOGGLE" => Some(ExpectedReturnType::JsonToggleReturnType),
b"GEOPOS" => Some(ExpectedReturnType::ArrayOfArraysOfDoubleOrNull),
Expand Down Expand Up @@ -637,6 +654,35 @@ mod tests {
assert!(expected_type_for_cmd(redis::cmd("ZREVRANK").arg("key").arg("member")).is_none());
}

#[test]
fn convert_zmscore() {
assert!(matches!(
expected_type_for_cmd(redis::cmd("ZMSCORE").arg("key").arg("member")),
Some(ExpectedReturnType::ArrayOfDoubleOrNull)
));

let array_response = Value::Array(vec![
Value::Nil,
Value::Double(1.5),
Value::BulkString(b"2.5".to_vec()),
]);
let converted_response = convert_to_expected_type(
array_response,
Some(ExpectedReturnType::ArrayOfDoubleOrNull),
)
.unwrap();
let expected_response =
Value::Array(vec![Value::Nil, Value::Double(1.5), Value::Double(2.5)]);
assert_eq!(expected_response, converted_response);

let unexpected_response_type = Value::Double(0.5);
assert!(convert_to_expected_type(
unexpected_response_type,
Some(ExpectedReturnType::ArrayOfDoubleOrNull)
)
.is_err());
}

#[test]
fn convert_smove_to_bool() {
assert!(matches!(
Expand Down
27 changes: 27 additions & 0 deletions python/python/glide/async_commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2402,6 +2402,33 @@ async def zscore(self, key: str, member: str) -> Optional[float]:
await self._execute_command(RequestType.ZScore, [key, member]),
)

async def zmscore(
self,
key: str,
members: List[str],
) -> List[Optional[float]]:
"""
Returns the scores associated with the specified `members` in the sorted set stored at `key`.
See https://valkey.io/commands/zmscore for more details.
Args:
key (str): The key of the sorted set.
members (List[str]): A list of members in the sorted set.
Returns:
List[Optional[float]]: A list of scores corresponding to `members`.
If a member does not exist in the sorted set, the corresponding value in the list will be None.
Examples:
>>> await client.zmscore("my_sorted_set", ["one", "non_existent_member", "three"])
[1.0, None, 3.0]
"""
return cast(
List[Optional[float]],
await self._execute_command(RequestType.ZMScore, [key] + members),
)

async def invoke_script(
self,
script: Script,
Expand Down
16 changes: 16 additions & 0 deletions python/python/glide/async_commands/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,22 @@ def zscore(self: TTransaction, key: str, member: str) -> TTransaction:
"""
return self.append_command(RequestType.ZScore, [key, member])

def zmscore(self: TTransaction, key: str, members: List[str]) -> TTransaction:
"""
Returns the scores associated with the specified `members` in the sorted set stored at `key`.
See https://valkey.io/commands/zmscore for more details.
Args:
key (str): The key of the sorted set.
members (List[str]): A list of members in the sorted set.
Command response:
List[Optional[float]]: A list of scores corresponding to `members`.
If a member does not exist in the sorted set, the corresponding value in the list will be None.
"""
return self.append_command(RequestType.ZMScore, [key] + members)

def dbsize(self: TTransaction) -> TTransaction:
"""
Returns the number of keys in the currently selected database.
Expand Down
22 changes: 22 additions & 0 deletions python/python/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,6 +1778,28 @@ async def test_zscore(self, redis_client: TRedisClient):
await redis_client.zscore("non_existing_key", "non_existing_member") is None
)

@pytest.mark.parametrize("cluster_mode", [True, False])
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
async def test_zmscore(self, redis_client: TRedisClient):
key1 = get_random_string(10)
key2 = get_random_string(10)
members_scores = {"one": 1, "two": 2, "three": 3}

assert await redis_client.zadd(key1, members_scores=members_scores) == 3
assert await redis_client.zmscore(key1, ["one", "two", "three"]) == [
1.0,
2.0,
3.0,
]
assert await redis_client.zmscore(
key1, ["one", "non_existing_member", "non_existing_member", "three"]
) == [1.0, None, None, 3.0]
assert await redis_client.zmscore("non_existing_key", ["one"]) == [None]

assert await redis_client.set(key2, "value") == OK
with pytest.raises(RequestError):
await redis_client.zmscore(key2, ["one"])

@pytest.mark.parametrize("cluster_mode", [True, False])
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
async def test_zpopmin(self, redis_client: TRedisClient):
Expand Down
2 changes: 2 additions & 0 deletions python/python/tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ async def transaction_test(
args.append(["two", "three", "four"])
transaction.zrange_withscores(key8, RangeByIndex(start=0, stop=-1))
args.append({"two": 2, "three": 3, "four": 4})
transaction.zmscore(key8, ["two", "three"])
args.append([2.0, 3.0])
transaction.zpopmin(key8)
args.append({"two": 2.0})
transaction.zpopmax(key8)
Expand Down

0 comments on commit 47364d3

Please sign in to comment.