diff --git a/.github/DEVELOPER.md b/.github/DEVELOPER.md index 2acc4ccb68..9b45af17e0 100644 --- a/.github/DEVELOPER.md +++ b/.github/DEVELOPER.md @@ -133,7 +133,7 @@ We use dynamic matrices for our CI/CD workflows, which are created using the `cr 4. It generates three matrices: - Engine matrix: Defines the types and versions of the engine to test against, for example Valkey 7.2.5. - Host matrix: Defines the host platforms to run the tests on, for example Ubuntu on ARM64. - - Language-version matrix: Defines the supported versions of languages, for example python 3.8. + - Language-version matrix: Defines the supported versions of languages, for example python 3.9. #### Outputs diff --git a/.github/json_matrices/build-matrix.json b/.github/json_matrices/build-matrix.json index dff057084b..0d0e7c10bb 100644 --- a/.github/json_matrices/build-matrix.json +++ b/.github/json_matrices/build-matrix.json @@ -16,22 +16,12 @@ "ARCH": "arm64", "TARGET": "aarch64-unknown-linux-gnu", "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], - "CONTAINER": "2_28", "languages": ["python", "node", "java", "go", "dotnet"] }, { "OS": "macos", "NAMED_OS": "darwin", - "RUNNER": "macos-12", - "ARCH": "x64", - "TARGET": "x86_64-apple-darwin", - "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], - "languages": ["python", "node", "java", "go", "dotnet"] - }, - { - "OS": "macos", - "NAMED_OS": "darwin", - "RUNNER": "macos-latest", + "RUNNER": "macos-14", "ARCH": "arm64", "TARGET": "aarch64-apple-darwin", "PACKAGE_MANAGERS": ["pypi", "npm", "maven"], @@ -43,7 +33,7 @@ "ARCH": "arm64", "TARGET": "aarch64-unknown-linux-musl", "RUNNER": ["self-hosted", "Linux", "ARM64"], - "IMAGE": "node:lts-alpine3.19", + "IMAGE": "node:lts-alpine", "CONTAINER_OPTIONS": "--user root --privileged --rm", "PACKAGE_MANAGERS": ["npm"], "languages": ["node"] @@ -54,7 +44,7 @@ "ARCH": "x64", "TARGET": "x86_64-unknown-linux-musl", "RUNNER": "ubuntu-latest", - "IMAGE": "node:lts-alpine3.19", + "IMAGE": "node:lts-alpine", "CONTAINER_OPTIONS": "--user root --privileged", "PACKAGE_MANAGERS": ["npm"], "languages": ["node"] diff --git a/.github/json_matrices/supported-languages-versions.json b/.github/json_matrices/supported-languages-versions.json index 0685ecf0ea..b3712a51ea 100644 --- a/.github/json_matrices/supported-languages-versions.json +++ b/.github/json_matrices/supported-languages-versions.json @@ -6,13 +6,13 @@ }, { "language": "python", - "versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"], - "always-run-versions": ["3.8", "3.13"] + "versions": ["3.9", "3.10", "3.11", "3.12", "3.13"], + "always-run-versions": ["3.9", "3.13"] }, { "language": "node", - "versions": ["16.x", "17.x", "18.x", "19.x", "20.x"], - "always-run-versions": ["16.x", "20.x"] + "versions": ["16.x", "17.x", "18.x", "19.x", "20.x", "21.x", "22.x"], + "always-run-versions": ["16.x", "22.x"] }, { "language": "dotnet", diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index a24c04af12..fde89563bf 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -8,6 +8,7 @@ on: - "v.?[0-9]+.[0-9]+" - "v?[0-9]+.[0-9]+.[0-9]+" - "v?[0-9]+.[0-9]+" + - release-* pull_request: branches: - "main" @@ -15,6 +16,7 @@ on: - "v.?[0-9]+.[0-9]+" - "v?[0-9]+.[0-9]+.[0-9]+" - "v?[0-9]+.[0-9]+" + - release-* schedule: - cron: "37 18 * * 6" diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index 100efee036..1cd5778a5c 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -9,8 +9,8 @@ on: paths: - csharp/** - 
glide-core/src/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - - submodules/** - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml - .github/workflows/test-benchmark/action.yml @@ -22,7 +22,7 @@ on: paths: - csharp/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/csharp.yml - .github/workflows/install-shared-dependencies/action.yml @@ -87,8 +87,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up dotnet ${{ matrix.dotnet }} uses: actions/setup-dotnet@v4 @@ -198,8 +196,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: lint rust uses: ./.github/workflows/lint-rust diff --git a/.github/workflows/full-matrix-tests.yml b/.github/workflows/full-matrix-tests.yml index 8a26e6a75e..ce812dc0fd 100644 --- a/.github/workflows/full-matrix-tests.yml +++ b/.github/workflows/full-matrix-tests.yml @@ -59,11 +59,11 @@ jobs: name: Run CI for GLIDE core lib secrets: inherit - # run-full-tests-for-redis-rs: - # if: (github.repository_owner == 'valkey-io' && github.event_name == 'schedule') || (github.event_name == 'workflow_dispatch' && inputs.redis-rs == true) - # uses: ./.github/workflows/redis-rs.yml - # name: Run CI for Redis-RS client - # secrets: inherit + run-full-tests-for-redis-rs: + if: (github.repository_owner == 'valkey-io' && github.event_name == 'schedule') || (github.event_name == 'workflow_dispatch' && inputs.redis-rs == true) + uses: ./.github/workflows/redis-rs.yml + name: Run CI for Redis-RS client + secrets: inherit run-full-tests-for-java: if: (github.repository_owner == 'valkey-io' && github.event_name == 'schedule') || (github.event_name == 'workflow_dispatch' && inputs.java == true) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index f7bf6c06da..1d17640188 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - go/** - .github/workflows/go.yml @@ -21,7 +21,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - go/** - .github/workflows/go.yml @@ -84,8 +84,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 @@ -136,8 +134,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive - uses: ./.github/workflows/lint-rust with: diff --git a/.github/workflows/install-shared-dependencies/action.yml b/.github/workflows/install-shared-dependencies/action.yml index 57647f3211..57e750ccee 100644 --- a/.github/workflows/install-shared-dependencies/action.yml +++ b/.github/workflows/install-shared-dependencies/action.yml @@ -25,7 +25,6 @@ inputs: description: "Engine version to install" required: false type: string - github-token: description: "GITHUB_TOKEN, GitHub App installation access token" required: true @@ -39,7 +38,7 @@ runs: if: "${{ inputs.os == 'macos' }}" run: | brew update - brew install git openssl coreutils + brew install openssl coreutils - name: Install software dependencies for Ubuntu GNU shell: bash diff --git a/.github/workflows/java-cd.yml b/.github/workflows/java-cd.yml index 6f21a2a517..f4c0146342 100644 --- a/.github/workflows/java-cd.yml +++ b/.github/workflows/java-cd.yml @@ 
-86,9 +86,6 @@ jobs: echo "No cleaning needed" fi - uses: actions/checkout@v4 - with: - submodules: recursive - - name: Set up JDK uses: actions/setup-java@v4 with: @@ -230,8 +227,6 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up JDK uses: actions/setup-java@v4 diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index fbb88e1310..66c99cca3e 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - java/** - utils/cluster_manager.py - .github/workflows/java.yml @@ -22,7 +22,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - java/** - utils/cluster_manager.py - .github/workflows/java.yml @@ -83,8 +83,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive - uses: gradle/actions/wrapper-validation@v3 @@ -117,6 +115,7 @@ jobs: run: ./gradlew spotlessDiagnose | grep 'All formatters are well behaved for all files' - uses: ./.github/workflows/test-benchmark + if: ${{ matrix.engine.version == '8.0' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.java == '17' }} with: language-flag: -java @@ -208,8 +207,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive - uses: ./.github/workflows/lint-rust with: @@ -218,9 +215,9 @@ jobs: name: lint java rust test-modules: - if: github.event.pull_request.head.repo.owner.login == 'valkey-io' + if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io' environment: AWS_ACTIONS - name: Running Module Tests + name: Modules Tests runs-on: [self-hosted, linux, ARM64] timeout-minutes: 15 steps: @@ -228,8 +225,6 @@ jobs: run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide - uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up JDK uses: actions/setup-java@v4 diff --git a/.github/workflows/lint-rust/action.yml b/.github/workflows/lint-rust/action.yml index aa8c433660..35c5e313c5 100644 --- a/.github/workflows/lint-rust/action.yml +++ b/.github/workflows/lint-rust/action.yml @@ -14,8 +14,6 @@ runs: using: "composite" steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Install Rust toolchain and protoc uses: ./.github/workflows/install-rust-and-protoc @@ -39,7 +37,7 @@ runs: - run: | cargo update - cargo install --locked --version 0.15.1 cargo-deny + cargo install --locked cargo-deny cargo deny check --config ${GITHUB_WORKSPACE}/deny.toml working-directory: ${{ inputs.cargo-toml-folder }} shell: bash diff --git a/.github/workflows/node.yml b/.github/workflows/node.yml index 19474cbd8e..a91a853534 100644 --- a/.github/workflows/node.yml +++ b/.github/workflows/node.yml @@ -8,7 +8,7 @@ on: - v* paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - node/** - utils/cluster_manager.py - .github/workflows/node.yml @@ -22,7 +22,7 @@ on: pull_request: paths: - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - node/** - utils/cluster_manager.py - .github/workflows/node.yml @@ -85,8 +85,6 @@ jobs: node: ${{ fromJson(needs.get-matrices.outputs.version-matrix-output) }} steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Setup Node uses: actions/setup-node@v4 @@ -110,6 +108,7 @@ jobs: working-directory: ./node - name: test 
hybrid node modules - commonjs + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} run: | npm install npm run test @@ -118,6 +117,7 @@ jobs: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-commonjs.html - name: test hybrid node modules - ecma + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} run: | npm install npm run test @@ -126,6 +126,7 @@ jobs: JEST_HTML_REPORTER_OUTPUT_PATH: test-report-ecma.html - uses: ./.github/workflows/test-benchmark + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.node == '20.x' }} with: language-flag: -node @@ -145,14 +146,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - with: - submodules: recursive - - uses: ./.github/workflows/lint-rust + - name: lint node rust + uses: ./.github/workflows/lint-rust with: cargo-toml-folder: ./node/rust-client github-token: ${{ secrets.GITHUB_TOKEN }} - name: lint node rust get-containers: runs-on: ubuntu-latest @@ -197,8 +196,6 @@ jobs: echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV # Replace `:` in the variable otherwise it can't be used in `upload-artifact` - uses: actions/checkout@v4 - with: - submodules: recursive - name: Setup musl on Linux if: ${{ contains(matrix.host.TARGET, 'musl') }} @@ -239,3 +236,43 @@ jobs: node/test-report*.html utils/clusters/** benchmarks/results/** + + test-modules: + if: (github.repository_owner == 'valkey-io' && github.event_name == 'workflow_dispatch') || github.event.pull_request.head.repo.owner.login == 'valkey-io' + environment: AWS_ACTIONS + name: Running Module Tests + runs-on: [self-hosted, linux, ARM64] + timeout-minutes: 15 + + steps: + - name: Setup self-hosted runner access + run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide + + - uses: actions/checkout@v4 + + - name: Install Node.js + uses: actions/setup-node@v4 + with: + node-version: 20.x + + - name: Build Node wrapper + uses: ./.github/workflows/build-node-wrapper + with: + os: ubuntu + named_os: linux + arch: arm64 + target: aarch64-unknown-linux-gnu + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: test + run: npm run test-modules -- --cluster-endpoints=${{ secrets.MEMDB_MODULES_ENDPOINT }} --tls=true + working-directory: ./node + + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: test-report-node-modules-ubuntu + path: | + node/test-report*.html diff --git a/.github/workflows/npm-cd.yml b/.github/workflows/npm-cd.yml index 4e58dc4c81..24497e83cc 100644 --- a/.github/workflows/npm-cd.yml +++ b/.github/workflows/npm-cd.yml @@ -120,10 +120,10 @@ jobs: INPUT_VERSION: ${{ github.event.inputs.version }} - name: Setup node - if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl' }} - uses: actions/setup-node@v3 + if: ${{ !contains(matrix.build.TARGET, 'musl') }} + uses: actions/setup-node@v4 with: - node-version: "20" + node-version: "latest" registry-url: "https://registry.npmjs.org" architecture: ${{ matrix.build.ARCH }} scope: "${{ vars.NPM_SCOPE }}" @@ -131,7 +131,7 @@ jobs: token: ${{ secrets.NPM_AUTH_TOKEN }} - name: Setup node for publishing - if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-musl' }} + if: ${{ !contains(matrix.build.TARGET, 'musl') }} working-directory: ./node run: | npm config set registry 
https://registry.npmjs.org/ @@ -181,18 +181,22 @@ jobs: working-directory: ./node run: | npm pkg fix - set +e - # 2>&1 1>&3- redirects stderr to stdout and then redirects the original stdout to another file descriptor, - # effectively separating stderr and stdout. The 3>&1 at the end redirects the original stdout back to the console. - # https://github.com/npm/npm/issues/118#issuecomment-325440 - ignoring notice messages since currentlly they are directed to stderr - { npm_publish_err=$(npm publish --tag ${{ env.NPM_TAG }} --access public 2>&1 1>&3- | grep -Ev "notice|ExperimentalWarning") ;} 3>&1 - if [[ "$npm_publish_err" == *"You cannot publish over the previously published versions"* ]] - then - echo "Skipping publishing, package already published" - elif [[ ! -z "$npm_publish_err" ]] - then - echo "Failed to publish with error: ${npm_publish_err}" - exit 1 + set +e # Disable immediate exit on non-zero exit codes + + # Redirect stderr to stdout, filter out notices and warnings + { npm_publish_err=$(npm publish --tag "${NPM_TAG}" --access public --loglevel=error 2>&1 1>&3- | grep -Ev "notice|ExperimentalWarning|WARN") ;} 3>&1 + publish_exit_code=$? + + # Re-enable immediate exit + set -e + + if [[ $publish_exit_code -eq 0 ]]; then + echo "Package published successfully." + elif echo "$npm_publish_err" | grep -q "You cannot publish over the previously published versions"; then + echo "Skipping publishing, package already published." + elif [[ ! -z "$npm_publish_err" ]]; then + echo "Failed to publish with error: $npm_publish_err" + exit 1 fi env: NODE_AUTH_TOKEN: ${{ secrets.NPM_AUTH_TOKEN }} @@ -325,8 +329,6 @@ jobs: - name: Checkout if: ${{ matrix.build.TARGET != 'aarch64-unknown-linux-musl'}} uses: actions/checkout@v4 - with: - submodules: "true" - name: Setup for musl if: ${{ contains(matrix.build.TARGET, 'musl') }} @@ -377,10 +379,56 @@ jobs: npm install --no-save @valkey/valkey-glide@${{ env.NPM_TAG }} npm run test + - name: Deprecating packages on failure + if: ${{ failure() }} + shell: bash + env: + GH_EVENT_NAME: ${{ github.event_name }} + GH_EVENT_INPUT_VERSION: ${{ github.event.inputs.version }} + GH_REF: ${{ github.ref }} + NODE_AUTH_TOKEN: ${{ secrets.NPM_AUTH_TOKEN }} + PLATFORM_MATRIX: ${{ needs.load-platform-matrix.outputs.PLATFORM_MATRIX }} + run: | + # Detect OS and install jq + if [[ "${OSTYPE}" == "darwin"* ]]; then + brew install jq || true + elif command -v apk > /dev/null; then + apk add --no-cache jq + else + sudo apt-get update && sudo apt-get install -y jq + fi + + # Set RELEASE_VERSION correctly using environment variables + if [[ "${GH_EVENT_NAME}" == "workflow_dispatch" ]]; then + RELEASE_VERSION="${GH_EVENT_INPUT_VERSION}" + else + RELEASE_VERSION="${GH_REF#refs/tags/v}" + fi + + # Validate RELEASE_VERSION + if [[ ! "${RELEASE_VERSION}" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-rc[0-9]+)?$ ]]; then + echo "Invalid release version format: ${RELEASE_VERSION}" + exit 1 + fi + + echo "Release version for Deprecating: ${RELEASE_VERSION}" + + # Deprecating base package + npm deprecate "@valkey/valkey-glide@${RELEASE_VERSION}" "This version has been deprecated" --force || true + + # Process platform matrix + echo "${PLATFORM_MATRIX}" > platform_matrix.json + + while read -r pkg; do + package_name="@valkey/valkey-glide-${pkg}" + echo "Deprecating ${package_name}@${RELEASE_VERSION}" + npm deprecate "${package_name}@${RELEASE_VERSION}" "This version has been deprecated" --force || true + done < <(jq -r '.[] | "\(.NAMED_OS)\(.TARGET | test("musl") | if . 
then "-musl" else "" end)-\(.ARCH)"' platform_matrix.json) + # Reset the repository to make sure we get the clean checkout of the action later in other actions. # It is not required since in other actions we are cleaning before the action, but it is a good practice to do it here as well. - name: Reset repository - if: ${{ contains(matrix.build.RUNNER, 'self-hosted') }} + if: ${{ always() }} && ${{ contains(matrix.build.RUNNER, 'self-hosted') }} shell: bash run: | git reset --hard diff --git a/.github/workflows/ort.yml b/.github/workflows/ort.yml index 8b405e121d..2134f1f7a4 100644 --- a/.github/workflows/ort.yml +++ b/.github/workflows/ort.yml @@ -43,7 +43,6 @@ jobs: - name: Checkout target branch uses: actions/checkout@v4 with: - submodules: "true" ref: ${{ env.TARGET_BRANCH }} repository: ${{ github.event.pull_request.head.repo.full_name }} token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pypi-cd.yml b/.github/workflows/pypi-cd.yml index 647dbd1c8e..4ea517a818 100644 --- a/.github/workflows/pypi-cd.yml +++ b/.github/workflows/pypi-cd.yml @@ -52,6 +52,9 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Start self hosted EC2 runner uses: ./.github/workflows/start-self-hosted-runner with: @@ -113,7 +116,7 @@ jobs: if: ${{ !contains(matrix.build.RUNNER, 'self-hosted') }} uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.13" - name: Update package version in config.toml uses: ./.github/workflows/update-glide-version @@ -143,7 +146,7 @@ jobs: with: working-directory: ./python target: ${{ matrix.build.TARGET }} - args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} + args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} manylinux: auto container: ${{ matrix.build.CONTAINER != '' && matrix.build.CONTAINER || '2014' }} before-script-linux: | @@ -151,7 +154,8 @@ jobs: if [[ $(`which apt`) != '' ]] then echo "installing unzip and curl" - apt install unzip curl -y + apt-get update + apt install unzip curl python3.13 -y fi PB_REL="https://github.com/protocolbuffers/protobuf/releases" ARCH=`uname -p` @@ -159,6 +163,8 @@ jobs: PROTOC_ARCH="x86_64" elif [[ $ARCH == 'aarch64' ]]; then PROTOC_ARCH="aarch_64" + export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc + export CFLAGS_aarch64_unknown_linux_gnu="-march=armv8-a" else echo "Running on unsupported architecture: $ARCH. 
Expected one of: ['x86_64', 'aarch64']" exit 1 @@ -171,10 +177,10 @@ jobs: if: startsWith(matrix.build.NAMED_OS, 'darwin') uses: PyO3/maturin-action@v1 with: - maturin-version: latest + maturin-version: 0.14.17 working-directory: ./python target: ${{ matrix.build.TARGET }} - args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.8 python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} + args: --release --strip --out wheels -i ${{ github.event_name != 'pull_request' && 'python3.9 python3.10 python3.11 python3.12 python3.13' || 'python3.12' }} - name: Upload Python wheels if: github.event_name != 'pull_request' @@ -214,6 +220,10 @@ jobs: matrix: build: ${{ fromJson(needs.load-platform-matrix.outputs.PLATFORM_MATRIX) }} steps: + - name: Setup self-hosted runner access + if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-gnu' }} + run: sudo chown -R $USER:$USER /home/ubuntu/actions-runner/_work/valkey-glide + - name: checkout uses: actions/checkout@v4 @@ -222,10 +232,31 @@ jobs: with: python-version: 3.12 - - name: Install engine - uses: ./.github/workflows/install-engine - with: - version: "8.0" + - name: Install engine Ubuntu ARM + if: ${{ matrix.build.TARGET == 'aarch64-unknown-linux-gnu' }} + shell: bash + # in self hosted runner we first want to check that engine is not already installed + run: | + if [[ $(`which redis-server`) == '' ]] + then + sudo apt-get update + sudo apt-get install -y redis-server + else + echo "Redis is already installed" + fi + + - name: Install engine Ubuntu x86 + if: ${{ matrix.build.TARGET == 'x86_64-unknown-linux-gnu' }} + shell: bash + run: | + sudo apt-get update + sudo apt-get install -y redis-server + + - name: Install engine MacOS + if: ${{ matrix.build.OS == 'macos' }} + shell: bash + run: | + brew install redis - name: Check if RC and set a distribution tag for the package shell: bash @@ -233,12 +264,11 @@ jobs: if [[ "${GITHUB_REF:11}" == *"rc"* ]] then echo "This is a release candidate" - export pip_pre="--pre" + echo "PIP_PRE=true" >> $GITHUB_ENV else echo "This is a stable release" - export pip_pre="" + echo "PIP_PRE=false" >> $GITHUB_ENV fi - echo "PIP_PRE=${pip_pre}" >> $GITHUB_ENV - name: Run the tests shell: bash @@ -247,7 +277,11 @@ jobs: python -m venv venv source venv/bin/activate pip install -U pip - pip install ${PIP_PRE} valkey-glide + if [[ "${{ env.PIP_PRE }}" == "true" ]]; then + pip install --pre valkey-glide + else + pip install valkey-glide + fi python rc_test.py # Reset the repository to make sure we get the clean checkout of the action later in other actions. 
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3fcd171498..699033cf1a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -9,7 +9,7 @@ on: paths: - python/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/**/** - utils/cluster_manager.py - .github/workflows/python.yml - .github/workflows/build-python-wrapper/action.yml @@ -25,7 +25,7 @@ on: paths: - python/** - glide-core/src/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/python.yml - .github/workflows/build-python-wrapper/action.yml @@ -91,8 +91,6 @@ jobs: host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up Python uses: actions/setup-python@v5 @@ -113,28 +111,31 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} engine-version: ${{ matrix.engine.version }} - - name: Type check with mypy - working-directory: ./python - run: | - # The type check should run inside the virtual env to get - # all installed dependencies and build files - source .env/bin/activate - pip install mypy types-protobuf - # Install the benchmark requirements - pip install -r ../benchmarks/python/requirements.txt - python -m mypy .. - - name: Test with pytest working-directory: ./python run: | source .env/bin/activate + pip install -r dev_requirements.txt cd python/tests/ pytest --asyncio-mode=auto --html=pytest_report.html --self-contained-html - uses: ./.github/workflows/test-benchmark + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.python == '3.12' }} with: language-flag: -python + - name: Type check with mypy + if: ${{ matrix.engine.version == '8.0' && matrix.host.OS == 'ubuntu' && matrix.host.RUNNER == 'ubuntu-latest' && matrix.python == '3.12' }} + working-directory: ./python + run: | + # The type check should run inside the virtual env to get + # all installed dependencies and build files + source .env/bin/activate + pip install mypy types-protobuf + # Install the benchmark requirements + pip install -r ../benchmarks/python/requirements.txt + python -m mypy .. 
+ - name: Upload test reports if: always() continue-on-error: true @@ -159,8 +160,6 @@ jobs: host: ${{ fromJson(needs.get-matrices.outputs.host-matrix-output) }} steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Set up Python uses: actions/setup-python@v5 @@ -179,6 +178,7 @@ jobs: working-directory: ./python run: | source .env/bin/activate + pip install -r dev_requirements.txt cd python/tests/ pytest --asyncio-mode=auto -k test_pubsub --html=pytest_report.html --self-contained-html @@ -196,8 +196,6 @@ jobs: timeout-minutes: 15 steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: lint rust uses: ./.github/workflows/lint-rust @@ -274,8 +272,6 @@ jobs: echo IMAGE=amazonlinux:latest | sed -r 's/:/-/g' >> $GITHUB_ENV # Replace `:` in the variable otherwise it can't be used in `upload-artifact` - uses: actions/checkout@v4 - with: - submodules: recursive - name: Build Python wrapper uses: ./.github/workflows/build-python-wrapper diff --git a/.github/workflows/redis-rs.yml b/.github/workflows/redis-rs.yml new file mode 100644 index 0000000000..5d3d82855a --- /dev/null +++ b/.github/workflows/redis-rs.yml @@ -0,0 +1,142 @@ +name: Redis-rs CI + +on: + push: + branches: + - main + - release-* + - v* + paths: + - glide-core/redis-rs/redis/** + - utils/cluster_manager.py + - deny.toml + - .github/workflows/install-shared-dependencies/action.yml + - .github/workflows/redis-rs.yml + pull_request: + paths: + - glide-core/redis-rs/redis/** + - utils/cluster_manager.py + - deny.toml + - .github/workflows/install-shared-dependencies/action.yml + - .github/workflows/redis-rs.yml + workflow_dispatch: + workflow_call: + +concurrency: + group: redis-rs-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +env: + CARGO_TERM_COLOR: always + +jobs: + redis-rs-CI: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install shared software dependencies + uses: ./.github/workflows/install-shared-dependencies + with: + os: "ubuntu" + target: "x86_64-unknown-linux-gnu" + engine-version: "7.2.5" + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true + workspaces: ./glide-core/redis-rs/redis + + - name: Build project + run: cargo build --release + working-directory: ./glide-core/redis-rs/redis/src + + - name: Lint redis-rs + shell: bash + run: | + cargo fmt --all -- --check + cargo clippy -- -D warnings + cargo install --locked cargo-deny + cargo deny check all --config ${GITHUB_WORKSPACE}/deny.toml --exclude-dev all + working-directory: ./glide-core/redis-rs/redis + + - name: Test + # TODO remove the concurrency limit after we fix test flakiness. 
+ run: | + cargo test --release -- --test-threads=1 | tee ../test-results.xml + echo "### Tests passed :v:" >> $GITHUB_STEP_SUMMARY + working-directory: ./glide-core/redis-rs/redis/src + + - name: Upload test reports + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: test-reports-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/test-results.xml + + - name: Run benchmarks + run: | + cargo bench | tee bench-results.xml + working-directory: ./glide-core/redis-rs/redis + + - name: Upload benchmark results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: benchmark-results-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/bench-results.xml + + - name: Test docs + run: | + cargo test --doc + working-directory: ./glide-core/redis-rs/redis/src + + - name: Security audit + run: | + cargo audit | tee audit-results.txt + if grep -q "Crate: " audit-results.txt; then + echo "## Security audit results summary: Security vulnerabilities found :exclamation: :exclamation:" >> $GITHUB_STEP_SUMMARY + echo "Security audit results summary: Security vulnerabilities found" + exit 1 + else + echo "### Security audit results summary: All good, no security vulnerabilities found :closed_lock_with_key:" >> $GITHUB_STEP_SUMMARY + echo "Security audit results summary: All good, no security vulnerabilities found" + fi + working-directory: ./glide-core/redis-rs/redis + + - name: Upload audit results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: audit-results-redis-rs--${{ github.sha }} + path: ./glide-core/redis-rs/redis/audit-results.txt + + - name: Run cargo machete + run: | + cargo install cargo-machete + cargo machete | tee machete-results.txt + if grep -A1 "cargo-machete found the following unused dependencies in this directory:" machete-results.txt | sed -n '2p' | grep -v "^if" > /dev/null; then + echo "Machete results summary: Unused dependencies found" >> $GITHUB_STEP_SUMMARY + echo "Machete results summary: Unused dependencies found" + cat machete-results.txt | grep -A1 "cargo-machete found the following unused dependencies in this directory:" | sed -n '2p' | grep -v "^if" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Machete results summary: All good, no unused dependencies found :rocket:" >> $GITHUB_STEP_SUMMARY + echo "Machete results summary: All good, no unused dependencies found" + fi + working-directory: ./glide-core/redis-rs/redis + + - name: Upload machete results + if: always() + continue-on-error: true + uses: actions/upload-artifact@v4 + with: + name: machete-results-redis-rs-${{ github.sha }} + path: ./glide-core/redis-rs/redis/machete-results.txt diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8f8f25a180..0c71fa2f86 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -9,7 +9,7 @@ on: paths: - logger_core/** - glide-core/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml @@ -22,7 +22,7 @@ on: paths: - logger_core/** - glide-core/** - - submodules/** + - glide-core/redis-rs/redis/src/** - utils/cluster_manager.py - .github/workflows/rust.yml - .github/workflows/install-shared-dependencies/action.yml @@ -86,8 +86,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - submodules: recursive - name: Install shared software dependencies uses: 
./.github/workflows/install-shared-dependencies @@ -122,8 +120,6 @@ jobs: timeout-minutes: 30 steps: - uses: actions/checkout@v4 - with: - submodules: recursive - uses: ./.github/workflows/lint-rust with: diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 3e4219bd7f..10523666fa 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -10,7 +10,10 @@ on: description: "The branch to run against the semgrep tool" required: true push: - branches: ["main"] + branches: + - main + - release-* + - v* # Schedule the CI job (this method uses cron syntax): schedule: - cron: "0 8 * * *" # Sets Semgrep to scan every day at 08:00 UTC. @@ -33,4 +36,4 @@ jobs: # Fetch project source with GitHub Actions Checkout. - uses: actions/checkout@v4 # Run the "semgrep ci" command on the command line of the docker image. - - run: semgrep ci --config auto --no-suppress-errors + - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key diff --git a/.github/workflows/test-benchmark/action.yml b/.github/workflows/test-benchmark/action.yml index 91cc36697f..3bd50dc0f2 100644 --- a/.github/workflows/test-benchmark/action.yml +++ b/.github/workflows/test-benchmark/action.yml @@ -11,7 +11,8 @@ runs: steps: - shell: bash - run: redis-server & + # Disable RDB snapshots to avoid configuration errors + run: redis-server --save "" --daemonize "yes" - shell: bash working-directory: ./benchmarks diff --git a/.gitignore b/.gitignore index 573bfc218d..6799f31ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,8 @@ logger-rs.linux-x64-gnu.node utils/clusters/ utils/tls_crts/ utils/TestUtils.js +.build/ +.project # OSS Review Toolkit (ORT) files **/ort*/** diff --git a/.gitmodules b/.gitmodules index 87a3d9b855..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "submodules/redis-rs"] - path = submodules/redis-rs - url = https://github.com/amazon-contributing/redis-rs diff --git a/.vscode/settings.json b/.vscode/settings.json index e599c02a5b..229045495f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ "node/rust-client/Cargo.toml", "logger_core/Cargo.toml", "csharp/lib/Cargo.toml", - "submodules/redis-rs/Cargo.toml", + "glide-core/redis-rs/Cargo.toml", "benchmarks/rust/Cargo.toml", "java/Cargo.toml" ], diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ee5dda24f..6a98e3bf76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,108 @@ #### Changes * Python: Add JSON.ARRLEN command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) +* Node: Client API for retrieving internal statistics ([#2727](https://github.com/valkey-io/valkey-glide/pull/2727)) +* Python: Client API for retrieving internal statistics ([#2707](https://github.com/valkey-io/valkey-glide/pull/2707)) +* Node, Python, Java: Adding support for replacing connection configured password ([#2651](https://github.com/valkey-io/valkey-glide/pull/2651), [#2659](https://github.com/valkey-io/valkey-glide/pull/2659), [#2677](https://github.com/valkey-io/valkey-glide/pull/2677)) +* Node, Python, Java: AZ Affinity - Python Wrapper Support ([#2686](https://github.com/valkey-io/valkey-glide/pull/2686), [#2676](https://github.com/valkey-io/valkey-glide/pull/2676), [#2678](https://github.com/valkey-io/valkey-glide/pull/2678)) +* Node: Add `FT._ALIASLIST` command ([#2652](https://github.com/valkey-io/valkey-glide/pull/2652)) +* Python: Add `FT._ALIASLIST` command 
([#2638](https://github.com/valkey-io/valkey-glide/pull/2638)) +* Node: Add `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2596](https://github.com/valkey-io/valkey-glide/pull/2596)) +* Python code cleanup ([#2573](https://github.com/valkey-io/valkey-glide/pull/2573)) +* Python: Add `FT.PROFILE` command ([#2543](https://github.com/valkey-io/valkey-glide/pull/2543)) +* Python: Add `FT.AGGREGATE` command ([#2530](https://github.com/valkey-io/valkey-glide/pull/2530)) +* Python: Add `JSON.OBJLEN` command ([#2495](https://github.com/valkey-io/valkey-glide/pull/2495)) +* Python: Add `FT.EXPLAIN` and `FT.EXPLAINCLI` commands ([#2508](https://github.com/valkey-io/valkey-glide/pull/2508)) +* Python: Add `FT.INFO` command ([#2429](https://github.com/valkey-io/valkey-glide/pull/2494)) +* Python: Add `FT.SEARCH` command ([#2470](https://github.com/valkey-io/valkey-glide/pull/2470)) +* Python: Add commands `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2471](https://github.com/valkey-io/valkey-glide/pull/2471)) +* Python: Add `FT.DROPINDEX` command ([#2437](https://github.com/valkey-io/valkey-glide/pull/2437)) +* Python: Add `FT.CREATE` command ([#2413](https://github.com/valkey-io/valkey-glide/pull/2413)) +* Python: Add `JSON.MGET` command ([#2507](https://github.com/valkey-io/valkey-glide/pull/2507)) +* Python: Add `JSON.ARRLEN` command ([#2403](https://github.com/valkey-io/valkey-glide/pull/2403)) +* Python: Add `JSON.CLEAR` command ([#2418](https://github.com/valkey-io/valkey-glide/pull/2418)) +* Python: Add `JSON.TYPE` command ([#2409](https://github.com/valkey-io/valkey-glide/pull/2409)) +* Python: Add `JSON.NUMINCRBY` command ([#2448](https://github.com/valkey-io/valkey-glide/pull/2448)) +* Python: Add `JSON.NUMMULTBY` command ([#2458](https://github.com/valkey-io/valkey-glide/pull/2458)) +* Python: Add `JSON.ARRINDEX` command ([#2528](https://github.com/valkey-io/valkey-glide/pull/2528)) +* Python: Add `FT._LIST` command ([#2571](https://github.com/valkey-io/valkey-glide/pull/2571)) +* Python: Add `JSON.DEBUG_MEMORY` and `JSON.DEBUG_FIELDS` commands ([#2481](https://github.com/valkey-io/valkey-glide/pull/2481)) +* Java: Added `FT.CREATE` ([#2414](https://github.com/valkey-io/valkey-glide/pull/2414)) +* Java: Added `FT.INFO` ([#2405](https://github.com/valkey-io/valkey-glide/pull/2441)) +* Java: Added `FT.DROPINDEX` ([#2440](https://github.com/valkey-io/valkey-glide/pull/2440)) +* Java: Added `FT.SEARCH` ([#2439](https://github.com/valkey-io/valkey-glide/pull/2439)) +* Java: Added `FT.AGGREGATE` ([#2466](https://github.com/valkey-io/valkey-glide/pull/2466)) +* Java: Added `FT.PROFILE` ([#2473](https://github.com/valkey-io/valkey-glide/pull/2473)) +* Java: Added `JSON.SET` and `JSON.GET` ([#2462](https://github.com/valkey-io/valkey-glide/pull/2462)) +* Java: Added `JSON.MGET` ([#2514](https://github.com/valkey-io/valkey-glide/pull/2514)) +* Node: Added `FT.CREATE` ([#2501](https://github.com/valkey-io/valkey-glide/pull/2501)) +* Node: Added `FT.INFO` ([#2540](https://github.com/valkey-io/valkey-glide/pull/2540)) +* Node: Added `FT.AGGREGATE` ([#2554](https://github.com/valkey-io/valkey-glide/pull/2554)) +* Node: Added `FT.PROFILE` ([#2633](https://github.com/valkey-io/valkey-glide/pull/2633)) +* Java: Added `JSON.DEBUG` ([#2520](https://github.com/valkey-io/valkey-glide/pull/2520)) +* Java: Added `JSON.ARRINSERT` and `JSON.ARRLEN` ([#2476](https://github.com/valkey-io/valkey-glide/pull/2476)) +* Java: Added `JSON.ARRINDEX` ([#2546](https://github.com/valkey-io/valkey-glide/pull/2546)) +* 
Java: Added `JSON.ARRPOP` ([#2486](https://github.com/valkey-io/valkey-glide/pull/2486)) +* Java: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2492](https://github.com/valkey-io/valkey-glide/pull/2492)) +* Java: Added `JSON.DEL` and `JSON.FORGET` ([#2490](https://github.com/valkey-io/valkey-glide/pull/2490)) +* Java: Added `FT.ALIASADD`, `FT.ALIASDEL`, `FT.ALIASUPDATE` ([#2442](https://github.com/valkey-io/valkey-glide/pull/2442)) +* Java: Added `FT._ALIASLIST` ([#2569](https://github.com/valkey-io/valkey-glide/pull/2569)) +* Java: Added `FT.EXPLAIN`, `FT.EXPLAINCLI` ([#2515](https://github.com/valkey-io/valkey-glide/pull/2515)) +* Core: Update routing for commands from server modules ([#2461](https://github.com/valkey-io/valkey-glide/pull/2461)) +* Node: Added `JSON.SET` and `JSON.GET` ([#2427](https://github.com/valkey-io/valkey-glide/pull/2427)) +* Node: Added `JSON.MGET` ([#2567](https://github.com/valkey-io/valkey-glide/pull/2567)) +* Java: Added `JSON.NUMINCRBY` and `JSON.NUMMULTBY` ([#2511](https://github.com/valkey-io/valkey-glide/pull/2511)) +* Java: Added `JSON.ARRAPPEND` ([#2489](https://github.com/valkey-io/valkey-glide/pull/2489)) +* Java: Added `JSON.ARRTRIM` ([#2518](https://github.com/valkey-io/valkey-glide/pull/2518)) +* Node: Added `JSON.TOGGLE` ([#2491](https://github.com/valkey-io/valkey-glide/pull/2491)) +* Node: Added `JSON.ARRINSERT`, `JSON.ARRPOP` and `JSON.ARRLEN` ([#2542](https://github.com/valkey-io/valkey-glide/pull/2542)) +* Node: Added `JSON.DEL` and `JSON.FORGET` ([#2505](https://github.com/valkey-io/valkey-glide/pull/2505)) +* Java: Added `JSON.TOGGLE` ([#2504](https://github.com/valkey-io/valkey-glide/pull/2504)) +* Java: Added `JSON.STRAPPEND` and `JSON.STRLEN` ([#2522](https://github.com/valkey-io/valkey-glide/pull/2522)) +* Java: Added `JSON.CLEAR` ([#2519](https://github.com/valkey-io/valkey-glide/pull/2519)) +* Node: Added `JSON.TYPE` ([#2510](https://github.com/valkey-io/valkey-glide/pull/2510)) +* Node: Added `JSON.ARRAPPEND` ([#2562](https://github.com/valkey-io/valkey-glide/pull/2562)) +* Java: Added `JSON.RESP` ([#2513](https://github.com/valkey-io/valkey-glide/pull/2513)) +* Java: Added `JSON.TYPE` ([#2525](https://github.com/valkey-io/valkey-glide/pull/2525)) +* Java: Added `FT._LIST` ([#2568](https://github.com/valkey-io/valkey-glide/pull/2568)) +* Node: Added `FT.DROPINDEX` ([#2516](https://github.com/valkey-io/valkey-glide/pull/2516)) +* Node: Added `FT._LIST` ([#2570](https://github.com/valkey-io/valkey-glide/pull/2570)) +* Node: Added `JSON.RESP` ([#2517](https://github.com/valkey-io/valkey-glide/pull/2517)) +* Node: Added `FT.EXPLAIN` and `FT.EXPLAINCLI` ([#2560](https://github.com/valkey-io/valkey-glide/pull/2560)) +* Node: Added `JSON.CLEAR` ([#2566](https://github.com/valkey-io/valkey-glide/pull/2566)) +* Node: Added `JSON.ARRTRIM` ([#2550](https://github.com/valkey-io/valkey-glide/pull/2550)) +* Node: Added `JSON.ARRINDEX` ([#2559](https://github.com/valkey-io/valkey-glide/pull/2559)) +* Node: Added `JSON.OBJLEN` and `JSON.OBJKEYS` ([#2563](https://github.com/valkey-io/valkey-glide/pull/2563)) +* Python: Add `JSON.STRAPPEND`, `JSON.STRLEN` commands ([#2372](https://github.com/valkey-io/valkey-glide/pull/2372)) +* Python: Add `JSON.OBJKEYS` command ([#2395](https://github.com/valkey-io/valkey-glide/pull/2395)) +* Python: Add `JSON.ARRINSERT` command ([#2464](https://github.com/valkey-io/valkey-glide/pull/2464)) +* Python: 
Add `JSON.ARRTRIM` command ([#2457](https://github.com/valkey-io/valkey-glide/pull/2457)) +* Python: Add `JSON.ARRAPPEND` command ([#2382](https://github.com/valkey-io/valkey-glide/pull/2382)) +* Python: Add `JSON.RESP` command ([#2451](https://github.com/valkey-io/valkey-glide/pull/2451)) +* Python: Add `JSON.ARRPOP` command ([#2407](https://github.com/valkey-io/valkey-glide/pull/2407)) +* Node: Add `JSON.STRLEN` and `JSON.STRAPPEND` command ([#2537](https://github.com/valkey-io/valkey-glide/pull/2537)) +* Node: Add `FT.SEARCH` ([#2551](https://github.com/valkey-io/valkey-glide/pull/2551)) +* Python: Fix example ([#2556](https://github.com/valkey-io/valkey-glide/issues/2556)) +* Core: Add support for sending multi-slot JSON.MSET and JSON.MGET commands ([#2587](https://github.com/valkey-io/valkey-glide/pull/2587)) +* Node: Add `JSON.DEBUG` command ([#2572](https://github.com/valkey-io/valkey-glide/pull/2572)) +* Node: Add `JSON.NUMINCRBY` and `JSON.NUMMULTBY` command ([#2555](https://github.com/valkey-io/valkey-glide/pull/2555)) +* Core: Add support to Availability Zone Affinity read strategy ([#2539](https://github.com/valkey-io/valkey-glide/pull/2539)) +* Core: Fix list of readonly commands ([#2634](https://github.com/valkey-io/valkey-glide/pull/2634), [#2649](https://github.com/valkey-io/valkey-glide/pull/2649)) +* Core: Improve retry logic and update unmaintained dependencies for Rust lint CI ([#2673](https://github.com/valkey-io/valkey-glide/pull/2643)) +* Core: Release the read lock while creating connections in `refresh_connections` ([#2630](https://github.com/valkey-io/valkey-glide/issues/2630)) +* Core: SlotMap refactor - Added NodesMap, Update the slot map upon MOVED errors ([#2682](https://github.com/valkey-io/valkey-glide/issues/2682)) #### Breaking Changes #### Fixes +* Core: UDS Socket Handling Rework ([#2482](https://github.com/valkey-io/valkey-glide/pull/2482)) #### Operational Enhancements +* Java: Add modules CI ([#2388](https://github.com/valkey-io/valkey-glide/pull/2388), [#2404](https://github.com/valkey-io/valkey-glide/pull/2404), [#2416](https://github.com/valkey-io/valkey-glide/pull/2416)) +* Node: Add modules CI ([#2472](https://github.com/valkey-io/valkey-glide/pull/2472)) +* Python: Fix modules CI ([#2487](https://github.com/valkey-io/valkey-glide/pull/2487)) + ## 1.1.0 (2024-09-24) #### Changes @@ -401,6 +497,7 @@ * Node: Added LINDEX command ([#999](https://github.com/valkey-io/valkey-glide/pull/999)) * Python, Node: Added ZPOPMAX command ([#996](https://github.com/valkey-io/valkey-glide/pull/996), [#1009](https://github.com/valkey-io/valkey-glide/pull/1009)) * Python: Added DBSIZE command ([#1040](https://github.com/valkey-io/valkey-glide/pull/1040)) +* Core: Log directory can now be modified by setting the environment variable `GLIDE_LOG_DIR` ([#2704](https://github.com/valkey-io/valkey-glide/issues/2704)) #### Features diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..877d78f6e9 --- /dev/null +++ b/Makefile @@ -0,0 +1,116 @@ +.PHONY: all java java-test python python-test node node-test check-redis-server go go-test + +BLUE=\033[34m +YELLOW=\033[33m +GREEN=\033[32m +RESET=\033[0m +ROOT_DIR=$(shell pwd) +PYENV_DIR=$(shell pwd)/python/.env +PY_PATH=$(shell find python/.env -name "site-packages"|xargs readlink -f) +PY_GLIDE_PATH=$(shell pwd)/python/python/ + +all: java java-test python python-test node node-test go go-test python-lint java-lint + +## +## Java targets +## +java: + @echo "$(GREEN)Building for Java (release)$(RESET)" + @cd 
java && ./gradlew :client:buildAllRelease + +java-lint: + @echo "$(GREEN)Running spotlessApply$(RESET)" + @cd java && ./gradlew :spotlessApply + +java-test: check-redis-server + @echo "$(GREEN)Running integration tests$(RESET)" + @cd java && ./gradlew :integTest:test + +## +## Python targets +## +python: .build/python_deps + @echo "$(GREEN)Building for Python (release)$(RESET)" + @cd python && VIRTUAL_ENV=$(PYENV_DIR) .env/bin/maturin develop --release --strip + +python-lint: .build/python_deps + @echo "$(GREEN)Building Linters for python$(RESET)" + cd python && \ + export VIRTUAL_ENV=$(PYENV_DIR); \ + export PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH); \ + export PATH=$(PYENV_DIR)/bin:$(PATH); \ + isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env && \ + black . --exclude python/glide/protobuf --exclude .env && \ + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \ + --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 && \ + flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \ + --statistics --exclude=python/glide/protobuf,.env/* \ + --extend-ignore=E230 + +python-test: .build/python_deps check-redis-server + cd python && PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH) .env/bin/pytest --asyncio-mode=auto + +.build/python_deps: + @echo "$(GREEN)Generating protobuf files...$(RESET)" + @protoc -Iprotobuf=$(ROOT_DIR)/glide-core/src/protobuf/ \ + --python_out=$(ROOT_DIR)/python/python/glide $(ROOT_DIR)/glide-core/src/protobuf/*.proto + @echo "$(GREEN)Building environment...$(RESET)" + @cd python && python3 -m venv .env + @echo "$(GREEN)Installing requirements...$(RESET)" + @cd python && .env/bin/pip install -r requirements.txt + @cd python && .env/bin/pip install -r dev_requirements.txt + @mkdir -p .build/ && touch .build/python_deps + +## +## NodeJS targets +## +node: .build/node_deps + @echo "$(GREEN)Building for NodeJS (release)...$(RESET)" + @cd node && npm run build:release + +.build/node_deps: + @echo "$(GREEN)Installing NodeJS dependencies...$(RESET)" + @cd node && npm i + @cd node/rust-client && npm i + @mkdir -p .build/ && touch .build/node_deps + +node-test: .build/node_deps check-redis-server + @echo "$(GREEN)Running tests for NodeJS$(RESET)" + @cd node && npm run build + cd node && npm test + +node-lint: .build/node_deps + @echo "$(GREEN)Running linters for NodeJS$(RESET)" + @cd node && npm run lint:fix + +## +## Go targets +## + + +go: .build/go_deps + $(MAKE) -C go build + +go-test: .build/go_deps + $(MAKE) -C go test + +go-lint: .build/go_deps + $(MAKE) -C go lint + +.build/go_deps: + @echo "$(GREEN)Installing GO dependencies...$(RESET)" + $(MAKE) -C go install-build-tools install-dev-tools + @mkdir -p .build/ && touch .build/go_deps + +## +## Common targets +## +check-redis-server: + which redis-server + +clean: + rm -fr .build/ + +help: + @echo "$(GREEN)Listing Makefile targets:$(RESET)" + @echo $(shell grep '^[^#[:space:]].*:' Makefile|cut -d":" -f1|grep -v PHONY|grep -v "^.build"|sort) diff --git a/benchmarks/rust/Cargo.toml b/benchmarks/rust/Cargo.toml index 6f0849d505..1c7baf0b70 100644 --- a/benchmarks/rust/Cargo.toml +++ b/benchmarks/rust/Cargo.toml @@ -11,11 +11,10 @@ authors = ["Valkey GLIDE Maintainers"] tokio = { version = "1", features = ["macros", "time", "rt-multi-thread"] } glide-core = { path = "../../glide-core" } logger_core = {path = "../../logger_core"} -redis = { path = "../../submodules/redis-rs/redis", features = ["aio"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio"] } 
futures = "0.3.28" rand = "0.8.5" itoa = "1.0.6" -futures-time = "^3.0.0" clap = { version = "4.3.8", features = ["derive"] } chrono = "0.4.26" serde_json = "1.0.99" diff --git a/csharp/lib/Cargo.toml b/csharp/lib/Cargo.toml index 95981480b2..b49e098bf7 100644 --- a/csharp/lib/Cargo.toml +++ b/csharp/lib/Cargo.toml @@ -12,7 +12,7 @@ name = "glide_rs" crate-type = ["cdylib"] [dependencies] -redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } glide-core = { path = "../../glide-core" } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../../logger_core"} diff --git a/deny.toml b/deny.toml index f97a82f0c8..526ce9bd1e 100644 --- a/deny.toml +++ b/deny.toml @@ -9,24 +9,6 @@ # The values provided in this template are the default values that will be used # when any section or field is not specified in your own configuration -# If 1 or more target triples (and optionally, target_features) are specified, -# only the specified targets will be checked when running `cargo deny check`. -# This means, if a particular package is only ever used as a target specific -# dependency, such as, for example, the `nix` crate only being used via the -# `target_family = "unix"` configuration, that only having windows targets in -# this list would mean the nix crate, as well as any of its exclusive -# dependencies not shared by any other crates, would be ignored, as the target -# list here is effectively saying which targets you are building for. -targets = [ - # The triple can be any string, but only the target triples built in to - # rustc (as of 1.40) can be checked against actual config expressions - #{ triple = "x86_64-unknown-linux-musl" }, - # You can also specify which target_features you promise are enabled for a - # particular target. target_features are currently not validated against - # the actual valid features supported by the target architecture. - #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, -] - # This section is considered when running `cargo deny check advisories` # More documentation for the advisories section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html @@ -35,22 +17,13 @@ targets = [ db-path = "~/.cargo/advisory-db" # The url(s) of the advisory databases to use db-urls = ["https://github.com/rustsec/advisory-db"] -# The lint level for security vulnerabilities -vulnerability = "deny" -# The lint level for unmaintained crates -unmaintained = "deny" # The lint level for crates that have been yanked from their source registry yanked = "deny" -# The lint level for crates with security notices. Note that as of -# 2019-12-17 there are no security notice advisories in -# https://github.com/rustsec/advisory-db -notice = "deny" -unsound = "deny" # A list of advisory IDs to ignore. Note that ignored advisories will still # output a note when they are encountered. ignore = [ # Unmaintained dependency error that needs more attention due to nested dependencies - "RUSTSEC-2024-0370" + "RUSTSEC-2024-0370", ] # Threshold for security vulnerabilities, any vulnerability with a CVSS score # lower than the range specified will be ignored. 
Note that ignored advisories @@ -72,8 +45,6 @@ ignore = [ # More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] -# The lint level for crates which do not have a detectable license -unlicensed = "deny" # List of explicitly allowed licenses # See https://spdx.org/licenses/ for list of possible licenses # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. @@ -85,28 +56,10 @@ allow = [ "BSD-3-Clause", "Unicode-DFS-2016", "ISC", - "OpenSSL" -] -# List of explicitly disallowed licenses -# See https://spdx.org/licenses/ for list of possible licenses -# [possible values: any SPDX 3.11 short identifier (+ optional exception)]. -deny = [ - #"Nokia", + "OpenSSL", + "MPL-2.0", + "Unicode-3.0" ] -# Lint level for licenses considered copyleft -copyleft = "deny" -# Blanket approval or denial for OSI-approved or FSF Free/Libre licenses -# * both - The license will be approved if it is both OSI-approved *AND* FSF -# * either - The license will be approved if it is either OSI-approved *OR* FSF -# * osi-only - The license will be approved if is OSI-approved *AND NOT* FSF -# * fsf-only - The license will be approved if is FSF *AND NOT* OSI-approved -# * neither - This predicate is ignored and the default lint level is used -allow-osi-fsf-free = "neither" -# Lint level used when no other predicates are matched -# 1. License isn't in the allow or deny lists -# 2. License isn't copyleft -# 3. License isn't OSI/FSF, or allow-osi-fsf-free = "neither" -default = "deny" # The confidence threshold for detecting a license from license text. # The higher the value, the more closely the license text must be to the # canonical license text of a valid SPDX license file. @@ -137,7 +90,7 @@ expression = "MIT AND ISC AND OpenSSL" # depending on the rest of your configuration license-files = [ # Each entry is a crate relative path, and the (opaque) hash of its contents - { path = "LICENSE", hash = 0xbd0eed23 } + { path = "LICENSE", hash = 0xbd0eed23 }, ] [licenses.private] diff --git a/examples/python/cluster_example.py b/examples/python/cluster_example.py index c3cefbd14f..01f916963b 100644 --- a/examples/python/cluster_example.py +++ b/examples/python/cluster_example.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Tuple +from typing import List, Tuple, Optional from glide import ( AllNodes, @@ -17,7 +17,7 @@ async def create_client( - nodes_list: List[Tuple[str, int]] = [("localhost", 6379)] + nodes_list: Optional[List[Tuple[str, int]]] = None ) -> GlideClusterClient: """ Creates and returns a GlideClusterClient instance. @@ -33,6 +33,8 @@ async def create_client( Returns: GlideClusterClient: An instance of GlideClusterClient connected to the discovered nodes. """ + if nodes_list is None: + nodes_list = [("localhost", 6379)] addresses = [NodeAddress(host, port) for host, port in nodes_list] # Check `GlideClusterClientConfiguration` for additional options. 
config = GlideClusterClientConfiguration( diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index e0a1b05368..bd12bb09c9 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -10,42 +10,63 @@ authors = ["Valkey GLIDE Maintainers"] [dependencies] bytes = "1" futures = "^0.3" -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } +redis = { path = "./redis-rs/redis", features = [ + "aio", + "tokio-comp", + "tokio-rustls-comp", + "connection-manager", + "cluster", + "cluster-async", +] } +telemetrylib = { path = "./telemetry" } tokio = { version = "1", features = ["macros", "time"] } -logger_core = {path = "../logger_core"} +logger_core = { path = "../logger_core" } dispose = "0.5.0" -tokio-util = {version = "^0.7", features = ["rt"], optional = true} +tokio-util = { version = "^0.7", features = ["rt"], optional = true } num_cpus = { version = "^1.15", optional = true } -tokio-retry = "0.3.0" -protobuf = { version= "3", features = ["bytes", "with-bytes"], optional = true } +tokio-retry2 = {version = "0.5", features = ["jitter"]} + +protobuf = { version = "3", features = [ + "bytes", + "with-bytes", +], optional = true } integer-encoding = { version = "4.0.0", optional = true } thiserror = "1" rand = { version = "0.8.5" } futures-intrusive = "0.5.0" -directories = { version = "4.0", optional = true } +directories = { version = "5.0", optional = true } once_cell = "1.18.0" -arcstr = "1.1.5" sha1_smol = "1.0.0" nanoid = "0.4.0" async-trait = { version = "0.1.24" } +serde_json = "1" +serde = { version = "1", features = ["derive"] } [features] -socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util"] +socket-layer = [ + "directories", + "integer-encoding", + "num_cpus", + "protobuf", + "tokio-util", +] standalone_heartbeat = [] [dev-dependencies] rsevents = "0.3.1" socket2 = "^0.5" tempfile = "3.3.0" -rstest = "^0.18" +rstest = "^0.23" serial_test = "3" criterion = { version = "^0.5", features = ["html_reports", "async_tokio"] } -which = "5" +which = "6" ctor = "0.2.2" -redis = { path = "../submodules/redis-rs/redis", features = ["tls-rustls-insecure"] } -iai-callgrind = "0.9" +redis = { path = "./redis-rs/redis", features = ["tls-rustls-insecure"] } +iai-callgrind = "0.14" tokio = { version = "1", features = ["rt-multi-thread"] } -glide-core = { path = ".", features = ["socket-layer"] } # always enable this feature in tests. +glide-core = { path = ".", features = [ + "socket-layer", +] } # always enable this feature in tests. [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(standalone_heartbeat)'] } diff --git a/glide-core/redis-rs/Cargo.toml b/glide-core/redis-rs/Cargo.toml new file mode 100644 index 0000000000..fa989b93cd --- /dev/null +++ b/glide-core/redis-rs/Cargo.toml @@ -0,0 +1,9 @@ +# Dummy package so ORT tool will not fail on virtual workspace +[package] +name = "dummy-for-ort" +version = "0.1.0" +edition = "2021" + +[workspace] +members = ["redis", "redis-test", "afl/parser"] +resolver = "2" diff --git a/glide-core/redis-rs/LICENSE b/glide-core/redis-rs/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. 
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+    * Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+    * Redistributions in binary form must reproduce the above
+      copyright notice, this list of conditions and the following
+      disclaimer in the documentation and/or other materials provided
+      with the distribution.
+
+    * The names of the contributors may not be used to endorse or
+      promote products derived from this software without specific
+      prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/glide-core/redis-rs/README.md b/glide-core/redis-rs/README.md
new file mode 100644
index 0000000000..34cdfe4778
--- /dev/null
+++ b/glide-core/redis-rs/README.md
@@ -0,0 +1,233 @@
+# redis-rs
+
+[![Rust](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml/badge.svg)](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml)
+[![crates.io](https://img.shields.io/crates/v/redis.svg)](https://crates.io/crates/redis)
+[![Chat](https://img.shields.io/discord/976380008299917365?logo=discord)](https://discord.gg/WHKcJK9AKP)
+
+Redis-rs is a high level redis library for Rust. It provides convenient access
+to all Redis functionality through a very flexible but low-level API. It
+uses a customizable type conversion trait so that any operation can return
+results in just the type you are expecting. This makes for a very pleasant
+development experience.
+
+The crate is called `redis` and you can depend on it via cargo:
+
+```ini
+[dependencies]
+redis = "0.25.2"
+```
+
+Documentation on the library can be found at
+[docs.rs/redis](https://docs.rs/redis).
+
+**Note: redis-rs requires at least Rust 1.60.**
+
+## Basic Operation
+
+To open a connection you need to create a client and then to fetch a
+connection from it. In the future there will be a connection pool for
+those, currently each connection is separate and not pooled.
+
+Many commands are implemented through the `Commands` trait but manual
+command creation is also possible.
+
+```rust
+use redis::Commands;
+
+fn fetch_an_integer() -> redis::RedisResult<isize> {
+    // connect to redis
+    let client = redis::Client::open("redis://127.0.0.1/")?;
+    let mut con = client.get_connection(None)?;
+    // throw away the result, just make sure it does not fail
+    let _: () = con.set("my_key", 42)?;
+    // read back the key and return it. Because the return value
+    // from the function is a result for integer this will automatically
+    // convert into one.
+    con.get("my_key")
+}
+```
+
+Variables are converted to and from the Redis format for a wide variety of types
+(`String`, num types, tuples, `Vec<u8>`).
If you want to use it with your own types, +you can implement the `FromRedisValue` and `ToRedisArgs` traits, or derive it with the +[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate. + +## Async support + +To enable asynchronous clients, enable the relevant feature in your Cargo.toml, +`tokio-comp` for tokio users or `async-std-comp` for async-std users. + +``` +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-comp"] } +``` + +## TLS Support + +To enable TLS support, you need to use the relevant feature entry in your Cargo.toml. +Currently, `native-tls` and `rustls` are supported. + +To use `native-tls`: + +``` +redis = { version = "0.25.2", features = ["tls-native-tls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-native-tls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-native-tls-comp"] } +``` + +To use `rustls`: + +``` +redis = { version = "0.25.2", features = ["tls-rustls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-rustls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-rustls-comp"] } +``` + +With `rustls`, you can add the following feature flags on top of other feature flags to enable additional features: + +- `tls-rustls-insecure`: Allow insecure TLS connections +- `tls-rustls-webpki-roots`: Use `webpki-roots` (Mozilla's root certificates) instead of native root certificates + +then you should be able to connect to a redis instance using the `rediss://` URL scheme: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/")?; +``` + +To enable insecure mode, append `#insecure` at the end of the URL: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/#insecure")?; +``` + +**Deprecation Notice:** If you were using the `tls` or `async-std-tls-comp` features, please use the `tls-native-tls` or `async-std-native-tls-comp` features respectively. + +## Cluster Support + +Support for Redis Cluster can be enabled by enabling the `cluster` feature in your Cargo.toml: + +`redis = { version = "0.25.2", features = [ "cluster"] }` + +Then you can simply use the `ClusterClient`, which accepts a list of available nodes. Note +that only one node in the cluster needs to be specified when instantiating the client, though +you can specify multiple. 
+
+```rust
+use redis::cluster::ClusterClient;
+use redis::Commands;
+
+fn fetch_an_integer() -> String {
+    let nodes = vec!["redis://127.0.0.1/"];
+    let client = ClusterClient::new(nodes).unwrap();
+    let mut connection = client.get_connection(None).unwrap();
+    let _: () = connection.set("test", "test_data").unwrap();
+    let rv: String = connection.get("test").unwrap();
+    return rv;
+}
+```
+
+Async Redis Cluster support can be enabled by enabling the `cluster-async` feature, along
+with your preferred async runtime, e.g.:
+
+`redis = { version = "0.25.2", features = [ "cluster-async", "tokio-comp" ] }`
+
+```rust
+use redis::cluster::ClusterClient;
+use redis::AsyncCommands;
+
+async fn fetch_an_integer() -> String {
+    let nodes = vec!["redis://127.0.0.1/"];
+    let client = ClusterClient::new(nodes).unwrap();
+    let mut connection = client.get_async_connection().await.unwrap();
+    let _: () = connection.set("test", "test_data").await.unwrap();
+    let rv: String = connection.get("test").await.unwrap();
+    return rv;
+}
+```
+
+## JSON Support
+
+Support for the RedisJSON Module can be enabled by specifying "json" as a feature in your Cargo.toml.
+
+`redis = { version = "0.25.2", features = ["json"] }`
+
+Then you can simply import the `JsonCommands` trait which will add the `json` commands to all Redis Connections (not to be confused with just `Commands` which only adds the default commands)
+
+```rust
+use redis::Client;
+use redis::JsonCommands;
+use redis::RedisResult;
+use redis::ToRedisArgs;
+
+// Result returns Ok(true) if the value was set
+// Result returns Err(e) if there was an error with the server itself OR serde_json was unable to serialize the boolean
+fn set_json_bool<P: ToRedisArgs>(key: P, path: P, b: bool) -> RedisResult<bool> {
+    let client = Client::open("redis://127.0.0.1").unwrap();
+    let mut connection = client.get_connection(None).unwrap();
+
+    // runs `JSON.SET {key} {path} {b}`
+    connection.json_set(key, path, &b)
+}
+
+```
+
+To parse the results, you'll need to use `serde_json` (or some other json lib) to deserialize
+the results from the bytes. It will always be a `Vec`, if no results were found at the path it'll
+be an empty `Vec`. If you want to handle deserialization and `Vec` unwrapping automatically,
+you can use the `Json` wrapper from the
+[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate.
+
+## Development
+
+To test `redis` you're going to need to be able to test with the Redis Modules, to do this
+you must set the following environment variable before running the test script
+
+- `REDIS_RS_REDIS_JSON_PATH` = The absolute path to the RedisJSON module (Either `librejson.so` for Linux or `librejson.dylib` for MacOS).
+
+- Please refer to this [link](https://github.com/RedisJSON/RedisJSON) to access the RedisJSON module:
+
+
+
+If you want to develop on the library there are a few commands provided
+by the makefile:
+
+To build:
+
+    $ make
+
+To test:
+
+    $ make test
+
+To run benchmarks:
+
+    $ make bench
+
+To build the docs (requires nightly compiler, see [rust-lang/rust#43781](https://github.com/rust-lang/rust/issues/43781)):
+
+    $ make docs
+
+We encourage you to run `clippy` prior to seeking a merge for your work. The lints can be quite strict.
Running this on your own workstation can save you time, since Travis CI will fail any build that doesn't satisfy `clippy`:
+
+    $ cargo clippy --all-features --all --tests --examples -- -D clippy::all -D warnings
+
+To run fuzz tests with afl, first install cargo-afl (`cargo install -f afl`),
+then run:
+
+    $ make fuzz
+
+If the fuzzer finds a crash, in order to reproduce it, run:
+
+    $ cd afl/<target>/
+    $ cargo run --bin reproduce -- out/crashes/<crashfile>
diff --git a/glide-core/redis-rs/afl/.gitignore b/glide-core/redis-rs/afl/.gitignore
new file mode 100644
index 0000000000..1776e13233
--- /dev/null
+++ b/glide-core/redis-rs/afl/.gitignore
@@ -0,0 +1,2 @@
+out/
+core.*
diff --git a/glide-core/redis-rs/afl/parser/Cargo.toml b/glide-core/redis-rs/afl/parser/Cargo.toml
new file mode 100644
index 0000000000..ef356faaf9
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "fuzz-target-parser"
+version = "0.1.0"
+authors = ["redis-rs developers"]
+edition = "2018"
+
+[[bin]]
+name = "fuzz-target"
+path = "src/main.rs"
+
+[[bin]]
+name = "reproduce"
+path = "src/reproduce.rs"
+
+[dependencies]
+afl = "0.15"
+redis = { path = "../../redis" }
diff --git a/glide-core/redis-rs/afl/parser/in/array b/glide-core/redis-rs/afl/parser/in/array
new file mode 100644
index 0000000000..c92e405790
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/array
@@ -0,0 +1,5 @@
+*3
+:1
+$-1
+$2
+hi
diff --git a/glide-core/redis-rs/afl/parser/in/array-null b/glide-core/redis-rs/afl/parser/in/array-null
new file mode 100644
index 0000000000..e0f619c1b3
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/array-null
@@ -0,0 +1 @@
+*-1
diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring b/glide-core/redis-rs/afl/parser/in/bulkstring
new file mode 100644
index 0000000000..930878abea
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/bulkstring
@@ -0,0 +1,2 @@
+$6
+foobar
diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring-null b/glide-core/redis-rs/afl/parser/in/bulkstring-null
new file mode 100644
index 0000000000..f4280bede5
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/bulkstring-null
@@ -0,0 +1 @@
+$-1
diff --git a/glide-core/redis-rs/afl/parser/in/error b/glide-core/redis-rs/afl/parser/in/error
new file mode 100644
index 0000000000..7cfd9a521a
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/error
@@ -0,0 +1 @@
+-ERR unknown command
diff --git a/glide-core/redis-rs/afl/parser/in/integer b/glide-core/redis-rs/afl/parser/in/integer
new file mode 100644
index 0000000000..49525f0d45
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/integer
@@ -0,0 +1 @@
+:1337
diff --git a/glide-core/redis-rs/afl/parser/in/invalid-string b/glide-core/redis-rs/afl/parser/in/invalid-string
new file mode 100644
index 0000000000..604dd3e85a
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/invalid-string
@@ -0,0 +1,2 @@
+$6
+foo
diff --git a/glide-core/redis-rs/afl/parser/in/string b/glide-core/redis-rs/afl/parser/in/string
new file mode 100644
index 0000000000..054430c700
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/in/string
@@ -0,0 +1 @@
++OK
diff --git a/glide-core/redis-rs/afl/parser/src/main.rs b/glide-core/redis-rs/afl/parser/src/main.rs
new file mode 100644
index 0000000000..6dc674edff
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/src/main.rs
@@ -0,0 +1,9 @@
+use afl::fuzz;
+
+use redis::parse_redis_value;
+
+fn main() {
+    fuzz!(|data: &[u8]| {
+        let _ = parse_redis_value(data);
+    });
+}
diff --git a/glide-core/redis-rs/afl/parser/src/reproduce.rs b/glide-core/redis-rs/afl/parser/src/reproduce.rs
new file mode 100644
index 0000000000..086dfffb50
--- /dev/null
+++ b/glide-core/redis-rs/afl/parser/src/reproduce.rs
@@ -0,0 +1,13 @@
+use redis::parse_redis_value;
+
+fn main() {
+    let args: Vec<String> = std::env::args().collect();
+    if args.len() != 2 {
+        println!("Usage: {} <file>", args[0]);
+        std::process::exit(1);
+    }
+
+    let data = std::fs::read(&args[1]).expect(&format!("Could not open file {}", args[1]));
+    let v = parse_redis_value(&data);
+    println!("Result: {:?}", v);
+}
diff --git a/glide-core/redis-rs/appveyor.yml b/glide-core/redis-rs/appveyor.yml
new file mode 100644
index 0000000000..8310b8def5
--- /dev/null
+++ b/glide-core/redis-rs/appveyor.yml
@@ -0,0 +1,23 @@
+os: Visual Studio 2015
+
+environment:
+  REDISRS_SERVER_TYPE: tcp
+  RUST_BACKTRACE: 1
+  matrix:
+    - channel: stable
+      target: x86_64-pc-windows-msvc
+    - channel: stable
+      target: x86_64-pc-windows-gnu
+install:
+  - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe
+  - rustup-init -yv --default-toolchain %channel% --default-host %target%
+  - set PATH=%PATH%;%USERPROFILE%\.cargo\bin
+  - rustc -vV
+  - cargo -vV
+  - cmd: nuget install redis-64 -excludeversion
+  - set PATH=%PATH%;%APPVEYOR_BUILD_FOLDER%\redis-64\tools\
+
+build: false
+
+test_script:
+  - cargo test --verbose --no-default-features --features tokio-comp %cargoflags%
diff --git a/glide-core/redis-rs/get_command_info.py b/glide-core/redis-rs/get_command_info.py
new file mode 100644
index 0000000000..4c719dd4d4
--- /dev/null
+++ b/glide-core/redis-rs/get_command_info.py
@@ -0,0 +1,228 @@
+# type: ignore
+import argparse
+import json
+import os
+from os.path import join
+
+"""Valkey command categorizer
+
+This script analyzes command info json files and categorizes the commands based on their routing. The output can be used
+to map commands in the cluster_routing.rs#base_routing function to their RouteBy category. Commands that cannot be
+categorized by the script will be listed under the "Uncategorized" section. These commands will need to be manually
+categorized.
+
+To use the script:
+1. Clone https://github.com/valkey-io/valkey
+2. cd into the cloned valkey repository and checkout the desired version of the code, eg 7.2.5
+3. cd into the directory containing this script
+4.
run: + python get_command_info.py --commands-dir=/valkey/src/commands +""" + + +class CommandCategory: + def __init__(self, name, description): + self.name = name + self.description = description + self.commands = [] + + def add_command(self, command_name): + self.commands.append(command_name) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyzes command info json and categorizes commands into their RouteBy categories") + parser.add_argument( + "--commands-dir", + type=str, + help="Path to the directory containing the command info json files (example: ../../valkey/src/commands)", + required=True, + ) + + args = parser.parse_args() + commands_dir = args.commands_dir + if not os.path.exists(commands_dir): + raise parser.error("The command info directory passed to the '--commands-dir' argument does not exist") + + all_nodes = CommandCategory("AllNodes", "Commands with an ALL_NODES request policy") + all_primaries = CommandCategory("AllPrimaries", "Commands with an ALL_SHARDS request policy") + multi_shard = CommandCategory("MultiShardNoValues or MultiShardWithValues", + "Commands with a MULTI_SHARD request policy") + first_arg = CommandCategory("FirstKey", "Commands with their first key argument at position 1") + second_arg = CommandCategory("SecondArg", "Commands with their first key argument at position 2") + second_arg_numkeys = ( + CommandCategory("SecondArgAfterKeyCount", + "Commands with their first key argument at position 2, after a numkeys argument")) + # all commands with their first key argument at position 3 have a numkeys argument at position 2, + # so there is a ThirdArgAfterKeyCount category but no ThirdArg category + third_arg_numkeys = ( + CommandCategory("ThirdArgAfterKeyCount", + "Commands with their first key argument at position 3, after a numkeys argument")) + streams_index = CommandCategory("StreamsIndex", "Commands that include a STREAMS token") + second_arg_slot = CommandCategory("SecondArgSlot", "Commands with a slot argument at position 2") + uncategorized = ( + CommandCategory( + "Uncategorized", + "Commands that don't fall into the other categories. These commands will have to be manually categorized.")) + + categories = [all_nodes, all_primaries, multi_shard, first_arg, second_arg, second_arg_numkeys, third_arg_numkeys, + streams_index, second_arg_slot, uncategorized] + + print("Gathering command info...\n") + + for filename in os.listdir(commands_dir): + file_path = join(commands_dir, filename) + _, file_extension = os.path.splitext(file_path) + if file_extension != ".json": + print(f"Note: {filename} is not a json file and will thus be ignored") + continue + + file = open(file_path) + command_json = json.load(file) + if len(command_json) == 0: + raise Exception( + f"The json for {filename} was empty. 
A json object with information about the command was expected.") + + command_name = next(iter(command_json)) + command_info = command_json[command_name] + if "container" in command_info: + # for two-word commands like 'XINFO GROUPS', the `next(iter(command_json))` statement above returns 'GROUPS' + # and `command_info['container']` returns 'XINFO' + command_name = f"{command_info['container']} {command_name}" + + if "command_tips" in command_info: + request_policy = get_request_policy(command_info["command_tips"]) + if request_policy == "ALL_NODES": + all_nodes.add_command(command_name) + continue + elif request_policy == "ALL_SHARDS": + all_primaries.add_command(command_name) + continue + elif request_policy == "MULTI_SHARD": + multi_shard.add_command(command_name) + continue + + if "arguments" not in command_info: + uncategorized.add_command(command_name) + continue + + command_args = command_info["arguments"] + split_name = command_name.split() + if len(split_name) == 0: + raise Exception(f"Encountered json with an empty command name in file '{filename}'") + + json_key_index, is_key_optional = get_first_key_info(command_args) + # cluster_routing.rs can handle optional keys if a keycount of 0 is provided, otherwise the command should + # fall under the "Uncategorized" section to indicate it will need to be manually inspected + if is_key_optional and not is_after_numkeys(command_args, json_key_index): + uncategorized.add_command(command_name) + continue + + if json_key_index == -1: + # the command does not have a key argument, check for a slot argument + json_slot_index, is_slot_optional = get_first_slot_info(command_args) + if is_slot_optional: + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_slot_index = -1 if json_slot_index == -1 else len(split_name) + json_slot_index + if cluster_routing_slot_index == 2: + second_arg_slot.add_command(command_name) + continue + + # the command does not have a slot argument, check for a "STREAMS" token + if has_streams_token(command_args): + streams_index.add_command(command_name) + continue + + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_key_index = -1 if json_key_index == -1 else len(split_name) + json_key_index + if cluster_routing_key_index == 1: + first_arg.add_command(command_name) + continue + elif cluster_routing_key_index == 2: + if is_after_numkeys(command_args, json_key_index): + second_arg_numkeys.add_command(command_name) + continue + else: + second_arg.add_command(command_name) + continue + # there aren't any commands that fall into a ThirdArg category, + # but there are commands that fall under ThirdArgAfterKeyCount category + elif cluster_routing_key_index == 3 and is_after_numkeys(command_args, json_key_index): + third_arg_numkeys.add_command(command_name) + continue + + uncategorized.add_command(command_name) + + print("\nNote: the following information considers each word in the command name to be an argument") + print("For example, for 'XGROUP DESTROY key group':") + print("'XGROUP' is arg0, 'DESTROY' is arg1, 'key' is arg2, and 'group' is arg3.\n") + + for category in categories: + print_category(category) + + +def get_request_policy(command_tips): + for command_tip in command_tips: + if command_tip.startswith("REQUEST_POLICY:"): + return command_tip[len("REQUEST_POLICY:"):] + 
+
+    return None
+
+
+def get_first_key_info(args_info_json) -> tuple[int, bool]:
+    for i in range(len(args_info_json)):
+        info = args_info_json[i]
+        if info["type"].lower() == "key":
+            is_optional = "optional" in info and info["optional"]
+            return i, is_optional
+
+    return -1, False
+
+
+def get_first_slot_info(args_info_json) -> tuple[int, bool]:
+    for i in range(len(args_info_json)):
+        info = args_info_json[i]
+        if info["name"].lower() == "slot":
+            is_optional = "optional" in info and info["optional"]
+            return i, is_optional
+
+    return -1, False
+
+
+def is_after_numkeys(args_info_json, json_index):
+    return json_index > 0 and args_info_json[json_index - 1]["name"].lower() == "numkeys"
+
+
+def has_streams_token(args_info_json):
+    for arg_info in args_info_json:
+        if "token" in arg_info and arg_info["token"].upper() == "STREAMS":
+            return True
+
+    return False
+
+
+def print_category(category):
+    print("============================")
+    print(f"Category: {category.name} commands")
+    print(f"Description: {category.description}")
+    print("List of commands in this category:\n")
+
+    if len(category.commands) == 0:
+        print("(No commands found for this category)")
+    else:
+        category.commands.sort()
+        for command_name in category.commands:
+            print(f"{command_name}")
+
+    print("\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/glide-core/redis-rs/redis-test/CHANGELOG.md b/glide-core/redis-rs/redis-test/CHANGELOG.md
new file mode 100644
index 0000000000..83d3ab3dc4
--- /dev/null
+++ b/glide-core/redis-rs/redis-test/CHANGELOG.md
@@ -0,0 +1,44 @@
+### 0.4.0 (2024-03-08)
+* Track redis 0.25.0 release
+
+### 0.3.0 (2023-12-05)
+* Track redis 0.24.0 release
+
+### 0.2.3 (2023-09-01)
+
+* Track redis 0.23.3 release
+
+### 0.2.2 (2023-08-10)
+
+* Track redis 0.23.2 release
+
+### 0.2.1 (2023-07-28)
+
+* Track redis 0.23.1 release
+
+
+### 0.2.0 (2023-04-05)
+
+* Track redis 0.23.0 release
+
+
+### 0.2.0-beta.1 (2023-03-28)
+
+* Track redis 0.23.0-beta.1 release
+
+
+### 0.1.1 (2022-10-18)
+
+#### Changes
+* Add README
+* Update LICENSE file / symlink from parent directory
+
+
+
+### 0.1.0 (2022-10-05)
+
+This is the initial release of the redis-test crate, which aims to provide mocking
+for connections and commands. Thanks @tdyas!
+ +#### Features +* Testing module with support for mocking redis connections and commands ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas) \ No newline at end of file diff --git a/glide-core/redis-rs/redis-test/Cargo.toml b/glide-core/redis-rs/redis-test/Cargo.toml new file mode 100644 index 0000000000..6e0bcc3a9f --- /dev/null +++ b/glide-core/redis-rs/redis-test/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "redis-test" +version = "0.4.0" +edition = "2021" +description = "Testing helpers for the `redis` crate" +homepage = "https://github.com/redis-rs/redis-rs" +repository = "https://github.com/redis-rs/redis-rs" +documentation = "https://docs.rs/redis-test" +license = "BSD-3-Clause" +rust-version = "1.65" + +[lib] +bench = false + +[dependencies] +redis = { version = "0.25.0", path = "../redis" } + +bytes = { version = "1", optional = true } +futures = { version = "0.3", optional = true } + +[features] +aio = ["futures", "redis/aio"] + +[dev-dependencies] +redis = { version = "0.25.0", path = "../redis", features = ["aio", "tokio-comp"] } +tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } diff --git a/glide-core/redis-rs/redis-test/LICENSE b/glide-core/redis-rs/redis-test/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis-test/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/redis-test/README.md b/glide-core/redis-rs/redis-test/README.md new file mode 100644 index 0000000000..b89bfc4edb --- /dev/null +++ b/glide-core/redis-rs/redis-test/README.md @@ -0,0 +1,4 @@ +# redis-test + +Testing utilities for the redis-rs crate. 
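The README above is terse; as a minimal usage sketch, assuming the `MockCmd`/`MockRedisConnection` API defined in `redis-test/src/lib.rs` later in this diff (the calls mirror the crate's own tests), testing against a mock connection looks like this:

```rust
use redis_test::{MockCmd, MockRedisConnection};

fn main() {
    // Queue the exact command sequence the code under test is expected to send.
    let mut conn = MockRedisConnection::new(vec![
        MockCmd::new(redis::cmd("SET").arg("foo").arg(42), Ok("")),
        MockCmd::new(redis::cmd("GET").arg("foo"), Ok(42)),
    ]);

    // The mock implements ConnectionLike, so no server is needed.
    redis::cmd("SET").arg("foo").arg(42).execute(&mut conn);
    let fetched: i64 = redis::cmd("GET").arg("foo").query(&mut conn).unwrap();
    assert_eq!(fetched, 42);
}
```

Any deviation from the queued sequence surfaces as a `ClientError`, which is what the `errors_for_unexpected_commands` tests below exercise.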
+
diff --git a/glide-core/redis-rs/redis-test/release.toml b/glide-core/redis-rs/redis-test/release.toml
new file mode 100644
index 0000000000..7dc5b7a0a6
--- /dev/null
+++ b/glide-core/redis-rs/redis-test/release.toml
@@ -0,0 +1 @@
+tag-name = "redis-test-{{version}}"
diff --git a/glide-core/redis-rs/redis-test/src/lib.rs b/glide-core/redis-rs/redis-test/src/lib.rs
new file mode 100644
index 0000000000..cafe8a347b
--- /dev/null
+++ b/glide-core/redis-rs/redis-test/src/lib.rs
@@ -0,0 +1,426 @@
+//! Testing support
+//!
+//! This module provides `MockRedisConnection` which implements ConnectionLike and can be
+//! used in the same place as any other type that behaves like a Redis connection. This is useful
+//! for writing unit tests without needing a Redis server.
+//!
+//! # Example
+//!
+//! ```rust
+//! use redis::{ConnectionLike, RedisError};
+//! use redis_test::{MockCmd, MockRedisConnection};
+//!
+//! fn my_exists<C: ConnectionLike>(conn: &mut C, key: &str) -> Result<bool, RedisError> {
+//!     let exists: bool = redis::cmd("EXISTS").arg(key).query(conn)?;
+//!     Ok(exists)
+//! }
+//!
+//! let mut mock_connection = MockRedisConnection::new(vec![
+//!     MockCmd::new(redis::cmd("EXISTS").arg("foo"), Ok("1")),
+//! ]);
+//!
+//! let result = my_exists(&mut mock_connection, "foo").unwrap();
+//! assert_eq!(result, true);
+//! ```
+
+use std::collections::VecDeque;
+use std::sync::{Arc, Mutex};
+
+use redis::{Cmd, ConnectionLike, ErrorKind, Pipeline, RedisError, RedisResult, Value};
+
+#[cfg(feature = "aio")]
+use futures::{future, FutureExt};
+
+#[cfg(feature = "aio")]
+use redis::{aio::ConnectionLike as AioConnectionLike, RedisFuture};
+
+/// Helper trait for converting test values into a `redis::Value` returned from a
+/// `MockRedisConnection`. This is necessary because neither `redis::types::ToRedisArgs`
+/// nor `redis::types::FromRedisValue` performs the precise conversion needed.
+pub trait IntoRedisValue {
+    /// Convert a value into `redis::Value`.
+    fn into_redis_value(self) -> Value;
+}
+
+impl IntoRedisValue for String {
+    fn into_redis_value(self) -> Value {
+        Value::BulkString(self.as_bytes().to_vec())
+    }
+}
+
+impl IntoRedisValue for &str {
+    fn into_redis_value(self) -> Value {
+        Value::BulkString(self.as_bytes().to_vec())
+    }
+}
+
+#[cfg(feature = "bytes")]
+impl IntoRedisValue for bytes::Bytes {
+    fn into_redis_value(self) -> Value {
+        Value::BulkString(self.to_vec())
+    }
+}
+
+impl IntoRedisValue for Vec<u8> {
+    fn into_redis_value(self) -> Value {
+        Value::BulkString(self)
+    }
+}
+
+impl IntoRedisValue for Value {
+    fn into_redis_value(self) -> Value {
+        self
+    }
+}
+
+impl IntoRedisValue for i64 {
+    fn into_redis_value(self) -> Value {
+        Value::Int(self)
+    }
+}
+
+/// Helper trait for converting `redis::Cmd` and `redis::Pipeline` instances into
+/// encoded byte vectors.
+pub trait IntoRedisCmdBytes {
+    /// Convert a command into an encoded byte vector.
+    fn into_redis_cmd_bytes(self) -> Vec<u8>;
+}
+
+impl IntoRedisCmdBytes for Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for &Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for &mut Cmd {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_command()
+    }
+}
+
+impl IntoRedisCmdBytes for Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+impl IntoRedisCmdBytes for &Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+impl IntoRedisCmdBytes for &mut Pipeline {
+    fn into_redis_cmd_bytes(self) -> Vec<u8> {
+        self.get_packed_pipeline()
+    }
+}
+
+/// Represents a command to be executed against a `MockConnection`.
+pub struct MockCmd {
+    cmd_bytes: Vec<u8>,
+    responses: Result<Vec<Value>, RedisError>,
+}
+
+impl MockCmd {
+    /// Create a new `MockCmd` given a Redis command and either a value convertible to
+    /// a `redis::Value` or a `RedisError`.
+    pub fn new<C, V>(cmd: C, response: Result<V, RedisError>) -> Self
+    where
+        C: IntoRedisCmdBytes,
+        V: IntoRedisValue,
+    {
+        MockCmd {
+            cmd_bytes: cmd.into_redis_cmd_bytes(),
+            responses: response.map(|r| vec![r.into_redis_value()]),
+        }
+    }
+
+    /// Create a new `MockCommand` given a Redis command/pipeline and a vector of values
+    /// convertible to a `redis::Value`, or a `RedisError`.
+    pub fn with_values<C, V>(cmd: C, responses: Result<Vec<V>, RedisError>) -> Self
+    where
+        C: IntoRedisCmdBytes,
+        V: IntoRedisValue,
+    {
+        MockCmd {
+            cmd_bytes: cmd.into_redis_cmd_bytes(),
+            responses: responses.map(|xs| xs.into_iter().map(|x| x.into_redis_value()).collect()),
+        }
+    }
+}
+
+/// A mock Redis client for testing without a server. `MockRedisConnection` checks whether the
+/// client submits a specific sequence of commands and generates an error if it does not.
+#[derive(Clone)]
+pub struct MockRedisConnection {
+    commands: Arc<Mutex<VecDeque<MockCmd>>>,
+}
+
+impl MockRedisConnection {
+    /// Construct a new `MockRedisConnection` from the given sequence of commands.
+    pub fn new<I>(commands: I) -> Self
+    where
+        I: IntoIterator<Item = MockCmd>,
+    {
+        MockRedisConnection {
+            commands: Arc::new(Mutex::new(VecDeque::from_iter(commands))),
+        }
+    }
+}
+
+impl ConnectionLike for MockRedisConnection {
+    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
+        let mut commands = self.commands.lock().unwrap();
+        let next_cmd = commands.pop_front().ok_or_else(|| {
+            RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                "unexpected command".to_owned(),
+            ))
+        })?;
+
+        if cmd != next_cmd.cmd_bytes {
+            return Err(RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                format!(
+                    "unexpected command: expected={}, actual={}",
+                    String::from_utf8(next_cmd.cmd_bytes)
+                        .unwrap_or_else(|_| "decode error".to_owned()),
+                    String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()),
+                ),
+            )));
+        }
+
+        next_cmd
+            .responses
+            .and_then(|values| match values.as_slice() {
+                [value] => Ok(value.clone()),
+                [] => Err(RedisError::from((
+                    ErrorKind::ClientError,
+                    "no value configured as response",
+                ))),
+                _ => Err(RedisError::from((
+                    ErrorKind::ClientError,
+                    "multiple values configured as response for command expecting a single value",
+                ))),
+            })
+    }
+
+    fn req_packed_commands(
+        &mut self,
+        cmd: &[u8],
+        _offset: usize,
+        _count: usize,
+    ) -> RedisResult<Vec<Value>> {
+        let mut commands = self.commands.lock().unwrap();
+        let next_cmd = commands.pop_front().ok_or_else(|| {
+            RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                "unexpected command".to_owned(),
+            ))
+        })?;
+
+        if cmd != next_cmd.cmd_bytes {
+            return Err(RedisError::from((
+                ErrorKind::ClientError,
+                "TEST",
+                format!(
+                    "unexpected command: expected={}, actual={}",
+                    String::from_utf8(next_cmd.cmd_bytes)
+                        .unwrap_or_else(|_| "decode error".to_owned()),
+                    String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()),
+                ),
+            )));
+        }
+
+        next_cmd.responses
+    }
+
+    fn get_db(&self) -> i64 {
+        0
+    }
+
+    fn check_connection(&mut self) -> bool {
+        true
+    }
+
+    fn is_open(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(feature = "aio")]
+impl AioConnectionLike for MockRedisConnection {
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+        let packed_cmd = cmd.get_packed_command();
+        let response = <MockRedisConnection as ConnectionLike>::req_packed_command(
+            self,
+            packed_cmd.as_slice(),
+        );
+        future::ready(response).boxed()
+    }
+
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>> {
+        let packed_cmd = cmd.get_packed_pipeline();
+        let response = <MockRedisConnection as ConnectionLike>::req_packed_commands(
+            self,
+            packed_cmd.as_slice(),
+            offset,
+            count,
+        );
+        future::ready(response).boxed()
+    }
+
+    fn get_db(&self) -> i64 {
+        0
+    }
+
+    fn is_closed(&self) -> bool {
+        false
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{MockCmd, MockRedisConnection};
+    use redis::{cmd, pipe, ErrorKind, Value};
+
+    #[test]
+    fn sync_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+            MockCmd::new(cmd("GET").arg("bar"), Ok("foo")),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42));
+
+        cmd("SET").arg("bar").arg("foo").execute(&mut conn);
+        assert_eq!(
+            cmd("GET").arg("bar").query(&mut conn),
+            Ok(Value::BulkString(b"foo".as_ref().into()))
+        );
+    }
+
+    #[cfg(feature = "aio")]
+    #[tokio::test]
+    async fn async_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+            MockCmd::new(cmd("GET").arg("bar"), Ok("foo")),
+        ]);
+
+        cmd("SET")
+            .arg("foo")
+            .arg("42")
+            .query_async::<_, ()>(&mut conn)
+            .await
+            .unwrap();
+        let result: Result<i64, _> = cmd("GET").arg("foo").query_async(&mut conn).await;
+        assert_eq!(result, Ok(42));
+
+        cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query_async::<_, ()>(&mut conn)
+            .await
+            .unwrap();
+        let result: Result<Vec<u8>, _> = cmd("GET").arg("bar").query_async(&mut conn).await;
+        assert_eq!(result.as_deref(), Ok(&b"foo"[..]));
+    }
+
+    #[test]
+    fn errors_for_unexpected_commands() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42));
+
+        let err = cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query::<()>(&mut conn)
+            .unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::ClientError);
+        assert_eq!(err.detail(), Some("unexpected command"));
+    }
+
+    #[test]
+    fn errors_for_mismatched_commands() {
+        let mut conn = MockRedisConnection::new(vec![
+            MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")),
+            MockCmd::new(cmd("GET").arg("foo"), Ok(42)),
+            MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")),
+        ]);
+
+        cmd("SET").arg("foo").arg(42).execute(&mut conn);
+        let err = cmd("SET")
+            .arg("bar")
+            .arg("foo")
+            .query::<()>(&mut conn)
+            .unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::ClientError);
+        assert!(err.detail().unwrap().contains("unexpected command"));
+    }
+
+    #[test]
+    fn pipeline_basic_test() {
+        let mut conn = MockRedisConnection::new(vec![MockCmd::with_values(
+            pipe().cmd("GET").arg("foo").cmd("GET").arg("bar"),
+            Ok(vec!["hello", "world"]),
+        )]);
+
+        let results: Vec<String> = pipe()
+            .cmd("GET")
+            .arg("foo")
+            .cmd("GET")
+            .arg("bar")
+            .query(&mut conn)
+            .expect("success");
+        assert_eq!(results, vec!["hello", "world"]);
+    }
+
+    #[test]
+    fn pipeline_atomic_test() {
+        let mut conn = MockRedisConnection::new(vec![MockCmd::with_values(
+            pipe().atomic().cmd("GET").arg("foo").cmd("GET").arg("bar"),
+            Ok(vec![Value::Array(
+                vec!["hello", "world"]
+                    .into_iter()
+                    .map(|x| Value::BulkString(x.as_bytes().into()))
+                    .collect(),
+            )]),
+        )]);
+
+        let results: Vec<String> = pipe()
+            .atomic()
+            .cmd("GET")
+            .arg("foo")
+            .cmd("GET")
+            .arg("bar")
+            .query(&mut conn)
+            .expect("success");
+        assert_eq!(results, vec!["hello", "world"]);
+    }
+}
diff --git a/glide-core/redis-rs/redis/CHANGELOG.md b/glide-core/redis-rs/redis/CHANGELOG.md
new file mode 100644
index 0000000000..9c3dd18524
--- /dev/null
+++ b/glide-core/redis-rs/redis/CHANGELOG.md
@@ -0,0 +1,828 @@
+### 0.25.2 (2024-03-15)
+
+* MultiplexedConnection: Separate response handling for pipeline. ([#1078](https://github.com/redis-rs/redis-rs/pull/1078))
+
+### 0.25.1 (2024-03-12)
+
+* Fix small ambiguity in examples ([#1072](https://github.com/redis-rs/redis-rs/pull/1072) @sunhuachuang)
+* Upgrade to socket2 0.5 ([#1073](https://github.com/redis-rs/redis-rs/pull/1073) @djc)
+* Avoid library dependency on futures-time ([#1074](https://github.com/redis-rs/redis-rs/pull/1074) @djc)
+
+
+### 0.25.0 (2024-03-08)
+
+#### Features
+
+* **Breaking change**: Add connection timeout to the cluster client ([#834](https://github.com/redis-rs/redis-rs/pull/834))
+* **Breaking change**: Deprecate aio::Connection ([#889](https://github.com/redis-rs/redis-rs/pull/889))
+* Cluster: fix read from replica & missing slots ([#965](https://github.com/redis-rs/redis-rs/pull/965))
+* Async cluster connection: Improve handling of missing connections ([#968](https://github.com/redis-rs/redis-rs/pull/968))
+* Add support for parsing to/from any sized arrays ([#981](https://github.com/redis-rs/redis-rs/pull/981))
+* Upgrade to rustls 0.22 ([#1000](https://github.com/redis-rs/redis-rs/pull/1000) @djc)
+* add SMISMEMBER command ([#1002](https://github.com/redis-rs/redis-rs/pull/1002) @Zacaria)
+* Add support for some big number types ([#1014](https://github.com/redis-rs/redis-rs/pull/1014) @AkiraMiyakoda)
+* Add Support for UUIDs ([#1029](https://github.com/redis-rs/redis-rs/pull/1029) @Rabbitminers)
+* Add FromRedisValue::from_owned_redis_value to reduce copies while parsing response ([#1030](https://github.com/redis-rs/redis-rs/pull/1030) @Nathan-Fenner)
+* Save reconnected connections during retries ([#1033](https://github.com/redis-rs/redis-rs/pull/1033))
+* Avoid panic on connection failure ([#1035](https://github.com/redis-rs/redis-rs/pull/1035))
+* add disable client setinfo feature and its default mode is off ([#1036](https://github.com/redis-rs/redis-rs/pull/1036) @Ggiggle)
+* Reconnect on parsing errors ([#1051](https://github.com/redis-rs/redis-rs/pull/1051))
+* preallocate buffer for evalsha in Script ([#1044](https://github.com/redis-rs/redis-rs/pull/1044) @framlog)
+
+#### Changes
+
+* Align more commands routings ([#938](https://github.com/redis-rs/redis-rs/pull/938))
+* Fix HashMap conversion ([#977](https://github.com/redis-rs/redis-rs/pull/977) @mxbrt)
+* MultiplexedConnection: Remove unnecessary allocation in send ([#990](https://github.com/redis-rs/redis-rs/pull/990))
+* Tests: Reduce cluster setup flakiness ([#999](https://github.com/redis-rs/redis-rs/pull/999))
+* Remove the unwrap_or! macro ([#1010](https://github.com/redis-rs/redis-rs/pull/1010))
+* Remove allocation from command function ([#1008](https://github.com/redis-rs/redis-rs/pull/1008))
+* Catch panics from task::spawn in tests ([#1015](https://github.com/redis-rs/redis-rs/pull/1015))
+* Fix lint errors from new Rust version ([#1016](https://github.com/redis-rs/redis-rs/pull/1016))
+* Fix warnings that appear only with native-TLS ([#1018](https://github.com/redis-rs/redis-rs/pull/1018))
+* Hide the req_packed_commands from docs ([#1020](https://github.com/redis-rs/redis-rs/pull/1020))
+* Fix documentation error ([#1022](https://github.com/redis-rs/redis-rs/pull/1022) @rcl-viveksharma)
+* Fixes minor grammar mistake in json.rs file ([#1026](https://github.com/redis-rs/redis-rs/pull/1026) @RScrusoe)
+* Enable ignored pipe test ([#1027](https://github.com/redis-rs/redis-rs/pull/1027))
+* Fix names of existing async cluster tests ([#1028](https://github.com/redis-rs/redis-rs/pull/1028))
+* Add lock file to keep MSRV constant ([#1039](https://github.com/redis-rs/redis-rs/pull/1039))
+* Fail CI if lock file isn't updated ([#1042](https://github.com/redis-rs/redis-rs/pull/1042))
+* impl Clone/Copy for SetOptions ([#1046](https://github.com/redis-rs/redis-rs/pull/1046) @ahmadbky)
+* docs: add "connection-manager" cfg attr ([#1048](https://github.com/redis-rs/redis-rs/pull/1048) @DCNick3)
+* Remove the usage of aio::Connection in tests ([#1049](https://github.com/redis-rs/redis-rs/pull/1049))
+* Fix new clippy lints ([#1052](https://github.com/redis-rs/redis-rs/pull/1052))
+* Handle server errors in array response ([#1056](https://github.com/redis-rs/redis-rs/pull/1056))
+* Appease Clippy ([#1061](https://github.com/redis-rs/redis-rs/pull/1061))
+* make Pipeline handle returned bulks correctly ([#1063](https://github.com/redis-rs/redis-rs/pull/1063) @framlog)
+* Update mio dependency due to vulnerability ([#1064](https://github.com/redis-rs/redis-rs/pull/1064))
+* Simplify Sink polling logic ([#1065](https://github.com/redis-rs/redis-rs/pull/1065))
+* Separate parsing errors from general response errors ([#1069](https://github.com/redis-rs/redis-rs/pull/1069))
+
+### 0.24.0 (2023-12-05)
+
+#### Features
+* **Breaking change**: Support Mutual TLS ([#858](https://github.com/redis-rs/redis-rs/pull/858) @sp-angel)
+* Implement `FromRedisValue` for `Box<[T]>` and `Arc<[T]>` ([#799](https://github.com/redis-rs/redis-rs/pull/799) @JOT85)
+* Sync Cluster: support multi-slot operations. ([#967](https://github.com/redis-rs/redis-rs/pull/967))
+* Execute multi-node requests using try_request. ([#919](https://github.com/redis-rs/redis-rs/pull/919))
+* Sorted set blocking commands ([#962](https://github.com/redis-rs/redis-rs/pull/962) @gheorghitamutu)
+* Allow passing routing information to cluster. ([#899](https://github.com/redis-rs/redis-rs/pull/899))
+* Add `tcp_nodelay` feature ([#941](https://github.com/redis-rs/redis-rs/pull/941) @PureWhiteWu)
+* Add support for multi-shard commands. ([#900](https://github.com/redis-rs/redis-rs/pull/900))
+
+#### Changes
+* Order in usage of ClusterParams. ([#997](https://github.com/redis-rs/redis-rs/pull/997))
+* **Breaking change**: Fix StreamId::contains_key signature ([#783](https://github.com/redis-rs/redis-rs/pull/783) @Ayush1325)
+* **Breaking change**: Update Command expiration values to be an appropriate type ([#589](https://github.com/redis-rs/redis-rs/pull/589) @joshleeb)
+* **Breaking change**: Bump aHash to v0.8.6 ([#966](https://github.com/redis-rs/redis-rs/pull/966) @aumetra)
+* Fix features for `load_native_certs`. ([#996](https://github.com/redis-rs/redis-rs/pull/996))
+* Revert redis-test versioning changes ([#993](https://github.com/redis-rs/redis-rs/pull/993))
+* Tests: Add retries to test cluster creation ([#994](https://github.com/redis-rs/redis-rs/pull/994))
+* Fix sync cluster behavior with transactions. ([#983](https://github.com/redis-rs/redis-rs/pull/983))
+* Sync Pub/Sub - cache received pub/sub messages. ([#910](https://github.com/redis-rs/redis-rs/pull/910))
+* Prefer routing to primary in a transaction. ([#986](https://github.com/redis-rs/redis-rs/pull/986))
+* Accept iterator at `ClusterClient` initialization ([#987](https://github.com/redis-rs/redis-rs/pull/987) @ruanpetterson)
+* **Breaking change**: Change timeouts from usize and isize to f64 ([#988](https://github.com/redis-rs/redis-rs/pull/988) @eythorhel19)
+* Update minimal rust version to 1.65 ([#982](https://github.com/redis-rs/redis-rs/pull/982))
+* Disable JSON module tests for redis 6.2.4. ([#980](https://github.com/redis-rs/redis-rs/pull/980))
+* Add connection string examples ([#976](https://github.com/redis-rs/redis-rs/pull/976) @NuclearOreo)
+* Move response policy into multi-node routing. ([#952](https://github.com/redis-rs/redis-rs/pull/952))
+* Added functions that allow tests to check version. ([#963](https://github.com/redis-rs/redis-rs/pull/963))
+* Fix XREADGROUP command ordering as per Redis Docs, and compatibility with Upstash Redis ([#960](https://github.com/redis-rs/redis-rs/pull/960) @prabhpreet)
+* Optimize make_pipeline_results by pre-allocate memory ([#957](https://github.com/redis-rs/redis-rs/pull/957) @PureWhiteWu)
+* Run module tests sequentially. ([#956](https://github.com/redis-rs/redis-rs/pull/956))
+* Log cluster creation output in tests. ([#955](https://github.com/redis-rs/redis-rs/pull/955))
+* CI: Update and use better maintained github actions. ([#954](https://github.com/redis-rs/redis-rs/pull/954))
+* Call CLIENT SETINFO on new connections. ([#945](https://github.com/redis-rs/redis-rs/pull/945))
+* Deprecate functions that erroneously use `tokio` in their name. ([#913](https://github.com/redis-rs/redis-rs/pull/913))
+* CI: Increase timeouts and use newer redis. ([#949](https://github.com/redis-rs/redis-rs/pull/949))
+* Remove redis version from redis-test. ([#943](https://github.com/redis-rs/redis-rs/pull/943))
+
+### 0.23.4 (2023-11-26)
+**Yanked** -- Inadvertently introduced breaking changes (sorry!). The changes in this tag
+have been pushed to 0.24.0.
+
+### 0.23.3 (2023-09-01)
+
+Note that this release fixes a small regression in async Redis Cluster handling of the `PING` command.
+Based on updated response aggregation logic in [#888](https://github.com/redis-rs/redis-rs/pull/888), it
+will again return a single response instead of an array.
+
+#### Features
+* Add `key_type` command ([#933](https://github.com/redis-rs/redis-rs/pull/933) @bruaba)
+* Async cluster: Group responses by response_policy.
([#888](https://github.com/redis-rs/redis-rs/pull/888)) + + +#### Fixes +* Remove unnecessary heap allocation ([#939](https://github.com/redis-rs/redis-rs/pull/939) @thechampagne) +* Sentinel tests: Ensure no ports are used twice ([#915](https://github.com/redis-rs/redis-rs/pull/915)) +* Fix lint issues ([#937](https://github.com/redis-rs/redis-rs/pull/937)) +* Fix JSON serialization error test ([#928](https://github.com/redis-rs/redis-rs/pull/928)) +* Remove unused dependencies ([#916](https://github.com/redis-rs/redis-rs/pull/916)) + + +### 0.23.2 (2023-08-10) + +#### Fixes +* Fix sentinel tests flakiness ([#912](https://github.com/redis-rs/redis-rs/pull/912)) +* Rustls: Remove usage of deprecated method ([#921](https://github.com/redis-rs/redis-rs/pull/921)) +* Fix compiling with sentinel feature, without aio feature ([#922](https://github.com/redis-rs/redis-rs/pull/923) @brocaar) +* Add timeouts to tests github action ([#911](https://github.com/redis-rs/redis-rs/pull/911)) + +### 0.23.1 (2023-07-28) + +#### Features +* Add basic Sentinel functionality ([#836](https://github.com/redis-rs/redis-rs/pull/836) @felipou) +* Enable keep alive on tcp connections via feature ([#886](https://github.com/redis-rs/redis-rs/pull/886) @DoumanAsh) +* Support fan-out commands in cluster-async ([#843](https://github.com/redis-rs/redis-rs/pull/843) @nihohit) +* connection_manager: retry and backoff on reconnect ([#804](https://github.com/redis-rs/redis-rs/pull/804) @nihohit) + +#### Changes +* Tests: Wait for all servers ([#901](https://github.com/redis-rs/redis-rs/pull/901) @barshaul) +* Pin `tempfile` dependency ([#902](https://github.com/redis-rs/redis-rs/pull/902)) +* Update routing data for commands. ([#887](https://github.com/redis-rs/redis-rs/pull/887) @nihohit) +* Add basic benchmark reporting to CI ([#880](https://github.com/redis-rs/redis-rs/pull/880)) +* Add `set_options` cmd ([#879](https://github.com/redis-rs/redis-rs/pull/879) @RokasVaitkevicius) +* Move random connection creation to when needed. ([#882](https://github.com/redis-rs/redis-rs/pull/882) @nihohit) +* Clean up existing benchmarks ([#881](https://github.com/redis-rs/redis-rs/pull/881)) +* Improve async cluster client performance. ([#877](https://github.com/redis-rs/redis-rs/pull/877) @nihohit) +* Allow configuration of cluster retry wait duration ([#859](https://github.com/redis-rs/redis-rs/pull/859) @nihohit) +* Fix async connect when ns resolved to multi ip ([#872](https://github.com/redis-rs/redis-rs/pull/872) @hugefiver) +* Reduce the number of unnecessary clones. ([#874](https://github.com/redis-rs/redis-rs/pull/874) @nihohit) +* Remove connection checking on every request. ([#873](https://github.com/redis-rs/redis-rs/pull/873) @nihohit) +* cluster_async: Wrap internal state with Arc. ([#864](https://github.com/redis-rs/redis-rs/pull/864) @nihohit) +* Fix redirect routing on request with no route. ([#870](https://github.com/redis-rs/redis-rs/pull/870) @nihohit) +* Amend README for macOS users ([#869](https://github.com/redis-rs/redis-rs/pull/869) @sarisssa) +* Improved redirection error handling ([#857](https://github.com/redis-rs/redis-rs/pull/857)) +* Fix minor async client bug. ([#862](https://github.com/redis-rs/redis-rs/pull/862) @nihohit) +* Split aio.rs to separate files. 
([#821](https://github.com/redis-rs/redis-rs/pull/821) @nihohit)
+* Add time feature to tokio dependency ([#855](https://github.com/redis-rs/redis-rs/pull/855) @robjtede)
+* Refactor cluster error handling ([#844](https://github.com/redis-rs/redis-rs/pull/844))
+* Fix unnecessarily mutable variable ([#849](https://github.com/redis-rs/redis-rs/pull/849) @kamulos)
+* Newtype SlotMap ([#845](https://github.com/redis-rs/redis-rs/pull/845))
+* Bump MSRV to 1.60 ([#846](https://github.com/redis-rs/redis-rs/pull/846))
+* Improve error logging. ([#838](https://github.com/redis-rs/redis-rs/pull/838) @nihohit)
+* Improve documentation, add references to `redis-macros` ([#769](https://github.com/redis-rs/redis-rs/pull/769) @daniel7grant)
+* Allow creating Cmd with capacity. ([#817](https://github.com/redis-rs/redis-rs/pull/817) @nihohit)
+
+
+### 0.23.0 (2023-04-05)
+In addition to *everything mentioned in 0.23.0-beta.1 notes*, this release adds support for Rustls,
+a long-sought feature. Thanks to @rharish101 and @LeoRowan for getting this in!
+
+#### Changes
+* Update Rustls to v0.21.0 ([#820](https://github.com/redis-rs/redis-rs/pull/820) @rharish101)
+* Implement support for Rustls ([#725](https://github.com/redis-rs/redis-rs/pull/725) @rharish101, @LeoRowan)
+
+
+### 0.23.0-beta.1 (2023-03-28)
+
+This release adds the `cluster_async` module, which introduces async Redis Cluster support. The code therein
+is largely taken from @Marwes's [redis-cluster-async crate](https://github.com/redis-rs/redis-cluster-async), which itself
+appears to have started from a sync Redis Cluster implementation started by @atuk721. In any case, thanks to @Marwes and @atuk721
+for the great work, and we hope to keep development moving forward in `redis-rs`.
+
+Though async Redis Cluster functionality for the time being has been kept as close to the originating crate as possible, previous users of
+`redis-cluster-async` should note the following changes:
+* Retries, while still configurable, can no longer be set to `None`/infinite retries
+* Routing and slot parsing logic has been removed and merged with existing `redis-rs` functionality
+* The client has been removed and superseded by common `ClusterClient`
+* Renamed `Connection` to `ClusterConnection`
+* Added support for reading from replicas
+* Added support for insecure TLS
+* Added support for setting both username and password
+
+#### Breaking Changes
+* Fix long-standing bug related to `AsyncIter`'s stream implementation in which polling the server
+  for additional data yielded broken data in most cases. Type bounds for `AsyncIter` have changed slightly,
+  making this a potentially breaking change.
([#597](https://github.com/redis-rs/redis-rs/pull/597) @roger) + +#### Changes +* Commands: Add additional generic args for key arguments ([#795](https://github.com/redis-rs/redis-rs/pull/795) @MaxOhn) +* Add `mset` / deprecate `set_multiple` ([#766](https://github.com/redis-rs/redis-rs/pull/766) @randomairborne) +* More efficient interfaces for `MultiplexedConnection` and `ConnectionManager` ([#811](https://github.com/redis-rs/redis-rs/pull/811) @nihohit) +* Refactor / remove flaky test ([#810](https://github.com/redis-rs/redis-rs/pull/810)) +* `cluster_async`: rename `Connection` to `ClusterConnection`, `Pipeline` to `ClusterConnInner` ([#808](https://github.com/redis-rs/redis-rs/pull/808)) +* Support parsing IPV6 cluster nodes ([#796](https://github.com/redis-rs/redis-rs/pull/796) @socs) +* Common client for sync/async cluster connections ([#798](https://github.com/redis-rs/redis-rs/pull/798)) + * `cluster::ClusterConnection` underlying connection type is now generic (with existing type as default) + * Support `read_from_replicas` in cluster_async + * Set retries in `ClusterClientBuilder` + * Add mock tests for `cluster` +* cluster-async common slot parsing([#793](https://github.com/redis-rs/redis-rs/pull/793)) +* Support async-std in cluster_async module ([#790](https://github.com/redis-rs/redis-rs/pull/790)) +* Async-Cluster use same routing as Sync-Cluster ([#789](https://github.com/redis-rs/redis-rs/pull/789)) +* Add Async Cluster Support ([#696](https://github.com/redis-rs/redis-rs/pull/696)) +* Fix broken json-module tests ([#786](https://github.com/redis-rs/redis-rs/pull/786)) +* `cluster`: Tls Builder support / simplify cluster connection map ([#718](https://github.com/redis-rs/redis-rs/pull/718) @0xWOF, @utkarshgupta137) + + +### 0.22.3 (2023-01-23) + +#### Changes +* Restore inherent `ClusterConnection::check_connection()` method ([#758](https://github.com/redis-rs/redis-rs/pull/758) @robjtede) + + + +### 0.22.2 (2023-01-07) + +This release adds various incremental improvements and fixes a few long-standing bugs. Thanks to all our +contributors for making this release happen. 
+
+#### Features
+* Implement ToRedisArgs for HashMap ([#722](https://github.com/redis-rs/redis-rs/pull/722) @gibranamparan)
+* Add explicit `MGET` command ([#729](https://github.com/redis-rs/redis-rs/pull/729) @vamshiaruru-virgodesigns)
+
+#### Bug fixes
+* Enable single-item-vector `get` responses ([#507](https://github.com/redis-rs/redis-rs/pull/507) @hank121314)
+* Fix empty result from xread_options with deleted entries ([#712](https://github.com/redis-rs/redis-rs/pull/712) @Quiwin)
+* Limit Parser Recursion ([#724](https://github.com/redis-rs/redis-rs/pull/724))
+* Improve MultiplexedConnection Error Handling ([#699](https://github.com/redis-rs/redis-rs/pull/699))
+
+#### Changes
+* Add test case for atomic pipeline ([#702](https://github.com/redis-rs/redis-rs/pull/702) @CNLHC)
+* Capture subscribe result error in PubSub doc example ([#739](https://github.com/redis-rs/redis-rs/pull/739) @baoyachi)
+* Use async-std name resolution when necessary ([#701](https://github.com/redis-rs/redis-rs/pull/701) @UgnilJoZ)
+* Add Script::invoke_async method ([#711](https://github.com/redis-rs/redis-rs/pull/711) @r-bk)
+* Cluster Refactorings ([#717](https://github.com/redis-rs/redis-rs/pull/717), [#716](https://github.com/redis-rs/redis-rs/pull/716), [#709](https://github.com/redis-rs/redis-rs/pull/709), [#707](https://github.com/redis-rs/redis-rs/pull/707), [#706](https://github.com/redis-rs/redis-rs/pull/706) @0xWOF, @utkarshgupta137)
+* Fix intermittent test failure ([#714](https://github.com/redis-rs/redis-rs/pull/714) @0xWOF, @utkarshgupta137)
+* Doc changes ([#705](https://github.com/redis-rs/redis-rs/pull/705) @0xWOF, @utkarshgupta137)
+* Lint fixes ([#704](https://github.com/redis-rs/redis-rs/pull/704) @0xWOF)
+
+
+
+### 0.22.1 (2022-10-18)
+
+#### Changes
+* Add README attribute to Cargo.toml
+* Update LICENSE file / symlink from parent directory
+
+
+### 0.22.0 (2022-10-05)
+
+This release adds various incremental improvements, including
+additional convenience commands, improved Cluster APIs, and various other bug
+fixes/library improvements.
+
+Although the changes here are incremental, this is a major release due to the
+breaking changes listed below.
+
+This release would not be possible without our many wonderful
+contributors -- thank you!
+
+#### Breaking changes
+* Box all large enum variants; changes enum signature ([#667](https://github.com/redis-rs/redis-rs/pull/667) @nihohit)
+* Support ACL commands by adding Rule::Other to cover newly defined flags; adds new enum variant ([#685](https://github.com/redis-rs/redis-rs/pull/685) @garyhai)
+* Switch from sha1 to sha1_smol; renames `sha1` feature ([#576](https://github.com/redis-rs/redis-rs/pull/576))
+
+#### Features
+* Add support for RedisJSON ([#657](https://github.com/redis-rs/redis-rs/pull/657) @d3rpp)
+* Add support for weights in zunionstore and zinterstore ([#641](https://github.com/redis-rs/redis-rs/pull/641) @ndd7xv)
+* Cluster: Create read_from_replicas option ([#635](https://github.com/redis-rs/redis-rs/pull/635) @utkarshgupta137)
+* Make Direction a public enum to use with Commands like BLMOVE ([#646](https://github.com/redis-rs/redis-rs/pull/646) @thorbadour)
+* Add `ahash` feature for using ahash internally & for redis values ([#636](https://github.com/redis-rs/redis-rs/pull/636) @utkarshgupta137)
+* Add Script::load function ([#603](https://github.com/redis-rs/redis-rs/pull/603) @zhiburt)
+* Add support for OBJECT ([#610](https://github.com/redis-rs/redis-rs/pull/610) @roger)
+* Add GETEX and GETDEL support ([#582](https://github.com/redis-rs/redis-rs/pull/582) @arpandaze)
+* Add support for ZMPOP ([#605](https://github.com/redis-rs/redis-rs/pull/605) @gkorland)
+
+#### Changes
+* Rust 2021 Edition / MSRV 1.59.0
+* Fix: Support IPV6 link-local address parsing ([#679](https://github.com/redis-rs/redis-rs/pull/679) @buaazp)
+* Derive Clone and add Deref trait to InfoDict ([#661](https://github.com/redis-rs/redis-rs/pull/661) @danni-m)
+* ClusterClient: add handling for empty initial_nodes, use ClusterParams to store cluster parameters, improve builder pattern ([#669](https://github.com/redis-rs/redis-rs/pull/669) @utkarshgupta137)
+* Implement Debug for MultiplexedConnection & Pipeline ([#664](https://github.com/redis-rs/redis-rs/pull/664) @elpiel)
+* Add support for casting RedisResult to CString ([#660](https://github.com/redis-rs/redis-rs/pull/660) @nihohit)
+* Move redis crate to subdirectory to support multiple crates in project ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas)
+* Stop versioning Cargo.lock ([#620](https://github.com/redis-rs/redis-rs/pull/620))
+* Auto-implement ConnectionLike for DerefMut ([#567](https://github.com/redis-rs/redis-rs/pull/567) @holmesmr)
+* Return errors from parsing inner items ([#608](https://github.com/redis-rs/redis-rs/pull/608))
+* Make dns resolution async, in async runtime ([#606](https://github.com/redis-rs/redis-rs/pull/606) @roger)
+* Make async_trait dependency optional ([#572](https://github.com/redis-rs/redis-rs/pull/572) @kamulos)
+* Add username to ClusterClient and ClusterConnection ([#596](https://github.com/redis-rs/redis-rs/pull/596) @gildaf)
+
+
+
+### 0.21.6 (2022-08-24)
+
+* Update dependencies ([#588](https://github.com/mitsuhiko/redis-rs/pull/588))
+
+
+### 0.21.5 (2022-01-10)
+
+#### Features
+
+* Add new list commands ([#570](https://github.com/mitsuhiko/redis-rs/pull/570))
+
+
+
+### 0.21.4 (2021-11-04)
+
+#### Features
+
+* Add convenience command: zrandmember ([#556](https://github.com/mitsuhiko/redis-rs/pull/556))
+
+
+
+
+### 0.21.3 (2021-10-15)
+
+#### Features
+
+* Add support for TLS with cluster mode ([#548](https://github.com/mitsuhiko/redis-rs/pull/548))
+
+#### Changes
+
+* Remove stunnel as a dep, use redis native tls
([#542](https://github.com/mitsuhiko/redis-rs/pull/542))
+
+
+
+
+
+### 0.21.2 (2021-09-02)
+
+
+#### Bug Fixes
+
+* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/))
+
+#### Breaking Changes
+
+* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/))
+
+
+
+
+### 0.21.1 (2021-08-25)
+
+
+#### Bug Fixes
+
+* pin futures dependency to required version ([9be392bc](https://github.com/mitsuhiko/redis-rs/commit/9be392bc5b22326a8a0eafc7aa18cc04c5d79e0e))
+
+
+
+
+### 0.21.0 (2021-07-16)
+
+
+#### Performance
+
+* Don't enqueue multiplexed commands if the receiver is dropped ([ca5019db](https://github.com/mitsuhiko/redis-rs/commit/ca5019dbe76cc56c93eaecb5721de8fcf74d1641))
+
+#### Features
+
+* Refactor ConnectionAddr to remove boxing and clarify fields
+
+
+### 0.20.2 (2021-06-17)
+
+#### Features
+
+* Provide a new_async_std function ([c3716d15](https://github.com/mitsuhiko/redis-rs/commit/c3716d154f067b71acdd5bd927e118305cd0830b))
+
+#### Bug Fixes
+
+* Return Ready(Ok(())) when we have flushed all messages ([ca319c06](https://github.com/mitsuhiko/redis-rs/commit/ca319c06ad80fc37f1f701aecebbd5dabb0dceb0))
+* Don't loop forever on shutdown of the multiplexed connection ([ddecce9e](https://github.com/mitsuhiko/redis-rs/commit/ddecce9e10b8ab626f41409aae289d62b4fb74be))
+
+
+
+
+### 0.20.1 (2021-05-18)
+
+
+#### Bug Fixes
+
+* Error properly if eof is reached in the decoder ([306797c3](https://github.com/mitsuhiko/redis-rs/commit/306797c3c55ab24e0a29b6517356af794731d326))
+
+
+
+
+## 0.20.0 (2021-02-17)
+
+
+#### Features
+
+* Make ErrorKind non_exhaustive for forwards compatibility ([ac5e1a60](https://github.com/mitsuhiko/redis-rs/commit/ac5e1a60d398814b18ed1b579fe3f6b337e545e9))
+* **aio:** Allow the underlying IO stream to be customized ([6d2fc8fa](https://github.com/mitsuhiko/redis-rs/commit/6d2fc8faa707fbbbaae9fe092bbc90ce01224523))
+
+
+
+
+## 0.19.0 (2020-12-26)
+
+
+#### Features
+
+* Update to tokio 1.0 ([41960194](https://github.com/mitsuhiko/redis-rs/commit/4196019494aafc2bab718bafd1fdfd5e8c195ffa))
+* use the node specified in the MOVED error ([8a53e269](https://github.com/mitsuhiko/redis-rs/commit/8a53e2699d7d7bd63f222de778ed6820b0a65665))
+
+
+
+
+## 0.18.0 (2020-12-03)
+
+
+#### Bug Fixes
+
+* Don't require tokio for the connection manager ([46be86f3](https://github.com/mitsuhiko/redis-rs/commit/46be86f3f07df4900559bf9a4dfd0b5138c3ac52))
+
+* Make ToRedisArgs and FromRedisValue consistent for booleans
+
+BREAKING CHANGE
+
+Booleans are now written as 0 and 1 instead of true and false. Parsing a bool still accepts true and false, so this should not break anything for most users; however, if you read back a value that was stored as a bool, you may see different results.
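+
+As a minimal, hypothetical sketch of the difference (assuming a synchronous `Connection` named `con` and the `Commands` trait in scope; the key name is made up):
+
+```rust
+use redis::Commands;
+
+fn bool_roundtrip(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    // The bool is now encoded as "1" on the wire (previously "true").
+    let _: () = con.set("flag", true)?;
+    // Parsing back into a bool still accepts both encodings.
+    let flag: bool = con.get("flag")?;
+    assert!(flag);
+    // Code that reads the raw value back as a String now sees "1", not "true".
+    let raw: String = con.get("flag")?;
+    assert_eq!(raw, "1");
+    Ok(())
+}
+```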
+ +#### Features + +* Update tokio dependency to 0.3 ([bf5e0af3](https://github.com/mitsuhiko/redis-rs/commit/bf5e0af31c08be1785656031ffda96c355ee83c4), closes [#396](https://github.com/mitsuhiko/redis-rs/issues/396)) +* add doc_cfg for Makefile and docs.rs config ([1bf79517](https://github.com/mitsuhiko/redis-rs/commit/1bf795174521160934f3695326897458246e4978)) +* Impl FromRedisValue for i128 and u128 + + +# Changelog + +## [0.18.0](https://github.com/mitsuhiko/redis-rs/compare/0.17.0...0.18.0) - 2020-12-03 + +## [0.17.0](https://github.com/mitsuhiko/redis-rs/compare/0.16.0...0.17.0) - 2020-07-29 + +**Fixes and improvements** + +* Added Redis Streams commands ([#162](https://github.com/mitsuhiko/redis-rs/pull/319)) +* Added support for zpopmin and zpopmax ([#351](https://github.com/mitsuhiko/redis-rs/pull/351)) +* Added TLS support, gated by a feature flag ([#305](https://github.com/mitsuhiko/redis-rs/pull/305)) +* Added Debug and Clone implementations to redis::Script ([#365](https://github.com/mitsuhiko/redis-rs/pull/365)) +* Added FromStr for ConnectionInfo ([#368](https://github.com/mitsuhiko/redis-rs/pull/368)) +* Support SCAN methods on async connections ([#326](https://github.com/mitsuhiko/redis-rs/pull/326)) +* Removed unnecessary overhead around `Value` conversions ([#327](https://github.com/mitsuhiko/redis-rs/pull/327)) +* Support for Redis 6 auth ([#341](https://github.com/mitsuhiko/redis-rs/pull/341)) +* BUGFIX: Make aio::Connection Sync again ([#321](https://github.com/mitsuhiko/redis-rs/pull/321)) +* BUGFIX: Return UnexpectedEof if we try to decode at eof ([#322](https://github.com/mitsuhiko/redis-rs/pull/322)) +* Added support to create a connection from a (host, port) tuple ([#370](https://github.com/mitsuhiko/redis-rs/pull/370)) + +## [0.16.0](https://github.com/mitsuhiko/redis-rs/compare/0.15.1...0.16.0) - 2020-05-10 + +**Fixes and improvements** + +* Reduce dependencies without async IO ([#266](https://github.com/mitsuhiko/redis-rs/pull/266)) +* Add an afl fuzz target ([#274](https://github.com/mitsuhiko/redis-rs/pull/274)) +* Updated to combine 4 and avoid async dependencies for sync-only ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + * **BREAKING CHANGE**: The parser type now only persists the buffer and takes the Read instance in `parse_value` +* Implement a connection manager for automatic reconnection ([#278](https://github.com/mitsuhiko/redis-rs/pull/278)) +* Add async-std support ([#281](https://github.com/mitsuhiko/redis-rs/pull/281)) +* Fix key extraction for some stream commands ([#283](https://github.com/mitsuhiko/redis-rs/pull/283)) +* Add asynchronous PubSub support ([#287](https://github.com/mitsuhiko/redis-rs/pull/287)) + +### Breaking changes + +#### Changes to the `Parser` type ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + +The parser type now only persists the buffer and takes the Read instance in `parse_value`. +`redis::parse_redis_value` is unchanged and continues to work. 
+
+
+Old:
+
+```rust
+let mut parser = Parser::new(bytes);
+let result = parser.parse_value();
+```
+
+New:
+
+```rust
+let mut parser = Parser::new();
+let result = parser.parse_value(bytes);
+```
+
+## [0.15.1](https://github.com/mitsuhiko/redis-rs/compare/0.15.0...0.15.1) - 2020-01-15
+
+**Fixes and improvements**
+
+* Fixed the `r2d2` feature (re-added it) ([#265](https://github.com/mitsuhiko/redis-rs/pull/265))
+
+## [0.15.0](https://github.com/mitsuhiko/redis-rs/compare/0.14.0...0.15.0) - 2020-01-15
+
+**Fixes and improvements**
+
+* Added support for redis cluster ([#239](https://github.com/mitsuhiko/redis-rs/pull/239))
+
+## [0.14.0](https://github.com/mitsuhiko/redis-rs/compare/0.13.0...0.14.0) - 2020-01-08
+
+**Fixes and improvements**
+
+* Fix the command verb being sent to redis for `zremrangebyrank` ([#240](https://github.com/mitsuhiko/redis-rs/pull/240))
+* Add `get_connection_with_timeout` to Client ([#243](https://github.com/mitsuhiko/redis-rs/pull/243))
+* **Breaking change:** Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253))
+* Async-ify the API ([#232](https://github.com/mitsuhiko/redis-rs/pull/232))
+* Bump minimal required Rust version to 1.39 (required for the async/await API)
+* Add async/await examples ([#261](https://github.com/mitsuhiko/redis-rs/pull/261), [#263](https://github.com/mitsuhiko/redis-rs/pull/263))
+* Added support for PSETEX and PTTL commands. ([#259](https://github.com/mitsuhiko/redis-rs/pull/259))
+
+### Breaking changes
+
+#### Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253))
+
+If you are using pipelines and were importing the `PipelineCommands` trait you can now remove that import
+and only use the `Commands` trait.
+
+Old:
+
+```rust
+use redis::{Commands, PipelineCommands};
+```
+
+New:
+
+```rust
+use redis::Commands;
+```
+
+## [0.13.0](https://github.com/mitsuhiko/redis-rs/compare/0.12.0...0.13.0) - 2019-10-14
+
+**Fixes and improvements**
+
+* **Breaking change:** rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)).
+* Run clippy over the entire codebase ([#238](https://github.com/mitsuhiko/redis-rs/pull/238))
+* **Breaking change:** Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242))
+
+### BREAKING CHANGES
+
+#### Rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)).
+
+If you used `redis::parse_async` before, you now need to change this to `redis::parse_redis_value_async`
+or import the method under the new name: `use redis::parse_redis_value_async`.
+
+#### Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242))
+
+`Script#invoke_async` was changed to be generic over `aio::ConnectionLike` in order to support wrapping a `SharedConnection` in user code.
+This required adding a new generic parameter to the method, causing an error when the return type is defined using the turbofish syntax.
+
+Old:
+
+```rust
+redis::Script::new("return ...")
+    .key("key1")
+    .arg("an argument")
+    .invoke_async::<String>()
+```
+
+New:
+
+```rust
+redis::Script::new("return ...")
+    .key("key1")
+    .arg("an argument")
+    .invoke_async::<_, String>()
+```
+
+## [0.12.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0...0.12.0) - 2019-08-26
+
+**Fixes and improvements**
+
+* **Breaking change:** Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223))
+* **Breaking change:** Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234))
+* **Breaking change:** (async) Fix `Script::invoke_async` return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233))
+* Add `GETRANGE` and `SETRANGE` commands ([#202](https://github.com/mitsuhiko/redis-rs/pull/202))
+* Fix `SINTERSTORE` wrapper name, it's now correctly `sinterstore` ([#225](https://github.com/mitsuhiko/redis-rs/pull/225))
+* Allow running `SharedConnection` with any other runtime ([#229](https://github.com/mitsuhiko/redis-rs/pull/229))
+* Reformatted as Edition 2018 code ([#235](https://github.com/mitsuhiko/redis-rs/pull/235))
+
+### BREAKING CHANGES
+
+#### Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223))
+
+Rust nightly deprecated bare trait objects.
+This PR adds the `dyn` keyword to all trait objects in order to get rid of the warning.
+This bumps the minimal supported Rust version to [Rust 1.27](https://blog.rust-lang.org/2018/06/21/Rust-1.27.html).
+
+#### Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234))
+
+We updated the `url` dependency to v2. We do expose this on our public API on the `redis::parse_redis_url` function. If you depend on that, make sure to also upgrade your direct dependency.
+
+#### (async) Fix Script::invoke_async return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233))
+
+Previously, invoking a script with a complex return type would cause the following error:
+
+```
+Response was of incompatible type: "Not a bulk response" (response was string data('"4b98bef92b171357ddc437b395c7c1a5145ca2bd"'))
+```
+
+This was because the Future returned when loading the script into the database returns the hash of the script, and thus the return type of `String` would not match the intended return type.
+
+This commit adds an enum to account for the different Future return types.
+
+
+## [0.11.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.2...0.11.0) - 2019-07-19
+
+This release includes all fixes & improvements from the two beta releases listed below.
+This release contains breaking changes.
+
+**Fixes and improvements**
+
+* (async) Fix performance problem for SharedConnection ([#222](https://github.com/mitsuhiko/redis-rs/pull/222))
+
+## [0.11.0-beta.2](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.1...0.11.0-beta.2) - 2019-07-14
+
+**Fixes and improvements**
+
+* (async) Don't block the executor from shutting down ([#217](https://github.com/mitsuhiko/redis-rs/pull/217))
+
+## [0.11.0-beta.1](https://github.com/mitsuhiko/redis-rs/compare/0.10.0...0.11.0-beta.1) - 2019-05-30
+
+**Fixes and improvements**
+
+* (async) Simplify implicit pipeline handling ([#182](https://github.com/mitsuhiko/redis-rs/pull/182))
+* (async) Use `tokio_sync`'s channels instead of futures ([#195](https://github.com/mitsuhiko/redis-rs/pull/195))
+* (async) Only allocate one oneshot per request ([#194](https://github.com/mitsuhiko/redis-rs/pull/194))
+* Remove redundant BufReader when parsing ([#197](https://github.com/mitsuhiko/redis-rs/pull/197))
+* Hide actual type returned from async parser ([#193](https://github.com/mitsuhiko/redis-rs/pull/193))
+* Use more performant operations for line parsing ([#198](https://github.com/mitsuhiko/redis-rs/pull/198))
+* Optimize the command encoding, see below for **breaking changes** ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+* Add support for geospatial commands ([#130](https://github.com/mitsuhiko/redis-rs/pull/130))
+* (async) Add support for async Script invocation ([#206](https://github.com/mitsuhiko/redis-rs/pull/206))
+
+### BREAKING CHANGES
+
+#### Renamed the async module to aio ([#189](https://github.com/mitsuhiko/redis-rs/pull/189))
+
+`async` is a reserved keyword in Rust 2018, so this avoids the need to write `r#async` in it.
+
+Old code:
+
+```rust
+use redis::async::SharedConnection;
+```
+
+New code:
+
+```rust
+use redis::aio::SharedConnection;
+```
+
+#### The trait `ToRedisArgs` was changed ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+`ToRedisArgs` has been changed to take an instance of `RedisWrite` instead of `Vec<Vec<u8>>`. Use the `write_arg` method instead of `Vec::push`.
+
+#### Minimum Rust version is now 1.26 ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+Upgrade your compiler.
+`impl Iterator` is used, requiring a more recent version of the Rust compiler.
+
+#### `iter` now takes `self` by value ([#165](https://github.com/mitsuhiko/redis-rs/pull/165))
+
+`iter` now takes `self` by value instead of cloning `self` inside the method.
+
+Old code:
+
+```rust
+let mut iter : redis::Iter<isize> = cmd.arg("my_set").cursor_arg(0).iter(&con).unwrap();
+```
+
+New code:
+
+```rust
+let mut iter : redis::Iter<isize> = cmd.arg("my_set").cursor_arg(0).clone().iter(&con).unwrap();
+```
+
+(Note the added `clone()` call: `arg` and `cursor_arg` return a mutable reference, so the command must be cloned into an owned value before `iter` can consume it.)
+
+#### A mutable connection object is now required ([#148](https://github.com/mitsuhiko/redis-rs/pull/148))
+
+We removed the internal usage of `RefCell` and `Cell` and instead require a mutable reference, `&mut ConnectionLike`,
+on all command calls.
+
+Old code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/")?;
+let con = client.get_connection(None)?;
+redis::cmd("SET").arg("my_key").arg(42).execute(&con);
+```
+
+New code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/")?;
+let mut con = client.get_connection(None)?;
+redis::cmd("SET").arg("my_key").arg(42).execute(&mut con);
+```
+
+Due to this, `transaction` has changed. The callback now also receives a mutable reference to the used connection.
+
+Old code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+let con = client.get_connection(None).unwrap();
+let key = "the_key";
+let (new_val,) : (isize,) = redis::transaction(&con, &[key], |pipe| {
+    let old_val : isize = con.get(key)?;
+    pipe
+        .set(key, old_val + 1).ignore()
+        .get(key).query(&con)
+})?;
+```
+
+New code:
+
+```rust
+let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+let mut con = client.get_connection(None).unwrap();
+let key = "the_key";
+let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
+    let old_val : isize = con.get(key)?;
+    pipe
+        .set(key, old_val + 1).ignore()
+        .get(key).query(con)
+})?;
+```
+
+#### Remove `rustc-serialize` feature ([#200](https://github.com/mitsuhiko/redis-rs/pull/200))
+
+We removed serialization to/from JSON. The underlying library has been deprecated for a long time.
+
+Old code in `Cargo.toml`:
+
+```
+[dependencies.redis]
+version = "0.9.1"
+features = ["with-rustc-json"]
+```
+
+There's no replacement for the feature.
+Use [serde](https://serde.rs/) and handle the serialization/deserialization in your own code.
+
+#### Remove `with-unix-sockets` feature ([#201](https://github.com/mitsuhiko/redis-rs/pull/201))
+
+We removed the Unix socket feature. It is now always enabled.
+We also removed auto-detection.
+
+Old code in `Cargo.toml`:
+
+```
+[dependencies.redis]
+version = "0.9.1"
+features = ["with-unix-sockets"]
+```
+
+There's no replacement for the feature. Unix sockets will continue to work by default.
+
+## [0.10.0](https://github.com/mitsuhiko/redis-rs/compare/0.9.1...0.10.0) - 2019-02-19
+
+* Fix handling of passwords with special characters (#163)
+* Better performance for async code due to less boxing (#167)
+  * CAUTION: redis-rs will now require Rust 1.26
+* Add `clear` method to the pipeline (#176)
+* Better benchmarking (#179)
+* Fully formatted source code (#181)
+
+## [0.9.1](https://github.com/mitsuhiko/redis-rs/compare/0.9.0...0.9.1) (2018-09-10)
+
+* Add ttl command
+
+## [0.9.0](https://github.com/mitsuhiko/redis-rs/compare/0.8.0...0.9.0) (2018-08-08)
+
+Some time has passed since the last release.
+This new release will bring fewer bugs, more commands, experimental async support and better performance.
+
+Highlights:
+
+* Implement flexible PubSub API (#136)
+* Avoid allocating some redundant Vec's during encoding (#140)
+* Add an async interface using futures-rs (#141)
+* Allow the async connection to have multiple in flight requests (#143)
+
+The async support is currently experimental.
+
+## [0.8.0](https://github.com/mitsuhiko/redis-rs/compare/0.7.1...0.8.0) (2016-12-26)
+
+* Add publish command
+
+## [0.7.1](https://github.com/mitsuhiko/redis-rs/compare/0.7.0...0.7.1) (2016-12-17)
+
+* Fix unix socket builds
+* Relax lifetimes for scripts
+
+## [0.7.0](https://github.com/mitsuhiko/redis-rs/compare/0.6.0...0.7.0) (2016-07-23)
+
+* Add support for built-in unix sockets
+
+## [0.6.0](https://github.com/mitsuhiko/redis-rs/compare/0.5.4...0.6.0) (2016-07-14)
+
+* feat: Make rustc-serialize an optional feature (#96)
+
+## [0.5.4](https://github.com/mitsuhiko/redis-rs/compare/0.5.3...0.5.4) (2016-06-25)
+
+* fix: Improved single arg handling (#95)
+* feat: Implement ToRedisArgs for &String (#89)
+* feat: Faster command encoding (#94)
+
+## [0.5.3](https://github.com/mitsuhiko/redis-rs/compare/0.5.2...0.5.3) (2016-05-03)
+
+* fix: Use explicit versions for dependencies
+* fix: Send `AUTH` command before other commands
+* fix: Shutdown connection upon protocol error
+* feat: Add `keys` method
+* feat: Possibility to set read and write timeouts for the connection
diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml
new file mode 100644
index 0000000000..25b06f64c2
--- /dev/null
+++ b/glide-core/redis-rs/redis/Cargo.toml
@@ -0,0 +1,251 @@
+[package]
+name = "redis"
+version = "0.25.2"
+keywords = ["redis", "database"]
+description = "Redis driver for Rust."
+homepage = "https://github.com/redis-rs/redis-rs"
+repository = "https://github.com/redis-rs/redis-rs"
+documentation = "https://docs.rs/redis"
+license = "BSD-3-Clause"
+edition = "2021"
+rust-version = "1.67"
+readme = "../README.md"
+
+[package.metadata.docs.rs]
+all-features = true
+rustdoc-args = ["--cfg", "docsrs"]
+
+[lib]
+bench = false
+
+[dependencies]
+# These two are generally really common simple dependencies so it does not seem
+# much of a point to optimize these, but these could in theory be removed for
+# an indirection through std::Formatter.
+ryu = "1.0"
+itoa = "1.0"
+
+# Strum is a set of macros and traits that make working with enums and strings easier in Rust.
+strum = "0.26" +strum_macros = "0.26" + +# This is a dependency that already exists in url +percent-encoding = "2.1" + +# We need this for redis url parsing +url = "= 2.5.0" + +# We need this for script support +sha1_smol = { version = "1.0", optional = true } + +combine = { version = "4.6", default-features = false, features = ["std"] } + +# Only needed for AIO +bytes = { version = "1", optional = true } +futures-util = { version = "0.3.15", default-features = false, optional = true } +pin-project-lite = { version = "0.2", optional = true } +tokio-util = { version = "0.7", optional = true } +tokio = { version = "1", features = ["rt", "net", "time", "sync"] } +socket2 = { version = "0.5", features = ["all"], optional = true } +dispose = { version = "0.5.0", optional = true } + +# Only needed for the connection manager +arc-swap = { version = "1.7.1" } +futures = { version = "0.3.3", optional = true } + +# Only needed for the r2d2 feature +r2d2 = { version = "0.8.8", optional = true } + +# Only needed for cluster +crc16 = { version = "0.4", optional = true } +rand = { version = "0.8", optional = true } + +# Only needed for async cluster +dashmap = { version = "6.0", optional = true } + +async-trait = { version = "0.1.24", optional = true } + +# Only needed for tokio support +tokio-retry2 = { version = "0.5", features = ["jitter"], optional = true } + +# Only needed for native tls +native-tls = { version = "0.2", optional = true } +tokio-native-tls = { version = "0.3", optional = true } + +# Only needed for rustls +rustls = { version = "0.22", optional = true } +webpki-roots = { version = "0.26", optional = true } +rustls-native-certs = { version = "0.7", optional = true } +tokio-rustls = { version = "0.25", optional = true } +rustls-pemfile = { version = "2", optional = true } +rustls-pki-types = { version = "1", optional = true } + +# Only needed for RedisJSON Support +serde = { version = "1.0.82", optional = true } +serde_json = { version = "1.0.82", optional = true } + +# Only needed for bignum Support +rust_decimal = { version = "1.33.1", optional = true } +bigdecimal = { version = "0.4.2", optional = true } +num-bigint = "0.4.4" + +# Optional aHash support +ahash = { version = "0.8.11", optional = true } + +tracing = "0.1" + +# Optional uuid support +uuid = { version = "1.6.1", optional = true } + +telemetrylib = { path = "../../telemetry" } + +lazy_static = "1" + +[features] +default = [ + "acl", + "streams", + "geospatial", + "script", + "keep-alive", + "tokio-comp", + "tokio-rustls-comp", + "connection-manager", + "cluster", + "cluster-async", +] +acl = [] +aio = [ + "bytes", + "pin-project-lite", + "futures-util", + "futures-util/alloc", + "futures-util/sink", + "tokio/io-util", + "tokio-util", + "tokio-util/codec", + "combine/tokio", + "async-trait", + "dispose", +] +geospatial = [] +json = ["serde", "serde/derive", "serde_json"] +cluster = ["crc16", "rand"] +script = ["sha1_smol"] +tls-native-tls = ["native-tls"] +tls-rustls = [ + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "rustls-pki-types", +] +tls-rustls-insecure = ["tls-rustls"] +tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] +tokio-comp = ["aio", "tokio/net", "tokio-retry2"] +tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"] +tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"] +connection-manager = ["futures", "aio", "tokio-retry2"] +streams = [] +cluster-async = ["cluster", "futures", "futures-util", "dashmap"] +keep-alive = ["socket2"] +sentinel = ["rand"] 
+tcp_nodelay = [] +rust_decimal = ["dep:rust_decimal"] +bigdecimal = ["dep:bigdecimal"] +num-bigint = [] +uuid = ["dep:uuid"] +disable-client-setinfo = [] + +# Deprecated features +tls = ["tls-native-tls"] # use "tls-native-tls" instead + +[dev-dependencies] +rand = "0.8" +socket2 = "0.5" +assert_approx_eq = "1.0" +fnv = "1.0.5" +futures = "0.3" +futures-time = "3" +criterion = "0.4" +partial-io = { version = "0.5", features = ["tokio", "quickcheck1"] } +quickcheck = "1.0.3" +tokio = { version = "1", features = [ + "rt", + "macros", + "rt-multi-thread", + "time", +] } +tempfile = "=3.6.0" +once_cell = "1" +anyhow = "1" +sscanf = "0.4.1" +serial_test = "2" +versions = "6.3" + +[[test]] +name = "test_async" +required-features = ["tokio-comp"] + +[[test]] +name = "parser" +required-features = ["aio"] + +[[test]] +name = "test_acl" + +[[test]] +name = "test_module_json" +required-features = ["json", "serde/derive"] + +[[test]] +name = "test_cluster_async" +required-features = ["cluster-async", "tokio-comp"] + +[[test]] +name = "test_async_cluster_connections_logic" +required-features = ["cluster-async"] + +[[test]] +name = "test_bignum" + +[[bench]] +name = "bench_basic" +harness = false +required-features = ["tokio-comp"] + +[[bench]] +name = "bench_cluster" +harness = false +required-features = ["cluster"] + +[[bench]] +name = "bench_cluster_async" +harness = false +required-features = ["cluster-async", "tokio-comp"] + +[[example]] +name = "async-multiplexed" +required-features = ["tokio-comp"] + +[[example]] +name = "async-await" +required-features = ["aio"] + +[[example]] +name = "async-pub-sub" +required-features = ["aio"] + +[[example]] +name = "async-scan" +required-features = ["aio"] + +[[example]] +name = "async-connection-loss" +required-features = ["connection-manager"] + +[[example]] +name = "streams" +required-features = ["streams"] + +[package.metadata.cargo-machete] +ignored = ["strum"] diff --git a/glide-core/redis-rs/redis/LICENSE b/glide-core/redis-rs/redis/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/redis/benches/bench_basic.rs b/glide-core/redis-rs/redis/benches/bench_basic.rs new file mode 100644 index 0000000000..356f74217e --- /dev/null +++ b/glide-core/redis-rs/redis/benches/bench_basic.rs @@ -0,0 +1,277 @@ +use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput}; +use futures::prelude::*; +use redis::{RedisError, Value}; + +use support::*; + +#[path = "../tests/support/mod.rs"] +mod support; + +fn bench_simple_getsetdel(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + b.iter(|| { + let key = "test_key"; + redis::cmd("SET").arg(key).arg(42).execute(&mut con); + let _: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap(); + redis::cmd("DEL").arg(key).execute(&mut con); + }); +} + +fn bench_simple_getsetdel_async(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime.block_on(ctx.async_connection()).unwrap(); + + b.iter(|| { + runtime + .block_on(async { + let key = "test_key"; + () = redis::cmd("SET") + .arg(key) + .arg(42) + .query_async(&mut con) + .await?; + let _: isize = redis::cmd("GET").arg(key).query_async(&mut con).await?; + () = redis::cmd("DEL").arg(key).query_async(&mut con).await?; + Ok::<_, RedisError>(()) + }) + .unwrap() + }); +} + +fn bench_simple_getsetdel_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + b.iter(|| { + let key = "test_key"; + let _: (usize,) = redis::pipe() + .cmd("SET") + .arg(key) + .arg(42) + .ignore() + .cmd("GET") + .arg(key) + .cmd("DEL") + .arg(key) + .ignore() + .query(&mut con) + .unwrap(); + }); +} + +fn bench_simple_getsetdel_pipeline_precreated(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let key = "test_key"; + let mut pipe = redis::pipe(); + pipe.cmd("SET") + .arg(key) + .arg(42) + .ignore() + .cmd("GET") + .arg(key) + .cmd("DEL") + .arg(key) + .ignore(); + + b.iter(|| { + let _: (usize,) = pipe.query(&mut con).unwrap(); + }); +} + +const PIPELINE_QUERIES: usize = 1_000; + +fn long_pipeline() -> redis::Pipeline { + let mut pipe = redis::pipe(); + + for i in 0..PIPELINE_QUERIES { + pipe.set(format!("foo{i}"), "bar").ignore(); + } + pipe +} + +fn bench_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let pipe = long_pipeline(); + + b.iter(|| { + pipe.query::<()>(&mut con).unwrap(); + }); +} + +fn bench_async_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime.block_on(ctx.async_connection()).unwrap(); + + let pipe = long_pipeline(); + + b.iter(|| { + runtime + .block_on(async { pipe.query_async::<_, ()>(&mut con).await }) + .unwrap(); + }); +} + +fn bench_multiplexed_async_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime + 
        .block_on(ctx.multiplexed_async_connection_tokio())
+        .unwrap();
+
+    let pipe = long_pipeline();
+
+    b.iter(|| {
+        runtime
+            .block_on(async { pipe.query_async::<_, ()>(&mut con).await })
+            .unwrap();
+    });
+}
+
+fn bench_multiplexed_async_implicit_pipeline(b: &mut Bencher) {
+    let ctx = TestContext::new();
+    let runtime = current_thread_runtime();
+    let con = runtime
+        .block_on(ctx.multiplexed_async_connection_tokio())
+        .unwrap();
+
+    let cmds: Vec<_> = (0..PIPELINE_QUERIES)
+        .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone())
+        .collect();
+
+    let mut connections = (0..PIPELINE_QUERIES)
+        .map(|_| con.clone())
+        .collect::<Vec<_>>();
+
+    b.iter(|| {
+        runtime
+            .block_on(async {
+                cmds.iter()
+                    .zip(&mut connections)
+                    .map(|(cmd, con)| cmd.query_async::<_, ()>(con))
+                    .collect::<stream::FuturesUnordered<_>>()
+                    .try_for_each(|()| async { Ok(()) })
+                    .await
+            })
+            .unwrap();
+    });
+}
+
+fn bench_query(c: &mut Criterion) {
+    let mut group = c.benchmark_group("query");
+    group
+        .bench_function("simple_getsetdel", bench_simple_getsetdel)
+        .bench_function("simple_getsetdel_async", bench_simple_getsetdel_async)
+        .bench_function("simple_getsetdel_pipeline", bench_simple_getsetdel_pipeline)
+        .bench_function(
+            "simple_getsetdel_pipeline_precreated",
+            bench_simple_getsetdel_pipeline_precreated,
+        );
+    group.finish();
+
+    let mut group = c.benchmark_group("query_pipeline");
+    group
+        .bench_function(
+            "multiplexed_async_implicit_pipeline",
+            bench_multiplexed_async_implicit_pipeline,
+        )
+        .bench_function(
+            "multiplexed_async_long_pipeline",
+            bench_multiplexed_async_long_pipeline,
+        )
+        .bench_function("async_long_pipeline", bench_async_long_pipeline)
+        .bench_function("long_pipeline", bench_long_pipeline)
+        .throughput(Throughput::Elements(PIPELINE_QUERIES as u64));
+    group.finish();
+}
+
+fn bench_encode_small(b: &mut Bencher) {
+    b.iter(|| {
+        let mut cmd = redis::cmd("HSETX");
+
+        cmd.arg("ABC:1237897325302:878241asdyuxpioaswehqwu")
+            .arg("some hash key")
+            .arg(124757920);
+
+        cmd.get_packed_command()
+    });
+}
+
+fn bench_encode_integer(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..1_000 {
+            pipe.set(123, 45679123).ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode_pipeline(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..1_000 {
+            pipe.set("foo", "bar").ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode_pipeline_nested(b: &mut Bencher) {
+    b.iter(|| {
+        let mut pipe = redis::pipe();
+
+        for _ in 0..200 {
+            pipe.set(
+                "foo",
+                ("bar", 123, b"1231279712", &["test", "test", "test"][..]),
+            )
+            .ignore();
+        }
+        pipe.get_packed_pipeline()
+    });
+}
+
+fn bench_encode(c: &mut Criterion) {
+    let mut group = c.benchmark_group("encode");
+    group
+        .bench_function("pipeline", bench_encode_pipeline)
+        .bench_function("pipeline_nested", bench_encode_pipeline_nested)
+        .bench_function("integer", bench_encode_integer)
+        .bench_function("small", bench_encode_small);
+    group.finish();
+}
+
+fn bench_decode_simple(b: &mut Bencher, input: &[u8]) {
+    b.iter(|| redis::parse_redis_value(input).unwrap());
+}
+fn bench_decode(c: &mut Criterion) {
+    let value = Value::Array(vec![
+        Value::Okay,
+        Value::SimpleString("testing".to_string()),
+        Value::Array(vec![]),
+        Value::Nil,
+        Value::BulkString(vec![b'a'; 10]),
+        Value::Int(7512182390),
+    ]);
+
+    let mut group = c.benchmark_group("decode");
+    {
+        let mut input = Vec::new();
+        support::encode_value(&value, &mut input).unwrap();
+
        assert_eq!(redis::parse_redis_value(&input).unwrap(), value);
+        group.bench_function("decode", move |b| bench_decode_simple(b, &input));
+    }
+    group.finish();
+}
+
+criterion_group!(bench, bench_query, bench_encode, bench_decode);
+criterion_main!(bench);
diff --git a/glide-core/redis-rs/redis/benches/bench_cluster.rs b/glide-core/redis-rs/redis/benches/bench_cluster.rs
new file mode 100644
index 0000000000..da854474ae
--- /dev/null
+++ b/glide-core/redis-rs/redis/benches/bench_cluster.rs
@@ -0,0 +1,108 @@
+#![allow(clippy::unit_arg)] // want to allow this for `black_box()`
+#![cfg(feature = "cluster")]
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use redis::cluster::cluster_pipe;
+
+use support::*;
+
+#[path = "../tests/support/mod.rs"]
+mod support;
+
+const PIPELINE_QUERIES: usize = 100;
+
+fn bench_set_get_and_del(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) {
+    let key = "test_key";
+
+    let mut group = c.benchmark_group("cluster_basic");
+
+    group.bench_function("set", |b| {
+        b.iter(|| {
+            redis::cmd("SET").arg(key).arg(42).execute(con);
+            black_box(())
+        })
+    });
+
+    group.bench_function("get", |b| {
+        b.iter(|| black_box(redis::cmd("GET").arg(key).query::<isize>(con).unwrap()))
+    });
+
+    let mut set_and_del = || {
+        redis::cmd("SET").arg(key).arg(42).execute(con);
+        redis::cmd("DEL").arg(key).execute(con);
+    };
+    group.bench_function("set_and_del", |b| {
+        b.iter(|| {
+            set_and_del();
+            black_box(())
+        })
+    });
+
+    group.finish();
+}
+
+fn bench_pipeline(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) {
+    let mut group = c.benchmark_group("cluster_pipeline");
+    group.throughput(Throughput::Elements(PIPELINE_QUERIES as u64));
+
+    let mut queries = Vec::new();
+    for i in 0..PIPELINE_QUERIES {
+        queries.push(format!("foo{i}"));
+    }
+
+    let build_pipeline = || {
+        let mut pipe = cluster_pipe();
+        for q in &queries {
+            pipe.set(q, "bar").ignore();
+        }
+    };
+    group.bench_function("build_pipeline", |b| {
+        b.iter(|| {
+            build_pipeline();
+            black_box(())
+        })
+    });
+
+    let mut pipe = cluster_pipe();
+    for q in &queries {
+        pipe.set(q, "bar").ignore();
+    }
+    group.bench_function("query_pipeline", |b| {
+        b.iter(|| {
+            pipe.query::<()>(con).unwrap();
+            black_box(())
+        })
+    });
+
+    group.finish();
+}
+
+fn bench_cluster_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new(6, 1);
+    cluster.wait_for_cluster_up();
+
+    let mut con = cluster.connection();
+    bench_set_get_and_del(c, &mut con);
+    bench_pipeline(c, &mut con);
+}
+
+#[allow(dead_code)]
+fn bench_cluster_read_from_replicas_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        6,
+        1,
+        |builder| builder.read_from_replicas(),
+        false,
+    );
+    cluster.wait_for_cluster_up();
+
+    let mut con = cluster.connection();
+    bench_set_get_and_del(c, &mut con);
+    bench_pipeline(c, &mut con);
+}
+
+criterion_group!(
+    cluster_bench,
+    bench_cluster_setup,
+    // bench_cluster_read_from_replicas_setup
+);
+criterion_main!(cluster_bench);
diff --git a/glide-core/redis-rs/redis/benches/bench_cluster_async.rs b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs
new file mode 100644
index 0000000000..28c3b83c87
--- /dev/null
+++ b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs
@@ -0,0 +1,88 @@
+#![allow(clippy::unit_arg)] // want to allow this for `black_box()`
+#![cfg(feature = "cluster")]
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use futures_util::{stream, TryStreamExt};
+use
redis::RedisError;
+
+use support::*;
+use tokio::runtime::Runtime;
+
+#[path = "../tests/support/mod.rs"]
+mod support;
+
+fn bench_cluster_async(
+    c: &mut Criterion,
+    con: &mut redis::cluster_async::ClusterConnection,
+    runtime: &Runtime,
+) {
+    let mut group = c.benchmark_group("cluster_async");
+    group.bench_function("set_get_and_del", |b| {
+        b.iter(|| {
+            runtime
+                .block_on(async {
+                    let key = "test_key";
+                    () = redis::cmd("SET").arg(key).arg(42).query_async(con).await?;
+                    let _: isize = redis::cmd("GET").arg(key).query_async(con).await?;
+                    () = redis::cmd("DEL").arg(key).query_async(con).await?;
+
+                    Ok::<_, RedisError>(())
+                })
+                .unwrap();
+            black_box(())
+        })
+    });
+
+    group.bench_function("parallel_requests", |b| {
+        let num_parallel = 100;
+        let cmds: Vec<_> = (0..num_parallel)
+            .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone())
+            .collect();
+
+        let mut connections = (0..num_parallel).map(|_| con.clone()).collect::<Vec<_>>();
+
+        b.iter(|| {
+            runtime
+                .block_on(async {
+                    cmds.iter()
+                        .zip(&mut connections)
+                        .map(|(cmd, con)| cmd.query_async::<_, ()>(con))
+                        .collect::<stream::FuturesUnordered<_>>()
+                        .try_for_each(|()| async { Ok(()) })
+                        .await
+                })
+                .unwrap();
+            black_box(())
+        });
+    });
+
+    group.bench_function("pipeline", |b| {
+        let num_queries = 100;
+
+        let mut pipe = redis::pipe();
+
+        for _ in 0..num_queries {
+            pipe.set("foo".to_string(), "bar").ignore();
+        }
+
+        b.iter(|| {
+            runtime
+                .block_on(async { pipe.query_async::<_, ()>(con).await })
+                .unwrap();
+            black_box(())
+        });
+    });
+
+    group.finish();
+}
+
+fn bench_cluster_setup(c: &mut Criterion) {
+    let cluster = TestClusterContext::new(6, 1);
+    cluster.wait_for_cluster_up();
+    let runtime = current_thread_runtime();
+    let mut con = runtime.block_on(cluster.async_connection(None));
+
+    bench_cluster_async(c, &mut con, &runtime);
+}
+
+criterion_group!(cluster_async_bench, bench_cluster_setup,);
+criterion_main!(cluster_async_bench);
diff --git a/glide-core/redis-rs/redis/examples/async-await.rs b/glide-core/redis-rs/redis/examples/async-await.rs
new file mode 100644
index 0000000000..2d829c7d60
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-await.rs
@@ -0,0 +1,24 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use redis::{AsyncCommands, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut con = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+
+    con.set("key1", b"foo").await?;
+
+    redis::cmd("SET")
+        .arg(&["key2", "bar"])
+        .query_async(&mut con)
+        .await?;
+
+    let result = redis::cmd("MGET")
+        .arg(&["key1", "key2"])
+        .query_async(&mut con)
+        .await;
+    assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-connection-loss.rs b/glide-core/redis-rs/redis/examples/async-connection-loss.rs
new file mode 100644
index 0000000000..a7dba3ab89
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-connection-loss.rs
@@ -0,0 +1,97 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+//! This example will connect to Redis in one of three modes:
+//!
+//! - Regular async connection
+//! - Async multiplexed connection
+//! - Async connection manager
+//!
+//! It will then send a PING every 100 ms and print the result.
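+//!
+//! A typical invocation (assuming a locally running server on the default
+//! port) is, for example, `cargo run --example async-connection-loss -- reconnect`,
+//! where the trailing argument selects one of the modes handled in `main` below.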
+
+use std::env;
+use std::process;
+use std::time::Duration;
+
+use futures::future;
+use redis::aio::ConnectionLike;
+use redis::GlideConnectionOptions;
+use redis::RedisResult;
+use tokio::time::interval;
+
+enum Mode {
+    Deprecated,
+    Default,
+    Reconnect,
+}
+
+async fn run_single<C: ConnectionLike>(mut con: C) -> RedisResult<()> {
+    let mut interval = interval(Duration::from_millis(100));
+    loop {
+        interval.tick().await;
+        println!();
+        println!("> PING");
+        let result: RedisResult<String> = redis::cmd("PING").query_async(&mut con).await;
+        println!("< {result:?}");
+    }
+}
+
+async fn run_multi<C: ConnectionLike + Clone>(mut con: C) -> RedisResult<()> {
+    let mut interval = interval(Duration::from_millis(100));
+    loop {
+        interval.tick().await;
+        println!();
+        println!("> PING");
+        println!("> PING");
+        println!("> PING");
+        let results: (
+            RedisResult<String>,
+            RedisResult<String>,
+            RedisResult<String>,
+        ) = future::join3(
+            redis::cmd("PING").query_async(&mut con.clone()),
+            redis::cmd("PING").query_async(&mut con.clone()),
+            redis::cmd("PING").query_async(&mut con),
+        )
+        .await;
+        println!("< {:?}", results.0);
+        println!("< {:?}", results.1);
+        println!("< {:?}", results.2);
+    }
+}
+
+#[tokio::main]
+async fn main() -> RedisResult<()> {
+    let mode = match env::args().nth(1).as_deref() {
+        Some("default") => {
+            println!("Using default connection mode\n");
+            Mode::Default
+        }
+        Some("reconnect") => {
+            println!("Using reconnect manager mode\n");
+            Mode::Reconnect
+        }
+        Some("deprecated") => {
+            println!("Using deprecated connection mode\n");
+            Mode::Deprecated
+        }
+        Some(_) | None => {
+            println!("Usage: reconnect-manager (default|reconnect|deprecated)");
+            process::exit(1);
+        }
+    };
+
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    match mode {
+        Mode::Default => {
+            run_multi(
+                client
+                    .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+                    .await?,
+            )
+            .await?
+        }
+        Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?,
+        #[allow(deprecated)]
+        Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?,
+    };
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-multiplexed.rs b/glide-core/redis-rs/redis/examples/async-multiplexed.rs
new file mode 100644
index 0000000000..2e5332359b
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-multiplexed.rs
@@ -0,0 +1,46 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures::prelude::*;
+use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult};
+
+async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> {
+    let mut con = con.clone();
+
+    let key = format!("key{i}");
+    let key2 = format!("key{i}_2");
+    let value = format!("foo{i}");
+
+    redis::cmd("SET")
+        .arg(&key)
+        .arg(&value)
+        .query_async(&mut con)
+        .await?;
+
+    redis::cmd("SET")
+        .arg(&[&key2, "bar"])
+        .query_async(&mut con)
+        .await?;
+
+    redis::cmd("MGET")
+        .arg(&[&key, &key2])
+        .query_async(&mut con)
+        .map(|result| {
+            assert_eq!(Ok((value, b"bar".to_vec())), result);
+            Ok(())
+        })
+        .await
+}
+
+#[tokio::main]
+async fn main() {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+
+    let con = client
+        .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+        .await
+        .unwrap();
+
+    let cmds = (0..100).map(|i| test_cmd(&con, i));
+    let result = future::try_join_all(cmds).await.unwrap();
+
+    assert_eq!(100, result.len());
+}
diff --git a/glide-core/redis-rs/redis/examples/async-pub-sub.rs b/glide-core/redis-rs/redis/examples/async-pub-sub.rs
new file mode 100644
index 0000000000..fe84b44fb9
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-pub-sub.rs
@@ -0,0 +1,22 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures_util::StreamExt as _;
+use redis::{AsyncCommands, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut publish_conn = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+    let mut pubsub_conn = client.get_async_pubsub().await?;
+
+    pubsub_conn.subscribe("wavephone").await?;
+    let mut pubsub_stream = pubsub_conn.on_message();
+
+    publish_conn.publish("wavephone", "banana").await?;
+
+    let pubsub_msg: String = pubsub_stream.next().await.unwrap().get_payload()?;
+    assert_eq!(&pubsub_msg, "banana");
+
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/async-scan.rs b/glide-core/redis-rs/redis/examples/async-scan.rs
new file mode 100644
index 0000000000..06a66fe83e
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/async-scan.rs
@@ -0,0 +1,25 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use futures::stream::StreamExt;
+use redis::{AsyncCommands, AsyncIter, GlideConnectionOptions};
+
+#[tokio::main]
+async fn main() -> redis::RedisResult<()> {
+    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    let mut con = client
+        .get_multiplexed_async_connection(GlideConnectionOptions::default())
+        .await?;
+
+    con.set("async-key1", b"foo").await?;
+    con.set("async-key2", b"foo").await?;
+
+    let iter: AsyncIter<String> = con.scan().await?;
+    let mut keys: Vec<_> = iter.collect().await;
+
+    keys.sort();
+
+    assert_eq!(
+        keys,
+        vec!["async-key1".to_string(), "async-key2".to_string()]
+    );
+    Ok(())
+}
diff --git a/glide-core/redis-rs/redis/examples/basic.rs
b/glide-core/redis-rs/redis/examples/basic.rs
new file mode 100644
index 0000000000..622dc36e59
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/basic.rs
@@ -0,0 +1,169 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+use redis::{transaction, Commands};
+
+use std::collections::HashMap;
+use std::env;
+
+/// This function demonstrates how a return value can be coerced into a
+/// hashmap of tuples. This is particularly useful for responses like
+/// CONFIG GET or almost all of the H (hash) functions, which return their
+/// responses as such lists of implied tuples.
+fn do_print_max_entry_limits(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    // since rust cannot know what format we actually want we need to be
+    // explicit here and define the type of our response. In this case
+    // String -> int fits all the items we query for.
+    let config: HashMap<String, isize> = redis::cmd("CONFIG")
+        .arg("GET")
+        .arg("*-max-*-entries")
+        .query(con)?;
+
+    println!("Max entry limits:");
+
+    println!(
+        "  max-intset: {}",
+        config.get("set-max-intset-entries").unwrap_or(&0)
+    );
+    println!(
+        "  hash-max-ziplist: {}",
+        config.get("hash-max-ziplist-entries").unwrap_or(&0)
+    );
+    println!(
+        "  list-max-ziplist: {}",
+        config.get("list-max-ziplist-entries").unwrap_or(&0)
+    );
+    println!(
+        "  zset-max-ziplist: {}",
+        config.get("zset-max-ziplist-entries").unwrap_or(&0)
+    );
+
+    Ok(())
+}
+
+/// This is a pretty stupid example that demonstrates how to create a large
+/// set through a pipeline and how to iterate over it through implied
+/// cursors.
+fn do_show_scanning(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    // This makes a large pipeline of commands. Because the pipeline is
+    // modified in place we can just ignore the return value upon the end
+    // of each iteration.
+    let mut pipe = redis::pipe();
+    for num in 0..1000 {
+        pipe.cmd("SADD").arg("my_set").arg(num).ignore();
+    }
+
+    // since we don't care about the return value of the pipeline we can
+    // just cast it into the unit type.
+    pipe.query(con)?;
+
+    // since rust currently does not track temporaries for us, we need to
+    // store it in a local variable.
+    let mut cmd = redis::cmd("SSCAN");
+    cmd.arg("my_set").cursor_arg(0);
+
+    // as a simple exercise we just sum up the iterator. The turbofish on
+    // `iter` tells rust to parse every set member as an i32 so that the
+    // sum is well typed.
+    let sum: i32 = cmd.iter::<i32>(con)?.sum();
+
+    println!("The sum of all numbers in the set 0-1000: {sum}");
+
+    Ok(())
+}
+
+/// Demonstrates how to do an atomic increment in a very low level way.
+fn do_atomic_increment_lowlevel(con: &mut redis::Connection) -> redis::RedisResult<()> {
+    let key = "the_key";
+    println!("Run low-level atomic increment:");
+
+    // set the initial value so we have something to test with.
+    redis::cmd("SET").arg(key).arg(42).query(con)?;
+
+    loop {
+        // we need to start watching the key we care about, so that our
+        // exec fails if the key changes.
+        redis::cmd("WATCH").arg(key).query(con)?;
+
+        // load the old value, so we know what to increment.
+        let val: isize = redis::cmd("GET").arg(key).query(con)?;
+
+        // at this point we can go into an atomic pipe (a multi block)
+        // and set up the keys.
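+        // If another client modified the watched key between WATCH and EXEC,
+        // the server aborts the transaction and the pipeline query below
+        // yields `None`, which is handled by retrying the loop.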
+ let response: Option<(isize,)> = redis::pipe() + .atomic() + .cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(con)?; + + match response { + None => { + continue; + } + Some(response) => { + let (new_val,) = response; + println!(" New value: {new_val}"); + break; + } + } + } + + Ok(()) +} + +/// Demonstrates how to do an atomic increment with transaction support. +fn do_atomic_increment(con: &mut redis::Connection) -> redis::RedisResult<()> { + let key = "the_key"; + println!("Run high-level atomic increment:"); + + // set the initial value so we have something to test with. + con.set(key, 42)?; + + // run the transaction block. + let (new_val,): (isize,) = transaction(con, &[key], |con, pipe| { + // load the old value, so we know what to increment. + let val: isize = con.get(key)?; + // increment + pipe.set(key, val + 1).ignore().get(key).query(con) + })?; + + // and print the result + println!("New value: {new_val}"); + + Ok(()) +} + +/// Runs all the examples and propagates errors up. +fn do_redis_code(url: &str) -> redis::RedisResult<()> { + // general connection handling + let client = redis::Client::open(url)?; + let mut con = client.get_connection(None)?; + + // read some config and print it. + do_print_max_entry_limits(&mut con)?; + + // demonstrate how scanning works. + do_show_scanning(&mut con)?; + + // shows an atomic increment. + do_atomic_increment_lowlevel(&mut con)?; + do_atomic_increment(&mut con)?; + + Ok(()) +} + +fn main() { + // at this point the errors are fatal, let's just fail hard. + let url = if env::args().nth(1) == Some("--tls".into()) { + "rediss://127.0.0.1:6380/#insecure" + } else { + "redis://127.0.0.1:6379/" + }; + + if let Err(err) = do_redis_code(url) { + println!("Could not execute example:"); + println!(" {}: {}", err.category(), err); + } +} diff --git a/glide-core/redis-rs/redis/examples/geospatial.rs b/glide-core/redis-rs/redis/examples/geospatial.rs new file mode 100644 index 0000000000..5033b6c775 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/geospatial.rs @@ -0,0 +1,68 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use std::process::exit; + +use redis::RedisResult; + +#[cfg(feature = "geospatial")] +fn run() -> RedisResult<()> { + use redis::{geo, Commands}; + use std::env; + use std::f64; + + let redis_url = match env::var("REDIS_URL") { + Ok(url) => url, + Err(..) => "redis://127.0.0.1/".to_string(), + }; + + let client = redis::Client::open(redis_url.as_str())?; + let mut con = client.get_connection(None)?; + + // Add some members to the geospatial index. + + let added: isize = con.geo_add( + "gis", + &[ + (geo::Coord::lon_lat("13.361389", "38.115556"), "Palermo"), + (geo::Coord::lon_lat("15.087269", "37.502669"), "Catania"), + (geo::Coord::lon_lat("13.5833332", "37.316667"), "Agrigento"), + ], + )?; + + println!("[geo_add] Added {added} members."); + + // Get the position of one of them. 
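+    // GEOPOS returns one (longitude, latitude) pair per requested member,
+    // hence the nested vector in the result type below.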
+
+    let position: Vec<Vec<f64>> = con.geo_pos("gis", "Palermo")?;
+    println!("[geo_pos] Position for Palermo: {position:?}");
+
+    // Search members near (13.5, 37.75)
+
+    let options = geo::RadiusOptions::default()
+        .order(geo::RadiusOrder::Asc)
+        .with_dist()
+        .limit(2);
+    let items: Vec<geo::RadiusSearchResult> =
+        con.geo_radius("gis", 13.5, 37.75, 150.0, geo::Unit::Kilometers, options)?;
+
+    for item in items {
+        println!(
+            "[geo_radius] {}, dist = {} Km",
+            item.name,
+            item.dist.unwrap_or(f64::NAN)
+        );
+    }
+
+    Ok(())
+}
+
+#[cfg(not(feature = "geospatial"))]
+fn run() -> RedisResult<()> {
+    Ok(())
+}
+
+fn main() {
+    if let Err(e) = run() {
+        println!("{e:?}");
+        exit(1);
+    }
+}
diff --git a/glide-core/redis-rs/redis/examples/streams.rs b/glide-core/redis-rs/redis/examples/streams.rs
new file mode 100644
index 0000000000..8c40ea487d
--- /dev/null
+++ b/glide-core/redis-rs/redis/examples/streams.rs
@@ -0,0 +1,270 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+#![cfg(feature = "streams")]
+
+use redis::streams::{StreamId, StreamKey, StreamMaxlen, StreamReadOptions, StreamReadReply};
+
+use redis::{Commands, RedisResult, Value};
+
+use std::thread;
+use std::time::Duration;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+const DOG_STREAM: &str = "example-dog";
+const CAT_STREAM: &str = "example-cat";
+const DUCK_STREAM: &str = "example-duck";
+
+const STREAMS: &[&str] = &[DOG_STREAM, CAT_STREAM, DUCK_STREAM];
+
+const SLOWNESSES: &[u8] = &[2, 3, 4];
+
+/// This program generates an arbitrary set of records across three
+/// different streams. It then reads the data back in such a way
+/// that demonstrates basic usage of both the XREAD and XREADGROUP
+/// commands.
+fn main() {
+    let client = redis::Client::open("redis://127.0.0.1/").expect("client");
+
+    println!("Demonstrating XADD followed by XREAD, single threaded\n");
+
+    add_records(&client).expect("contrived record generation");
+
+    read_records(&client).expect("simple read");
+
+    demo_group_reads(&client);
+
+    clean_up(&client)
+}
+
+fn demo_group_reads(client: &redis::Client) {
+    println!("\n\nDemonstrating a longer stream of data flowing\nin over time, consumed by multiple threads using XREADGROUP\n");
+
+    let mut handles = vec![];
+
+    let cc = client.clone();
+    // Launch a producer thread which repeatedly adds records,
+    // with only a small delay between writes.
+    handles.push(thread::spawn(move || {
+        let repeat = 30;
+        let slowness = 1;
+        for _ in 0..repeat {
+            add_records(&cc).expect("add");
+            thread::sleep(Duration::from_millis(random_wait_millis(slowness)))
+        }
+    }));
+
+    // Launch consumer threads which repeatedly read from the
+    // streams at various speeds. They'll effectively compete
+    // to consume the stream.
+    //
+    // Consumer groups are only appropriate for cases where you
+    // do NOT want each consumer to read ALL of the data. This
+    // example is a contrived scenario so that each consumer
+    // receives its own, specific chunk of data.
+    //
+    // Once the data is read, the redis-rs lib will automatically
+    // acknowledge its receipt via XACK.
+    //
+    // Read more about reading with consumer groups here:
+    // https://redis.io/commands/xreadgroup
+    for slowness in SLOWNESSES {
+        let repeat = 5;
+        let ca = client.clone();
+        handles.push(thread::spawn(move || {
+            let mut con = ca.get_connection(None).expect("con");
+
+            // We must create each group and each consumer
+            // See https://redis.io/commands/xreadgroup#differences-between-xread-and-xreadgroup
+
+            for key in STREAMS {
+                let created: Result<(), _> = con.xgroup_create_mkstream(*key, GROUP_NAME, "$");
+                if let Err(e) = created {
+                    println!("Group already exists: {e:?}")
+                }
+            }
+
+            for _ in 0..repeat {
+                let read_reply = read_group_records(&ca, *slowness).expect("group read");
+
+                // fake some expensive work
+                for StreamKey { key, ids } in read_reply.keys {
+                    for StreamId { id, map: _ } in &ids {
+                        thread::sleep(Duration::from_millis(random_wait_millis(*slowness)));
+                        println!(
+                            "Stream {} ID {} Consumer slowness {} SysTime {}",
+                            key,
+                            id,
+                            slowness,
+                            SystemTime::now()
+                                .duration_since(UNIX_EPOCH)
+                                .expect("time")
+                                .as_millis()
+                        );
+                    }
+
+                    // acknowledge each stream and message ID once all messages are
+                    // correctly processed
+                    let id_strs: Vec<&String> =
+                        ids.iter().map(|StreamId { id, map: _ }| id).collect();
+                    con.xack(key, GROUP_NAME, &id_strs).expect("ack")
+                }
+            }
+        }))
+    }
+
+    for h in handles {
+        h.join().expect("Join")
+    }
+}
+
+/// Generate some contrived records and add them to various
+/// streams.
+fn add_records(client: &redis::Client) -> RedisResult<()> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let maxlen = StreamMaxlen::Approx(1000);
+
+    // a stream whose records have two fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            DOG_STREAM,
+            maxlen,
+            "*",
+            &[("bark", arbitrary_value()), ("groom", arbitrary_value())],
+        )?;
+    }
+
+    // a stream whose records have three fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            CAT_STREAM,
+            maxlen,
+            "*",
+            &[
+                ("meow", arbitrary_value()),
+                ("groom", arbitrary_value()),
+                ("hunt", arbitrary_value()),
+            ],
+        )?;
+    }
+
+    // a stream whose records have four fields
+    for _ in 0..thrifty_rand() {
+        con.xadd_maxlen(
+            DUCK_STREAM,
+            maxlen,
+            "*",
+            &[
+                ("quack", arbitrary_value()),
+                ("waddle", arbitrary_value()),
+                ("splash", arbitrary_value()),
+                ("flap", arbitrary_value()),
+            ],
+        )?;
+    }
+
+    Ok(())
+}
+
+/// An approximation of randomness, without leaving the stdlib.
+fn thrifty_rand() -> u8 {
+    let penultimate_num = 2;
+    (SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .expect("Time travel")
+        .as_nanos()
+        % penultimate_num) as u8
+        + 1
+}
+
+const MAGIC: u64 = 11;
+fn random_wait_millis(slowness: u8) -> u64 {
+    thrifty_rand() as u64 * thrifty_rand() as u64 * MAGIC * slowness as u64
+}
+
+/// Generate a potentially unique value.
+fn arbitrary_value() -> String {
+    format!(
+        "{}",
+        SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("Time travel")
+            .as_nanos()
+    )
+}
+
+/// Block the thread for this many milliseconds while
+/// waiting for data to arrive on the stream.
+const BLOCK_MILLIS: usize = 5000;
+
+/// Read back records from all three streams, if they're available.
+/// Doesn't bother with consumer groups. Generally the user
+/// would be responsible for keeping track of the most recent
+/// ID from which they need to read, but in this example, we
+/// just go back to the beginning of time and ask for all the
+/// records in the stream.
+fn read_records(client: &redis::Client) -> RedisResult<()> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let opts = StreamReadOptions::default().block(BLOCK_MILLIS);
+
+    // Oldest known time index
+    let starting_id = "0-0";
+    // Same as above
+    let another_form = "0";
+
+    let srr: StreamReadReply = con
+        .xread_options(STREAMS, &[starting_id, another_form, starting_id], &opts)
+        .expect("read");
+
+    for StreamKey { key, ids } in srr.keys {
+        println!("Stream {key}");
+        for StreamId { id, map } in ids {
+            println!("\tID {id}");
+            for (n, s) in map {
+                if let Value::BulkString(bytes) = s {
+                    println!("\t\t{}: {}", n, String::from_utf8(bytes).expect("utf8"))
+                } else {
+                    panic!("Weird data")
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn consumer_name(slowness: u8) -> String {
+    format!("example-consumer-{slowness}")
+}
+
+const GROUP_NAME: &str = "example-group-aaa";
+
+fn read_group_records(client: &redis::Client, slowness: u8) -> RedisResult<StreamReadReply> {
+    let mut con = client.get_connection(None).expect("conn");
+
+    let opts = StreamReadOptions::default()
+        .block(BLOCK_MILLIS)
+        .count(3)
+        .group(GROUP_NAME, consumer_name(slowness));
+
+    let srr: StreamReadReply = con
+        .xread_options(
+            &[DOG_STREAM, CAT_STREAM, DUCK_STREAM],
+            &[">", ">", ">"],
+            &opts,
+        )
+        .expect("records");
+
+    Ok(srr)
+}
+
+fn clean_up(client: &redis::Client) {
+    let mut con = client.get_connection(None).expect("con");
+    for k in STREAMS {
+        let trimmed: RedisResult<()> = con.xtrim(*k, StreamMaxlen::Equals(0));
+        trimmed.expect("trim");
+
+        let destroyed: RedisResult<()> = con.xgroup_destroy(*k, GROUP_NAME);
+        destroyed.expect("xgroup destroy");
+    }
+}
diff --git a/glide-core/redis-rs/redis/release.toml b/glide-core/redis-rs/redis/release.toml
new file mode 100644
index 0000000000..942730e0b6
--- /dev/null
+++ b/glide-core/redis-rs/redis/release.toml
@@ -0,0 +1,2 @@
+pre-release-hook = "../scripts/update-versions.sh"
+tag-name = "{{version}}"
diff --git a/glide-core/redis-rs/redis/src/acl.rs b/glide-core/redis-rs/redis/src/acl.rs
new file mode 100644
index 0000000000..ef85877ba6
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/acl.rs
@@ -0,0 +1,312 @@
+//! Defines types to use with the ACL commands.
+
+use crate::types::{
+    ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value,
+};
+
+macro_rules! not_convertible_error {
+    ($v:expr, $det:expr) => {
+        RedisError::from((
+            ErrorKind::TypeError,
+            "Response type not convertible",
+            format!("{:?} (response was {:?})", $det, $v),
+        ))
+    };
+}
+
+/// ACL rules are used in order to activate or remove a flag, or to perform a
+/// given change to the user ACL, which under the hood are just single words.
+#[derive(Debug, Eq, PartialEq)]
+pub enum Rule {
+    /// Enable the user: it is possible to authenticate as this user.
+    On,
+    /// Disable the user: it's no longer possible to authenticate with this
+    /// user, however the already authenticated connections will still work.
+    Off,
+
+    /// Add the command to the list of commands the user can call.
+    AddCommand(String),
+    /// Remove the command from the list of commands the user can call.
+    RemoveCommand(String),
+    /// Add all the commands in such category to be called by the user.
+    AddCategory(String),
+    /// Remove the commands from such category the client can call.
+    RemoveCategory(String),
+    /// Alias for `+@all`. Note that it implies the ability to execute all the
+    /// future commands loaded via the modules system.
+    AllCommands,
+    /// Alias for `-@all`.
+    NoCommands,
+
+    /// Add this password to the list of valid passwords for the user.
+    AddPass(String),
+    /// Remove this password from the list of valid passwords.
+    RemovePass(String),
+    /// Add this SHA-256 hash value to the list of valid passwords for the user.
+    AddHashedPass(String),
+    /// Remove this hash value from the list of valid passwords
+    RemoveHashedPass(String),
+    /// All the set passwords of the user are removed, and the user is flagged
+    /// as requiring no password: it means that every password will work
+    /// against this user.
+    NoPass,
+    /// Flush the list of allowed passwords. Moreover removes the _nopass_ status.
+    ResetPass,
+
+    /// Add a pattern of keys that can be mentioned as part of commands.
+    Pattern(String),
+    /// Alias for `~*`.
+    AllKeys,
+    /// Flush the list of allowed keys patterns.
+    ResetKeys,
+
+    /// Performs the following actions: `resetpass`, `resetkeys`, `off`, `-@all`.
+    /// The user returns to the same state it has immediately after its creation.
+    Reset,
+
+    /// Raw text of an [`ACL rule`][1] that is not enumerated above.
+    ///
+    /// [1]: https://redis.io/docs/manual/security/acl
+    Other(String),
+}
+
+impl ToRedisArgs for Rule {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        use self::Rule::*;
+
+        match self {
+            On => out.write_arg(b"on"),
+            Off => out.write_arg(b"off"),
+
+            AddCommand(cmd) => out.write_arg_fmt(format_args!("+{cmd}")),
+            RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{cmd}")),
+            AddCategory(cat) => out.write_arg_fmt(format_args!("+@{cat}")),
+            RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{cat}")),
+            AllCommands => out.write_arg(b"allcommands"),
+            NoCommands => out.write_arg(b"nocommands"),
+
+            AddPass(pass) => out.write_arg_fmt(format_args!(">{pass}")),
+            RemovePass(pass) => out.write_arg_fmt(format_args!("<{pass}")),
+            AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{pass}")),
+            RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{pass}")),
+            NoPass => out.write_arg(b"nopass"),
+            ResetPass => out.write_arg(b"resetpass"),
+
+            Pattern(pat) => out.write_arg_fmt(format_args!("~{pat}")),
+            AllKeys => out.write_arg(b"allkeys"),
+            ResetKeys => out.write_arg(b"resetkeys"),
+
+            Reset => out.write_arg(b"reset"),
+
+            Other(rule) => out.write_arg(rule.as_bytes()),
+        };
+    }
+}
+
+/// An info dictionary type storing Redis ACL information as multiple `Rule`.
+/// This type collects key/value data returned by the [`ACL GETUSER`][1] command.
+///
+/// [1]: https://redis.io/commands/acl-getuser
+#[derive(Debug, Eq, PartialEq)]
+pub struct AclInfo {
+    /// Describes flag rules for the user. Represented by [`Rule::On`][1],
+    /// [`Rule::Off`][2], [`Rule::AllKeys`][3], [`Rule::AllCommands`][4] and
+    /// [`Rule::NoPass`][5].
+    ///
+    /// [1]: ./enum.Rule.html#variant.On
+    /// [2]: ./enum.Rule.html#variant.Off
+    /// [3]: ./enum.Rule.html#variant.AllKeys
+    /// [4]: ./enum.Rule.html#variant.AllCommands
+    /// [5]: ./enum.Rule.html#variant.NoPass
+    pub flags: Vec<Rule>,
+    /// Describes the user's passwords. Represented by [`Rule::AddHashedPass`][1].
+    ///
+    /// [1]: ./enum.Rule.html#variant.AddHashedPass
+    pub passwords: Vec<Rule>,
+    /// Describes capabilities of which commands the user can call.
+    /// Represented by [`Rule::AddCommand`][1], [`Rule::AddCategory`][2],
+    /// [`Rule::RemoveCommand`][3] and [`Rule::RemoveCategory`][4].
+    ///
+    /// [1]: ./enum.Rule.html#variant.AddCommand
+    /// [2]: ./enum.Rule.html#variant.AddCategory
+    /// [3]: ./enum.Rule.html#variant.RemoveCommand
+    /// [4]: ./enum.Rule.html#variant.RemoveCategory
+    pub commands: Vec<Rule>,
+    /// Describes patterns of keys which the user can access. Represented by
+    /// [`Rule::Pattern`][1].
+    ///
+    /// [1]: ./enum.Rule.html#variant.Pattern
+    pub keys: Vec<Rule>,
+}
+
+impl FromRedisValue for AclInfo {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        let mut it = v
+            .as_sequence()
+            .ok_or_else(|| not_convertible_error!(v, ""))?
+            .iter()
+            .skip(1)
+            .step_by(2);
+
+        let (flags, passwords, commands, keys) = match (it.next(), it.next(), it.next(), it.next())
+        {
+            (Some(flags), Some(passwords), Some(commands), Some(keys)) => {
+                // Parse flags
+                // Ref: https://github.com/redis/redis/blob/0cabe0cfa7290d9b14596ec38e0d0a22df65d1df/src/acl.c#L83-L90
+                let flags = flags
+                    .as_sequence()
+                    .ok_or_else(|| {
+                        not_convertible_error!(flags, "Expect an array response of ACL flags")
+                    })?
+                    .iter()
+                    .map(|flag| match flag {
+                        Value::BulkString(flag) => match flag.as_slice() {
+                            b"on" => Ok(Rule::On),
+                            b"off" => Ok(Rule::Off),
+                            b"allkeys" => Ok(Rule::AllKeys),
+                            b"allcommands" => Ok(Rule::AllCommands),
+                            b"nopass" => Ok(Rule::NoPass),
+                            other => Ok(Rule::Other(String::from_utf8_lossy(other).into_owned())),
+                        },
+                        _ => Err(not_convertible_error!(
+                            flag,
+                            "Expect an arbitrary binary data"
+                        )),
+                    })
+                    .collect::<RedisResult<_>>()?;
+
+                let passwords = passwords
+                    .as_sequence()
+                    .ok_or_else(|| {
+                        not_convertible_error!(passwords, "Expect an array response of ACL passwords")
+                    })?
+                    .iter()
+                    .map(|pass| Ok(Rule::AddHashedPass(String::from_redis_value(pass)?)))
+                    .collect::<RedisResult<_>>()?;
+
+                let commands = match commands {
+                    Value::BulkString(cmd) => std::str::from_utf8(cmd)?,
+                    _ => {
+                        return Err(not_convertible_error!(
+                            commands,
+                            "Expect a valid UTF8 string"
+                        ))
+                    }
+                }
+                .split_terminator(' ')
+                .map(|cmd| match cmd {
+                    x if x.starts_with("+@") => Ok(Rule::AddCategory(x[2..].to_owned())),
+                    x if x.starts_with("-@") => Ok(Rule::RemoveCategory(x[2..].to_owned())),
+                    x if x.starts_with('+') => Ok(Rule::AddCommand(x[1..].to_owned())),
+                    x if x.starts_with('-') => Ok(Rule::RemoveCommand(x[1..].to_owned())),
+                    _ => Err(not_convertible_error!(
+                        cmd,
+                        "Expect a command addition/removal"
+                    )),
+                })
+                .collect::<RedisResult<_>>()?;
+
+                let keys = keys
+                    .as_sequence()
+                    .ok_or_else(|| not_convertible_error!(keys, ""))?
+                    .iter()
+                    .map(|pat| Ok(Rule::Pattern(String::from_redis_value(pat)?)))
+                    .collect::<RedisResult<_>>()?;
+
+                (flags, passwords, commands, keys)
+            }
+            _ => {
+                return Err(not_convertible_error!(
+                    v,
+                    "Expect a response from `ACL GETUSER`"
+                ))
+            }
+        };
+
+        Ok(Self {
+            flags,
+            passwords,
+            commands,
+            keys,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_args {
assert_args { + ($rule:expr, $arg:expr) => { + assert_eq!($rule.to_redis_args(), vec![$arg.to_vec()]); + }; + } + + #[test] + fn test_rule_to_arg() { + use self::Rule::*; + + assert_args!(On, b"on"); + assert_args!(Off, b"off"); + assert_args!(AddCommand("set".to_owned()), b"+set"); + assert_args!(RemoveCommand("set".to_owned()), b"-set"); + assert_args!(AddCategory("hyperloglog".to_owned()), b"+@hyperloglog"); + assert_args!(RemoveCategory("hyperloglog".to_owned()), b"-@hyperloglog"); + assert_args!(AllCommands, b"allcommands"); + assert_args!(NoCommands, b"nocommands"); + assert_args!(AddPass("mypass".to_owned()), b">mypass"); + assert_args!(RemovePass("mypass".to_owned()), b">> { + con: C, + buf: Vec, + decoder: combine::stream::Decoder>, + db: i64, + + // Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`. + // + // This flag is checked when attempting to send a command, and if it's raised, we attempt to + // exit the pubsub state before executing the new request. + pubsub: bool, + + // Field indicating which protocol to use for server communications. + protocol: ProtocolVersion, +} + +fn assert_sync() {} + +#[allow(unused)] +fn test() { + assert_sync::(); +} + +impl Connection +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + /// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object + /// and a `RedisConnectionInfo` + pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult { + let mut rv = Connection { + con, + buf: Vec::new(), + decoder: combine::stream::Decoder::new(), + db: connection_info.db, + pubsub: false, + protocol: connection_info.protocol, + }; + setup_connection(connection_info, &mut rv, false).await?; + Ok(rv) + } + + /// Converts this [`Connection`] into [`PubSub`]. + pub fn into_pubsub(self) -> PubSub { + PubSub::new(self) + } + + /// Converts this [`Connection`] into [`Monitor`] + pub fn into_monitor(self) -> Monitor { + Monitor::new(self) + } + + /// Fetches a single response from the connection. + async fn read_response(&mut self) -> RedisResult { + crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await + } + + /// Brings [`Connection`] out of `PubSub` mode. + /// + /// This will unsubscribe this [`Connection`] from all subscriptions. + /// + /// If this function returns error then on all command send tries will be performed attempt + /// to exit from `PubSub` mode until it will be successful. + async fn exit_pubsub(&mut self) -> RedisResult<()> { + let res = self.clear_active_subscriptions().await; + if res.is_ok() { + self.pubsub = false; + } else { + // Raise the pubsub flag to indicate the connection is "stuck" in that state. + self.pubsub = true; + } + + res + } + + /// Get the inner connection out of a PubSub + /// + /// Any active subscriptions are unsubscribed. In the event of an error, the connection is + /// dropped. + async fn clear_active_subscriptions(&mut self) -> RedisResult<()> { + // Responses to unsubscribe commands return in a 3-tuple with values + // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs). + // The "count of remaining subs" includes both pattern subscriptions and non pattern + // subscriptions. Thus, to accurately drain all unsubscribe messages received from the + // server, both commands need to be executed at once. 
+        {
+            // Prepare both unsubscribe commands
+            let unsubscribe = crate::Pipeline::new()
+                .add_command(cmd("UNSUBSCRIBE"))
+                .add_command(cmd("PUNSUBSCRIBE"))
+                .get_packed_pipeline();
+
+            // Execute commands
+            self.con.write_all(&unsubscribe).await?;
+        }
+
+        // Receive responses
+        //
+        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
+        // commands. There may be more responses if there are active subscriptions. In this case,
+        // messages are received until the _subscription count_ in the responses reaches zero.
+        let mut received_unsub = false;
+        let mut received_punsub = false;
+        if self.protocol != ProtocolVersion::RESP2 {
+            while let Value::Push { kind, data } =
+                from_owned_redis_value(self.read_response().await?)?
+            {
+                if data.len() >= 2 {
+                    if let Value::Int(num) = data[1] {
+                        if resp3_is_pub_sub_state_cleared(
+                            &mut received_unsub,
+                            &mut received_punsub,
+                            &kind,
+                            num as isize,
+                        ) {
+                            break;
+                        }
+                    }
+                }
+            }
+        } else {
+            loop {
+                let res: (Vec<u8>, (), isize) =
+                    from_owned_redis_value(self.read_response().await?)?;
+                if resp2_is_pub_sub_state_cleared(
+                    &mut received_unsub,
+                    &mut received_punsub,
+                    &res.0,
+                    res.2,
+                ) {
+                    break;
+                }
+            }
+        }
+
+        // Finally, the connection is back in its normal state since all subscriptions were
+        // cancelled *and* all unsubscribe messages were received.
+        Ok(())
+    }
+}
+
+impl<C> ConnectionLike for Connection<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+        (async move {
+            if self.pubsub {
+                self.exit_pubsub().await?;
+            }
+            self.buf.clear();
+            cmd.write_packed_command(&mut self.buf);
+            self.con.write_all(&self.buf).await?;
+            if cmd.is_no_response() {
+                return Ok(Value::Nil);
+            }
+            loop {
+                match self.read_response().await? {
+                    Value::Push { .. } => continue,
+                    val => return Ok(val),
+                }
+            }
+        })
+        .boxed()
+    }
+
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>> {
+        (async move {
+            if self.pubsub {
+                self.exit_pubsub().await?;
+            }
+
+            self.buf.clear();
+            cmd.write_packed_pipeline(&mut self.buf);
+            self.con.write_all(&self.buf).await?;
+
+            let mut first_err = None;
+
+            for _ in 0..offset {
+                let response = self.read_response().await;
+                if let Err(err) = response {
+                    if first_err.is_none() {
+                        first_err = Some(err);
+                    }
+                }
+            }
+
+            let mut rv = Vec::with_capacity(count);
+            let mut count = count;
+            let mut idx = 0;
+            while idx < count {
+                let response = self.read_response().await;
+                match response {
+                    Ok(item) => {
+                        // RESP3 can insert push data between command replies
+                        if let Value::Push { .. } = item {
+                            // if that is the case we have to extend the loop and handle push data
+                            count += 1;
+                        } else {
+                            rv.push(item);
+                        }
+                    }
+                    Err(err) => {
+                        if first_err.is_none() {
+                            first_err = Some(err);
+                        }
+                    }
+                }
+                idx += 1;
+            }
+
+            if let Some(err) = first_err {
+                Err(err)
+            } else {
+                Ok(rv)
+            }
+        })
+        .boxed()
+    }
+
+    fn get_db(&self) -> i64 {
+        self.db
+    }
+
+    fn is_closed(&self) -> bool {
+        // always false for AsyncRead + AsyncWrite (can't do better)
+        false
+    }
+}
+
+/// Represents a `PubSub` connection.
+pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
+
+/// Represents a `Monitor` connection.
+pub struct Monitor<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
+
+impl<C> PubSub<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    fn new(con: Connection<C>) -> Self {
+        Self(con)
+    }
+
+    /// Subscribes to a new channel.
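Taken together, the subscribe methods and message streams below can be used like this (a hypothetical sketch, assuming this wrapper is re-exported as `redis::aio::PubSub` and a tokio runtime drives the future):

```rust
use futures_util::StreamExt;
use redis::{aio::PubSub, RedisResult};
use tokio::io::{AsyncRead, AsyncWrite};

async fn listen_once<C>(mut pubsub: PubSub<C>) -> RedisResult<()>
where
    C: Unpin + AsyncRead + AsyncWrite + Send,
{
    pubsub.subscribe("news").await?;
    // Each delivered message can be decoded via the helper methods on `Msg`.
    if let Some(msg) = pubsub.on_message().next().await {
        let payload: String = msg.get_payload()?;
        println!("channel '{}': {}", msg.get_channel_name(), payload);
    }
    Ok(())
}
```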
+    pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
+        let mut cmd = cmd("SUBSCRIBE");
+        cmd.arg(channel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Subscribes to a new channel with a pattern.
+    pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
+        let mut cmd = cmd("PSUBSCRIBE");
+        cmd.arg(pchannel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Unsubscribes from a channel.
+    pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
+        let mut cmd = cmd("UNSUBSCRIBE");
+        cmd.arg(channel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Unsubscribes from a channel with a pattern.
+    pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
+        let mut cmd = cmd("PUNSUBSCRIBE");
+        cmd.arg(pchannel);
+        if self.0.protocol != ProtocolVersion::RESP2 {
+            cmd.set_no_response(true);
+        }
+        cmd.query_async(&mut self.0).await
+    }
+
+    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]'s subscriptions.
+    ///
+    /// The message itself is still generic and can be converted into an appropriate type through
+    /// the helper methods on it.
+    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
+        ValueCodec::default()
+            .framed(&mut self.0.con)
+            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
+    }
+
+    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]'s subscriptions, consuming it.
+    ///
+    /// The message itself is still generic and can be converted into an appropriate type through
+    /// the helper methods on it.
+    /// This can be useful in cases where the stream needs to be returned or held by something other
+    /// than the [`PubSub`].
+    pub fn into_on_message(self) -> impl Stream<Item = Msg> {
+        ValueCodec::default()
+            .framed(self.0.con)
+            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
+    }
+
+    /// Exits from `PubSub` mode and converts [`PubSub`] into [`Connection`].
+    #[deprecated(note = "aio::Connection is deprecated")]
+    pub async fn into_connection(mut self) -> Connection<C> {
+        self.0.exit_pubsub().await.ok();
+
+        self.0
+    }
+}
+
+impl<C> Monitor<C>
+where
+    C: Unpin + AsyncRead + AsyncWrite + Send,
+{
+    /// Create a [`Monitor`] from a [`Connection`]
+    pub fn new(con: Connection<C>) -> Self {
+        Self(con)
+    }
+
+    /// Deliver the MONITOR command to this [`Monitor`]ing wrapper.
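A matching hypothetical sketch for the Monitor wrapper (again assuming a `redis::aio::Monitor` re-export and a tokio context): issue MONITOR, then log every command the server observes as a plain `String`.

```rust
use futures_util::StreamExt;
use redis::{aio::Monitor, RedisResult};
use tokio::io::{AsyncRead, AsyncWrite};

async fn tail_commands<C>(mut monitor: Monitor<C>) -> RedisResult<()>
where
    C: Unpin + AsyncRead + AsyncWrite + Send,
{
    monitor.monitor().await?;
    // Every line the server reports is decoded as a String.
    let mut feed = monitor.on_message::<String>();
    while let Some(line) = feed.next().await {
        println!("server saw: {line}");
    }
    Ok(())
}
```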
+    pub async fn monitor(&mut self) -> RedisResult<()> {
+        cmd("MONITOR").query_async(&mut self.0).await
+    }
+
+    /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection
+    pub fn on_message<T: FromRedisValue>(&mut self) -> impl Stream<Item = T> + '_ {
+        ValueCodec::default()
+            .framed(&mut self.0.con)
+            .filter_map(|value| {
+                Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() })
+            })
+    }
+
+    /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection
+    pub fn into_on_message<T: FromRedisValue>(self) -> impl Stream<Item = T> {
+        ValueCodec::default()
+            .framed(self.0.con)
+            .filter_map(|value| {
+                Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() })
+            })
+    }
+}
+
+pub(crate) async fn get_socket_addrs(
+    host: &str,
+    port: u16,
+) -> RedisResult<impl Iterator<Item = SocketAddr> + Send + '_> {
+    #[cfg(feature = "tokio-comp")]
+    let socket_addrs = lookup_host((host, port)).await?;
+
+    let mut socket_addrs = socket_addrs.peekable();
+    match socket_addrs.peek() {
+        Some(_) => Ok(socket_addrs),
+        None => Err(RedisError::from((
+            ErrorKind::InvalidClientConfig,
+            "No address found for host",
+        ))),
+    }
+}
+
+/// Logs the creation of a connection, including its type, the node, and optionally its IP address.
+fn log_conn_creation<T>(conn_type: &str, node: T, ip: Option<IpAddr>)
+where
+    T: std::fmt::Debug,
+{
+    tracing::debug!(
+        "Creating {conn_type} connection for node: {node:?}{}",
+        ip.map(|ip| format!(", IP: {:?}", ip)).unwrap_or_default()
+    );
+}
+
+pub(crate) async fn connect_simple<T: RedisRuntime>(
+    connection_info: &ConnectionInfo,
+    _socket_addr: Option<SocketAddr>,
+) -> RedisResult<(T, Option<IpAddr>)> {
+    Ok(match connection_info.addr {
+        ConnectionAddr::Tcp(ref host, port) => {
+            if let Some(socket_addr) = _socket_addr {
+                return Ok::<_, RedisError>((
+                    <T>::connect_tcp(socket_addr).await?,
+                    Some(socket_addr.ip()),
+                ));
+            }
+            let socket_addrs = get_socket_addrs(host, port).await?;
+            select_ok(socket_addrs.map(|socket_addr| {
+                log_conn_creation("TCP", format!("{host}:{port}"), Some(socket_addr.ip()));
+                Box::pin(async move {
+                    Ok::<_, RedisError>((
+                        <T>::connect_tcp(socket_addr).await?,
+                        Some(socket_addr.ip()),
+                    ))
+                })
+            }))
+            .await?
+            .0
+        }
+
+        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
+        ConnectionAddr::TcpTls {
+            ref host,
+            port,
+            insecure,
+            ref tls_params,
+        } => {
+            if let Some(socket_addr) = _socket_addr {
+                return Ok::<_, RedisError>((
+                    <T>::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?,
+                    Some(socket_addr.ip()),
+                ));
+            }
+            let socket_addrs = get_socket_addrs(host, port).await?;
+            select_ok(socket_addrs.map(|socket_addr| {
+                log_conn_creation(
+                    "TCP with TLS",
+                    format!("{host}:{port}"),
+                    Some(socket_addr.ip()),
+                );
+                Box::pin(async move {
+                    Ok::<_, RedisError>((
+                        <T>::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?,
+                        Some(socket_addr.ip()),
+                    ))
+                })
+            }))
+            .await?
+            .0
+        }
+
+        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
+        ConnectionAddr::TcpTls { .. } => {
+            fail!((
+                ErrorKind::InvalidClientConfig,
+                "Cannot connect to TCP with TLS without the tls feature"
+            ));
+        }
+
+        #[cfg(unix)]
+        ConnectionAddr::Unix(ref path) => {
+            log_conn_creation("UDS", path, None);
+            (<T>::connect_unix(path).await?, None)
+        }
+
+        #[cfg(not(unix))]
+        ConnectionAddr::Unix(_) => {
+            return Err(RedisError::from((
+                ErrorKind::InvalidClientConfig,
+                "Cannot connect to unix sockets \
+                 on this platform",
+            )))
+        }
+    })
+}
diff --git a/glide-core/redis-rs/redis/src/aio/connection_manager.rs b/glide-core/redis-rs/redis/src/aio/connection_manager.rs
new file mode 100644
index 0000000000..02e6976d15
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/aio/connection_manager.rs
@@ -0,0 +1,312 @@
+use super::RedisFuture;
+use crate::client::GlideConnectionOptions;
+use crate::cmd::Cmd;
+use crate::push_manager::PushManager;
+use crate::types::{RedisError, RedisResult, Value};
+use crate::{
+    aio::{ConnectionLike, MultiplexedConnection, Runtime},
+    Client,
+};
+
+use arc_swap::ArcSwap;
+use futures::{
+    future::{self, Shared},
+    FutureExt,
+};
+use futures_util::future::BoxFuture;
+use std::sync::Arc;
+use tokio_retry2::strategy::{jitter, ExponentialBackoff};
+use tokio_retry2::{Retry, RetryError};
+
+/// A `ConnectionManager` is a proxy that wraps a [multiplexed
+/// connection][multiplexed-connection] and automatically reconnects to the
+/// server when necessary.
+///
+/// Like the [`MultiplexedConnection`][multiplexed-connection], this
+/// manager can be cloned, allowing requests to be sent concurrently on
+/// the same underlying connection (tcp/unix socket).
+///
+/// ## Behavior
+///
+/// - When creating an instance of the `ConnectionManager`, an initial
+///   connection will be established and awaited. Connection errors will be
+///   returned directly.
+/// - When a command sent to the server fails with an error that represents
+///   a "connection dropped" condition, that error will be passed on to the
+///   user, but it will trigger a reconnection in the background.
+/// - The reconnect code will atomically swap the current (dead) connection
+///   with a future that will eventually resolve to a `MultiplexedConnection`
+///   or to a `RedisError`
+/// - All commands that are issued after the reconnect process has been
+///   initiated, will have to await the connection future.
+/// - If reconnecting fails, all pending commands will be failed as well. A
+///   new reconnection attempt will be triggered if the error is an I/O error.
+///
+/// [multiplexed-connection]: struct.MultiplexedConnection.html
+#[derive(Clone)]
+pub struct ConnectionManager {
+    /// Information used for the connection. This is needed to be able to reconnect.
+    client: Client,
+    /// The connection future.
+    ///
+    /// The `ArcSwap` is required to be able to replace the connection
+    /// without making the `ConnectionManager` mutable.
+    connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,
+
+    runtime: Runtime,
+    retry_strategy: ExponentialBackoff,
+    number_of_retries: usize,
+    response_timeout: std::time::Duration,
+    connection_timeout: std::time::Duration,
+    push_manager: PushManager,
+}
+
+/// A `RedisResult` that can be cloned because `RedisError` is behind an `Arc`.
+type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;
+
+/// Type alias for a shared boxed future that will resolve to a `CloneableRedisResult`.
+type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;
+
+/// Handle a command result. If the connection was dropped, reconnect.
+macro_rules! reconnect_if_dropped {
+    ($self:expr, $result:expr, $current:expr) => {
+        if let Err(ref e) = $result {
+            if e.is_unrecoverable_error() {
+                $self.reconnect($current);
+            }
+        }
+    };
+}
+
+/// Handle a connection result. If the connection has dropped, reconnect.
+/// Propagate any error.
+macro_rules! reconnect_if_conn_dropped {
+    ($self:expr, $result:expr, $current:expr) => {
+        if let Err(e) = $result {
+            if e.is_connection_dropped() {
+                $self.reconnect($current);
+            }
+            return Err(e);
+        }
+    };
+}
+
+impl ConnectionManager {
+    const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2;
+    const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100;
+    const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;
+
+    /// Connect to the server and store the connection inside the returned `ConnectionManager`.
+    ///
+    /// This requires the `connection-manager` feature, which will also pull in
+    /// the Tokio executor.
+    pub async fn new(client: Client) -> RedisResult<Self> {
+        Self::new_with_backoff(
+            client,
+            Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
+            Self::DEFAULT_CONNECTION_RETRY_FACTOR,
+            Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
+        )
+        .await
+    }
+
+    /// Connect to the server and store the connection inside the returned `ConnectionManager`.
+    ///
+    /// This requires the `connection-manager` feature, which will also pull in
+    /// the Tokio executor.
+    ///
+    /// In case of reconnection issues, the manager will retry reconnection
+    /// number_of_retries times, with an exponentially increasing delay, calculated as
+    /// rand(0 .. factor * (exponent_base ^ current-try)).
+    pub async fn new_with_backoff(
+        client: Client,
+        exponent_base: u64,
+        factor: u64,
+        number_of_retries: usize,
+    ) -> RedisResult<Self> {
+        Self::new_with_backoff_and_timeouts(
+            client,
+            exponent_base,
+            factor,
+            number_of_retries,
+            std::time::Duration::MAX,
+            std::time::Duration::MAX,
+        )
+        .await
+    }
+
+    /// Connect to the server and store the connection inside the returned `ConnectionManager`.
+    ///
+    /// This requires the `connection-manager` feature, which will also pull in
+    /// the Tokio executor.
+    ///
+    /// In case of reconnection issues, the manager will retry reconnection
+    /// number_of_retries times, with an exponentially increasing delay, calculated as
+    /// rand(0 .. factor * (exponent_base ^ current-try)).
+    ///
+    /// The new connection will timeout operations after `response_timeout` has passed.
+    /// Each connection attempt to the server will timeout after `connection_timeout`.
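A back-of-the-envelope check of the documented delay formula, rand(0 .. factor * exponent_base^try). With the defaults above (base 2, factor 100 ms, 6 retries) the upper bounds come out to 200 ms through 6.4 s; the actual delays are randomized below these bounds by the jitter applied in `new_connection` (the exact distribution is up to `tokio_retry2`):

```rust
/// Upper bound of the randomized backoff delay for a given retry attempt.
fn max_backoff_ms(exponent_base: u64, factor: u64, try_index: u32) -> u64 {
    factor * exponent_base.pow(try_index)
}

#[test]
fn default_backoff_bounds() {
    let bounds: Vec<u64> = (1..=6).map(|i| max_backoff_ms(2, 100, i)).collect();
    assert_eq!(bounds, vec![200, 400, 800, 1600, 3200, 6400]);
}
```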
+    pub async fn new_with_backoff_and_timeouts(
+        client: Client,
+        exponent_base: u64,
+        factor: u64,
+        number_of_retries: usize,
+        response_timeout: std::time::Duration,
+        connection_timeout: std::time::Duration,
+    ) -> RedisResult<Self> {
+        // Create a MultiplexedConnection and wait for it to be established
+        let push_manager = PushManager::default();
+        let runtime = Runtime::locate();
+        let retry_strategy = ExponentialBackoff::from_millis(exponent_base).factor(factor);
+        let mut connection = Self::new_connection(
+            client.clone(),
+            retry_strategy.clone(),
+            number_of_retries,
+            response_timeout,
+            connection_timeout,
+        )
+        .await?;
+
+        // Wrap the connection in an `ArcSwap` instance for fast atomic access
+        connection.set_push_manager(push_manager.clone()).await;
+        Ok(Self {
+            client,
+            connection: Arc::new(ArcSwap::from_pointee(
+                future::ok(connection).boxed().shared(),
+            )),
+            runtime,
+            number_of_retries,
+            retry_strategy,
+            response_timeout,
+            connection_timeout,
+            push_manager,
+        })
+    }
+
+    async fn new_connection(
+        client: Client,
+        exponential_backoff: ExponentialBackoff,
+        number_of_retries: usize,
+        response_timeout: std::time::Duration,
+        connection_timeout: std::time::Duration,
+    ) -> RedisResult<MultiplexedConnection> {
+        let retry_strategy = exponential_backoff.map(jitter).take(number_of_retries);
+        Retry::spawn(retry_strategy, || async {
+            client
+                .get_multiplexed_async_connection_with_timeouts(
+                    response_timeout,
+                    connection_timeout,
+                    GlideConnectionOptions::default(),
+                )
+                .await
+                .map_err(RetryError::transient)
+        })
+        .await
+    }
+
+    /// Reconnect and overwrite the old connection.
+    ///
+    /// The `current` guard points to the shared future that was active
+    /// when the connection loss was detected.
+    fn reconnect(&self, current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>) {
+        let client = self.client.clone();
+        let retry_strategy = self.retry_strategy.clone();
+        let number_of_retries = self.number_of_retries;
+        let response_timeout = self.response_timeout;
+        let connection_timeout = self.connection_timeout;
+        let pmc = self.push_manager.clone();
+        let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
+            let mut con = Self::new_connection(
+                client,
+                retry_strategy,
+                number_of_retries,
+                response_timeout,
+                connection_timeout,
+            )
+            .await?;
+            con.set_push_manager(pmc).await;
+            Ok(con)
+        }
+        .boxed()
+        .shared();
+
+        // Update the connection in the connection manager
+        let new_connection_arc = Arc::new(new_connection.clone());
+        let prev = self
+            .connection
+            .compare_and_swap(&current, new_connection_arc);
+
+        // If the swap happened...
+        if Arc::ptr_eq(&prev, &current) {
+            // ...start the connection attempt immediately but do not wait on it.
+            self.runtime.spawn(new_connection.map(|_| ()));
+        }
+    }
+
+    /// Sends an already encoded (packed) command into the TCP socket and
+    /// reads the single response from it.
+    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
+        // Clone connection to avoid having to lock the ArcSwap in write mode
+        let guard = self.connection.load();
+        let connection_result = (**guard)
+            .clone()
+            .await
+            .map_err(|e| e.clone_mostly("Reconnecting failed"));
+        reconnect_if_conn_dropped!(self, connection_result, guard);
+        let result = connection_result?.send_packed_command(cmd).await;
+        reconnect_if_dropped!(self, &result, guard);
+        result
+    }
+
+    /// Sends multiple already encoded (packed) commands into the TCP socket
+    /// and reads `count` responses from it. This is used to implement
+    /// pipelining.
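On the `offset`/`count` contract: `offset` replies are read and discarded before `count` replies are collected. For an atomic (MULTI/EXEC) pipeline the usual convention in redis-rs is to skip the MULTI acknowledgement plus one QUEUED reply per command and keep only the EXEC reply; a small sketch of that arithmetic under this assumption (helper name is hypothetical):

```rust
/// (offset, count) for an atomic pipeline of `command_count` commands:
/// skip the MULTI ack and the QUEUED replies, keep the single EXEC reply.
fn atomic_offset_count(command_count: usize) -> (usize, usize) {
    (command_count + 1, 1)
}

#[test]
fn offsets_for_three_commands() {
    assert_eq!(atomic_offset_count(3), (4, 1));
}
```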
+    pub async fn send_packed_commands(
+        &mut self,
+        cmd: &crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisResult<Vec<Value>> {
+        // Clone shared connection future to avoid having to lock the ArcSwap in write mode
+        let guard = self.connection.load();
+        let connection_result = (**guard)
+            .clone()
+            .await
+            .map_err(|e| e.clone_mostly("Reconnecting failed"));
+        reconnect_if_conn_dropped!(self, connection_result, guard);
+        let result = connection_result?
+            .send_packed_commands(cmd, offset, count)
+            .await;
+        reconnect_if_dropped!(self, &result, guard);
+        result
+    }
+
+    /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types
+    pub fn get_push_manager(&self) -> PushManager {
+        self.push_manager.clone()
+    }
+}
+
+impl ConnectionLike for ConnectionManager {
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
+        (async move { self.send_packed_command(cmd).await }).boxed()
+    }
+
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>> {
+        (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
+    }
+
+    fn get_db(&self) -> i64 {
+        self.client.connection_info().redis.db
+    }
+
+    fn is_closed(&self) -> bool {
+        // always return false due to automatic reconnect
+        false
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/aio/mod.rs b/glide-core/redis-rs/redis/src/aio/mod.rs
new file mode 100644
index 0000000000..077046feba
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/aio/mod.rs
@@ -0,0 +1,328 @@
+//! Adds async IO support to redis.
+use crate::cmd::{cmd, Cmd};
+use crate::connection::{
+    get_resp3_hello_command_error, PubSubSubscriptionKind, RedisConnectionInfo,
+};
+use crate::types::{
+    ErrorKind, FromRedisValue, InfoDict, ProtocolVersion, RedisError, RedisFuture, RedisResult,
+    Value,
+};
+use crate::PushKind;
+use ::tokio::io::{AsyncRead, AsyncWrite};
+use async_trait::async_trait;
+use futures_util::Future;
+use std::net::SocketAddr;
+#[cfg(unix)]
+use std::path::Path;
+use std::pin::Pin;
+use std::time::Duration;
+
+#[cfg(feature = "tls-rustls")]
+use crate::tls::TlsConnParams;
+
+#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+use crate::connection::TlsConnParams;
+
+/// Enables the tokio compatibility
+#[cfg(feature = "tokio-comp")]
+#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+pub mod tokio;
+
+/// Represents the ability of connecting via TCP or via Unix socket
+#[async_trait]
+pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
+    /// Performs a TCP connection
+    async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self>;
+
+    // Performs a TCP TLS connection
+    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
+    async fn connect_tcp_tls(
+        hostname: &str,
+        socket_addr: SocketAddr,
+        insecure: bool,
+        tls_params: &Option<TlsConnParams>,
+    ) -> RedisResult<Self>;
+
+    /// Performs a UNIX connection
+    #[cfg(unix)]
+    async fn connect_unix(path: &Path) -> RedisResult<Self>;
+
+    fn spawn(f: impl Future<Output = ()> + Send + 'static);
+
+    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
+        Box::pin(self)
+    }
+}
+
+/// Trait for objects that implement `AsyncRead` and `AsyncWrite`
+pub trait AsyncStream: AsyncRead + AsyncWrite {}
+impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
+
+/// An async abstraction over connections.
+pub trait ConnectionLike {
+    /// Sends an already encoded (packed) command into the TCP socket and
+    /// reads the single response from it.
+    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
+
+    /// Sends multiple already encoded (packed) commands into the TCP socket
+    /// and reads `count` responses from it. This is used to implement
+    /// pipelining.
+    /// Important - this function is meant for internal usage, since it's
+    /// easy to pass incorrect `offset` & `count` parameters, which might
+    /// cause the connection to enter an erroneous state. Users shouldn't
+    /// call it, instead using the Pipeline::query_async function.
+    #[doc(hidden)]
+    fn req_packed_commands<'a>(
+        &'a mut self,
+        cmd: &'a crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisFuture<'a, Vec<Value>>;
+
+    /// Returns the database this connection is bound to. Note that this
+    /// information might be unreliable because it's initially cached and
+    /// also might be incorrect if the connection like object is not
+    /// actually connected.
+    fn get_db(&self) -> i64;
+
+    /// Returns the state of the connection
+    fn is_closed(&self) -> bool;
+
+    /// Get the connection availability zone
+    fn get_az(&self) -> Option<String> {
+        None
+    }
+
+    /// Set the connection availability zone
+    fn set_az(&mut self, _az: Option<String>) {}
+}
+
+/// Implements ability to notify about disconnection events
+#[async_trait]
+pub trait DisconnectNotifier: Send + Sync {
+    /// Notify about disconnect event
+    fn notify_disconnect(&mut self);
+
+    /// Wait for disconnect event with timeout
+    async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration);
+
+    /// Intended to be used with Box
+    fn clone_box(&self) -> Box<dyn DisconnectNotifier>;
+}
+
+impl Clone for Box<dyn DisconnectNotifier> {
+    fn clone(&self) -> Box<dyn DisconnectNotifier> {
+        self.clone_box()
+    }
+}
+
+// Helper function to extract and update availability zone from INFO command
+async fn update_az_from_info<C>(con: &mut C) -> RedisResult<()>
+where
+    C: ConnectionLike,
+{
+    let info_res = con.req_packed_command(&cmd("INFO")).await;
+
+    match info_res {
+        Ok(value) => {
+            let info_dict: InfoDict = FromRedisValue::from_redis_value(&value)?;
+            if let Some(node_az) = info_dict.get::<String>("availability_zone") {
+                con.set_az(Some(node_az));
+            }
+            Ok(())
+        }
+        Err(e) => {
+            // Handle the error case for the INFO command
+            Err(RedisError::from((
+                ErrorKind::ResponseError,
+                "Failed to execute INFO command. ",
+                format!("{:?}", e),
+            )))
+        }
+    }
+}
+
+// Initial setup for every connection.
+async fn setup_connection<C>(
+    connection_info: &RedisConnectionInfo,
+    con: &mut C,
+    // This parameter is set to 'true' if ReadFromReplica strategy is set to AZAffinity.
+    // An INFO command will be triggered in the connection's setup to update the 'availability_zone' property.
+    discover_az: bool,
+) -> RedisResult<()>
+where
+    C: ConnectionLike,
+{
+    if connection_info.protocol != ProtocolVersion::RESP2 {
+        let hello_cmd = resp3_hello(connection_info);
+        let val: RedisResult<Value> = hello_cmd.query_async(con).await;
+        if let Err(err) = val {
+            return Err(get_resp3_hello_command_error(err));
+        }
+    } else if let Some(password) = &connection_info.password {
+        let mut command = cmd("AUTH");
+        if let Some(username) = &connection_info.username {
+            command.arg(username);
+        }
+        match command.arg(password).query_async(con).await {
+            Ok(Value::Okay) => (),
+            Err(e) => {
+                let err_msg = e.detail().ok_or((
+                    ErrorKind::AuthenticationFailed,
+                    "Password authentication failed",
+                ))?;
+
+                if !err_msg.contains("wrong number of arguments for 'auth' command") {
+                    fail!((
+                        ErrorKind::AuthenticationFailed,
+                        "Password authentication failed",
+                    ));
+                }
+
+                let mut command = cmd("AUTH");
+                match command.arg(password).query_async(con).await {
+                    Ok(Value::Okay) => (),
+                    _ => {
+                        fail!((
+                            ErrorKind::AuthenticationFailed,
+                            "Password authentication failed"
+                        ));
+                    }
+                }
+            }
+            _ => {
+                fail!((
+                    ErrorKind::AuthenticationFailed,
+                    "Password authentication failed"
+                ));
+            }
+        }
+    }
+
+    if connection_info.db != 0 {
+        match cmd("SELECT")
+            .arg(connection_info.db)
+            .query_async(con)
+            .await
+        {
+            Ok(Value::Okay) => (),
+            _ => fail!((
+                ErrorKind::ResponseError,
+                "Redis server refused to switch database"
+            )),
+        }
+    }
+
+    if let Some(client_name) = &connection_info.client_name {
+        match cmd("CLIENT")
+            .arg("SETNAME")
+            .arg(client_name)
+            .query_async(con)
+            .await
+        {
+            Ok(Value::Okay) => {}
+            _ => fail!((
+                ErrorKind::ResponseError,
+                "Redis server refused to set client name"
+            )),
+        }
+    }
+
+    if discover_az {
+        update_az_from_info(con).await?;
+    }
+
+    // result is ignored, as per the command's instructions.
+    // https://redis.io/commands/client-setinfo/
+    #[cfg(not(feature = "disable-client-setinfo"))]
+    let _: RedisResult<()> = crate::connection::client_set_info_pipeline()
+        .query_async(con)
+        .await;
+
+    // resubscribe
+    if connection_info.protocol != ProtocolVersion::RESP3 {
+        return Ok(());
+    }
+    static KIND_TO_COMMAND: [(PubSubSubscriptionKind, &str); 3] = [
+        (PubSubSubscriptionKind::Exact, "SUBSCRIBE"),
+        (PubSubSubscriptionKind::Pattern, "PSUBSCRIBE"),
+        (PubSubSubscriptionKind::Sharded, "SSUBSCRIBE"),
+    ];
+
+    if connection_info.pubsub_subscriptions.is_none() {
+        return Ok(());
+    }
+
+    for (subscription_kind, channels_patterns) in
+        connection_info.pubsub_subscriptions.as_ref().unwrap()
+    {
+        for channel_pattern in channels_patterns.iter() {
+            let mut subscribe_command =
+                cmd(KIND_TO_COMMAND[Into::<usize>::into(*subscription_kind)].1);
+            subscribe_command.arg(channel_pattern);
+
+            // This is quite intricate code - per RESP3, subscription commands do not return anything.
+            // Instead, push messages will be pushed for each channel. Thus, this is not a typical request-response pattern.
+            // The act of pushing is asynchronous with regard to the subscription command, and might be delayed for some time after the server state was already updated
+            // (i.e. the behaviour is implementation defined).
+            // We will assume the configured timeout is enough for the server to push the notifications.
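For reference, the push notification the match below expects for a successful SUBSCRIBE of a channel "news" has this shape (constructed by hand here, purely for illustration):

```rust
use redis::{PushKind, Value};

fn expected_subscribe_push() -> Value {
    Value::Push {
        kind: PushKind::Subscribe,
        // data[0] is the channel that was subscribed; data[1] is the
        // resulting subscription count reported by the server.
        data: vec![Value::BulkString(b"news".to_vec()), Value::Int(1)],
    }
}
```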
+            match subscribe_command.query_async(con).await {
+                Ok(Value::Push { kind, data }) => {
+                    match *subscription_kind {
+                        PubSubSubscriptionKind::Exact => {
+                            if kind != PushKind::Subscribe
+                                || Value::BulkString(channel_pattern.clone()) != data[0]
+                            {
+                                fail!((
+                                    ErrorKind::ResponseError,
+                                    // TODO: Consider printing the exact command
+                                    "Failed to restore Exact subscription channels"
+                                ));
+                            }
+                        }
+                        PubSubSubscriptionKind::Pattern => {
+                            if kind != PushKind::PSubscribe
+                                || Value::BulkString(channel_pattern.clone()) != data[0]
+                            {
+                                fail!((
+                                    ErrorKind::ResponseError,
+                                    // TODO: Consider printing the exact command
+                                    "Failed to restore Pattern subscription channels"
+                                ));
+                            }
+                        }
+                        PubSubSubscriptionKind::Sharded => {
+                            if kind != PushKind::SSubscribe
+                                || Value::BulkString(channel_pattern.clone()) != data[0]
+                            {
+                                fail!((
+                                    ErrorKind::ResponseError,
+                                    // TODO: Consider printing the exact command
+                                    "Failed to restore Sharded subscription channels"
+                                ));
+                            }
+                        }
+                    }
+                }
+                _ => {
+                    fail!((
+                        ErrorKind::ResponseError,
+                        // TODO: Consider printing the exact command
+                        "Failed to receive subscription notification while restoring subscription channels"
+                    ));
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+mod connection;
+pub use connection::*;
+mod multiplexed_connection;
+pub use multiplexed_connection::*;
+#[cfg(feature = "connection-manager")]
+mod connection_manager;
+#[cfg(feature = "connection-manager")]
+#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
+pub use connection_manager::*;
+mod runtime;
+use crate::commands::resp3_hello;
+pub(super) use runtime::*;
diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs
new file mode 100644
index 0000000000..98b3667cc9
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs
@@ -0,0 +1,784 @@
+use super::{ConnectionLike, Runtime};
+use crate::aio::setup_connection;
+use crate::aio::DisconnectNotifier;
+use crate::client::GlideConnectionOptions;
+use crate::cmd::Cmd;
+#[cfg(feature = "tokio-comp")]
+use crate::parser::ValueCodec;
+use crate::push_manager::PushManager;
+use crate::types::{RedisError, RedisFuture, RedisResult, Value};
+use crate::{cmd, ConnectionInfo, ProtocolVersion, PushKind};
+use ::tokio::{
+    io::{AsyncRead, AsyncWrite},
+    sync::{mpsc, oneshot},
+};
+use arc_swap::ArcSwap;
+use futures_util::{
+    future::{Future, FutureExt},
+    ready,
+    sink::Sink,
+    stream::{self, Stream, StreamExt, TryStreamExt as _},
+};
+use pin_project_lite::pin_project;
+use std::collections::VecDeque;
+use std::fmt;
+use std::fmt::Debug;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::task::{self, Poll};
+use std::time::Duration;
+#[cfg(feature = "tokio-comp")]
+use tokio_util::codec::Decoder;
+
+// Default connection timeout in ms
+const DEFAULT_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_millis(250);
+
+// Senders through which the result of a single request is sent
+type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
+
+enum ResponseAggregate {
+    SingleCommand,
+    Pipeline {
+        expected_response_count: usize,
+        current_response_count: usize,
+        buffer: Vec<Value>,
+        first_err: Option<RedisError>,
+    },
+}
+
+impl ResponseAggregate {
+    fn new(pipeline_response_count: Option<usize>) -> Self {
+        match pipeline_response_count {
+            Some(response_count) => ResponseAggregate::Pipeline {
+                expected_response_count: response_count,
+                current_response_count: 0,
+                buffer: Vec::new(),
+                first_err: None,
+            },
+            None => ResponseAggregate::SingleCommand,
+        }
+    }
+}
+
+struct InFlight {
+    output: PipelineOutput,
+    response_aggregate: ResponseAggregate,
+}
+
+// A single message sent through the pipeline
+struct PipelineMessage<S> {
+    input: S,
+    output: PipelineOutput,
+    // If `None`, this is a single request, not a pipeline of multiple requests.
+    pipeline_response_count: Option<usize>,
+}
+
+/// Wrapper around a `Stream + Sink` where each item sent through the `Sink` results in one or more
+/// items being output by the `Stream` (the number is specified at time of sending). With the
+/// interface provided by `Pipeline`, requests map straightforwardly to responses, hiding the
+/// underlying `Stream` and `Sink`.
+#[derive(Clone)]
+pub(crate) struct Pipeline<SinkItem> {
+    sender: mpsc::Sender<PipelineMessage<SinkItem>>,
+    push_manager: Arc<ArcSwap<PushManager>>,
+    is_stream_closed: Arc<AtomicBool>,
+}
+
+impl<SinkItem> Debug for Pipeline<SinkItem>
+where
+    SinkItem: Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("Pipeline").field(&self.sender).finish()
+    }
+}
+
+pin_project! {
+    struct PipelineSink<T> {
+        #[pin]
+        sink_stream: T,
+        in_flight: VecDeque<InFlight>,
+        error: Option<RedisError>,
+        push_manager: Arc<ArcSwap<PushManager>>,
+        disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
+        is_stream_closed: Arc<AtomicBool>,
+    }
+}
+
+impl<T> PipelineSink<T>
+where
+    T: Stream<Item = RedisResult<Value>> + 'static,
+{
+    fn new<SinkItem>(
+        sink_stream: T,
+        push_manager: Arc<ArcSwap<PushManager>>,
+        disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
+        is_stream_closed: Arc<AtomicBool>,
+    ) -> Self
+    where
+        T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
+    {
+        PipelineSink {
+            sink_stream,
+            in_flight: VecDeque::new(),
+            error: None,
+            push_manager,
+            disconnect_notifier,
+            is_stream_closed,
+        }
+    }
+
+    // Read messages from the stream and send them back to the caller
+    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
+        loop {
+            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
+                Some(result) => result,
+                // The redis response stream is not going to produce any more items so we `Err`
+                // to break out of the `forward` combinator and stop handling requests
+                None => {
+                    // this is the right place to notify about the passive TCP disconnect
+                    // In other places we cannot distinguish between the active destruction of MultiplexedConnection and passive disconnect
+                    if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier {
+                        disconnect_notifier.notify_disconnect();
+                    }
+                    self.is_stream_closed.store(true, Ordering::Relaxed);
+                    return Poll::Ready(Err(()));
+                }
+            };
+            self.as_mut().send_result(item);
+        }
+    }
+
+    fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
+        let self_ = self.project();
+        let mut skip_value = false;
+        if let Ok(res) = &result {
+            if let Value::Push { kind, data: _data } = res {
+                self_.push_manager.load().try_send_raw(res);
+                if !kind.has_reply() {
+                    // If it's not true then push kind is converted to reply of a command
+                    skip_value = true;
+                }
+            }
+        }
+
+        let mut entry = match self_.in_flight.pop_front() {
+            Some(entry) => entry,
+            None => return,
+        };
+
+        if skip_value {
+            self_.in_flight.push_front(entry);
+            return;
+        }
+
+        match &mut entry.response_aggregate {
+            ResponseAggregate::SingleCommand => {
+                entry.output.send(result).ok();
+            }
+            ResponseAggregate::Pipeline {
+                expected_response_count,
+                current_response_count,
+                buffer,
+                first_err,
+            } => {
+                match result {
+                    Ok(item) => {
+                        buffer.push(item);
+                    }
+                    Err(err) => {
+                        if first_err.is_none() {
+                            *first_err = Some(err);
+                        }
+                    }
+                }
+
+                *current_response_count += 1;
+                if current_response_count < expected_response_count {
+                    // Need to gather more response values
+                    self_.in_flight.push_front(entry);
+                    return;
+                }
+
+                let response = match first_err.take() {
+                    Some(err) => Err(err),
+                    None => Ok(Value::Array(std::mem::take(buffer))),
+                };
+
+                // `Err` means that the receiver was dropped in which case it does not
+                // care about the output and we can continue by just dropping the value
+                // and sender
+                entry.output.send(response).ok();
+            }
+        }
+    }
+}
+
+impl<SinkItem, T> Sink<PipelineMessage<SinkItem>> for PipelineSink<T>
+where
+    T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
+{
+    type Error = ();
+
+    // Retrieve incoming messages and write them to the sink
+    fn poll_ready(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+    ) -> Poll<Result<(), Self::Error>> {
+        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
+            Ok(()) => Ok(()).into(),
+            Err(err) => {
+                *self.project().error = Some(err);
+                Ok(()).into()
+            }
+        }
+    }
+
+    fn start_send(
+        mut self: Pin<&mut Self>,
+        PipelineMessage {
+            input,
+            output,
+            pipeline_response_count,
+        }: PipelineMessage<SinkItem>,
+    ) -> Result<(), Self::Error> {
+        // If there is nothing to receive our output we do not need to send the message as it is
+        // ambiguous whether the message will be sent anyway. Helps shed some load on the
+        // connection.
+        if output.is_closed() {
+            return Ok(());
+        }
+
+        let self_ = self.as_mut().project();
+
+        if let Some(err) = self_.error.take() {
+            let _ = output.send(Err(err));
+            return Err(());
+        }
+
+        match self_.sink_stream.start_send(input) {
+            Ok(()) => {
+                let response_aggregate = ResponseAggregate::new(pipeline_response_count);
+                let entry = InFlight {
+                    output,
+                    response_aggregate,
+                };
+
+                self_.in_flight.push_back(entry);
+                Ok(())
+            }
+            Err(err) => {
+                let _ = output.send(Err(err));
+                Err(())
+            }
+        }
+    }
+
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+    ) -> Poll<Result<(), Self::Error>> {
+        ready!(self
+            .as_mut()
+            .project()
+            .sink_stream
+            .poll_flush(cx)
+            .map_err(|err| {
+                self.as_mut().send_result(Err(err));
+            }))?;
+        self.poll_read(cx)
+    }
+
+    fn poll_close(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+    ) -> Poll<Result<(), Self::Error>> {
+        // No new requests will come in after the first call to `close` but we need to complete any
+        // in progress requests before closing
+        if !self.in_flight.is_empty() {
+            ready!(self.as_mut().poll_flush(cx))?;
+        }
+        let this = self.as_mut().project();
+        this.sink_stream.poll_close(cx).map_err(|err| {
+            self.send_result(Err(err));
+        })
+    }
+}
+
+impl<SinkItem> Pipeline<SinkItem>
+where
+    SinkItem: Send + 'static,
+{
+    fn new<T>(
+        sink_stream: T,
+        disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
+    ) -> (Self, impl Future<Output = ()>)
+    where
+        T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
+        T: Send + 'static,
+        T::Item: Send,
+        T::Error: Send,
+        T::Error: ::std::fmt::Debug,
+    {
+        const BUFFER_SIZE: usize = 50;
+        let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
+        let push_manager: Arc<ArcSwap<PushManager>> =
+            Arc::new(ArcSwap::new(Arc::new(PushManager::default())));
+        let is_stream_closed = Arc::new(AtomicBool::new(false));
+        let sink = PipelineSink::new::<SinkItem>(
+            sink_stream,
+            push_manager.clone(),
+            disconnect_notifier,
+            is_stream_closed.clone(),
+        );
+        let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
+            .map(Ok)
+            .forward(sink)
+            .map(|_| ());
+        (
+            Pipeline {
+                sender,
+                push_manager,
+                is_stream_closed,
+            },
+            f,
+        )
+    }
+
+    // `None` means that the stream was out of items causing that poll loop to shut down.
+    async fn send_single(
+        &mut self,
+        item: SinkItem,
+        timeout: Duration,
+    ) -> Result<Value, RedisError> {
+        self.send_recv(item, None, timeout).await
+    }
+
+    async fn send_recv(
+        &mut self,
+        input: SinkItem,
+        // If `None`, this is a single request, not a pipeline of multiple requests.
+        pipeline_response_count: Option<usize>,
+        timeout: Duration,
+    ) -> Result<Value, RedisError> {
+        let (sender, receiver) = oneshot::channel();
+
+        self.sender
+            .send(PipelineMessage {
+                input,
+                pipeline_response_count,
+                output: sender,
+            })
+            .await
+            .map_err(|err| {
+                // If an error occurs here, it means the request never reached the server, as guaranteed
+                // by the 'send' function. Since the server did not receive the data, it is safe to retry
+                // the request.
+                RedisError::from((
+                    crate::ErrorKind::FatalSendError,
+                    "Failed to send the request to the server",
+                    err.to_string(),
+                ))
+            })?;
+        match Runtime::locate().timeout(timeout, receiver).await {
+            Ok(Ok(result)) => result,
+            Ok(Err(err)) => {
+                // The `sender` was dropped, likely indicating a failure in the stream.
+                // This error suggests that it's unclear whether the server received the request before the connection failed,
+                // making it unsafe to retry. For example, retrying an INCR request could result in double increments.
+                Err(RedisError::from((
+                    crate::ErrorKind::FatalReceiveError,
+                    "Failed to receive a response due to a fatal error",
+                    err.to_string(),
+                )))
+            }
+            Err(elapsed) => Err(elapsed.into()),
+        }
+    }
+
+    /// Sets `PushManager` of Pipeline
+    async fn set_push_manager(&mut self, push_manager: PushManager) {
+        self.push_manager.store(Arc::new(push_manager));
+    }
+
+    /// Checks if the pipeline is closed.
+    pub fn is_closed(&self) -> bool {
+        self.is_stream_closed.load(Ordering::Relaxed)
+    }
+}
+
+/// A connection object which can be cloned, allowing requests to be sent concurrently
+/// on the same underlying connection (tcp/unix socket).
+#[derive(Clone)]
+pub struct MultiplexedConnection {
+    pipeline: Pipeline<Vec<u8>>,
+    db: i64,
+    response_timeout: Duration,
+    protocol: ProtocolVersion,
+    push_manager: PushManager,
+    availability_zone: Option<String>,
+    password: Option<String>,
+}
+
+impl Debug for MultiplexedConnection {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("MultiplexedConnection")
+            .field("pipeline", &self.pipeline)
+            .field("db", &self.db)
+            .finish()
+    }
+}
+
+impl MultiplexedConnection {
+    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
+    /// and a `ConnectionInfo`
+    pub async fn new<C>(
+        connection_info: &ConnectionInfo,
+        stream: C,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<(Self, impl Future<Output = ()>)>
+    where
+        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
+    {
+        Self::new_with_response_timeout(
+            connection_info,
+            stream,
+            std::time::Duration::MAX,
+            glide_connection_options,
+        )
+        .await
+    }
+
+    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
+    /// and a `ConnectionInfo`. The new object will wait on operations for the given `response_timeout`.
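A hypothetical end-to-end construction over a raw TCP stream with a 250 ms response timeout. The URL is a placeholder and the import paths for `GlideConnectionOptions` are assumptions (the crate-internal path is `crate::client::GlideConnectionOptions`):

```rust
use std::time::Duration;
use redis::aio::MultiplexedConnection;
use redis::GlideConnectionOptions; // assumed re-export path

async fn connect() -> redis::RedisResult<MultiplexedConnection> {
    let client = redis::Client::open("redis://127.0.0.1:6379")?;
    let info = client.get_connection_info().clone();
    let stream = tokio::net::TcpStream::connect("127.0.0.1:6379").await?;
    let (con, driver) = MultiplexedConnection::new_with_response_timeout(
        &info,
        stream,
        Duration::from_millis(250),
        GlideConnectionOptions::default(),
    )
    .await?;
    // The driver future must be polled for the connection to make progress.
    tokio::spawn(driver);
    Ok(con)
}
```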
+    pub async fn new_with_response_timeout<C>(
+        connection_info: &ConnectionInfo,
+        stream: C,
+        response_timeout: std::time::Duration,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<(Self, impl Future<Output = ()>)>
+    where
+        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
+    {
+        let codec = ValueCodec::default()
+            .framed(stream)
+            .and_then(|msg| async move { msg });
+        let (mut pipeline, driver) =
+            Pipeline::new(codec, glide_connection_options.disconnect_notifier);
+        let driver = Box::pin(driver);
+        let pm = PushManager::default();
+        if let Some(sender) = glide_connection_options.push_sender {
+            pm.replace_sender(sender);
+        }
+
+        pipeline.set_push_manager(pm.clone()).await;
+
+        let mut con = MultiplexedConnection::builder(pipeline)
+            .with_db(connection_info.redis.db)
+            .with_response_timeout(response_timeout)
+            .with_push_manager(pm)
+            .with_protocol(connection_info.redis.protocol)
+            .with_password(connection_info.redis.password.clone())
+            .with_availability_zone(None)
+            .build()
+            .await?;
+
+        let driver = {
+            let auth = setup_connection(
+                &connection_info.redis,
+                &mut con,
+                glide_connection_options.discover_az,
+            );
+
+            futures_util::pin_mut!(auth);
+
+            match futures_util::future::select(auth, driver).await {
+                futures_util::future::Either::Left((result, driver)) => {
+                    result?;
+                    driver
+                }
+                futures_util::future::Either::Right(((), _)) => {
+                    return Err(RedisError::from((
+                        crate::ErrorKind::IoError,
+                        "Multiplexed connection driver unexpectedly terminated",
+                    )));
+                }
+            }
+        };
+
+        Ok((con, driver))
+    }
+
+    /// Sets the time that the multiplexer will wait for responses on operations before failing.
+    pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
+        self.response_timeout = timeout;
+    }
+
+    /// Sends an already encoded (packed) command into the TCP socket and
+    /// reads the single response from it.
+    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
+        let result = self
+            .pipeline
+            .send_single(cmd.get_packed_command(), self.response_timeout)
+            .await;
+        if self.protocol != ProtocolVersion::RESP2 {
+            if let Err(e) = &result {
+                if e.is_connection_dropped() {
+                    // Notify the PushManager that the connection was lost
+                    self.push_manager.try_send_raw(&Value::Push {
+                        kind: PushKind::Disconnection,
+                        data: vec![],
+                    });
+                }
+            }
+        }
+        result
+    }
+
+    /// Sends multiple already encoded (packed) commands into the TCP socket
+    /// and reads `count` responses from it. This is used to implement
+    /// pipelining.
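+    ///
+    /// A short sketch of the call pattern (hypothetical key and values; the
+    /// `offset`/`count` pair mirrors what `req_packed_commands` passes in):
+    ///
+    /// ```rust,no_run
+    /// # async fn example(con: &mut redis::aio::MultiplexedConnection) -> redis::RedisResult<()> {
+    /// let mut pipe = redis::pipe();
+    /// pipe.cmd("SET").arg("key").arg("value").ignore()
+    ///     .cmd("GET").arg("key");
+    /// // One ignored response (offset = 1) and one returned response (count = 1).
+    /// let values = con.send_packed_commands(&pipe, 1, 1).await?;
+    /// assert_eq!(values.len(), 1);
+    /// # Ok(())
+    /// # }
+    /// ```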
+    pub async fn send_packed_commands(
+        &mut self,
+        cmd: &crate::Pipeline,
+        offset: usize,
+        count: usize,
+    ) -> RedisResult<Vec<Value>> {
+        let result = self
+            .pipeline
+            .send_recv(
+                cmd.get_packed_pipeline(),
+                Some(offset + count),
+                self.response_timeout,
+            )
+            .await;
+
+        if self.protocol != ProtocolVersion::RESP2 {
+            if let Err(e) = &result {
+                if e.is_connection_dropped() {
+                    // Notify the PushManager that the connection was lost
+                    self.push_manager.try_send_raw(&Value::Push {
+                        kind: PushKind::Disconnection,
+                        data: vec![],
+                    });
+                }
+            }
+        }
+        let value = result?;
+        match value {
+            Value::Array(mut values) => {
+                values.drain(..offset);
+                Ok(values)
+            }
+            _ => Ok(vec![value]),
+        }
+    }
+
+    /// Sets `PushManager` of connection
+    pub async fn set_push_manager(&mut self, push_manager: PushManager) {
+        self.push_manager = push_manager.clone();
+        self.pipeline.set_push_manager(push_manager).await;
+    }
+
+    /// For external visibility (glide-core)
+    pub fn get_availability_zone(&self) -> Option<String> {
+        self.availability_zone.clone()
+    }
+
+    /// Replace the password used to authenticate with the server.
+    /// If `None` is provided, the password will be removed.
+    pub async fn update_connection_password(
+        &mut self,
+        password: Option<String>,
+    ) -> RedisResult<Value> {
+        self.password = password;
+        Ok(Value::Okay)
+    }
+
+    /// Creates a new `MultiplexedConnectionBuilder` for constructing a `MultiplexedConnection`.
+    pub(crate) fn builder(pipeline: Pipeline<Vec<u8>>) -> MultiplexedConnectionBuilder {
+        MultiplexedConnectionBuilder::new(pipeline)
+    }
+}
+
+/// A builder for creating `MultiplexedConnection` instances.
+pub struct MultiplexedConnectionBuilder {
+    pipeline: Pipeline<Vec<u8>>,
+    db: Option<i64>,
+    response_timeout: Option<Duration>,
+    push_manager: Option<PushManager>,
+    protocol: Option<ProtocolVersion>,
+    password: Option<String>,
+    /// Represents the node's availability zone
+    availability_zone: Option<String>,
+}
+
+impl MultiplexedConnectionBuilder {
+    /// Creates a new builder with the required pipeline
+    pub(crate) fn new(pipeline: Pipeline<Vec<u8>>) -> Self {
+        Self {
+            pipeline,
+            db: None,
+            response_timeout: None,
+            push_manager: None,
+            protocol: None,
+            password: None,
+            availability_zone: None,
+        }
+    }
+
+    /// Sets the database index for the `MultiplexedConnectionBuilder`.
+    pub fn with_db(mut self, db: i64) -> Self {
+        self.db = Some(db);
+        self
+    }
+
+    /// Sets the response timeout for the `MultiplexedConnectionBuilder`.
+    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
+        self.response_timeout = Some(timeout);
+        self
+    }
+
+    /// Sets the push manager for the `MultiplexedConnectionBuilder`.
+    pub fn with_push_manager(mut self, push_manager: PushManager) -> Self {
+        self.push_manager = Some(push_manager);
+        self
+    }
+
+    /// Sets the protocol version for the `MultiplexedConnectionBuilder`.
+    pub fn with_protocol(mut self, protocol: ProtocolVersion) -> Self {
+        self.protocol = Some(protocol);
+        self
+    }
+
+    /// Sets the password for the `MultiplexedConnectionBuilder`.
+    pub fn with_password(mut self, password: Option<String>) -> Self {
+        self.password = password;
+        self
+    }
+
+    /// Sets the availability zone for the `MultiplexedConnectionBuilder`.
+    pub fn with_availability_zone(mut self, az: Option<String>) -> Self {
+        self.availability_zone = az;
+        self
+    }
+
+    /// Builds and returns a new `MultiplexedConnection` instance using the configured settings.
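+    ///
+    /// A sketch of the crate-internal builder flow (the `pipeline` handle comes
+    /// from `Pipeline::new`, as in `new_with_response_timeout` above); unset
+    /// fields fall back to defaults:
+    ///
+    /// ```text
+    /// MultiplexedConnection::builder(pipeline)
+    ///     .with_db(0)
+    ///     .with_response_timeout(Duration::from_secs(1))
+    ///     .with_protocol(ProtocolVersion::RESP3)
+    ///     .build()
+    ///     .await
+    /// ```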
+ pub async fn build(self) -> RedisResult { + let db = self.db.unwrap_or_default(); + let response_timeout = self + .response_timeout + .unwrap_or(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT); + let push_manager = self.push_manager.unwrap_or_default(); + let protocol = self.protocol.unwrap_or_default(); + let password = self.password; + + let con = MultiplexedConnection { + pipeline: self.pipeline, + db, + response_timeout, + push_manager, + protocol, + password, + availability_zone: self.availability_zone, + }; + + Ok(con) + } +} + +impl ConnectionLike for MultiplexedConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + self.db + } + + fn is_closed(&self) -> bool { + self.pipeline.is_closed() + } + + /// Get the node's availability zone + fn get_az(&self) -> Option { + self.availability_zone.clone() + } + + /// Set the node's availability zone + fn set_az(&mut self, az: Option) { + self.availability_zone = az; + } +} +impl MultiplexedConnection { + /// Subscribes to a new channel. + pub async fn subscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("SUBSCRIBE"); + cmd.arg(channel_name.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel. + pub async fn unsubscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("UNSUBSCRIBE"); + cmd.arg(channel_name); + cmd.query_async(self).await?; + Ok(()) + } + + /// Subscribes to a new channel with pattern. + pub async fn psubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PSUBSCRIBE"); + cmd.arg(channel_pattern.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel pattern. 
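+    ///
+    /// Sketch of a pattern-subscription round trip (hypothetical pattern;
+    /// RESP3-only, as the check below enforces):
+    ///
+    /// ```text
+    /// con.psubscribe("news.*".to_string()).await?;
+    /// // push messages arrive via the configured PushManager sender
+    /// con.punsubscribe("news.*".to_string()).await?;
+    /// ```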
+ pub async fn punsubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PUNSUBSCRIBE"); + cmd.arg(channel_pattern); + cmd.query_async(self).await?; + Ok(()) + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/runtime.rs b/glide-core/redis-rs/redis/src/aio/runtime.rs new file mode 100644 index 0000000000..2222783ed8 --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/runtime.rs @@ -0,0 +1,57 @@ +use std::{io, time::Duration}; + +use futures_util::Future; + +#[cfg(feature = "tokio-comp")] +use super::tokio; +use super::RedisRuntime; +use crate::types::RedisError; + +#[derive(Clone, Debug)] +pub(crate) enum Runtime { + #[cfg(feature = "tokio-comp")] + Tokio, +} + +impl Runtime { + pub(crate) fn locate() -> Self { + #[cfg(not(feature = "tokio-comp"))] + { + compile_error!("tokio-comp feature is required for aio feature") + } + #[cfg(feature = "tokio-comp")] + { + Runtime::Tokio + } + } + + #[allow(dead_code)] + pub(super) fn spawn(&self, f: impl Future + Send + 'static) { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => tokio::Tokio::spawn(f), + } + } + + pub(crate) async fn timeout( + &self, + duration: Duration, + future: F, + ) -> Result { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => ::tokio::time::timeout(duration, future) + .await + .map_err(|_| Elapsed(())), + } + } +} + +#[derive(Debug)] +pub(crate) struct Elapsed(()); + +impl From for RedisError { + fn from(_: Elapsed) -> Self { + io::Error::from(io::ErrorKind::TimedOut).into() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/tokio.rs b/glide-core/redis-rs/redis/src/aio/tokio.rs new file mode 100644 index 0000000000..3a68c0ebfc --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/tokio.rs @@ -0,0 +1,204 @@ +use super::{AsyncStream, RedisResult, RedisRuntime, SocketAddr}; +use async_trait::async_trait; +#[allow(unused_imports)] // fixes "Duration" unused when built with non-default feature set +use std::{ + future::Future, + io, + pin::Pin, + task::{self, Poll}, + time::Duration, +}; +#[cfg(unix)] +use tokio::net::UnixStream as UnixStreamTokio; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpStream as TcpStreamTokio, +}; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::TlsConnector; + +#[cfg(feature = "tls-rustls")] +use crate::connection::create_rustls_config; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; +#[cfg(feature = "tls-rustls")] +use tokio_rustls::{client::TlsStream, TlsConnector}; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] +use tokio_native_tls::TlsStream; + +#[cfg(feature = "tokio-rustls-comp")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +#[cfg(unix)] +use super::Path; + +#[inline(always)] +async fn connect_tcp(addr: &SocketAddr) -> io::Result { + let socket = TcpStreamTokio::connect(addr).await?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = 
socket2::TcpKeepalive::new();
+        // These error paths are not expected to occur in practice
+        let std_socket = socket.into_std()?;
+        let socket2: socket2::Socket = std_socket.into();
+        socket2.set_tcp_keepalive(&KEEP_ALIVE)?;
+        // TCP_USER_TIMEOUT configuration isn't supported across all operating systems
+        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+        {
+            // TODO: Replace this hardcoded timeout with a configurable timeout when https://github.com/redis-rs/redis-rs/issues/1147 is resolved
+            const DEFAULT_USER_TCP_TIMEOUT: Duration = Duration::from_secs(5);
+            socket2.set_tcp_user_timeout(Some(DEFAULT_USER_TCP_TIMEOUT))?;
+        }
+        TcpStreamTokio::from_std(socket2.into())
+    }
+
+    #[cfg(not(feature = "keep-alive"))]
+    {
+        Ok(socket)
+    }
+}
+
+pub(crate) enum Tokio {
+    /// Represents a Tokio TCP connection.
+    Tcp(TcpStreamTokio),
+    /// Represents a Tokio TLS encrypted TCP connection
+    #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
+    TcpTls(Box<TlsStream<TcpStreamTokio>>),
+    /// Represents a Tokio Unix connection.
+    #[cfg(unix)]
+    Unix(UnixStreamTokio),
+}
+
+impl AsyncWrite for Tokio {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match &mut *self {
+            Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf),
+            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
+            Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf),
+            #[cfg(unix)]
+            Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf),
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
+        match &mut *self {
+            Tokio::Tcp(r) => Pin::new(r).poll_flush(cx),
+            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
+            Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx),
+            #[cfg(unix)]
+            Tokio::Unix(r) => Pin::new(r).poll_flush(cx),
+        }
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
+        match &mut *self {
+            Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx),
+            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
+            Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx),
+            #[cfg(unix)]
+            Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx),
+        }
+    }
+}
+
+impl AsyncRead for Tokio {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match &mut *self {
+            Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf),
+            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
+            Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf),
+            #[cfg(unix)]
+            Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf),
+        }
+    }
+}
+
+#[async_trait]
+impl RedisRuntime for Tokio {
+    async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self> {
+        Ok(connect_tcp(&socket_addr).await.map(Tokio::Tcp)?)
+    }
+
+    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+    async fn connect_tcp_tls(
+        hostname: &str,
+        socket_addr: SocketAddr,
+        insecure: bool,
+        _: &Option<TlsConnParams>,
+    ) -> RedisResult<Self> {
+        let tls_connector: tokio_native_tls::TlsConnector = if insecure {
+            TlsConnector::builder()
+                .danger_accept_invalid_certs(true)
+                .danger_accept_invalid_hostnames(true)
+                .use_sni(false)
+                .build()?
+        } else {
+            TlsConnector::new()?
+        }
+        .into();
+        Ok(tls_connector
+            .connect(hostname, connect_tcp(&socket_addr).await?)
+            .await
+            .map(|con| Tokio::TcpTls(Box::new(con)))?)
+ } + + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult { + let config = create_rustls_config(insecure, tls_params.clone())?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + rustls_pki_types::ServerName::try_from(hostname)?.to_owned(), + connect_tcp(&socket_addr).await?, + ) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult { + Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?) + } + + #[cfg(feature = "tokio-comp")] + fn spawn(f: impl Future + Send + 'static) { + tokio::spawn(f); + } + + #[cfg(not(feature = "tokio-comp"))] + fn spawn(_: impl Future + Send + 'static) { + unreachable!() + } + + fn boxed(self) -> Pin> { + match self { + Tokio::Tcp(x) => Box::pin(x), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(x) => Box::pin(x), + #[cfg(unix)] + Tokio::Unix(x) => Box::pin(x), + } + } +} diff --git a/glide-core/redis-rs/redis/src/client.rs b/glide-core/redis-rs/redis/src/client.rs new file mode 100644 index 0000000000..6ac3f40bcf --- /dev/null +++ b/glide-core/redis-rs/redis/src/client.rs @@ -0,0 +1,610 @@ +use std::time::Duration; + +#[cfg(feature = "aio")] +use crate::aio::DisconnectNotifier; + +use crate::{ + connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, + push_manager::PushInfo, + types::{RedisResult, Value}, +}; +#[cfg(feature = "aio")] +use std::net::IpAddr; +#[cfg(feature = "aio")] +use std::net::SocketAddr; +#[cfg(feature = "aio")] +use std::pin::Pin; +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{inner_build_with_tls, TlsCertificates}; + +/// The client type. +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) connection_info: ConnectionInfo, +} + +/// The client acts as connector to the redis server. By itself it does not +/// do much other than providing a convenient way to fetch a connection from +/// it. In the future the plan is to provide a connection pool in the client. +/// +/// When opening a client a URL in the following format should be used: +/// +/// ```plain +/// redis://host:port/db +/// ``` +/// +/// Example usage:: +/// +/// ```rust,no_run +/// let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// let con = client.get_connection(None).unwrap(); +/// ``` +impl Client { + /// Connects to a redis server and returns a client. This does not + /// actually open a connection yet but it does perform some basic + /// checks on the URL that might make the operation fail. + pub fn open(params: T) -> RedisResult { + Ok(Client { + connection_info: params.into_connection_info()?, + }) + } + + /// Instructs the client to actually connect to redis and returns a + /// connection object. The connection object can be used to send + /// commands to the server. This can fail with a variety of errors + /// (like unreachable host) so it's important that you handle those + /// errors. + pub fn get_connection( + &self, + _push_sender: Option>, + ) -> RedisResult { + connect(&self.connection_info, None) + } + + /// Instructs the client to actually connect to redis with specified + /// timeout and returns a connection object. The connection object + /// can be used to send commands to the server. 
This can fail with
+    /// a variety of errors (like unreachable host) so it's important
+    /// that you handle those errors.
+    pub fn get_connection_with_timeout(&self, timeout: Duration) -> RedisResult<Connection> {
+        connect(&self.connection_info, Some(timeout))
+    }
+
+    /// Returns a reference to the client's connection info.
+    pub fn get_connection_info(&self) -> &ConnectionInfo {
+        &self.connection_info
+    }
+}
+
+/// Glide-specific connection options
+#[derive(Clone, Default)]
+pub struct GlideConnectionOptions {
+    /// Queue for RESP3 notifications
+    pub push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    #[cfg(feature = "aio")]
+    /// Passive disconnect notifier
+    pub disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
+    /// If ReadFromReplica strategy is set to AZAffinity, this parameter will be set to 'true'.
+    /// In this case, an INFO command will be triggered in the connection's setup to update the connection's 'availability_zone' property.
+    pub discover_az: bool,
+}
+
+/// To enable async support you need to enable the feature: `tokio-comp`
+#[cfg(feature = "aio")]
+#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
+impl Client {
+    /// Returns an async connection from the client.
+    #[cfg(feature = "tokio-comp")]
+    #[deprecated(
+        note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead."
+    )]
+    #[allow(deprecated)]
+    pub async fn get_async_connection(
+        &self,
+        _push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    ) -> RedisResult<crate::aio::Connection> {
+        let (con, _ip) = match Runtime::locate() {
+            #[cfg(feature = "tokio-comp")]
+            Runtime::Tokio => {
+                self.get_simple_async_connection::<crate::aio::tokio::Tokio>(None)
+                    .await?
+            }
+        };
+
+        crate::aio::Connection::new(&self.connection_info.redis, con).await
+    }
+
+    /// Returns an async connection from the client.
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    pub async fn get_multiplexed_async_connection(
+        &self,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<crate::aio::MultiplexedConnection> {
+        self.get_multiplexed_async_connection_with_timeouts(
+            std::time::Duration::MAX,
+            std::time::Duration::MAX,
+            glide_connection_options,
+        )
+        .await
+    }
+
+    /// Returns an async connection from the client.
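+    ///
+    /// `connection_timeout` bounds connection establishment, while
+    /// `response_timeout` bounds each request sent on the returned connection.
+    /// A sketch with hypothetical values:
+    ///
+    /// ```text
+    /// let con = client
+    ///     .get_multiplexed_async_connection_with_timeouts(
+    ///         Duration::from_millis(250), // response_timeout
+    ///         Duration::from_secs(2),     // connection_timeout
+    ///         GlideConnectionOptions::default(),
+    ///     )
+    ///     .await?;
+    /// ```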
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    pub async fn get_multiplexed_async_connection_with_timeouts(
+        &self,
+        response_timeout: std::time::Duration,
+        connection_timeout: std::time::Duration,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<crate::aio::MultiplexedConnection> {
+        let result = match Runtime::locate() {
+            #[cfg(feature = "tokio-comp")]
+            rt @ Runtime::Tokio => {
+                rt.timeout(
+                    connection_timeout,
+                    self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(
+                        response_timeout,
+                        None,
+                        glide_connection_options,
+                    ),
+                )
+                .await
+            }
+        };
+
+        match result {
+            Ok(Ok(connection)) => Ok(connection),
+            Ok(Err(e)) => Err(e),
+            Err(elapsed) => Err(elapsed.into()),
+        }
+        .map(|(conn, _ip)| conn)
+    }
+
+    /// For TCP connections: returns (async connection, Some(the direct IP address))
+    /// For Unix connections, returns (async connection, None)
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    pub async fn get_multiplexed_async_connection_ip(
+        &self,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<(crate::aio::MultiplexedConnection, Option<IpAddr>)> {
+        match Runtime::locate() {
+            #[cfg(feature = "tokio-comp")]
+            Runtime::Tokio => {
+                self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(
+                    Duration::MAX,
+                    None,
+                    glide_connection_options,
+                )
+                .await
+            }
+        }
+    }
+
+    /// Returns an async multiplexed connection from the client.
+    ///
+    /// A multiplexed connection can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    pub async fn get_multiplexed_tokio_connection_with_response_timeouts(
+        &self,
+        response_timeout: std::time::Duration,
+        connection_timeout: std::time::Duration,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<crate::aio::MultiplexedConnection> {
+        let result = Runtime::locate()
+            .timeout(
+                connection_timeout,
+                self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(
+                    response_timeout,
+                    None,
+                    glide_connection_options,
+                ),
+            )
+            .await;
+
+        match result {
+            Ok(Ok((connection, _ip))) => Ok(connection),
+            Ok(Err(e)) => Err(e),
+            Err(elapsed) => Err(elapsed.into()),
+        }
+    }
+
+    /// Returns an async multiplexed connection from the client.
+    ///
+    /// A multiplexed connection can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    #[cfg(feature = "tokio-comp")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
+    pub async fn get_multiplexed_tokio_connection(
+        &self,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<crate::aio::MultiplexedConnection> {
+        self.get_multiplexed_tokio_connection_with_response_timeouts(
+            std::time::Duration::MAX,
+            std::time::Duration::MAX,
+            glide_connection_options,
+        )
+        .await
+    }
+
+    /// Returns an async [`ConnectionManager`][connection-manager] from the client.
+    ///
+    /// The connection manager wraps a
+    /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that
+    /// connection fails with a connection error, then a new connection is
+    /// established in the background and the error is returned to the caller.
+    ///
+    /// This means that on connection loss at least one command will fail, but
+    /// the connection will be re-established automatically if possible. Please
+    /// refer to the [`ConnectionManager`][connection-manager] docs for
+    /// detailed reconnecting behavior.
+    ///
+    /// A connection manager can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    ///
+    /// [connection-manager]: aio/struct.ConnectionManager.html
+    /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html
+    #[cfg(feature = "connection-manager")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
+    #[deprecated(note = "use get_connection_manager instead")]
+    pub async fn get_tokio_connection_manager(&self) -> RedisResult<crate::aio::ConnectionManager> {
+        crate::aio::ConnectionManager::new(self.clone()).await
+    }
+
+    /// Returns an async [`ConnectionManager`][connection-manager] from the client.
+    ///
+    /// The connection manager wraps a
+    /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that
+    /// connection fails with a connection error, then a new connection is
+    /// established in the background and the error is returned to the caller.
+    ///
+    /// This means that on connection loss at least one command will fail, but
+    /// the connection will be re-established automatically if possible. Please
+    /// refer to the [`ConnectionManager`][connection-manager] docs for
+    /// detailed reconnecting behavior.
+    ///
+    /// A connection manager can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    ///
+    /// [connection-manager]: aio/struct.ConnectionManager.html
+    /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html
+    #[cfg(feature = "connection-manager")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
+    pub async fn get_connection_manager(&self) -> RedisResult<crate::aio::ConnectionManager> {
+        crate::aio::ConnectionManager::new(self.clone()).await
+    }
+
+    /// Returns an async [`ConnectionManager`][connection-manager] from the client.
+    ///
+    /// The connection manager wraps a
+    /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that
+    /// connection fails with a connection error, then a new connection is
+    /// established in the background and the error is returned to the caller.
+    ///
+    /// This means that on connection loss at least one command will fail, but
+    /// the connection will be re-established automatically if possible. Please
+    /// refer to the [`ConnectionManager`][connection-manager] docs for
+    /// detailed reconnecting behavior.
+    ///
+    /// A connection manager can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    ///
+    /// [connection-manager]: aio/struct.ConnectionManager.html
+    /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html
+    #[cfg(feature = "connection-manager")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
+    #[deprecated(note = "use get_connection_manager_with_backoff instead")]
+    pub async fn get_tokio_connection_manager_with_backoff(
+        &self,
+        exponent_base: u64,
+        factor: u64,
+        number_of_retries: usize,
+    ) -> RedisResult<crate::aio::ConnectionManager> {
+        self.get_connection_manager_with_backoff_and_timeouts(
+            exponent_base,
+            factor,
+            number_of_retries,
+            std::time::Duration::MAX,
+            std::time::Duration::MAX,
+        )
+        .await
+    }
+
+    /// Returns an async [`ConnectionManager`][connection-manager] from the client.
+    ///
+    /// The connection manager wraps a
+    /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that
+    /// connection fails with a connection error, then a new connection is
+    /// established in the background and the error is returned to the caller.
+    ///
+    /// This means that on connection loss at least one command will fail, but
+    /// the connection will be re-established automatically if possible. Please
+    /// refer to the [`ConnectionManager`][connection-manager] docs for
+    /// detailed reconnecting behavior.
+    ///
+    /// A connection manager can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+    ///
+    /// [connection-manager]: aio/struct.ConnectionManager.html
+    /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html
+    #[cfg(feature = "connection-manager")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
+    pub async fn get_connection_manager_with_backoff_and_timeouts(
+        &self,
+        exponent_base: u64,
+        factor: u64,
+        number_of_retries: usize,
+        response_timeout: std::time::Duration,
+        connection_timeout: std::time::Duration,
+    ) -> RedisResult<crate::aio::ConnectionManager> {
+        crate::aio::ConnectionManager::new_with_backoff_and_timeouts(
+            self.clone(),
+            exponent_base,
+            factor,
+            number_of_retries,
+            response_timeout,
+            connection_timeout,
+        )
+        .await
+    }
+
+    /// Returns an async [`ConnectionManager`][connection-manager] from the client.
+    ///
+    /// The connection manager wraps a
+    /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that
+    /// connection fails with a connection error, then a new connection is
+    /// established in the background and the error is returned to the caller.
+    ///
+    /// This means that on connection loss at least one command will fail, but
+    /// the connection will be re-established automatically if possible. Please
+    /// refer to the [`ConnectionManager`][connection-manager] docs for
+    /// detailed reconnecting behavior.
+    ///
+    /// A connection manager can be cloned, allowing requests to be sent concurrently
+    /// on the same underlying connection (tcp/unix socket).
+ /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager_with_backoff( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff( + self.clone(), + exponent_base, + factor, + number_of_retries, + ) + .await + } + + pub(crate) async fn get_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> + where + T: crate::aio::RedisRuntime, + { + let (connection, driver, ip) = self + .create_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ) + .await?; + T::spawn(driver); + Ok((connection, ip)) + } + + async fn create_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (con, ip) = self.get_simple_async_connection::(socket_addr).await?; + crate::aio::MultiplexedConnection::new_with_response_timeout( + &self.connection_info, + con, + response_timeout, + glide_connection_options, + ) + .await + .map(|res| (res.0, res.1, ip)) + } + + async fn get_simple_async_connection( + &self, + socket_addr: Option, + ) -> RedisResult<( + Pin>, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (conn, ip) = + crate::aio::connect_simple::(&self.connection_info, socket_addr).await?; + Ok((conn.boxed(), ip)) + } + + #[cfg(feature = "connection-manager")] + pub(crate) fn connection_info(&self) -> &ConnectionInfo { + &self.connection_info + } + + /// Constructs a new `Client` with parameters necessary to create a TLS connection. + /// + /// - `conn_info` - URL using the `rediss://` scheme. + /// - `tls_certs` - `TlsCertificates` structure containing: + /// -- `client_tls` - Optional `ClientTlsConfig` containing byte streams for + /// --- `client_cert` - client's byte stream containing client certificate in PEM format + /// --- `client_key` - client's byte stream containing private key in PEM format + /// -- `root_cert` - Optional byte stream yielding PEM formatted file for root certificates. + /// + /// If `ClientTlsConfig` ( cert+key pair ) is not provided, then client-side authentication is not enabled. + /// If `root_cert` is not provided, then system root certificates are used instead. 
+ /// + /// # Examples + /// + /// ```no_run + /// use std::{fs::File, io::{BufReader, Read}}; + /// + /// use redis::{Client, AsyncCommands as _, TlsCertificates, ClientTlsConfig}; + /// + /// async fn do_redis_code( + /// url: &str, + /// root_cert_file: &str, + /// cert_file: &str, + /// key_file: &str + /// ) -> redis::RedisResult<()> { + /// let root_cert_file = File::open(root_cert_file).expect("cannot open private cert file"); + /// let mut root_cert_vec = Vec::new(); + /// BufReader::new(root_cert_file) + /// .read_to_end(&mut root_cert_vec) + /// .expect("Unable to read ROOT cert file"); + /// + /// let cert_file = File::open(cert_file).expect("cannot open private cert file"); + /// let mut client_cert_vec = Vec::new(); + /// BufReader::new(cert_file) + /// .read_to_end(&mut client_cert_vec) + /// .expect("Unable to read client cert file"); + /// + /// let key_file = File::open(key_file).expect("cannot open private key file"); + /// let mut client_key_vec = Vec::new(); + /// BufReader::new(key_file) + /// .read_to_end(&mut client_key_vec) + /// .expect("Unable to read client key file"); + /// + /// let client = Client::build_with_tls( + /// url, + /// TlsCertificates { + /// client_tls: Some(ClientTlsConfig{ + /// client_cert: client_cert_vec, + /// client_key: client_key_vec, + /// }), + /// root_cert: Some(root_cert_vec), + /// } + /// ) + /// .expect("Unable to build client"); + /// + /// let connection_info = client.get_connection_info(); + /// + /// println!(">>> connection info: {connection_info:?}"); + /// + /// let mut con = client.get_async_connection(None).await?; + /// + /// con.set("key1", b"foo").await?; + /// + /// redis::cmd("SET") + /// .arg(&["key2", "bar"]) + /// .query_async(&mut con) + /// .await?; + /// + /// let result = redis::cmd("MGET") + /// .arg(&["key1", "key2"]) + /// .query_async(&mut con) + /// .await; + /// assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + /// println!("Result from MGET: {result:?}"); + /// + /// Ok(()) + /// } + /// ``` + #[cfg(feature = "tls-rustls")] + pub fn build_with_tls( + conn_info: C, + tls_certs: TlsCertificates, + ) -> RedisResult { + let connection_info = conn_info.into_connection_info()?; + + inner_build_with_tls(connection_info, tls_certs) + } + + /// Returns an async receiver for pub-sub messages. + #[cfg(feature = "tokio-comp")] + // TODO - do we want to type-erase pubsub using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_pubsub(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_pubsub()) + } + + /// Returns an async receiver for monitor messages. + #[cfg(feature = "tokio-comp")] + // TODO - do we want to type-erase monitor using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_monitor(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_monitor()) + } +} + +#[cfg(feature = "aio")] +use crate::aio::Runtime; + +impl ConnectionLike for Client { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.get_connection(None)?.req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.get_connection(None)? 
+ .req_packed_commands(cmd, offset, count) + } + + fn get_db(&self) -> i64 { + self.connection_info.redis.db + } + + fn check_connection(&mut self) -> bool { + if let Ok(mut conn) = self.get_connection(None) { + conn.check_connection() + } else { + false + } + } + + fn is_open(&self) -> bool { + if let Ok(conn) = self.get_connection(None) { + conn.is_open() + } else { + false + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn regression_293_parse_ipv6_with_interface() { + assert!(Client::open(("fe80::cafe:beef%eno1", 6379)).is_ok()); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster.rs b/glide-core/redis-rs/redis/src/cluster.rs new file mode 100644 index 0000000000..ffd537152a --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster.rs @@ -0,0 +1,1076 @@ +//! This module extends the library to support Redis Cluster. +//! +//! Note that this module does not currently provide pubsub +//! functionality. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::ClusterClient; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let _: () = connection.set("test", "test_data").unwrap(); +//! let rv: String = connection.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! # Pipelining +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::{cluster_pipe, ClusterClient}; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let key = "test"; +//! +//! let _: () = cluster_pipe() +//! .rpush(key, "123").ignore() +//! .ltrim(key, -10, -1).ignore() +//! .expire(key, 60).ignore() +//! .query(&mut connection).unwrap(); +//! 
``` +pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; +use crate::cluster_pipeline::UNROUTABLE_ERROR; +pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; +use crate::cluster_routing::{ + MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, +}; +use crate::cluster_slotmap::SlotMap; +use crate::cluster_topology::parse_and_count_slots; +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo, +}; +use crate::parser::parse_redis_value; +use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, RetryMethod, Value}; +pub use crate::TlsMode; // Pub for backwards compatibility +use crate::{ + cluster_client::ClusterParams, + cluster_routing::{Redirect, Route, RoutingInfo}, + IntoConnectionInfo, PushInfo, +}; +use rand::{seq::IteratorRandom, thread_rng}; +use std::cell::RefCell; +use std::collections::HashSet; +use std::str::FromStr; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[derive(Clone)] +enum Input<'a> { + Slice { + cmd: &'a [u8], + routable: Value, + }, + Cmd(&'a Cmd), + Commands { + cmd: &'a [u8], + route: SingleNodeRoutingInfo, + offset: usize, + count: usize, + }, +} + +impl<'a> Input<'a> { + fn send(&'a self, connection: &mut impl ConnectionLike) -> RedisResult { + match self { + Input::Slice { cmd, routable: _ } => { + connection.req_packed_command(cmd).map(Output::Single) + } + Input::Cmd(cmd) => connection.req_command(cmd).map(Output::Single), + Input::Commands { + cmd, + route: _, + offset, + count, + } => connection + .req_packed_commands(cmd, *offset, *count) + .map(Output::Multi), + } + } +} + +impl<'a> Routable for Input<'a> { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + match self { + Input::Slice { cmd: _, routable } => routable.arg_idx(idx), + Input::Cmd(cmd) => cmd.arg_idx(idx), + Input::Commands { .. } => None, + } + } + + fn position(&self, candidate: &[u8]) -> Option { + match self { + Input::Slice { cmd: _, routable } => routable.position(candidate), + Input::Cmd(cmd) => cmd.position(candidate), + Input::Commands { .. } => None, + } + } +} + +enum Output { + Single(Value), + Multi(Vec), +} + +impl From for Value { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => value, + Output::Multi(values) => Value::Array(values), + } + } +} + +impl From for Vec { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => vec![value], + Output::Multi(values) => values, + } + } +} + +/// Implements the process of connecting to a Redis server +/// and obtaining and configuring a connection handle. +pub trait Connect: Sized { + /// Connect to a node, returning handle for command execution. + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo; + + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()>; + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. 
It is an error to pass the zero `Duration` to this + /// method. + fn set_write_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_read_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + fn recv_response(&mut self) -> RedisResult; +} + +impl Connect for Connection { + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + connect(&info.into_connection_info()?, timeout) + } + + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + Self::send_packed_command(self, cmd) + } + + fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_write_timeout(self, dur) + } + + fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_read_timeout(self, dur) + } + + fn recv_response(&mut self) -> RedisResult { + Self::recv_response(self) + } +} + +/// This represents a Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. +pub struct ClusterConnection { + initial_nodes: Vec, + connections: RefCell>, + slots: RefCell, + auto_reconnect: RefCell, + read_timeout: RefCell>, + write_timeout: RefCell>, + cluster_params: ClusterParams, +} + +impl ClusterConnection +where + C: ConnectionLike + Connect, +{ + pub(crate) fn new( + cluster_params: ClusterParams, + initial_nodes: Vec, + _push_sender: Option>, + ) -> RedisResult { + let connection = Self { + connections: RefCell::new(HashMap::new()), + slots: RefCell::new(SlotMap::new( + vec![], + cluster_params.read_from_replicas.clone(), + )), + auto_reconnect: RefCell::new(true), + cluster_params, + read_timeout: RefCell::new(None), + write_timeout: RefCell::new(None), + initial_nodes: initial_nodes.to_vec(), + }; + connection.create_initial_connections()?; + + Ok(connection) + } + + /// Set an auto reconnect attribute. + /// Default value is true; + pub fn set_auto_reconnect(&self, value: bool) { + let mut auto_reconnect = self.auto_reconnect.borrow_mut(); + *auto_reconnect = value; + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. + if dur.is_some() && dur.unwrap().is_zero() { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Duration should be None or non-zero.", + ))); + } + + let mut t = self.write_timeout.borrow_mut(); + *t = dur; + let connections = self.connections.borrow(); + for conn in connections.values() { + conn.set_write_timeout(dur)?; + } + Ok(()) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. 
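+        // `None` leaves reads blocking; a zero `Duration` is rejected up front
+        // because the underlying socket APIs error on zero timeouts.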
+        if dur.is_some() && dur.unwrap().is_zero() {
+            return Err(RedisError::from((
+                ErrorKind::InvalidClientConfig,
+                "Duration should be None or non-zero.",
+            )));
+        }
+
+        let mut t = self.read_timeout.borrow_mut();
+        *t = dur;
+        let connections = self.connections.borrow();
+        for conn in connections.values() {
+            conn.set_read_timeout(dur)?;
+        }
+        Ok(())
+    }
+
+    /// Check that all connections it has are available (`PING` internally).
+    #[doc(hidden)]
+    pub fn check_connection(&mut self) -> bool {
+        <Self as ConnectionLike>::check_connection(self)
+    }
+
+    pub(crate) fn execute_pipeline(&mut self, pipe: &ClusterPipeline) -> RedisResult<Vec<Value>> {
+        self.send_recv_and_retry_cmds(pipe.commands())
+    }
+
+    /// Returns the connection status.
+    ///
+    /// The connection is open until any `read_response` call received an
+    /// invalid response from the server (most likely a closed or dropped
+    /// connection, otherwise a Redis protocol error). When using unix
+    /// sockets the connection is open until writing a command failed with a
+    /// `BrokenPipe` error.
+    fn create_initial_connections(&self) -> RedisResult<()> {
+        let mut connections = HashMap::with_capacity(self.initial_nodes.len());
+
+        for info in self.initial_nodes.iter() {
+            let addr = info.addr.to_string();
+
+            if let Ok(mut conn) = self.connect(&addr) {
+                if conn.check_connection() {
+                    connections.insert(addr, conn);
+                    break;
+                }
+            }
+        }
+
+        if connections.is_empty() {
+            return Err(RedisError::from((
+                ErrorKind::IoError,
+                "Failed to connect to any of the startup nodes.",
+            )));
+        }
+
+        *self.connections.borrow_mut() = connections;
+        self.refresh_slots()?;
+        Ok(())
+    }
+
+    // Query a node to discover slot -> master mappings.
+    fn refresh_slots(&self) -> RedisResult<()> {
+        let mut slots = self.slots.borrow_mut();
+        *slots = self.create_new_slots()?;
+
+        let nodes = slots.all_node_addresses();
+        let mut connections = self.connections.borrow_mut();
+        *connections = nodes
+            .into_iter()
+            .filter_map(|addr| {
+                let addr = addr.to_string();
+                if connections.contains_key(&addr) {
+                    let mut conn = connections.remove(&addr).unwrap();
+                    if conn.check_connection() {
+                        return Some((addr.to_string(), conn));
+                    }
+                }
+
+                if let Ok(mut conn) = self.connect(&addr) {
+                    if conn.check_connection() {
+                        return Some((addr.to_string(), conn));
+                    }
+                }
+
+                None
+            })
+            .collect();
+
+        Ok(())
+    }
+
+    fn create_new_slots(&self) -> RedisResult<SlotMap> {
+        let mut connections = self.connections.borrow_mut();
+        let mut rng = thread_rng();
+        let len = connections.len();
+        let samples = connections.iter_mut().choose_multiple(&mut rng, len);
+        let mut result = Err(RedisError::from((
+            ErrorKind::ResponseError,
+            "Slot refresh error.",
+            "didn't get any slots from server".to_string(),
+        )));
+        for (addr, conn) in samples {
+            let value = conn.req_command(&slot_cmd())?;
+            let addr = addr.split(':').next().ok_or(RedisError::from((
+                ErrorKind::ClientError,
+                "can't parse node address",
+            )))?;
+            match parse_and_count_slots(&value, self.cluster_params.tls, addr).map(|slots_data| {
+                SlotMap::new(slots_data.1, self.cluster_params.read_from_replicas.clone())
+            }) {
+                Ok(new_slots) => {
+                    result = Ok(new_slots);
+                    break;
+                }
+                Err(err) => result = Err(err),
+            }
+        }
+        result
+    }
+
+    fn connect(&self, node: &str) -> RedisResult<C> {
+        let info = get_connection_info(node, self.cluster_params.clone())?;
+
+        let mut conn = C::connect(info, Some(self.cluster_params.connection_timeout))?;
+        if self.cluster_params.read_from_replicas
+            != crate::cluster_slotmap::ReadFromReplicaStrategy::AlwaysFromPrimary
+        {
+            // If READONLY is sent to
primary nodes, it will have no effect + cmd("READONLY").query(&mut conn)?; + } + conn.set_read_timeout(*self.read_timeout.borrow())?; + conn.set_write_timeout(*self.write_timeout.borrow())?; + Ok(conn) + } + + fn get_connection<'a>( + &self, + connections: &'a mut HashMap, + route: &Route, + ) -> RedisResult<(String, &'a mut C)> { + let slots = self.slots.borrow(); + if let Some(addr) = slots.slot_addr_for_route(route) { + Ok(( + addr.to_string(), + self.get_connection_by_addr(connections, &addr)?, + )) + } else { + // try a random node next. This is safe if slots are involved + // as a wrong node would reject the request. + Ok(get_random_connection(connections)) + } + } + + fn get_connection_by_addr<'a>( + &self, + connections: &'a mut HashMap, + addr: &str, + ) -> RedisResult<&'a mut C> { + if connections.contains_key(addr) { + Ok(connections.get_mut(addr).unwrap()) + } else { + // Create new connection. + // TODO: error handling + let conn = self.connect(addr)?; + Ok(connections.entry(addr.to_string()).or_insert(conn)) + } + } + + fn get_addr_for_cmd(&self, cmd: &Cmd) -> RedisResult { + let slots = self.slots.borrow(); + + let addr_for_slot = |route: Route| -> RedisResult { + let slot_addr = slots + .slot_addr_for_route(&route) + .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; + Ok(slot_addr.to_string()) + }; + + match RoutingInfo::for_routable(cmd) { + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + | Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::RandomPrimary)) => { + Ok(addr_for_slot(Route::new_random_primary())?) + } + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => { + Ok(addr_for_slot(route)?) + } + _ => fail!(UNROUTABLE_ERROR), + } + } + + fn map_cmds_to_nodes(&self, cmds: &[Cmd]) -> RedisResult> { + let mut cmd_map: HashMap = HashMap::new(); + + for (idx, cmd) in cmds.iter().enumerate() { + let addr = self.get_addr_for_cmd(cmd)?; + let nc = cmd_map + .entry(addr.clone()) + .or_insert_with(|| NodeCmd::new(addr)); + nc.indexes.push(idx); + cmd.write_packed_command(&mut nc.pipe); + } + + let mut result = Vec::new(); + for (_, v) in cmd_map.drain() { + result.push(v); + } + Ok(result) + } + + fn execute_on_all<'a>( + &'a self, + input: Input, + addresses: HashSet>, + connections: &'a mut HashMap, + ) -> Vec, Value)>> { + addresses + .into_iter() + .map(|addr| { + let connection = self.get_connection_by_addr(connections, &addr)?; + match input { + Input::Slice { cmd, routable: _ } => connection.req_packed_command(cmd), + Input::Cmd(cmd) => connection.req_command(cmd), + Input::Commands { + cmd: _, + route: _, + offset: _, + count: _, + } => Err(( + ErrorKind::ClientError, + "req_packed_commands isn't supported with multiple nodes", + ) + .into()), + } + .map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_all_nodes<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec, Value)>> { + self.execute_on_all(input, slots.all_node_addresses(), connections) + } + + fn execute_on_all_primaries<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec, Value)>> { + self.execute_on_all(input, slots.addresses_for_all_primaries(), connections) + } + + fn execute_multi_slot<'a, 'b>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + routes: &'b [(Route, Vec)], + ) -> Vec, Value)>> + where + 'b: 'a, + { + slots + .addresses_for_multi_slot(routes) + .enumerate() + .map(|(index, addr)| { + let addr = 
addr.ok_or(RedisError::from(( + ErrorKind::IoError, + "Couldn't find connection", + )))?; + let connection = self.get_connection_by_addr(connections, &addr)?; + let (_, indices) = routes.get(index).unwrap(); + let cmd = + crate::cluster_routing::command_for_multi_slot_indices(&input, indices.iter()); + connection.req_command(&cmd).map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_multiple_nodes( + &self, + input: Input, + routing: MultipleNodeRoutingInfo, + response_policy: Option, + ) -> RedisResult { + let mut connections = self.connections.borrow_mut(); + let mut slots = self.slots.borrow_mut(); + + let results = match &routing { + MultipleNodeRoutingInfo::MultiSlot((routes, _)) => { + self.execute_multi_slot(input, &mut slots, &mut connections, routes) + } + MultipleNodeRoutingInfo::AllMasters => { + self.execute_on_all_primaries(input, &mut slots, &mut connections) + } + MultipleNodeRoutingInfo::AllNodes => { + self.execute_on_all_nodes(input, &mut slots, &mut connections) + } + }; + + match response_policy { + Some(ResponsePolicy::AllSucceeded) => { + for result in results { + result?; + } + + Ok(Value::Okay) + } + Some(ResponsePolicy::OneSucceeded) => { + let mut last_failure = None; + + for result in results { + match result { + Ok((_, val)) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + + Err(last_failure + .unwrap_or_else(|| (ErrorKind::IoError, "Couldn't find a connection").into())) + } + Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => { + // Attempt to return the first result that isn't `Nil` or an error. + // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`. + // If we received only errors, return the last received error. + // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty, + // thus we return the last received error instead of `Nil`. 
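+                // For example: [Nil, Nil] => Nil; [Nil, "x"] => "x";
+                // [Nil, Err] => the Err (we can't prove all shards were empty).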
+ let mut last_failure = None; + let num_of_results = results.len(); + let mut nil_counter = 0; + for result in results { + match result.map(|(_, res)| res) { + Ok(Value::Nil) => nil_counter += 1, + Ok(val) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + if nil_counter == num_of_results { + Ok(Value::Nil) + } else { + Err(last_failure.unwrap_or_else(|| { + (ErrorKind::IoError, "Couldn't find a connection").into() + })) + } + } + Some(ResponsePolicy::Aggregate(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::aggregate(results, op) + } + Some(ResponsePolicy::AggregateLogical(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::logical_aggregate(results, op) + } + Some(ResponsePolicy::CombineArrays) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + match routing { + MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)) => { + crate::cluster_routing::combine_and_sort_array_results( + results, + &vec, + &args_pattern, + ) + } + _ => crate::cluster_routing::combine_array_results(results), + } + } + Some(ResponsePolicy::CombineMaps) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::combine_map_results(results) + } + Some(ResponsePolicy::Special) | None => { + // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. + // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. + let results = results + .into_iter() + .map(|result| { + result.map(|(addr, val)| (Value::BulkString(addr.as_bytes().to_vec()), val)) + }) + .collect::>>()?; + Ok(Value::Map(results)) + } + } + } + + #[allow(clippy::unnecessary_unwrap)] + fn request(&self, input: Input) -> RedisResult { + let route_option = match &input { + Input::Slice { cmd: _, routable } => RoutingInfo::for_routable(routable), + Input::Cmd(cmd) => RoutingInfo::for_routable(*cmd), + Input::Commands { + cmd: _, + route, + offset: _, + count: _, + } => Some(RoutingInfo::SingleNode(route.clone())), + }; + let single_node_routing = match route_option { + Some(RoutingInfo::SingleNode(single_node_routing)) => single_node_routing, + Some(RoutingInfo::MultiNode((multi_node_routing, response_policy))) => { + return self + .execute_on_multiple_nodes(input, multi_node_routing, response_policy) + .map(Output::Single); + } + None => fail!(UNROUTABLE_ERROR), + }; + + let mut retries = 0; + let mut redirected = None::; + + loop { + // Get target address and response. + let (addr, rv) = { + let mut connections = self.connections.borrow_mut(); + let (addr, conn) = if let Some(redirected) = redirected.take() { + let (addr, is_asking) = match redirected { + Redirect::Moved(addr) => (addr, false), + Redirect::Ask(addr) => (addr, true), + }; + let conn = self.get_connection_by_addr(&mut connections, &addr)?; + if is_asking { + // if we are in asking mode we want to feed a single + // ASKING command into the connection before what we + // actually want to execute. 
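+                        // The raw bytes below are the RESP encoding of `ASKING`:
+                        // `*1\r\n` (an array of one element) followed by
+                        // `$6\r\nASKING\r\n` (a six-byte bulk string).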
+ conn.req_packed_command(&b"*1\r\n$6\r\nASKING\r\n"[..])?; + } + (addr.to_string(), conn) + } else { + match &single_node_routing { + SingleNodeRoutingInfo::Random => get_random_connection(&mut connections), + SingleNodeRoutingInfo::SpecificNode(route) => { + self.get_connection(&mut connections, route)? + } + SingleNodeRoutingInfo::RandomPrimary => { + self.get_connection(&mut connections, &Route::new_random_primary())? + } + SingleNodeRoutingInfo::ByAddress { host, port } => { + let address = format!("{host}:{port}"); + let conn = self.get_connection_by_addr(&mut connections, &address)?; + (address, conn) + } + } + }; + (addr, input.send(conn)) + }; + + match rv { + Ok(rv) => return Ok(rv), + Err(err) => { + if retries == self.cluster_params.retry_params.number_of_retries { + return Err(err); + } + retries += 1; + + match err.retry_method() { + RetryMethod::AskRedirect => { + redirected = err + .redirect_node() + .map(|(node, _slot)| Redirect::Ask(node.to_string())); + } + RetryMethod::MovedRedirect => { + // Refresh slots. + self.refresh_slots()?; + // Request again. + redirected = err + .redirect_node() + .map(|(node, _slot)| Redirect::Moved(node.to_string())); + } + RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica + | RetryMethod::WaitAndRetry => { + // Sleep and retry. + let sleep_time = self + .cluster_params + .retry_params + .wait_time_for_retry(retries); + thread::sleep(sleep_time); + } + RetryMethod::Reconnect | RetryMethod::ReconnectAndRetry => { + if *self.auto_reconnect.borrow() { + if let Ok(mut conn) = self.connect(&addr) { + if conn.check_connection() { + self.connections.borrow_mut().insert(addr, conn); + } + } + } + } + RetryMethod::NoRetry => { + return Err(err); + } + RetryMethod::RetryImmediately => {} + } + } + } + } + } + + fn send_recv_and_retry_cmds(&self, cmds: &[Cmd]) -> RedisResult> { + // Vector to hold the results, pre-populated with `Nil` values. This allows the original + // cmd ordering to be re-established by inserting the response directly into the result + // vector (e.g., results[10] = response). + let mut results = vec![Value::Nil; cmds.len()]; + + let to_retry = self + .send_all_commands(cmds) + .and_then(|node_cmds| self.recv_all_commands(&mut results, &node_cmds))?; + + if to_retry.is_empty() { + return Ok(results); + } + + // Refresh the slots to ensure that we have a clean slate for the retry attempts. + self.refresh_slots()?; + + // Given that there are commands that need to be retried, it means something in the cluster + // topology changed. Execute each command separately to take advantage of the existing + // retry logic that handles these cases. + for retry_idx in to_retry { + let cmd = &cmds[retry_idx]; + results[retry_idx] = self.request(Input::Cmd(cmd))?.into(); + } + Ok(results) + } + + // Build up a pipeline per node, then send it + fn send_all_commands(&self, cmds: &[Cmd]) -> RedisResult> { + let mut connections = self.connections.borrow_mut(); + + let node_cmds = self.map_cmds_to_nodes(cmds)?; + for nc in &node_cmds { + self.get_connection_by_addr(&mut connections, &nc.addr)? + .send_packed_command(&nc.pipe)?; + } + Ok(node_cmds) + } + + // Receive from each node, keeping track of which commands need to be retried. 
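+    //
+    // For example, if cmds[0] and cmds[2] hashed to node A and cmds[1] to node B, node A's
+    // `NodeCmd` carries `indexes == [0, 2]`, so its two pipelined responses are written back
+    // into results[0] and results[2], re-establishing the caller's original command order.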
+ fn recv_all_commands( + &self, + results: &mut [Value], + node_cmds: &[NodeCmd], + ) -> RedisResult> { + let mut to_retry = Vec::new(); + let mut connections = self.connections.borrow_mut(); + let mut first_err = None; + + for nc in node_cmds { + for cmd_idx in &nc.indexes { + match self + .get_connection_by_addr(&mut connections, &nc.addr)? + .recv_response() + { + Ok(item) => results[*cmd_idx] = item, + Err(err) if err.is_cluster_error() => to_retry.push(*cmd_idx), + Err(err) => first_err = first_err.or(Some(err)), + } + } + } + match first_err { + Some(err) => Err(err), + None => Ok(to_retry), + } + } +} + +const MULTI: &[u8] = "*1\r\n$5\r\nMULTI\r\n".as_bytes(); +impl ConnectionLike for ClusterConnection { + fn supports_pipelining(&self) -> bool { + false + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.request(Input::Cmd(cmd)).map(|res| res.into()) + } + + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + self.request(Input::Slice { + cmd, + routable: value, + }) + .map(|res| res.into()) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + let route = match RoutingInfo::for_routable(&value) { + Some(RoutingInfo::MultiNode(_)) => None, + Some(RoutingInfo::SingleNode(route)) => Some(route), + None => None, + } + .unwrap_or(SingleNodeRoutingInfo::Random); + self.request(Input::Commands { + cmd, + offset, + count, + route, + }) + .map(|res| res.into()) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_open(&self) -> bool { + let connections = self.connections.borrow(); + for conn in connections.values() { + if !conn.is_open() { + return false; + } + } + true + } + + fn check_connection(&mut self) -> bool { + let mut connections = self.connections.borrow_mut(); + for conn in connections.values_mut() { + if !conn.check_connection() { + return false; + } + } + true + } +} + +#[derive(Debug)] +struct NodeCmd { + // The original command indexes + indexes: Vec, + pipe: Vec, + addr: String, +} + +impl NodeCmd { + fn new(a: String) -> NodeCmd { + NodeCmd { + indexes: vec![], + pipe: vec![], + addr: a, + } + } +} + +// TODO: This function can panic and should probably +// return an Option instead: +fn get_random_connection( + connections: &mut HashMap, +) -> (String, &mut C) { + let addr = connections + .keys() + .choose(&mut thread_rng()) + .expect("Connections is empty") + .to_string(); + let con = connections.get_mut(&addr).expect("Connections is empty"); + (addr, con) +} + +// The node string passed to this function will always be in the format host:port as it is either: +// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode) +// - Returned from redis via the ASK/MOVED response +pub(crate) fn get_connection_info( + node: &str, + cluster_params: ClusterParams, +) -> RedisResult { + let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string"); + + let (host, port) = node + .rsplit_once(':') + .and_then(|(host, port)| { + Some(host.trim_start_matches('[').trim_end_matches(']')) + .filter(|h| !h.is_empty()) + .zip(u16::from_str(port).ok()) + }) + .ok_or_else(invalid_error)?; + + Ok(ConnectionInfo { + addr: get_connection_addr( + host.to_string(), + port, + 
cluster_params.tls, + cluster_params.tls_params, + ), + redis: RedisConnectionInfo { + password: cluster_params.password, + username: cluster_params.username, + client_name: cluster_params.client_name, + protocol: cluster_params.protocol, + db: 0, + pubsub_subscriptions: cluster_params.pubsub_subscriptions, + }, + }) +} + +pub(crate) fn get_connection_addr( + host: String, + port: u16, + tls: Option, + tls_params: Option, +) -> ConnectionAddr { + match tls { + Some(TlsMode::Secure) => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params, + }, + Some(TlsMode::Insecure) => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + tls_params, + }, + _ => ConnectionAddr::Tcp(host, port), + } +} + +pub(crate) fn slot_cmd() -> Cmd { + let mut cmd = Cmd::new(); + cmd.arg("CLUSTER").arg("SLOTS"); + cmd +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_cluster_node_host_port() { + let cases = vec![ + ( + "127.0.0.1:6379", + ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379u16), + ), + ( + "localhost.localdomain:6379", + ConnectionAddr::Tcp("localhost.localdomain".to_string(), 6379u16), + ), + ( + "dead::cafe:beef:30001", + ConnectionAddr::Tcp("dead::cafe:beef".to_string(), 30001u16), + ), + ( + "[fe80::cafe:beef%en1]:30001", + ConnectionAddr::Tcp("fe80::cafe:beef%en1".to_string(), 30001u16), + ), + ]; + + for (input, expected) in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!(res.unwrap().addr, expected); + } + + let cases = vec![":0", "[]:6379"]; + for input in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!( + res.err(), + Some(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Invalid node string", + ))), + ); + } + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/LICENSE b/glide-core/redis-rs/redis/src/cluster_async/LICENSE new file mode 100644 index 0000000000..aaa71a1638 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/LICENSE @@ -0,0 +1,7 @@ +Copyright 2019 Atsushi Koge, Markus Westerlind + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs new file mode 100644 index 0000000000..955d24d9e9 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs @@ -0,0 +1,1192 @@ +use crate::cluster_async::ConnectionFuture; +use crate::cluster_routing::{Route, ShardAddrs, SlotAddr}; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap, SlotMapValue}; +use crate::cluster_topology::TopologyHash; +use dashmap::DashMap; +use futures::FutureExt; +use rand::seq::IteratorRandom; +use std::net::IpAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use telemetrylib::Telemetry; + +/// Count the number of connections in a connections_map object +macro_rules! count_connections { + ($conn_map:expr) => {{ + let mut count = 0usize; + for a in $conn_map { + count = count.saturating_add(if a.management_connection.is_some() { + 2 + } else { + 1 + }); + } + count + }}; +} + +/// A struct that encapsulates a network connection along with its associated IP address and AZ. +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ConnectionDetails { + /// The actual connection + pub conn: Connection, + /// The IP associated with the connection + pub ip: Option, + /// The availability zone associated with the connection + pub az: Option, +} + +impl ConnectionDetails +where + Connection: Clone + Send + 'static, +{ + /// Consumes the current instance and returns a new `ConnectionDetails` + /// where the connection is wrapped in a future. + #[doc(hidden)] + pub fn into_future(self) -> ConnectionDetails> { + ConnectionDetails { + conn: async { self.conn }.boxed().shared(), + ip: self.ip, + az: self.az, + } + } +} + +impl From<(Connection, Option, Option)> + for ConnectionDetails +{ + fn from(val: (Connection, Option, Option)) -> Self { + ConnectionDetails { + conn: val.0, + ip: val.1, + az: val.2, + } + } +} + +impl From> + for (Connection, Option, Option) +{ + fn from(val: ConnectionDetails) -> Self { + (val.conn, val.ip, val.az) + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ClusterNode { + pub user_connection: ConnectionDetails, + pub management_connection: Option>, +} + +impl ClusterNode +where + Connection: Clone, +{ + pub fn new( + user_connection: ConnectionDetails, + management_connection: Option>, + ) -> Self { + Self { + user_connection, + management_connection, + } + } + + /// Return the number of underlying connections managed by this instance of ClusterNode + pub fn connections_count(&self) -> usize { + if self.management_connection.is_some() { + 2 + } else { + 1 + } + } + + pub(crate) fn get_connection(&self, conn_type: &ConnectionType) -> Connection { + match conn_type { + ConnectionType::User => self.user_connection.conn.clone(), + ConnectionType::PreferManagement => self.management_connection.as_ref().map_or_else( + || self.user_connection.conn.clone(), + |management_conn| management_conn.conn.clone(), + ), + } + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] + +pub(crate) enum ConnectionType { + User, + PreferManagement, +} + +pub(crate) struct ConnectionsMap(pub(crate) DashMap>); + +impl std::fmt::Display for ConnectionsMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for item in self.0.iter() { + let (address, node) = (item.key(), item.value()); + match node.user_connection.ip { + Some(ip) => writeln!(f, "{address} - {ip}")?, + None => writeln!(f, "{address}")?, + }; + } + Ok(()) + } +} + +pub(crate) struct 
ConnectionsContainer<Connection> {
+    connection_map: DashMap<String, ClusterNode<Connection>>,
+    pub(crate) slot_map: SlotMap,
+    read_from_replica_strategy: ReadFromReplicaStrategy,
+    topology_hash: TopologyHash,
+}
+
+impl<Connection> Drop for ConnectionsContainer<Connection> {
+    fn drop(&mut self) {
+        let count = count_connections!(&self.connection_map);
+        Telemetry::decr_total_connections(count);
+    }
+}
+
+impl<Connection> Default for ConnectionsContainer<Connection> {
+    fn default() -> Self {
+        Self {
+            connection_map: Default::default(),
+            slot_map: Default::default(),
+            read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary,
+            topology_hash: 0,
+        }
+    }
+}
+
+pub(crate) type ConnectionAndAddress<Connection> = (String, Connection);
+
+impl<Connection> ConnectionsContainer<Connection>
+where
+    Connection: Clone,
+{
+    pub(crate) fn new(
+        slot_map: SlotMap,
+        connection_map: ConnectionsMap<Connection>,
+        read_from_replica_strategy: ReadFromReplicaStrategy,
+        topology_hash: TopologyHash,
+    ) -> Self {
+        let connection_map = connection_map.0;
+
+        // Update the telemetry with the number of connections
+        let count = count_connections!(&connection_map);
+        Telemetry::incr_total_connections(count);
+
+        Self {
+            connection_map,
+            slot_map,
+            read_from_replica_strategy,
+            topology_hash,
+        }
+    }
+
+    /// Returns an iterator over the nodes in the `slot_map`, yielding pairs of the node address and its associated shard addresses.
+    pub(crate) fn slot_map_nodes(
+        &self,
+    ) -> impl Iterator<Item = (Arc<String>, Arc<ShardAddrs>)> + '_ {
+        self.slot_map
+            .nodes_map()
+            .iter()
+            .map(|item| (item.key().clone(), item.value().clone()))
+    }
+
+    // Extends the current connection map with the provided one
+    pub(crate) fn extend_connection_map(
+        &mut self,
+        other_connection_map: ConnectionsMap<Connection>,
+    ) {
+        let conn_count_before = count_connections!(&self.connection_map);
+        self.connection_map.extend(other_connection_map.0);
+        let conn_count_after = count_connections!(&self.connection_map);
+        // Update the number of connections by the difference
+        Telemetry::incr_total_connections(conn_count_after.saturating_sub(conn_count_before));
+    }
+
+    /// Returns the availability zone associated with the connection in address
+    pub(crate) fn az_for_address(&self, address: &str) -> Option<String> {
+        self.connection_map
+            .get(address)
+            .map(|item| item.value().user_connection.az.clone())?
+    }
+
+    /// Returns true if the address represents a known primary node.
+    pub(crate) fn is_primary(&self, address: &String) -> bool {
+        self.connection_for_address(address).is_some() && self.slot_map.is_primary(address)
+    }
+
+    fn round_robin_read_from_replica(
+        &self,
+        slot_map_value: &SlotMapValue,
+    ) -> Option<ConnectionAndAddress<Connection>> {
+        let addrs = &slot_map_value.addrs;
+        let initial_index = slot_map_value.last_used_replica.load(Ordering::Relaxed);
+        let mut check_count = 0;
+        loop {
+            check_count += 1;
+
+            // Looped through all replicas, no connected replica was found.
+            if check_count > addrs.replicas().len() {
+                return self.connection_for_address(addrs.primary().as_str());
+            }
+            let index = (initial_index + check_count) % addrs.replicas().len();
+            if let Some(connection) = self.connection_for_address(addrs.replicas()[index].as_str())
+            {
+                let _ = slot_map_value.last_used_replica.compare_exchange_weak(
+                    initial_index,
+                    index,
+                    Ordering::Relaxed,
+                    Ordering::Relaxed,
+                );
+                return Some(connection);
+            }
+        }
+    }
+
+    /// Returns the node's connection in the same availability zone as `client_az`, chosen in round-robin order if one exists;
+    /// if not, it will fall back to any available replica or the primary.
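+    ///
+    /// For example (mirroring the test fixtures further down in this file): with replicas
+    /// `replica3-1` (az `use-1a`) and `replica3-2` (az `use-1b`) and `client_az == "use-1a"`,
+    /// this keeps returning `replica3-1`'s connection while it is present, and falls back
+    /// through `round_robin_read_from_replica` once no `use-1a` replica remains.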
+ pub(crate) fn round_robin_read_from_replica_with_az_awareness( + &self, + slot_map_value: &SlotMapValue, + client_az: String, + ) -> Option> { + let addrs = &slot_map_value.addrs; + let initial_index = slot_map_value.last_used_replica.load(Ordering::Relaxed); + let mut retries = 0usize; + + loop { + retries = retries.saturating_add(1); + // Looped through all replicas; no connected replica found in the same availability zone. + if retries > addrs.replicas().len() { + // Attempt a fallback to any available replica or primary if needed. + return self.round_robin_read_from_replica(slot_map_value); + } + + // Calculate index based on initial index and check count. + let index = (initial_index + retries) % addrs.replicas().len(); + let replica = &addrs.replicas()[index]; + + // Check if this replica’s availability zone matches the user’s availability zone. + if let Some((address, connection_details)) = + self.connection_details_for_address(replica.as_str()) + { + if self.az_for_address(&address) == Some(client_az.clone()) { + // Attempt to update `latest_used_replica` with the index of this replica. + let _ = slot_map_value.last_used_replica.compare_exchange_weak( + initial_index, + index, + Ordering::Relaxed, + Ordering::Relaxed, + ); + return Some((address, connection_details.conn)); + } + } + } + } + + fn lookup_route(&self, route: &Route) -> Option> { + let slot_map_value = self.slot_map.slot_value_for_route(route)?; + let addrs = &slot_map_value.addrs; + if addrs.replicas().is_empty() { + return self.connection_for_address(addrs.primary().as_str()); + } + + match route.slot_addr() { + // Master strategy will be in use when the command is not read_only + SlotAddr::Master => self.connection_for_address(addrs.primary().as_str()), + // ReplicaOptional strategy will be in use when the command is read_only + SlotAddr::ReplicaOptional => match &self.read_from_replica_strategy { + ReadFromReplicaStrategy::AlwaysFromPrimary => { + self.connection_for_address(addrs.primary().as_str()) + } + ReadFromReplicaStrategy::RoundRobin => { + self.round_robin_read_from_replica(slot_map_value) + } + ReadFromReplicaStrategy::AZAffinity(az) => self + .round_robin_read_from_replica_with_az_awareness( + slot_map_value, + az.to_string(), + ), + }, + // when the user strategy per command is replica_preffered + SlotAddr::ReplicaRequired => match &self.read_from_replica_strategy { + ReadFromReplicaStrategy::AZAffinity(az) => self + .round_robin_read_from_replica_with_az_awareness( + slot_map_value, + az.to_string(), + ), + _ => self.round_robin_read_from_replica(slot_map_value), + }, + } + } + + pub(crate) fn connection_for_route( + &self, + route: &Route, + ) -> Option> { + self.lookup_route(route).or_else(|| { + if route.slot_addr() != SlotAddr::Master { + self.lookup_route(&Route::new(route.slot(), SlotAddr::Master)) + } else { + None + } + }) + } + + pub(crate) fn all_node_connections( + &self, + ) -> impl Iterator> + '_ { + self.connection_map.iter().map(move |item| { + let (node, address) = (item.key(), item.value()); + (node.clone(), address.user_connection.conn.clone()) + }) + } + + pub(crate) fn all_primary_connections( + &self, + ) -> impl Iterator> + '_ { + self.slot_map + .addresses_for_all_primaries() + .into_iter() + .flat_map(|addr| self.connection_for_address(&addr)) + } + + pub(crate) fn node_for_address(&self, address: &str) -> Option> { + self.connection_map + .get(address) + .map(|item| item.value().clone()) + } + + pub(crate) fn connection_for_address( + &self, + address: &str, + ) -> Option> { 
+        self.connection_map.get(address).map(|item| {
+            let (address, conn) = (item.key(), item.value());
+            (address.clone(), conn.user_connection.conn.clone())
+        })
+    }
+
+    pub(crate) fn connection_details_for_address(
+        &self,
+        address: &str,
+    ) -> Option<ConnectionAndAddress<ConnectionDetails<Connection>>> {
+        self.connection_map.get(address).map(|item| {
+            let (address, conn) = (item.key(), item.value());
+            (address.clone(), conn.user_connection.clone())
+        })
+    }
+
+    pub(crate) fn random_connections(
+        &self,
+        amount: usize,
+        conn_type: ConnectionType,
+    ) -> Option<Vec<ConnectionAndAddress<Connection>>> {
+        (!self.connection_map.is_empty()).then_some({
+            self.connection_map
+                .iter()
+                .choose_multiple(&mut rand::thread_rng(), amount)
+                .into_iter()
+                .map(move |item| {
+                    let (address, node) = (item.key(), item.value());
+                    let conn = node.get_connection(&conn_type);
+                    (address.clone(), conn)
+                })
+                .collect::<Vec<_>>()
+        })
+    }
+
+    pub(crate) fn replace_or_add_connection_for_address(
+        &self,
+        address: impl Into<String>,
+        node: ClusterNode<Connection>,
+    ) -> String {
+        let address = address.into();
+
+        // Increase the total number of connections by the number of connections managed by `node`
+        Telemetry::incr_total_connections(node.connections_count());
+
+        if let Some(old_conn) = self.connection_map.insert(address.clone(), node) {
+            // We are replacing a node. Reduce the counter by the number of connections managed by
+            // the old connection
+            Telemetry::decr_total_connections(old_conn.connections_count());
+        };
+        address
+    }
+
+    pub(crate) fn remove_node(&self, address: &String) -> Option<ClusterNode<Connection>> {
+        if let Some((_key, old_conn)) = self.connection_map.remove(address) {
+            Telemetry::decr_total_connections(old_conn.connections_count());
+            Some(old_conn)
+        } else {
+            None
+        }
+    }
+
+    pub(crate) fn len(&self) -> usize {
+        self.connection_map.len()
+    }
+
+    pub(crate) fn connection_map(&self) -> &DashMap<String, ClusterNode<Connection>> {
+        &self.connection_map
+    }
+
+    pub(crate) fn get_current_topology_hash(&self) -> TopologyHash {
+        self.topology_hash
+    }
+
+    /// Returns true if the connections container contains no connections.
+ pub(crate) fn is_empty(&self) -> bool { + self.connection_map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use crate::cluster_routing::Slot; + + use super::*; + impl ClusterNode + where + Connection: Clone, + { + pub(crate) fn new_only_with_user_conn(user_connection: Connection) -> Self { + let ip = None; + let az = None; + Self { + user_connection: (user_connection, ip, az).into(), + management_connection: None, + } + } + } + fn remove_nodes(container: &ConnectionsContainer, addresses: &[&str]) { + for address in addresses { + container.remove_node(&(*address).into()); + } + } + + fn remove_all_connections(container: &ConnectionsContainer) { + remove_nodes( + container, + &[ + "primary1", + "primary2", + "primary3", + "replica2-1", + "replica3-1", + "replica3-2", + ], + ); + } + + fn one_of( + connection: Option>, + expected_connections: &[usize], + ) -> bool { + let found = connection.unwrap().1; + expected_connections.contains(&found) + } + fn create_cluster_node( + connection: usize, + use_management_connections: bool, + node_az: Option, + ) -> ClusterNode { + let ip = None; + ClusterNode::new( + (connection, ip, node_az.clone()).into(), + if use_management_connections { + Some((connection * 10, ip, node_az).into()) + } else { + None + }, + ) + } + + fn create_container_with_az_strategy( + use_management_connections: bool, + ) -> ConnectionsContainer { + let slot_map = SlotMap::new( + vec![ + Slot::new(1, 1000, "primary1".to_owned(), Vec::new()), + Slot::new( + 1002, + 2000, + "primary2".to_owned(), + vec!["replica2-1".to_owned()], + ), + Slot::new( + 2001, + 3000, + "primary3".to_owned(), + vec![ + "replica3-1".to_owned(), + "replica3-2".to_owned(), + "replica3-3".to_owned(), + ], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, // this argument shouldn't matter, since we overload the RFR strategy. 
+ ); + let connection_map = DashMap::new(); + connection_map.insert( + "primary1".into(), + create_cluster_node(1, use_management_connections, None), + ); + connection_map.insert( + "primary2".into(), + create_cluster_node(2, use_management_connections, None), + ); + connection_map.insert( + "primary3".into(), + create_cluster_node(3, use_management_connections, None), + ); + connection_map.insert( + "replica2-1".into(), + create_cluster_node(21, use_management_connections, None), + ); + connection_map.insert( + "replica3-1".into(), + create_cluster_node(31, use_management_connections, Some("use-1a".to_string())), + ); + connection_map.insert( + "replica3-2".into(), + create_cluster_node(32, use_management_connections, Some("use-1b".to_string())), + ); + connection_map.insert( + "replica3-3".into(), + create_cluster_node(33, use_management_connections, Some("use-1a".to_string())), + ); + connection_map.insert( + "replica3-3".into(), + create_cluster_node(33, use_management_connections, Some("use-1a".to_string())), + ); + + ConnectionsContainer { + slot_map, + connection_map, + read_from_replica_strategy: ReadFromReplicaStrategy::AZAffinity("use-1a".to_string()), + topology_hash: 0, + } + } + + fn create_container_with_strategy( + strategy: ReadFromReplicaStrategy, + use_management_connections: bool, + ) -> ConnectionsContainer { + let slot_map = SlotMap::new( + vec![ + Slot::new(1, 1000, "primary1".to_owned(), Vec::new()), + Slot::new( + 1002, + 2000, + "primary2".to_owned(), + vec!["replica2-1".to_owned()], + ), + Slot::new( + 2001, + 3000, + "primary3".to_owned(), + vec!["replica3-1".to_owned(), "replica3-2".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, // this argument shouldn't matter, since we overload the RFR strategy. + ); + let connection_map = DashMap::new(); + connection_map.insert( + "primary1".into(), + create_cluster_node(1, use_management_connections, None), + ); + connection_map.insert( + "primary2".into(), + create_cluster_node(2, use_management_connections, None), + ); + connection_map.insert( + "primary3".into(), + create_cluster_node(3, use_management_connections, None), + ); + connection_map.insert( + "replica2-1".into(), + create_cluster_node(21, use_management_connections, None), + ); + connection_map.insert( + "replica3-1".into(), + create_cluster_node(31, use_management_connections, None), + ); + connection_map.insert( + "replica3-2".into(), + create_cluster_node(32, use_management_connections, None), + ); + + ConnectionsContainer { + slot_map, + connection_map, + read_from_replica_strategy: strategy, + topology_hash: 0, + } + } + + fn create_container() -> ConnectionsContainer { + create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, false) + } + + #[test] + fn get_connection_for_primary_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + 
.connection_for_route(&Route::new(2001, SlotAddr::Master)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_for_replica_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)), + &[31, 32], + )); + } + + #[test] + fn get_primary_connection_for_replica_route_if_no_replicas_were_added() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_some_but_not_all_replicas_were_removed() { + let container = create_container(); + container.remove_node(&"replica3-2".into()); + + assert_eq!( + 31, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_replica_is_required_even_if_strategy_is_always_from_primary( + ) { + let container = + create_container_with_strategy(ReadFromReplicaStrategy::AlwaysFromPrimary, false); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)), + &[31, 32], + )); + } + + #[test] + fn get_primary_connection_for_replica_route_if_all_replicas_were_removed() { + let container = create_container(); + remove_nodes(&container, &["replica2-1", "replica3-1", "replica3-2"]); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_for_az_affinity_route() { + let container = create_container_with_az_strategy(false); + + // slot number is not exits + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::ReplicaOptional)) + .is_none()); + // Get the replica that holds the slot 1002 + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + // Get the Primary that holds the slot 1500 + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + .1 + ); + + // receive one of the replicas that holds the slot 2001 and is in the availability zone of the client ("use-1a") + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)), + &[31, 33], + )); + + // remove the replica in the same client's az and get the other replica in the same az + remove_nodes(&container, &["replica3-3"]); + assert_eq!( + 31, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + // remove the replica in the same clients az and get the other replica + 
remove_nodes(&container, &["replica3-1"]); + assert_eq!( + 32, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + // remove the last replica and get the primary + remove_nodes(&container, &["replica3-2"]); + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_for_az_affinity_route_round_robin() { + let container = create_container_with_az_strategy(false); + + let mut addresses = vec![ + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1, + ]; + addresses.sort(); + assert_eq!(addresses, vec![31, 31, 33, 33]); + } + + #[test] + fn get_connection_by_address() { + let container = create_container(); + + assert!(container.connection_for_address("foobar").is_none()); + + assert_eq!(1, container.connection_for_address("primary1").unwrap().1); + assert_eq!(2, container.connection_for_address("primary2").unwrap().1); + assert_eq!(3, container.connection_for_address("primary3").unwrap().1); + assert_eq!( + 21, + container.connection_for_address("replica2-1").unwrap().1 + ); + assert_eq!( + 31, + container.connection_for_address("replica3-1").unwrap().1 + ); + assert_eq!( + 32, + container.connection_for_address("replica3-2").unwrap().1 + ); + } + + #[test] + fn get_connection_by_address_returns_none_if_connection_was_removed() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + assert!(container.connection_for_address("primary1").is_none()); + } + + #[test] + fn get_connection_by_address_returns_added_connection() { + let container = create_container(); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + assert_eq!( + (address, 4), + container.connection_for_address("foobar").unwrap() + ); + } + + #[test] + fn get_random_connections_without_repetitions() { + let container = create_container(); + + let random_connections: HashSet<_> = container + .random_connections(3, ConnectionType::User) + .expect("No connections found") + .into_iter() + .map(|pair| pair.1) + .collect(); + + assert_eq!(random_connections.len(), 3); + assert!(random_connections + .iter() + .all(|connection| [1, 2, 3, 21, 31, 32].contains(connection))); + } + + #[test] + fn get_random_connections_returns_none_if_all_connections_were_removed() { + let container = create_container(); + remove_all_connections(&container); + + assert!(container + .random_connections(1, ConnectionType::User) + .is_none()); + } + + #[test] + fn get_random_connections_returns_added_connection() { + let container = create_container(); + remove_all_connections(&container); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + let random_connections: Vec<_> = container + .random_connections(1, ConnectionType::User) + .expect("No connections found") + .into_iter() + .collect(); + + assert_eq!(vec![(address, 4)], random_connections); + } + + #[test] + fn get_random_connections_is_bound_by_the_number_of_connections_in_the_map() { + let container = create_container(); + let mut 
random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::User) + .expect("No connections found") + .into_iter() + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![1, 2, 3, 21, 31, 32]); + } + + #[test] + fn get_random_management_connections() { + let container = create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, true); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::PreferManagement) + .expect("No connections found") + .into_iter() + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![10, 20, 30, 210, 310, 320]); + } + + #[test] + fn get_all_user_connections() { + let container = create_container(); + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_returns_added_connection() { + let container = create_container(); + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 4, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_primaries() { + let container = create_container(); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3], connections); + } + + #[test] + fn get_all_primaries_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3], connections); + } + + #[test] + fn len_is_adjusted_on_removals_and_additions() { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"primary1".into()); + assert_eq!(container.len(), 5); + + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn len_is_not_adjusted_on_removals_of_nonexisting_connections_or_additions_of_existing_connections( + ) { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"foobar".into()); + assert_eq!(container.len(), 6); + + container.replace_or_add_connection_for_address( + "primary1", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn remove_node_returns_connection_if_it_exists() { + let container = create_container(); + + let connection = container.remove_node(&"primary1".into()); + assert_eq!(connection, Some(ClusterNode::new_only_with_user_conn(1))); + + let non_connection = container.remove_node(&"foobar".into()); + assert_eq!(non_connection, None); + } + + #[test] + fn test_is_empty() { + let container = create_container(); + + 
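+        // `create_container()` builds six nodes (three primaries, three replicas);
+        // the container only reports empty after the last of them is removed.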
assert!(!container.is_empty()); + container.remove_node(&"primary1".into()); + assert!(!container.is_empty()); + container.remove_node(&"primary2".into()); + container.remove_node(&"primary3".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica2-1".into()); + container.remove_node(&"replica3-1".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica3-2".into()); + assert!(container.is_empty()); + } + + #[test] + fn is_primary_returns_true_for_known_primary() { + let container = create_container(); + + assert!(container.is_primary(&"primary1".into())); + } + + #[test] + fn is_primary_returns_false_for_known_replica() { + let container = create_container(); + + assert!(!container.is_primary(&"replica2-1".into())); + } + + #[test] + fn is_primary_returns_false_for_removed_node() { + let container = create_container(); + let address = "primary1".into(); + container.remove_node(&address); + + assert!(!container.is_primary(&address)); + } + + #[test] + fn test_extend_connection_map() { + let mut container = create_container(); + let mut current_addresses: Vec<_> = container + .all_node_connections() + .map(|conn| conn.0) + .collect(); + + let new_node = "new_primary1".to_string(); + // Check that `new_node` not exists in the current + assert!(container.connection_for_address(&new_node).is_none()); + // Create new connection map + let new_connection_map = DashMap::new(); + new_connection_map.insert(new_node.clone(), create_cluster_node(1, false, None)); + + // Extend the current connection map + container.extend_connection_map(ConnectionsMap(new_connection_map)); + + // Check that the new addresses vector contains both the new node and all previous nodes + let mut new_addresses: Vec<_> = container + .all_node_connections() + .map(|conn| conn.0) + .collect(); + current_addresses.push(new_node); + current_addresses.sort(); + new_addresses.sort(); + assert_eq!(current_addresses, new_addresses); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs new file mode 100644 index 0000000000..4f9b3f0d4e --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs @@ -0,0 +1,496 @@ +use super::{ + connections_container::{ClusterNode, ConnectionDetails}, + Connect, +}; +use crate::cluster_slotmap::ReadFromReplicaStrategy; +use crate::{ + aio::{ConnectionLike, DisconnectNotifier}, + client::GlideConnectionOptions, + cluster::get_connection_info, + cluster_client::ClusterParams, + ErrorKind, RedisError, RedisResult, +}; +use std::net::SocketAddr; + +use futures::prelude::*; +use futures_util::{future::BoxFuture, join}; +use tracing::warn; + +pub(crate) type ConnectionFuture = futures::future::Shared>; +/// Cluster node for async connections +#[doc(hidden)] +pub type AsyncClusterNode = ClusterNode>; + +#[doc(hidden)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum RefreshConnectionType { + // Refresh only user connections + OnlyUserConnection, + // Refresh only management connections + OnlyManagementConnection, + // Refresh all connections: both management and user connections. + AllConnections, +} + +fn failed_management_connection( + addr: &str, + user_conn: ConnectionDetails>, + err: RedisError, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Send + Clone + Sync + Connect + 'static, +{ + warn!( + "Failed to create management connection for node `{:?}`. 
Error: `{:?}`", + addr, err + ); + ConnectAndCheckResult::ManagementConnectionFailed { + node: AsyncClusterNode::new(user_conn, None), + err, + } +} + +pub(crate) async fn get_or_create_conn( + addr: &str, + node: Option>, + params: &ClusterParams, + conn_type: RefreshConnectionType, + glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Send + Clone + Sync + Connect + 'static, +{ + if let Some(node) = node { + // We won't check whether the DNS address of this node has changed and now points to a new IP. + // Instead, we depend on managed Redis services to close the connection for refresh if the node has changed. + match check_node_connections(&node, params, conn_type, addr).await { + None => Ok(node), + Some(conn_type) => connect_and_check( + addr, + params.clone(), + None, + conn_type, + Some(node), + glide_connection_options, + ) + .await + .get_node(), + } + } else { + connect_and_check( + addr, + params.clone(), + None, + conn_type, + None, + glide_connection_options, + ) + .await + .get_node() + } +} + +fn create_async_node( + user_conn: ConnectionDetails, + management_conn: Option>, +) -> AsyncClusterNode +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + AsyncClusterNode::new( + user_conn.into_future(), + management_conn.map(|conn| conn.into_future()), + ) +} + +pub(crate) async fn connect_and_check_all_connections( + addr: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match future::join( + // User connection + create_connection( + addr, + params.clone(), + socket_addr, + false, + glide_connection_options.clone(), + ), + // Management connection + create_connection( + addr, + params.clone(), + socket_addr, + true, + glide_connection_options, + ), + ) + .await + { + (Ok(conn_1), Ok(conn_2)) => { + // Both connections were successfully established + let mut user_conn: ConnectionDetails = conn_1; + let mut management_conn: ConnectionDetails = conn_2; + if let Err(err) = setup_user_connection(&mut user_conn, params).await { + return err.into(); + } + match setup_management_connection(&mut management_conn.conn).await { + Ok(_) => ConnectAndCheckResult::Success(create_async_node( + user_conn, + Some(management_conn), + )), + Err(err) => failed_management_connection(addr, user_conn.into_future(), err), + } + } + (Ok(mut connection), Err(err)) | (Err(err), Ok(mut connection)) => { + // Only a single connection was successfully established. Use it for the user connection + match setup_user_connection(&mut connection, params).await { + Ok(_) => failed_management_connection(addr, connection.into_future(), err), + Err(err) => err.into(), + } + } + (Err(err_1), Err(err_2)) => { + // Neither of the connections succeeded. 
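+            // Fold both underlying errors into the returned error's detail string so
+            // the caller can see why each of the two connection attempts failed.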
+ RedisError::from(( + ErrorKind::IoError, + "Failed to refresh both connections", + format!( + "Node: {:?} received errors: `{:?}`, `{:?}`", + addr, err_1, err_2 + ), + )) + .into() + } + } +} + +async fn connect_and_check_only_management_conn( + addr: &str, + params: ClusterParams, + socket_addr: Option, + prev_node: AsyncClusterNode, + disconnect_notifier: Option>, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + let discover_az = matches!( + params.read_from_replicas, + crate::cluster_slotmap::ReadFromReplicaStrategy::AZAffinity(_) + ); + + match create_connection::( + addr, + params.clone(), + socket_addr, + true, + GlideConnectionOptions { + push_sender: None, + disconnect_notifier, + discover_az, + }, + ) + .await + { + Err(conn_err) => failed_management_connection(addr, prev_node.user_connection, conn_err), + + Ok(mut connection) => { + if let Err(err) = setup_management_connection(&mut connection.conn).await { + return failed_management_connection(addr, prev_node.user_connection, err); + } + + ConnectAndCheckResult::Success(ClusterNode { + user_connection: prev_node.user_connection, + management_connection: Some(connection.into_future()), + }) + } + } +} + +#[doc(hidden)] +#[must_use] +pub enum ConnectAndCheckResult { + // Returns a node that was fully connected according to the request. + Success(AsyncClusterNode), + // Returns a node that failed to create a management connection, but has a working user connection. + ManagementConnectionFailed { + node: AsyncClusterNode, + err: RedisError, + }, + // Request failed completely, could not return a node with any working connection. + Failed(RedisError), +} + +impl ConnectAndCheckResult { + pub fn get_node(self) -> RedisResult> { + match self { + ConnectAndCheckResult::Success(node) => Ok(node), + ConnectAndCheckResult::ManagementConnectionFailed { node, .. } => Ok(node), + ConnectAndCheckResult::Failed(err) => Err(err), + } + } + + pub fn get_error(self) -> Option { + match self { + ConnectAndCheckResult::Success(_) => None, + ConnectAndCheckResult::ManagementConnectionFailed { err, .. } => Some(err), + ConnectAndCheckResult::Failed(err) => Some(err), + } + } +} + +impl From for ConnectAndCheckResult { + fn from(value: RedisError) -> Self { + ConnectAndCheckResult::Failed(value) + } +} + +impl From> for ConnectAndCheckResult { + fn from(value: AsyncClusterNode) -> Self { + ConnectAndCheckResult::Success(value) + } +} + +impl From>> for ConnectAndCheckResult { + fn from(value: RedisResult>) -> Self { + match value { + Ok(value) => value.into(), + Err(err) => err.into(), + } + } +} + +#[doc(hidden)] +pub async fn connect_and_check( + addr: &str, + params: ClusterParams, + socket_addr: Option, + conn_type: RefreshConnectionType, + node: Option>, + glide_connection_options: GlideConnectionOptions, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match conn_type { + RefreshConnectionType::OnlyUserConnection => { + let user_conn = match create_and_setup_user_connection( + addr, + params.clone(), + socket_addr, + glide_connection_options, + ) + .await + { + Ok(tuple) => tuple, + Err(err) => return err.into(), + }; + let management_conn = node.and_then(|node| node.management_connection); + AsyncClusterNode::new(user_conn.into_future(), management_conn).into() + } + RefreshConnectionType::OnlyManagementConnection => { + // Refreshing only the management connection requires the node to exist alongside a user connection. 
Otherwise, refresh all connections. + match node { + Some(node) => { + connect_and_check_only_management_conn( + addr, + params, + socket_addr, + node, + glide_connection_options.disconnect_notifier, + ) + .await + } + None => { + connect_and_check_all_connections( + addr, + params, + socket_addr, + glide_connection_options, + ) + .await + } + } + } + RefreshConnectionType::AllConnections => { + connect_and_check_all_connections(addr, params, socket_addr, glide_connection_options) + .await + } + } +} + +async fn create_and_setup_user_connection( + node: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let mut connection: ConnectionDetails = create_connection( + node, + params.clone(), + socket_addr, + false, + glide_connection_options, + ) + .await?; + setup_user_connection(&mut connection, params).await?; + Ok(connection) +} + +async fn setup_user_connection( + conn_details: &mut ConnectionDetails, + params: ClusterParams, +) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let read_from_replicas = + params.read_from_replicas != ReadFromReplicaStrategy::AlwaysFromPrimary; + let connection_timeout = params.connection_timeout; + check_connection(&mut conn_details.conn, connection_timeout).await?; + if read_from_replicas { + // If READONLY is sent to primary nodes, it will have no effect + crate::cmd("READONLY") + .query_async(&mut conn_details.conn) + .await?; + } + + Ok(()) +} + +#[doc(hidden)] +pub const MANAGEMENT_CONN_NAME: &str = "glide_management_connection"; + +async fn setup_management_connection(conn: &mut C) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + crate::cmd("CLIENT") + .arg(&["SETNAME", MANAGEMENT_CONN_NAME]) + .query_async(conn) + .await?; + Ok(()) +} + +async fn create_connection( + node: &str, + mut params: ClusterParams, + socket_addr: Option, + is_management: bool, + mut glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let connection_timeout = params.connection_timeout; + let response_timeout = params.response_timeout; + // ignore pubsub subscriptions and push notifications for management connections + if is_management { + params.pubsub_subscriptions = None; + } + let info = get_connection_info(node, params)?; + // management connection does not require notifications or disconnect notifications + if is_management { + glide_connection_options.disconnect_notifier = None; + } + C::connect( + info, + response_timeout, + connection_timeout, + socket_addr, + glide_connection_options, + ) + .await + .map(|conn| { + let az = conn.0.get_az(); + (conn.0, conn.1, az).into() + }) +} + +/// The function returns None if the checked connection/s are healthy. Otherwise, it returns the type of the unhealthy connection/s. 
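+///
+/// For example, when both connections are checked (`AllConnections`):
+/// - only the user connection is unhealthy       -> `Some(RefreshConnectionType::OnlyUserConnection)`
+/// - only the management connection is unhealthy -> `Some(RefreshConnectionType::OnlyManagementConnection)`
+/// - both are unhealthy                          -> `Some(RefreshConnectionType::AllConnections)`
+/// - both are healthy                            -> `None`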
+#[allow(dead_code)] +#[doc(hidden)] +pub async fn check_node_connections( + node: &AsyncClusterNode, + params: &ClusterParams, + conn_type: RefreshConnectionType, + address: &str, +) -> Option +where + C: ConnectionLike + Send + 'static + Clone, +{ + let timeout = params.connection_timeout; + let (check_mgmt_connection, check_user_connection) = match conn_type { + RefreshConnectionType::OnlyUserConnection => (false, true), + RefreshConnectionType::OnlyManagementConnection => (true, false), + RefreshConnectionType::AllConnections => (true, true), + }; + let check = |conn, timeout, conn_type| async move { + match check_connection(&mut conn.await, timeout).await { + Ok(_) => false, + Err(err) => { + warn!( + "The {} connection for node {} is unhealthy. Error: {:?}", + conn_type, address, err + ); + true + } + } + }; + let (mgmt_failed, user_failed) = join!( + async { + if !check_mgmt_connection { + return false; + } + match node.management_connection.clone() { + Some(connection) => check(connection.conn, timeout, "management").await, + None => { + warn!("The management connection for node {} isn't set", address); + true + } + } + }, + async { + if !check_user_connection { + return false; + } + let conn = node.user_connection.conn.clone(); + check(conn, timeout, "user").await + }, + ); + + match (mgmt_failed, user_failed) { + (true, true) => Some(RefreshConnectionType::AllConnections), + (true, false) => Some(RefreshConnectionType::OnlyManagementConnection), + (false, true) => Some(RefreshConnectionType::OnlyUserConnection), + (false, false) => None, + } +} + +async fn check_connection(conn: &mut C, timeout: std::time::Duration) -> RedisResult<()> +where + C: ConnectionLike + Send + 'static, +{ + tokio::time::timeout(timeout, crate::cmd("PING").query_async::<_, String>(conn)).await??; + Ok(()) +} + +/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned. +/// [addr] should be in the following format: ":". +pub(crate) fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> { + let parts: Vec<&str> = addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + let host = parts.first().unwrap(); + let port = parts.get(1).unwrap(); + port.parse::().ok().map(|port| (*host, port)) +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs new file mode 100644 index 0000000000..3726d7a674 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -0,0 +1,2979 @@ +//! This module provides async functionality for Redis Cluster. +//! +//! By default, [`ClusterConnection`] makes use of [`MultiplexedConnection`] and maintains a pool +//! of connections to each node in the cluster. While it generally behaves similarly to +//! the sync cluster module, certain commands do not route identically, due most notably to +//! a current lack of support for routing commands to multiple nodes. +//! +//! Also note that pubsub functionality is not currently provided by this module. +//! +//! # Example +//! ```rust,no_run +//! use redis::cluster::ClusterClient; +//! use redis::AsyncCommands; +//! +//! async fn fetch_an_integer() -> String { +//! let nodes = vec!["redis://127.0.0.1/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_async_connection(None).await.unwrap(); +//! let _: () = connection.set("test", "test_data").await.unwrap(); +//! let rv: String = connection.get("test").await.unwrap(); +//! return rv; +//! } +//! 
``` + +mod connections_container; +mod connections_logic; +/// Exposed only for testing. +pub mod testing { + pub use super::connections_container::ConnectionDetails; + pub use super::connections_logic::*; +} +use crate::{ + client::GlideConnectionOptions, + cluster_routing::{Routable, RoutingInfo, ShardUpdateResult}, + cluster_slotmap::SlotMap, + cluster_topology::{ + calculate_topology, get_slot, SlotRefreshState, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES, + DEFAULT_REFRESH_SLOTS_RETRY_BASE_DURATION_MILLIS, DEFAULT_REFRESH_SLOTS_RETRY_BASE_FACTOR, + SLOT_SIZE, + }, + cmd, + commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, + FromRedisValue, InfoDict, ToRedisArgs, +}; +use dashmap::DashMap; +use std::{ + collections::{HashMap, HashSet}, + fmt, io, mem, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::{ + atomic::{self, AtomicUsize, Ordering}, + Arc, Mutex, + }, + task::{self, Poll}, + time::SystemTime, +}; +use strum_macros::Display; +#[cfg(feature = "tokio-comp")] +use tokio::task::JoinHandle; + +#[cfg(feature = "tokio-comp")] +use crate::aio::DisconnectNotifier; +use telemetrylib::Telemetry; + +use crate::{ + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, + cluster::slot_cmd, + cluster_async::connections_logic::{ + get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, + }, + cluster_client::{ClusterParams, RetryParams}, + cluster_routing::{ + self, MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, SingleNodeRoutingInfo, + SlotAddr, + }, + connection::{PubSubSubscriptionInfo, PubSubSubscriptionKind}, + push_manager::PushInfo, + Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult, + Value, +}; +use futures::stream::{FuturesUnordered, StreamExt}; +use std::time::Duration; + +#[cfg(feature = "tokio-comp")] +use async_trait::async_trait; +#[cfg(feature = "tokio-comp")] +use tokio_retry2::strategy::{jitter_range, ExponentialFactorBackoff}; +#[cfg(feature = "tokio-comp")] +use tokio_retry2::{Retry, RetryError}; + +#[cfg(feature = "tokio-comp")] +use tokio::{sync::Notify, time::timeout}; + +use dispose::{Disposable, Dispose}; +use futures::{future::BoxFuture, prelude::*, ready}; +use pin_project_lite::pin_project; +use std::sync::RwLock as StdRwLock; +use tokio::sync::{ + mpsc, + oneshot::{self, Receiver}, + RwLock as TokioRwLock, +}; +use tracing::{debug, info, trace, warn}; + +use self::{ + connections_container::{ConnectionAndAddress, ConnectionType, ConnectionsMap}, + connections_logic::connect_and_check, +}; +use crate::types::RetryMethod; + +pub(crate) const MUTEX_READ_ERR: &str = "Failed to obtain read lock. Poisoned mutex?"; +const MUTEX_WRITE_ERR: &str = "Failed to obtain write lock. Poisoned mutex?"; +/// This represents an async Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. 
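+///
+/// Cloning a `ClusterConnection` is cheap: the struct wraps only the sender half of an
+/// mpsc channel, so every clone forwards its requests to the same driver task (spawned
+/// in `new`) that owns the per-node connections.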
+/// This represents an async Redis Cluster connection. It stores the
+/// underlying connections maintained for each node in the cluster, as well
+/// as common parameters for connecting to nodes and executing commands.
+#[derive(Clone)]
+pub struct ClusterConnection<C = MultiplexedConnection>(mpsc::Sender<Message<C>>);
+
+impl<C> ClusterConnection<C>
+where
+    C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
+{
+    pub(crate) async fn new(
+        initial_nodes: &[ConnectionInfo],
+        cluster_params: ClusterParams,
+        push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    ) -> RedisResult<ClusterConnection<C>> {
+        ClusterConnInner::new(initial_nodes, cluster_params, push_sender)
+            .await
+            .map(|inner| {
+                let (tx, mut rx) = mpsc::channel::<Message<_>>(100);
+                let stream = async move {
+                    let _ = stream::poll_fn(move |cx| rx.poll_recv(cx))
+                        .map(Ok)
+                        .forward(inner)
+                        .await;
+                };
+                #[cfg(feature = "tokio-comp")]
+                tokio::spawn(stream);
+                ClusterConnection(tx)
+            })
+    }
+
+    /// Special handling for the `SCAN` command, using `cluster_scan`.
+    /// If you wish to use a match pattern, use [`cluster_scan_with_pattern`].
+    /// Perform a `SCAN` command on a Redis cluster, using a scan state object in order to handle changes in topology
+    /// and to make sure that all keys that were in the cluster from the start to the end of the scan are scanned.
+    /// In order to make sure all keys in the cluster are scanned, topology refresh occurs more frequently and may affect performance.
+    ///
+    /// # Arguments
+    ///
+    /// * `scan_state_rc` - A reference to the scan state. For initiating a new scan send [`ScanStateRC::new()`];
+    ///   for each subsequent iteration use the returned [`ScanStateRC`].
+    /// * `count` - An optional count of keys requested; the number returned may vary and is not guaranteed to match `count` exactly.
+    /// * `object_type` - An optional [`ObjectType`] enum of the requested key redis type.
+    ///
+    /// # Returns
+    ///
+    /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan.
+    /// Structure of the returned value:
+    /// `Ok((ScanStateRC, Vec<Value>))`
+    ///
+    /// When the scan is finished [`ScanStateRC`] will be None, and this can be checked by calling `scan_state_wrapper.is_finished()`.
+    ///
+    /// # Example
+    /// ```rust,no_run
+    /// use redis::cluster::ClusterClient;
+    /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType};
+    ///
+    /// async fn scan_all_cluster() -> Vec<String> {
+    ///     let nodes = vec!["redis://127.0.0.1/"];
+    ///     let client = ClusterClient::new(nodes).unwrap();
+    ///     let mut connection = client.get_async_connection(None).await.unwrap();
+    ///     let mut scan_state_rc = ScanStateRC::new();
+    ///     let mut keys: Vec<String> = vec![];
+    ///     loop {
+    ///         let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) =
+    ///             connection.cluster_scan(scan_state_rc, None, None).await.unwrap();
+    ///         scan_state_rc = next_cursor;
+    ///         let mut scan_keys = scan_keys
+    ///             .into_iter()
+    ///             .map(|v| from_redis_value(&v).unwrap())
+    ///             .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+    ///         keys.append(&mut scan_keys);
+    ///         if scan_state_rc.is_finished() {
+    ///             break;
+    ///         }
+    ///     }
+    ///     keys
+    /// }
+    /// ```
+    pub async fn cluster_scan(
+        &mut self,
+        scan_state_rc: ScanStateRC,
+        count: Option<usize>,
+        object_type: Option<ObjectType>,
+    ) -> RedisResult<(ScanStateRC, Vec<Value>)> {
+        let cluster_scan_args = ClusterScanArgs::new(scan_state_rc, None, count, object_type);
+        self.route_cluster_scan(cluster_scan_args).await
+    }
+    /// Special handling for the `SCAN` command, using `cluster_scan_with_pattern`.
+    /// It is a special case of [`cluster_scan`], with an additional match pattern.
+    /// Perform a `SCAN` command on a Redis cluster, using a scan state object in order to handle changes in topology
+    /// and to make sure that all keys that were in the cluster from the start to the end of the scan are scanned.
+    /// In order to make sure all keys in the cluster are scanned, topology refresh occurs more frequently and may affect performance.
+    ///
+    /// # Arguments
+    ///
+    /// * `scan_state_rc` - A reference to the scan state. For initiating a new scan send [`ScanStateRC::new()`];
+    ///   for each subsequent iteration use the returned [`ScanStateRC`].
+    /// * `match_pattern` - A match pattern of requested keys.
+    /// * `count` - An optional count of keys requested; the number returned may vary and is not guaranteed to match `count` exactly.
+    /// * `object_type` - An optional [`ObjectType`] enum of the requested key redis type.
+    ///
+    /// # Returns
+    ///
+    /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan.
+    /// Structure of the returned value:
+    /// `Ok((ScanStateRC, Vec<Value>))`
+    ///
+    /// When the scan is finished [`ScanStateRC`] will be None, and this can be checked by calling `scan_state_wrapper.is_finished()`.
+    ///
+    /// # Example
+    /// ```rust,no_run
+    /// use redis::cluster::ClusterClient;
+    /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType};
+    ///
+    /// async fn scan_all_cluster() -> Vec<String> {
+    ///     let nodes = vec!["redis://127.0.0.1/"];
+    ///     let client = ClusterClient::new(nodes).unwrap();
+    ///     let mut connection = client.get_async_connection(None).await.unwrap();
+    ///     let mut scan_state_rc = ScanStateRC::new();
+    ///     let mut keys: Vec<String> = vec![];
+    ///     loop {
+    ///         let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) =
+    ///             connection.cluster_scan_with_pattern(scan_state_rc, b"my_key", None, None).await.unwrap();
+    ///         scan_state_rc = next_cursor;
+    ///         let mut scan_keys = scan_keys
+    ///             .into_iter()
+    ///             .map(|v| from_redis_value(&v).unwrap())
+    ///             .collect::<Vec<String>>(); // Change the type of `keys` to `Vec<String>`
+    ///         keys.append(&mut scan_keys);
+    ///         if scan_state_rc.is_finished() {
+    ///             break;
+    ///         }
+    ///     }
+    ///     keys
+    /// }
+    /// ```
+    pub async fn cluster_scan_with_pattern<K: ToRedisArgs>(
+        &mut self,
+        scan_state_rc: ScanStateRC,
+        match_pattern: K,
+        count: Option<usize>,
+        object_type: Option<ObjectType>,
+    ) -> RedisResult<(ScanStateRC, Vec<Value>)> {
+        let cluster_scan_args = ClusterScanArgs::new(
+            scan_state_rc,
+            Some(match_pattern.to_redis_args().concat()),
+            count,
+            object_type,
+        );
+        self.route_cluster_scan(cluster_scan_args).await
+    }
+
+    /// Route cluster scan to be handled by the internal cluster_scan command
+    async fn route_cluster_scan(
+        &mut self,
+        cluster_scan_args: ClusterScanArgs,
+    ) -> RedisResult<(ScanStateRC, Vec<Value>)> {
+        let (sender, receiver) = oneshot::channel();
+        self.0
+            .send(Message {
+                cmd: CmdArg::ClusterScan { cluster_scan_args },
+                sender,
+            })
+            .await
+            .map_err(|_| {
+                RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "redis_cluster: Unable to send command",
+                ))
+            })?;
+        receiver
+            .await
+            .unwrap_or_else(|_| {
+                Err(RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "redis_cluster: Unable to receive command",
+                )))
+            })
+            .map(|response| match response {
+                Response::ClusterScanResult(new_scan_state_ref, key) => (new_scan_state_ref, key),
+                Response::Single(_) | Response::Multiple(_) => unreachable!(),
+            })
+    }
+    /// Send a command to the given `routing`. If `routing` is [None], it will be computed from `cmd`.
+    pub async fn route_command(
+        &mut self,
+        cmd: &Cmd,
+        routing: cluster_routing::RoutingInfo,
+    ) -> RedisResult<Value> {
+        trace!("route_command");
+        let (sender, receiver) = oneshot::channel();
+        self.0
+            .send(Message {
+                cmd: CmdArg::Cmd {
+                    cmd: Arc::new(cmd.clone()),
+                    routing: routing.into(),
+                },
+                sender,
+            })
+            .await
+            .map_err(|_| {
+                RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "redis_cluster: Unable to send command",
+                ))
+            })?;
+        receiver
+            .await
+            .unwrap_or_else(|_| {
+                Err(RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "redis_cluster: Unable to receive command",
+                )))
+            })
+            .map(|response| match response {
+                Response::Single(value) => value,
+                Response::ClusterScanResult(..) | Response::Multiple(_) => unreachable!(),
+            })
+    }
+
+    /// Send commands in `pipeline` to the given `route`. If `route` is [None], it will be computed from `pipeline`.
+    pub async fn route_pipeline<'a>(
+        &'a mut self,
+        pipeline: &'a crate::Pipeline,
+        offset: usize,
+        count: usize,
+        route: SingleNodeRoutingInfo,
+    ) -> RedisResult<Vec<Value>> {
+        let (sender, receiver) = oneshot::channel();
+        self.0
+            .send(Message {
+                cmd: CmdArg::Pipeline {
+                    pipeline: Arc::new(pipeline.clone()),
+                    offset,
+                    count,
+                    route: route.into(),
+                },
+                sender,
+            })
+            .await
+            .map_err(|err| {
+                RedisError::from(io::Error::new(io::ErrorKind::BrokenPipe, err.to_string()))
+            })?;
+
+        receiver
+            .await
+            .unwrap_or_else(|err| {
+                Err(RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    err.to_string(),
+                )))
+            })
+            .map(|response| match response {
+                Response::Multiple(values) => values,
+                Response::ClusterScanResult(..) | Response::Single(_) => unreachable!(),
+            })
+    }
+
+    /// Update the password used to authenticate with all cluster servers
+    pub async fn update_connection_password(
+        &mut self,
+        password: Option<String>,
+    ) -> RedisResult<Value> {
+        self.route_operation_request(Operation::UpdateConnectionPassword(password))
+            .await
+    }
+
+    /// Routes an operation request to the appropriate handler.
+    async fn route_operation_request(
+        &mut self,
+        operation_request: Operation,
+    ) -> RedisResult<Value> {
+        let (sender, receiver) = oneshot::channel();
+        self.0
+            .send(Message {
+                cmd: CmdArg::OperationRequest(operation_request),
+                sender,
+            })
+            .await
+            .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?;
+
+        receiver
+            .await
+            .unwrap_or_else(|err| {
+                Err(RedisError::from(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    err.to_string(),
+                )))
+            })
+            .map(|response| match response {
+                Response::Single(values) => values,
+                Response::ClusterScanResult(..) | Response::Multiple(_) => unreachable!(),
+            })
+    }
+}
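
All of the `route_*` methods above share one shape: the public handle owns only an `mpsc::Sender`, every call packages its arguments together with a fresh `oneshot` sender, and the background task created in `new` replies through it, which is why a dropped task surfaces as `BrokenPipe`. A minimal, self-contained sketch of that actor pattern (illustrative names; assumes tokio with the `rt`, `sync`, and `macros` features):

```rust
use tokio::sync::{mpsc, oneshot};

enum Request {
    Echo { payload: String, reply: oneshot::Sender<String> },
}

#[derive(Clone)]
struct Handle(mpsc::Sender<Request>);

impl Handle {
    fn spawn() -> Handle {
        let (tx, mut rx) = mpsc::channel::<Request>(100);
        tokio::spawn(async move {
            // The single owner of the real state; requests are serialized here.
            while let Some(Request::Echo { payload, reply }) = rx.recv().await {
                let _ = reply.send(payload); // the receiver may have been dropped
            }
        });
        Handle(tx)
    }

    async fn echo(&self, payload: String) -> Result<String, &'static str> {
        let (reply, rx) = oneshot::channel();
        self.0
            .send(Request::Echo { payload, reply })
            .await
            .map_err(|_| "task is gone")?; // mirrors the BrokenPipe mapping above
        rx.await.map_err(|_| "task dropped the request")
    }
}

#[tokio::main]
async fn main() {
    let handle = Handle::spawn();
    assert_eq!(handle.echo("ping".into()).await.unwrap(), "ping");
}
```
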
+#[cfg(feature = "tokio-comp")]
+#[derive(Clone)]
+struct TokioDisconnectNotifier {
+    disconnect_notifier: Arc<Notify>,
+}
+
+#[cfg(feature = "tokio-comp")]
+#[async_trait]
+impl DisconnectNotifier for TokioDisconnectNotifier {
+    fn notify_disconnect(&mut self) {
+        self.disconnect_notifier.notify_one();
+    }
+
+    async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) {
+        let _ = timeout(*max_wait, async {
+            self.disconnect_notifier.notified().await;
+        })
+        .await;
+    }
+
+    fn clone_box(&self) -> Box<dyn DisconnectNotifier> {
+        Box::new(self.clone())
+    }
+}
+
+#[cfg(feature = "tokio-comp")]
+impl TokioDisconnectNotifier {
+    fn new() -> TokioDisconnectNotifier {
+        TokioDisconnectNotifier {
+            disconnect_notifier: Arc::new(Notify::new()),
+        }
+    }
+}
+
+type ConnectionMap<C> = connections_container::ConnectionsMap<ConnectionFuture<C>>;
+type ConnectionsContainer<C> =
+    self::connections_container::ConnectionsContainer<ConnectionFuture<C>>;
+
+pub(crate) struct InnerCore<C> {
+    pub(crate) conn_lock: StdRwLock<ConnectionsContainer<C>>,
+    cluster_params: StdRwLock<ClusterParams>,
+    pending_requests: Mutex<Vec<PendingRequest<C>>>,
+    slot_refresh_state: SlotRefreshState,
+    initial_nodes: Vec<ConnectionInfo>,
+    subscriptions_by_address: TokioRwLock<HashMap<String, PubSubSubscriptionInfo>>,
+    unassigned_subscriptions: TokioRwLock<PubSubSubscriptionInfo>,
+    glide_connection_options: GlideConnectionOptions,
+}
+
+pub(crate) type Core<C> = Arc<InnerCore<C>>;
+
+impl<C> InnerCore<C>
+where
+    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
+{
+    fn get_cluster_param<T, F>(&self, f: F) -> Result<T, RedisError>
+    where
+        F: FnOnce(&ClusterParams) -> T,
+        T: Clone,
+    {
+        self.cluster_params
+            .read()
+            .map(|guard| f(&guard).clone())
+            .map_err(|_| RedisError::from((ErrorKind::ClientError, MUTEX_READ_ERR)))
+    }
+
+    fn set_cluster_param<F>(&self, f: F) -> Result<(), RedisError>
+    where
+        F: FnOnce(&mut ClusterParams),
+    {
+        self.cluster_params
+            .write()
+            .map(|mut params| {
+                f(&mut params);
+            })
+            .map_err(|_| RedisError::from((ErrorKind::ClientError, MUTEX_WRITE_ERR)))
+    }
+
+    // return address of node for slot
+    pub(crate) async fn get_address_from_slot(
+        &self,
+        slot: u16,
+        slot_addr: SlotAddr,
+    ) -> Option<Arc<String>> {
+        self.conn_lock
+            .read()
+            .expect(MUTEX_READ_ERR)
+            .slot_map
+            .get_node_address_for_slot(slot, slot_addr)
+    }
+
+    // return epoch of node
+    pub(crate) async fn get_address_epoch(&self, node_address: &str) -> Result<u64, RedisError> {
+        let command = cmd("CLUSTER").arg("INFO").to_owned();
+        let node_conn = self
+            .conn_lock
+            .read()
+            .expect(MUTEX_READ_ERR)
+            .connection_for_address(node_address)
+            .ok_or(RedisError::from((
+                ErrorKind::ResponseError,
+                "Failed to parse cluster info",
+            )))?;
+
+        let cluster_info = node_conn.1.await.req_packed_command(&command).await;
+        match cluster_info {
+            Ok(value) => {
+                let info_dict: Result<InfoDict, RedisError> =
+                    FromRedisValue::from_redis_value(&value);
+                if let Ok(info_dict) = info_dict {
+                    let epoch = info_dict.get("cluster_my_epoch");
+                    if let Some(epoch) = epoch {
+                        Ok(epoch)
+                    } else {
+                        Err(RedisError::from((
+                            ErrorKind::ResponseError,
+                            "Failed to get epoch from cluster info",
+                        )))
+                    }
+                } else {
+                    Err(RedisError::from((
+                        ErrorKind::ResponseError,
+                        "Failed to parse cluster info",
+                    )))
+                }
+            }
+            Err(redis_error) => Err(redis_error),
+        }
+    }
+
+    // return slots of node
+    pub(crate) async fn get_slots_of_address(&self, node_address: Arc<String>) -> Vec<u16> {
+        self.conn_lock
+            .read()
+            .expect(MUTEX_READ_ERR)
+            .slot_map
+            .get_slots_of_node(node_address)
+    }
+}
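
`get_address_epoch` pulls `cluster_my_epoch` out of a `CLUSTER INFO` reply via `InfoDict`. The reply itself is plain `key:value` lines, so the extraction it performs amounts to the following hand-rolled sketch (the crate goes through `InfoDict` rather than string parsing; this is only to show the shape of the data):

```rust
/// Extract `cluster_my_epoch` from a raw CLUSTER INFO payload.
fn my_epoch(cluster_info: &str) -> Option<u64> {
    cluster_info
        .lines()
        .find_map(|line| line.strip_prefix("cluster_my_epoch:"))
        .and_then(|v| v.trim().parse::<u64>().ok()) // trim drops the trailing '\r'
}

fn main() {
    let info = "cluster_enabled:1\r\ncluster_state:ok\r\ncluster_my_epoch:6\r\n";
    assert_eq!(my_epoch(info), Some(6));
    assert_eq!(my_epoch("cluster_state:ok"), None);
}
```
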
+pub(crate) struct ClusterConnInner<C> {
+    pub(crate) inner: Core<C>,
+    state: ConnectionState,
+    #[allow(clippy::complexity)]
+    in_flight_requests: stream::FuturesUnordered<Pin<Box<Request<C>>>>,
+    refresh_error: Option<RedisError>,
+    // Handler of the periodic check task.
+    periodic_checks_handler: Option<JoinHandle<()>>,
+    // Handler of the fast connection validation task
+    connections_validation_handler: Option<JoinHandle<()>>,
+}
+
+impl<C> Dispose for ClusterConnInner<C> {
+    fn dispose(self) {
+        if let Ok(conn_lock) = self.inner.conn_lock.try_read() {
+            // Each node may contain a user connection and *maybe* a management connection
+            let mut count = 0usize;
+            for node in conn_lock.connection_map() {
+                count += node.connections_count();
+            }
+            Telemetry::decr_total_connections(count);
+        }
+
+        if let Some(handle) = self.periodic_checks_handler {
+            #[cfg(feature = "tokio-comp")]
+            handle.abort()
+        }
+
+        if let Some(handle) = self.connections_validation_handler {
+            #[cfg(feature = "tokio-comp")]
+            handle.abort()
+        }
+
+        // Reduce the number of clients
+        Telemetry::decr_total_clients(1);
+    }
+}
+
+#[derive(Clone)]
+pub(crate) enum InternalRoutingInfo<C> {
+    SingleNode(InternalSingleNodeRouting<C>),
+    MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
+}
+
+#[derive(PartialEq, Clone, Debug)]
+/// Represents different policies for refreshing the cluster slots.
+pub(crate) enum RefreshPolicy {
+    /// `Throttable` indicates that the refresh operation can be throttled,
+    /// meaning it can be delayed or rate-limited if necessary.
+    Throttable,
+    /// `NotThrottable` indicates that the refresh operation should not be throttled,
+    /// meaning it should be executed immediately without any delay or rate-limiting.
+    NotThrottable,
+}
+
+impl<C> From<cluster_routing::RoutingInfo> for InternalRoutingInfo<C> {
+    fn from(value: cluster_routing::RoutingInfo) -> Self {
+        match value {
+            cluster_routing::RoutingInfo::SingleNode(route) => {
+                InternalRoutingInfo::SingleNode(route.into())
+            }
+            cluster_routing::RoutingInfo::MultiNode(routes) => {
+                InternalRoutingInfo::MultiNode(routes)
+            }
+        }
+    }
+}
+
+impl<C> From<InternalSingleNodeRouting<C>> for InternalRoutingInfo<C> {
+    fn from(value: InternalSingleNodeRouting<C>) -> Self {
+        InternalRoutingInfo::SingleNode(value)
+    }
+}
+
+#[derive(Clone)]
+pub(crate) enum InternalSingleNodeRouting<C> {
+    Random,
+    SpecificNode(Route),
+    ByAddress(String),
+    Connection {
+        address: String,
+        conn: ConnectionFuture<C>,
+    },
+    Redirect {
+        redirect: Redirect,
+        previous_routing: Box<InternalSingleNodeRouting<C>>,
+    },
+}
+
+impl<C> Default for InternalSingleNodeRouting<C> {
+    fn default() -> Self {
+        Self::Random
+    }
+}
+
+impl<C> From<SingleNodeRoutingInfo> for InternalSingleNodeRouting<C> {
+    fn from(value: SingleNodeRoutingInfo) -> Self {
+        match value {
+            SingleNodeRoutingInfo::Random => InternalSingleNodeRouting::Random,
+            SingleNodeRoutingInfo::SpecificNode(route) => {
+                InternalSingleNodeRouting::SpecificNode(route)
+            }
+            SingleNodeRoutingInfo::RandomPrimary => {
+                InternalSingleNodeRouting::SpecificNode(Route::new_random_primary())
+            }
+            SingleNodeRoutingInfo::ByAddress { host, port } => {
+                InternalSingleNodeRouting::ByAddress(format!("{host}:{port}"))
+            }
+        }
+    }
+}
+
+#[derive(Clone)]
+enum CmdArg<C> {
+    Cmd {
+        cmd: Arc<Cmd>,
+        routing: InternalRoutingInfo<C>,
+    },
+    Pipeline {
+        pipeline: Arc<crate::Pipeline>,
+        offset: usize,
+        count: usize,
+        route: InternalSingleNodeRouting<C>,
+    },
+    ClusterScan {
+        // struct containing the arguments for the cluster scan command - scan state cursor, match pattern, count and object type.
+        cluster_scan_args: ClusterScanArgs,
+    },
+    // Operational requests which are connected to the internal state of the connection and not sent as a command to the server.
+    OperationRequest(Operation),
+}
+// Operation requests which are connected to the internal state of the connection and not sent as a command to the server.
+#[derive(Clone)]
+enum Operation {
+    UpdateConnectionPassword(Option<String>),
+}
+
+fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult<Option<Route>> {
+    fn route_for_command(cmd: &Cmd) -> Option<Route> {
+        match cluster_routing::RoutingInfo::for_routable(cmd) {
+            Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
+            Some(cluster_routing::RoutingInfo::SingleNode(
+                SingleNodeRoutingInfo::SpecificNode(route),
+            )) => Some(route),
+            Some(cluster_routing::RoutingInfo::SingleNode(
+                SingleNodeRoutingInfo::RandomPrimary,
+            )) => Some(Route::new_random_primary()),
+            Some(cluster_routing::RoutingInfo::MultiNode(_)) => None,
+            Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                ..
+            })) => None,
+            None => None,
+        }
+    }
+
+    // Find the first specific slot and send to it. There's no need to check if later commands
+    // should be routed to a different slot, since the server will return an error indicating this.
+    pipeline.cmd_iter().map(route_for_command).try_fold(
+        None,
+        |chosen_route, next_cmd_route| match (chosen_route, next_cmd_route) {
+            (None, _) => Ok(next_cmd_route),
+            (_, None) => Ok(chosen_route),
+            (Some(chosen_route), Some(next_cmd_route)) => {
+                if chosen_route.slot() != next_cmd_route.slot() {
+                    Err((ErrorKind::CrossSlot, "Received crossed slots in pipeline").into())
+                } else if chosen_route.slot_addr() == SlotAddr::ReplicaOptional {
+                    Ok(Some(next_cmd_route))
+                } else {
+                    Ok(Some(chosen_route))
+                }
+            }
+        },
+    )
+}
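
The fold in `route_for_pipeline` is easiest to see with concrete slots. A toy rendition below keeps only the slot logic: `None` models commands that may run anywhere (e.g. PING), `Some(slot)` models slot-pinned commands, and the replica-preference tie-break of the real fold is deliberately omitted:

```rust
#[derive(Debug, PartialEq)]
enum RouteError {
    CrossSlot,
}

fn route_for_pipeline(cmd_slots: &[Option<u16>]) -> Result<Option<u16>, RouteError> {
    cmd_slots
        .iter()
        .copied()
        .try_fold(None, |chosen, next| match (chosen, next) {
            (None, _) => Ok(next),
            (_, None) => Ok(chosen),
            (Some(a), Some(b)) if a == b => Ok(Some(a)),
            _ => Err(RouteError::CrossSlot),
        })
}

fn main() {
    // Two commands on the same slot: the pipeline is pinned there.
    assert_eq!(route_for_pipeline(&[Some(12182), Some(12182), None]), Ok(Some(12182)));
    // Slot-free pipelines keep a free (random) route.
    assert_eq!(route_for_pipeline(&[None, None]), Ok(None));
    // Keys in different slots cannot share one pipeline.
    assert_eq!(route_for_pipeline(&[Some(12182), Some(866)]), Err(RouteError::CrossSlot));
}
```
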
+fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> {
+    Box::pin(tokio::time::sleep(duration))
+}
+
+#[derive(Debug, Display)]
+pub(crate) enum Response {
+    Single(Value),
+    ClusterScanResult(ScanStateRC, Vec<Value>),
+    Multiple(Vec<Value>),
+}
+
+#[derive(Debug)]
+pub(crate) enum OperationTarget {
+    Node { address: String },
+    FanOut,
+    NotFound,
+}
+type OperationResult = Result<Response, (OperationTarget, RedisError)>;
+
+impl From<String> for OperationTarget {
+    fn from(address: String) -> Self {
+        OperationTarget::Node { address }
+    }
+}
+
+/// Represents a node to which a `MOVED` or `ASK` error redirects.
+#[derive(Clone, Debug)]
+pub(crate) struct RedirectNode {
+    /// The address of the redirect node.
+    pub address: String,
+    /// The slot of the redirect node.
+    pub slot: u16,
+}
+
+impl RedirectNode {
+    /// Constructs a `RedirectNode` from an optional tuple containing an address and a slot number.
+    pub(crate) fn from_option_tuple(option: Option<(&str, u16)>) -> Option<Self> {
+        option.map(|(address, slot)| RedirectNode {
+            address: address.to_string(),
+            slot,
+        })
+    }
+}
+
+struct Message<C> {
+    cmd: CmdArg<C>,
+    sender: oneshot::Sender<RedisResult<Response>>,
+}
+
+enum RecoverFuture {
+    RecoverSlots(BoxFuture<'static, RedisResult<()>>),
+    Reconnect(BoxFuture<'static, ()>),
+}
+
+enum ConnectionState {
+    PollComplete,
+    Recover(RecoverFuture),
+}
+
+impl fmt::Debug for ConnectionState {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                ConnectionState::PollComplete => "PollComplete",
+                ConnectionState::Recover(_) => "Recover",
+            }
+        )
+    }
+}
+
+#[derive(Clone)]
+struct RequestInfo<C> {
+    cmd: CmdArg<C>,
+}
+
+impl<C> RequestInfo<C> {
+    fn set_redirect(&mut self, redirect: Option<Redirect>) {
+        if let Some(redirect) = redirect {
+            match &mut self.cmd {
+                CmdArg::Cmd { routing, .. } => match routing {
+                    InternalRoutingInfo::SingleNode(route) => {
+                        let redirect = InternalSingleNodeRouting::Redirect {
+                            redirect,
+                            previous_routing: Box::new(std::mem::take(route)),
+                        }
+                        .into();
+                        *routing = redirect;
+                    }
+                    InternalRoutingInfo::MultiNode(_) => {
+                        panic!("Cannot redirect multinode requests")
+                    }
+                },
+                CmdArg::Pipeline { route, .. } => {
+                    let redirect = InternalSingleNodeRouting::Redirect {
+                        redirect,
+                        previous_routing: Box::new(std::mem::take(route)),
+                    };
+                    *route = redirect;
+                }
+                // cluster_scan is sent as a normal command internally so we will not reach that point.
+                CmdArg::ClusterScan { .. } => {
+                    unreachable!()
+                }
+                // Operation requests are not routed.
+                CmdArg::OperationRequest(_) => {
+                    unreachable!()
+                }
+            }
+        }
+    }
+
+    fn reset_routing(&mut self) {
+        let fix_route = |route: &mut InternalSingleNodeRouting<C>| {
+            match route {
+                InternalSingleNodeRouting::Redirect {
+                    previous_routing, ..
+                } => {
+                    let previous_routing = std::mem::take(previous_routing.as_mut());
+                    *route = previous_routing;
+                }
+                // If a specific connection is specified, then reconnecting without resetting the routing
+                // will mean that the request is still routed to the old connection.
+                InternalSingleNodeRouting::Connection { address, .. } => {
+                    *route = InternalSingleNodeRouting::ByAddress(address.to_string());
+                }
+                _ => {}
+            }
+        };
+        match &mut self.cmd {
+            CmdArg::Cmd { routing, .. } => {
+                if let InternalRoutingInfo::SingleNode(route) = routing {
+                    fix_route(route);
+                }
+            }
+            CmdArg::Pipeline { route, .. } => {
+                fix_route(route);
+            }
+            // cluster_scan is sent as a normal command internally so we will not reach that point.
+            CmdArg::ClusterScan { .. } => {
+                unreachable!()
+            }
+            // Operation requests are not routed.
+            CmdArg::OperationRequest { .. } => {
+                unreachable!()
+            }
+        }
+    }
+}
+
+pin_project! {
+    #[project = RequestStateProj]
+    enum RequestState<F> {
+        None,
+        Future {
+            #[pin]
+            future: F,
+        },
+        Sleep {
+            #[pin]
+            sleep: BoxFuture<'static, ()>,
+        },
+        UpdateMoved {
+            #[pin]
+            future: BoxFuture<'static, RedisResult<()>>,
+        },
+    }
+}
+
+struct PendingRequest<C> {
+    retry: u32,
+    sender: oneshot::Sender<RedisResult<Response>>,
+    info: RequestInfo<C>,
+}
+
+pin_project! {
+    struct Request<C> {
+        retry_params: RetryParams,
+        request: Option<PendingRequest<C>>,
+        #[pin]
+        future: RequestState<BoxFuture<'static, OperationResult>>,
+    }
+}
+
+#[must_use]
+enum Next<C> {
+    Retry {
+        request: PendingRequest<C>,
+    },
+    RetryBusyLoadingError {
+        request: PendingRequest<C>,
+        address: String,
+    },
+    Reconnect {
+        // if not set, then a reconnect should happen without sending a request afterwards
+        request: Option<PendingRequest<C>>,
+        target: String,
+    },
+    RefreshSlots {
+        // if not set, then a slot refresh should happen without sending a request afterwards
+        request: Option<PendingRequest<C>>,
+        sleep_duration: Option<Duration>,
+        moved_redirect: Option<RedirectNode>,
+    },
+    ReconnectToInitialNodes {
+        // if not set, then a reconnect should happen without sending a request afterwards
+        request: Option<PendingRequest<C>>,
+    },
+    Done,
+}
+impl<C> Future for Request<C> {
+    type Output = Next<C>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
+        let mut this = self.as_mut().project();
+        // If the sender is closed, the caller is no longer waiting for the reply, and it is ambiguous
+        // whether they expect the side-effect of the request to happen or not.
+        if this.request.is_none() || this.request.as_ref().unwrap().sender.is_closed() {
+            return Poll::Ready(Next::Done);
+        }
+        let future = match this.future.as_mut().project() {
+            RequestStateProj::Future { future } => future,
+            RequestStateProj::Sleep { sleep } => {
+                ready!(sleep.poll(cx));
+                return Next::Retry {
+                    request: self.project().request.take().unwrap(),
+                }
+                .into();
+            }
+            RequestStateProj::UpdateMoved { future } => {
+                if let Err(err) = ready!(future.poll(cx)) {
+                    // Updating the slot map based on the MOVED error is an optimization.
+                    // If it fails, proceed by retrying the request with the redirected node,
+                    // and allow the slot refresh task to correct the slot map.
+                    info!(
+                        "Failed to update the slot map based on the received MOVED error.
+                        Error: {err:?}"
+                    );
+                }
+                if let Some(request) = self.project().request.take() {
+                    return Next::Retry { request }.into();
+                } else {
+                    return Next::Done.into();
+                }
+            }
+            _ => panic!("Request future must be Some"),
+        };
+
+        match ready!(future.poll(cx)) {
+            Ok(item) => {
+                self.respond(Ok(item));
+                Next::Done.into()
+            }
+            Err((target, err)) => {
+                let request = this.request.as_mut().unwrap();
+                // TODO - would be nice if we didn't need to repeat this code twice, with & without retries.
+                if request.retry >= this.retry_params.number_of_retries {
+                    let retry_method = err.retry_method();
+                    let next = if err.kind() == ErrorKind::AllConnectionsUnavailable {
+                        Next::ReconnectToInitialNodes { request: None }.into()
+                    } else if matches!(err.retry_method(), RetryMethod::MovedRedirect)
+                        || matches!(target, OperationTarget::NotFound)
+                    {
+                        Next::RefreshSlots {
+                            request: None,
+                            sleep_duration: None,
+                            moved_redirect: RedirectNode::from_option_tuple(err.redirect_node()),
+                        }
+                        .into()
+                    } else if matches!(retry_method, RetryMethod::Reconnect)
+                        || matches!(retry_method, RetryMethod::ReconnectAndRetry)
+                    {
+                        if let OperationTarget::Node { address } = target {
+                            Next::Reconnect {
+                                request: None,
+                                target: address,
+                            }
+                            .into()
+                        } else {
+                            Next::Done.into()
+                        }
+                    } else {
+                        Next::Done.into()
+                    };
+                    self.respond(Err(err));
+                    return next;
+                }
+                request.retry = request.retry.saturating_add(1);
+
+                if err.kind() == ErrorKind::AllConnectionsUnavailable {
+                    return Next::ReconnectToInitialNodes {
+                        request: Some(this.request.take().unwrap()),
+                    }
+                    .into();
+                }
+
+                let sleep_duration = this.retry_params.wait_time_for_retry(request.retry);
+
+                let address = match target {
+                    OperationTarget::Node { address } => address,
+                    OperationTarget::FanOut => {
+                        trace!("Request error `{}` multi-node request", err);
+                        // Fanout operations are retried per internal request, and don't need additional retries.
+                        self.respond(Err(err));
+                        return Next::Done.into();
+                    }
+                    OperationTarget::NotFound => {
+                        // TODO - this is essentially a repeat of the retriable error. probably can remove duplication.
+                        let mut request = this.request.take().unwrap();
+                        request.info.reset_routing();
+                        return Next::RefreshSlots {
+                            request: Some(request),
+                            sleep_duration: Some(sleep_duration),
+                            moved_redirect: None,
+                        }
+                        .into();
+                    }
+                };
+
+                warn!("Received request error {} on node {:?}.", err, address);
+
+                match err.retry_method() {
+                    RetryMethod::AskRedirect => {
+                        let mut request = this.request.take().unwrap();
+                        request.info.set_redirect(
+                            err.redirect_node()
+                                .map(|(node, _slot)| Redirect::Ask(node.to_string())),
+                        );
+                        Next::Retry { request }.into()
+                    }
+                    RetryMethod::MovedRedirect => {
+                        let mut request = this.request.take().unwrap();
+                        let redirect_node = err.redirect_node();
+                        request.info.set_redirect(
+                            err.redirect_node()
+                                .map(|(node, _slot)| Redirect::Moved(node.to_string())),
+                        );
+                        Next::RefreshSlots {
+                            request: Some(request),
+                            sleep_duration: None,
+                            moved_redirect: RedirectNode::from_option_tuple(redirect_node),
+                        }
+                        .into()
+                    }
+                    RetryMethod::WaitAndRetry => {
+                        let sleep_duration = this.retry_params.wait_time_for_retry(request.retry);
+                        // Sleep and retry.
+                        this.future.set(RequestState::Sleep {
+                            sleep: boxed_sleep(sleep_duration),
+                        });
+                        self.poll(cx)
+                    }
+                    RetryMethod::Reconnect | RetryMethod::ReconnectAndRetry => {
+                        let mut request = this.request.take().unwrap();
+                        // TODO should we reset the redirect here?
+                        request.info.reset_routing();
+                        warn!("disconnected from {:?}", address);
+                        let should_retry =
+                            matches!(err.retry_method(), RetryMethod::ReconnectAndRetry);
+                        Next::Reconnect {
+                            request: should_retry.then_some(request),
+                            target: address,
+                        }
+                        .into()
+                    }
+                    RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => {
+                        Next::RetryBusyLoadingError {
+                            request: this.request.take().unwrap(),
+                            address,
+                        }
+                        .into()
+                    }
+                    RetryMethod::RetryImmediately => Next::Retry {
+                        request: this.request.take().unwrap(),
+                    }
+                    .into(),
+                    RetryMethod::NoRetry => {
+                        self.respond(Err(err));
+                        Next::Done.into()
+                    }
+                }
+            }
+        }
+    }
+}
+impl<C> Request<C> {
+    fn respond(self: Pin<&mut Self>, msg: RedisResult<Response>) {
+        // If `send` errors the receiver has dropped and thus does not care about the message
+        let _ = self
+            .project()
+            .request
+            .take()
+            .expect("Result should only be sent once")
+            .sender
+            .send(msg);
+    }
+}
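
Stripped of the routing detail, the retry arms above reduce to a bounded retry loop: bump the counter with `saturating_add`, sleep for `wait_time_for_retry(retry)`, try again, and give up once `number_of_retries` is exhausted. A sketch under the assumption of an exponential backoff shape (the crate's actual `RetryParams` is configurable and may compute the wait differently):

```rust
use std::time::Duration;

struct RetryParams {
    number_of_retries: u32,
    base_ms: u64,
    factor: f64,
}

impl RetryParams {
    // Assumed shape: base * factor^(retry - 1).
    fn wait_time_for_retry(&self, retry: u32) -> Duration {
        let ms = self.base_ms as f64 * self.factor.powi(retry.saturating_sub(1) as i32);
        Duration::from_millis(ms as u64)
    }
}

async fn with_retries<T, E>(
    params: &RetryParams,
    mut op: impl FnMut() -> Result<T, E>,
) -> Result<T, E> {
    let mut retry: u32 = 0;
    loop {
        match op() {
            Ok(v) => return Ok(v),
            Err(e) if retry >= params.number_of_retries => return Err(e),
            Err(_) => {
                retry = retry.saturating_add(1);
                tokio::time::sleep(params.wait_time_for_retry(retry)).await;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let params = RetryParams { number_of_retries: 3, base_ms: 10, factor: 2.0 };
    let mut attempts = 0;
    let res: Result<u32, &str> = with_retries(&params, || {
        attempts += 1;
        if attempts < 3 { Err("transient") } else { Ok(42) }
    })
    .await;
    assert_eq!(res, Ok(42));
}
```
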
+enum ConnectionCheck<C> {
+    Found((String, ConnectionFuture<C>)),
+    OnlyAddress(String),
+    RandomConnection,
+}
+
+impl<C> ClusterConnInner<C>
+where
+    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
+{
+    async fn new(
+        initial_nodes: &[ConnectionInfo],
+        cluster_params: ClusterParams,
+        push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
+    ) -> RedisResult<Disposable<Self>> {
+        let disconnect_notifier = {
+            #[cfg(feature = "tokio-comp")]
+            {
+                Some::<Box<dyn DisconnectNotifier>>(Box::new(TokioDisconnectNotifier::new()))
+            }
+            #[cfg(not(feature = "tokio-comp"))]
+            None
+        };
+
+        let discover_az = matches!(
+            cluster_params.read_from_replicas,
+            crate::cluster_slotmap::ReadFromReplicaStrategy::AZAffinity(_)
+        );
+
+        let glide_connection_options = GlideConnectionOptions {
+            push_sender,
+            disconnect_notifier,
+            discover_az,
+        };
+
+        let connections = Self::create_initial_connections(
+            initial_nodes,
+            &cluster_params,
+            glide_connection_options.clone(),
+        )
+        .await?;
+
+        let topology_checks_interval = cluster_params.topology_checks_interval;
+        let slots_refresh_rate_limiter = cluster_params.slots_refresh_rate_limit;
+        let inner = Arc::new(InnerCore {
+            conn_lock: StdRwLock::new(ConnectionsContainer::new(
+                Default::default(),
+                connections,
+                cluster_params.read_from_replicas.clone(),
+                0,
+            )),
+            cluster_params: StdRwLock::new(cluster_params.clone()),
+            pending_requests: Mutex::new(Vec::new()),
+            slot_refresh_state: SlotRefreshState::new(slots_refresh_rate_limiter),
+            initial_nodes: initial_nodes.to_vec(),
+            unassigned_subscriptions: TokioRwLock::new(
+                if let Some(subs) = cluster_params.pubsub_subscriptions {
+                    subs.clone()
+                } else {
+                    PubSubSubscriptionInfo::new()
+                },
+            ),
+            subscriptions_by_address: TokioRwLock::new(Default::default()),
+            glide_connection_options,
+        });
+        let mut connection = ClusterConnInner {
+            inner,
+            in_flight_requests: Default::default(),
+            refresh_error: None,
+            state: ConnectionState::PollComplete,
+            periodic_checks_handler: None,
+            connections_validation_handler: None,
+        };
+        Self::refresh_slots_and_subscriptions_with_retries(
+            connection.inner.clone(),
+            &RefreshPolicy::NotThrottable,
+        )
+        .await?;
+
+        if let Some(duration) = topology_checks_interval {
+            let periodic_task =
+                ClusterConnInner::periodic_topology_check(connection.inner.clone(), duration);
+            #[cfg(feature = "tokio-comp")]
+            {
+                connection.periodic_checks_handler = Some(tokio::spawn(periodic_task));
+            }
+        }
+
+        let connections_validation_interval = cluster_params.connections_validation_interval;
+        if let Some(duration) = connections_validation_interval {
+            let connections_validation_handler =
+                ClusterConnInner::connections_validation_task(connection.inner.clone(), duration);
+            #[cfg(feature = "tokio-comp")]
+            {
+                connection.connections_validation_handler =
+                    Some(tokio::spawn(connections_validation_handler));
+            }
+        }
+
+        // New client added
+        Telemetry::incr_total_clients(1);
+        Ok(Disposable::new(connection))
+    }
+    /// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
+    /// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
+    /// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved.
+    pub(crate) async fn try_to_expand_initial_nodes(
+        initial_nodes: &[ConnectionInfo],
+    ) -> Vec<(String, Option<SocketAddr>)> {
+        stream::iter(initial_nodes)
+            .fold(
+                Vec::with_capacity(initial_nodes.len()),
+                |mut acc, info| async {
+                    let (host, port) = match &info.addr {
+                        crate::ConnectionAddr::Tcp(host, port) => (host, port),
+                        crate::ConnectionAddr::TcpTls {
+                            host,
+                            port,
+                            insecure: _,
+                            tls_params: _,
+                        } => (host, port),
+                        crate::ConnectionAddr::Unix(_) => {
+                            // We don't support multiple addresses for a Unix address. Store the initial node address and continue.
+                            acc.push((info.addr.to_string(), None));
+                            return acc;
+                        }
+                    };
+                    match get_socket_addrs(host, *port).await {
+                        Ok(socket_addrs) => {
+                            for addr in socket_addrs {
+                                acc.push((info.addr.to_string(), Some(addr)));
+                            }
+                        }
+                        Err(_) => {
+                            // Couldn't find socket addresses, store the initial node address and continue
+                            acc.push((info.addr.to_string(), None));
+                        }
+                    };
+                    acc
+                },
+            )
+            .await
+    }
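
The expansion step is ordinary DNS fan-out: one hostname may resolve to several A/AAAA records, and each resolved `SocketAddr` becomes its own seed entry. Outside the crate (whose `get_socket_addrs` helper is internal), the same effect looks like this, assuming tokio's `net` feature:

```rust
use std::net::SocketAddr;
use tokio::net::lookup_host;

/// Expand one "host:port" seed into all addresses DNS returns for it.
async fn expand_seed(host: &str, port: u16) -> Vec<SocketAddr> {
    lookup_host((host, port))
        .await
        .map(|addrs| addrs.collect())
        .unwrap_or_default() // on resolution failure, fall back to the name itself upstream
}

#[tokio::main]
async fn main() {
    for addr in expand_seed("localhost", 6379).await {
        println!("seed candidate: {addr}");
    }
}
```
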
+    async fn create_initial_connections(
+        initial_nodes: &[ConnectionInfo],
+        params: &ClusterParams,
+        glide_connection_options: GlideConnectionOptions,
+    ) -> RedisResult<ConnectionMap<C>> {
+        let initial_nodes: Vec<(String, Option<SocketAddr>)> =
+            Self::try_to_expand_initial_nodes(initial_nodes).await;
+        let connections = stream::iter(initial_nodes.iter().cloned())
+            .map(|(node_addr, socket_addr)| {
+                let mut params: ClusterParams = params.clone();
+                let glide_connection_options = glide_connection_options.clone();
+                // set subscriptions to none, they will be applied upon the topology discovery
+                params.pubsub_subscriptions = None;
+
+                async move {
+                    let result = connect_and_check(
+                        &node_addr,
+                        params,
+                        socket_addr,
+                        RefreshConnectionType::AllConnections,
+                        None,
+                        glide_connection_options,
+                    )
+                    .await
+                    .get_node();
+                    let node_address = if let Some(socket_addr) = socket_addr {
+                        socket_addr.to_string()
+                    } else {
+                        node_addr
+                    };
+                    result.map(|node| (node_address, node))
+                }
+            })
+            .buffer_unordered(initial_nodes.len())
+            .fold(
+                (
+                    ConnectionsMap(DashMap::with_capacity(initial_nodes.len())),
+                    None,
+                ),
+                |connections: (ConnectionMap<C>, Option<String>), addr_conn_res| async move {
+                    match addr_conn_res {
+                        Ok((addr, node)) => {
+                            connections.0 .0.insert(addr, node);
+                            (connections.0, None)
+                        }
+                        Err(e) => (connections.0, Some(e.to_string())),
+                    }
+                },
+            )
+            .await;
+        if connections.0 .0.is_empty() {
+            return Err(RedisError::from((
+                ErrorKind::IoError,
+                "Failed to create initial connections",
+                connections.1.unwrap_or("".to_string()),
+            )));
+        }
+        info!("Connected to initial nodes:\n{}", connections.0);
+        Ok(connections.0)
+    }
+
+    // Reconnect to the initial nodes provided by the user in the creation of the client,
+    // and try to refresh the slots based on the initial connections.
+    // Being used when all cluster connections are unavailable.
+    fn reconnect_to_initial_nodes(inner: Arc<InnerCore<C>>) -> impl Future<Output = ()> {
+        let inner = inner.clone();
+        let cluster_params = match inner.get_cluster_param(|params| params.clone()) {
+            Ok(params) => params,
+            Err(err) => {
+                warn!("Failed to get cluster params: {}", err);
+                return async {}.boxed();
+            }
+        };
+        Box::pin(async move {
+            let connection_map = match Self::create_initial_connections(
+                &inner.initial_nodes,
+                &cluster_params,
+                inner.glide_connection_options.clone(),
+            )
+            .await
+            {
+                Ok(map) => map,
+                Err(err) => {
+                    warn!("Can't reconnect to initial nodes: `{err}`");
+                    return;
+                }
+            };
+            inner
+                .conn_lock
+                .write()
+                .expect(MUTEX_WRITE_ERR)
+                .extend_connection_map(connection_map);
+            if let Err(err) = Self::refresh_slots_and_subscriptions_with_retries(
+                inner.clone(),
+                &RefreshPolicy::Throttable,
+            )
+            .await
+            {
+                warn!("Can't refresh slots with initial nodes: `{err}`");
+            };
+        })
+    }
+
+    // Validate all existing user connections and try to reconnect if necessary.
+    // In addition, as a safety measure, drop nodes that do not have any assigned slots.
+    // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently.
+    // The function does not discover the topology from the cluster and assumes the cached topology is valid.
+    // In addition, the validation is done by peeking at the state of the underlying transport w/o the overhead of additional commands to the server.
+    async fn validate_all_user_connections(inner: Arc<InnerCore<C>>) {
+        let mut all_valid_conns = HashMap::new();
+        // prep connections and clean out those without assigned slots, as we might have established connections to unwanted hosts
+        let mut nodes_to_delete = Vec::new();
+        let all_nodes_with_slots: HashSet<Arc<String>>;
+        {
+            let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
+
+            all_nodes_with_slots = connections_container.slot_map.all_node_addresses();
+
+            connections_container
+                .all_node_connections()
+                .for_each(|(addr, con)| {
+                    if all_nodes_with_slots.contains(&addr) {
+                        all_valid_conns.insert(addr.clone(), con.clone());
+                    } else {
+                        nodes_to_delete.push(addr.clone());
+                    }
+                });
+
+            for addr in &nodes_to_delete {
+                connections_container.remove_node(addr);
+            }
+        }
+
+        // identify nodes with a closed connection
+        let mut addrs_to_refresh = Vec::new();
+        for (addr, con_fut) in &all_valid_conns {
+            let con = con_fut.clone().await;
+            // the connection object might be present despite the transport being closed
+            if con.is_closed() {
+                // transport is closed, need to refresh
+                addrs_to_refresh.push(addr.clone());
+            }
+        }
+
+        // identify missing nodes
+        addrs_to_refresh.extend(
+            all_nodes_with_slots
+                .iter()
+                .filter(|addr| !all_valid_conns.contains_key(addr.as_str()))
+                .map(|addr| addr.to_string()),
+        );
+
+        if !addrs_to_refresh.is_empty() {
+            // don't try existing nodes, since we know they either don't exist or exist with a closed connection
+            Self::refresh_connections(
+                inner.clone(),
+                addrs_to_refresh,
+                RefreshConnectionType::AllConnections,
+                false,
+            )
+            .await;
+        }
+    }
+    async fn refresh_connections(
+        inner: Arc<InnerCore<C>>,
+        addresses: Vec<String>,
+        conn_type: RefreshConnectionType,
+        check_existing_conn: bool,
+    ) {
+        info!("Started refreshing connections to {:?}", addresses);
+        let mut tasks = FuturesUnordered::new();
+        let inner = inner.clone();
+
+        for address in addresses.into_iter() {
+            let inner = inner.clone();
+
+            tasks.push(async move {
+                let node_option = if check_existing_conn {
+                    let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
+                    connections_container.remove_node(&address)
+                } else {
+                    None
+                };
+
+                // Override subscriptions for this connection
+                let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone();
+                let subs_guard = inner.subscriptions_by_address.read().await;
+                cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned();
+                drop(subs_guard);
+
+                let node = get_or_create_conn(
+                    &address,
+                    node_option,
+                    &cluster_params,
+                    conn_type,
+                    inner.glide_connection_options.clone(),
+                )
+                .await;
+
+                (address, node)
+            });
+        }
+
+        // Poll connection tasks as soon as each one finishes
+        while let Some(result) = tasks.next().await {
+            match result {
+                (address, Ok(node)) => {
+                    let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
+                    connections_container.replace_or_add_connection_for_address(address, node);
+                }
+                (address, Err(err)) => {
+                    warn!(
+                        "Failed to refresh connection for node {}. Error: `{:?}`",
+                        address, err
+                    );
+                }
+            }
+        }
+        debug!("refresh connections completed");
+    }
+    async fn aggregate_results(
+        receivers: Vec<(Option<String>, oneshot::Receiver<RedisResult<Response>>)>,
+        routing: &MultipleNodeRoutingInfo,
+        response_policy: Option<ResponsePolicy>,
+    ) -> RedisResult<Value> {
+        let extract_result = |response| match response {
+            Response::Single(value) => value,
+            Response::Multiple(_) => unreachable!(),
+            Response::ClusterScanResult(_, _) => unreachable!(),
+        };
+
+        let convert_result = |res: Result<RedisResult<Response>, _>| {
+            res.map_err(|_| RedisError::from((ErrorKind::ResponseError, "request wasn't handled due to internal failure"))) // this happens only if the result sender is dropped before usage.
+                .and_then(|res| res.map(extract_result))
+        };
+
+        let get_receiver = |(_, receiver): (_, oneshot::Receiver<RedisResult<Response>>)| async {
+            convert_result(receiver.await)
+        };
+
+        // Sanity
+        if receivers.is_empty() {
+            return Err(RedisError::from((
+                ErrorKind::ClientError,
+                "Client internal error",
+                "Failed to aggregate results for multi-slot command. Maybe a malformed command?"
+                    .to_string(),
+            )));
+        }
+
+        // TODO - once Value::Error will be merged, these will need to be updated to handle this new value.
+        match response_policy {
+            Some(ResponsePolicy::AllSucceeded) => {
+                future::try_join_all(receivers.into_iter().map(get_receiver))
+                    .await
+                    .map(|mut results| {
+                        results.pop().unwrap() // unwrap is safe, since at least one function succeeded
+                    })
+            }
+            Some(ResponsePolicy::OneSucceeded) => future::select_ok(
+                receivers
+                    .into_iter()
+                    .map(|tuple| Box::pin(get_receiver(tuple))),
+            )
+            .await
+            .map(|(result, _)| result),
+            Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => {
+                // Attempt to return the first result that isn't `Nil` or an error.
+                // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`.
+                // If we received only errors, return the last received error.
+                // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty,
+                // thus we return the last received error instead of `Nil`.
+                let num_of_results: usize = receivers.len();
+                let mut futures = receivers
+                    .into_iter()
+                    .map(get_receiver)
+                    .collect::<FuturesUnordered<_>>();
+                let mut nil_counter = 0;
+                let mut last_err = None;
+                while let Some(result) = futures.next().await {
+                    match result {
+                        Ok(Value::Nil) => nil_counter += 1,
+                        Ok(val) => return Ok(val),
+                        Err(e) => last_err = Some(e),
+                    }
+                }
+
+                if nil_counter == num_of_results {
+                    // All received results are `Nil`
+                    Ok(Value::Nil)
+                } else {
+                    Err(last_err.unwrap_or_else(|| {
+                        (
+                            ErrorKind::AllConnectionsUnavailable,
+                            "Couldn't find any connection",
+                        )
+                            .into()
+                    }))
+                }
+            }
+            Some(ResponsePolicy::Aggregate(op)) => {
+                future::try_join_all(receivers.into_iter().map(get_receiver))
+                    .await
+                    .and_then(|results| crate::cluster_routing::aggregate(results, op))
+            }
+            Some(ResponsePolicy::AggregateLogical(op)) => {
+                future::try_join_all(receivers.into_iter().map(get_receiver))
+                    .await
+                    .and_then(|results| crate::cluster_routing::logical_aggregate(results, op))
+            }
+            Some(ResponsePolicy::CombineArrays) => {
+                future::try_join_all(receivers.into_iter().map(get_receiver))
+                    .await
+                    .and_then(|results| match routing {
+                        MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)) => {
+                            crate::cluster_routing::combine_and_sort_array_results(
+                                results,
+                                vec,
+                                args_pattern,
+                            )
+                        }
+                        _ => crate::cluster_routing::combine_array_results(results),
+                    })
+            }
+            Some(ResponsePolicy::CombineMaps) => {
+                future::try_join_all(receivers.into_iter().map(get_receiver))
+                    .await
+                    .and_then(crate::cluster_routing::combine_map_results)
+            }
+            Some(ResponsePolicy::Special) | None => {
+                // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
+                // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
+                future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move {
+                    let result = convert_result(receiver.await)?;
+                    // The unwrap here is possible, because if `addr` is None, an error should have been sent on the receiver.
+                    Ok((Value::BulkString(addr.unwrap().as_bytes().to_vec()), result))
+                }))
+                .await
+                .map(Value::Map)
+            }
+        }
+    }
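
The `FirstSucceededNonEmptyOrAllEmpty` arm is the subtle one: the first non-`Nil` success wins, all-`Nil` means a genuinely empty result, and a mix of `Nil`s and errors must surface the error. Its decision table, reduced to plain `Result`s (a sketch mirroring the `nil_counter` logic above; the real code also races the receivers with `FuturesUnordered` instead of iterating in order):

```rust
#[derive(Debug, PartialEq)]
enum Reply {
    Nil,
    Value(i64),
}

fn first_non_empty(results: Vec<Result<Reply, String>>) -> Result<Reply, String> {
    let total = results.len();
    let mut nils = 0;
    let mut last_err = None;
    for r in results {
        match r {
            Ok(Reply::Value(v)) => return Ok(Reply::Value(v)), // first real value wins
            Ok(Reply::Nil) => nils += 1,
            Err(e) => last_err = Some(e),
        }
    }
    if nils == total {
        Ok(Reply::Nil) // every shard answered, and all were empty
    } else {
        // Some shard failed; we cannot claim "all empty", so report the error.
        Err(last_err.unwrap_or_else(|| "no connections".into()))
    }
}

fn main() {
    assert_eq!(first_non_empty(vec![Ok(Reply::Nil), Ok(Reply::Value(7))]), Ok(Reply::Value(7)));
    assert_eq!(first_non_empty(vec![Ok(Reply::Nil), Ok(Reply::Nil)]), Ok(Reply::Nil));
    assert_eq!(first_non_empty(vec![Ok(Reply::Nil), Err("boom".into())]), Err("boom".into()));
}
```
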
+    // Query a node to discover slot -> master mappings with retries
+    async fn refresh_slots_and_subscriptions_with_retries(
+        inner: Arc<InnerCore<C>>,
+        policy: &RefreshPolicy,
+    ) -> RedisResult<()> {
+        let SlotRefreshState {
+            in_progress,
+            last_run,
+            rate_limiter,
+        } = &inner.slot_refresh_state;
+        // Ensure only a single slot refresh operation occurs at a time
+        if in_progress
+            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
+            .is_err()
+        {
+            return Ok(());
+        }
+        let mut should_refresh_slots = true;
+        if *policy == RefreshPolicy::Throttable {
+            // Check if the current slot refresh is triggered before the wait duration has passed
+            let last_run_rlock = last_run.read().await;
+            if let Some(last_run_time) = *last_run_rlock {
+                let passed_time = SystemTime::now()
+                    .duration_since(last_run_time)
+                    .unwrap_or_else(|err| {
+                        warn!(
+                            "Failed to get the duration since the last slot refresh, received error: {:?}",
+                            err
+                        );
+                        // Setting the passed time to 0 will force the current refresh to continue and reset the stored last_run timestamp with the current one
+                        Duration::from_secs(0)
+                    });
+                let wait_duration = rate_limiter.wait_duration();
+                if passed_time <= wait_duration {
+                    debug!(
+                        "Skipping slot refresh as the wait duration hasn't yet passed. Passed time = {:?}, Wait duration = {:?}",
+                        passed_time, wait_duration
+                    );
+                    should_refresh_slots = false;
+                }
+            }
+        }
+
+        let mut res = Ok(());
+        if should_refresh_slots {
+            let retry_strategy = ExponentialFactorBackoff::from_millis(
+                DEFAULT_REFRESH_SLOTS_RETRY_BASE_DURATION_MILLIS,
+                DEFAULT_REFRESH_SLOTS_RETRY_BASE_FACTOR,
+            )
+            .map(jitter_range(0.8, 1.2))
+            .take(DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES);
+            let retries_counter = AtomicUsize::new(0);
+            res = Retry::spawn(retry_strategy, || async {
+                let curr_retry = retries_counter.fetch_add(1, atomic::Ordering::Relaxed);
+                Self::refresh_slots(inner.clone(), curr_retry)
+                    .await
+                    .map_err(|err| {
+                        if err.kind() == ErrorKind::AllConnectionsUnavailable {
+                            RetryError::permanent(err)
+                        } else {
+                            RetryError::transient(err)
+                        }
+                    })
+            })
+            .await;
+        }
+        in_progress.store(false, Ordering::Relaxed);
+
+        Self::refresh_pubsub_subscriptions(inner).await;
+
+        res
+    }
+
+    /// Determines if the cluster topology has changed and refreshes slots and subscriptions if needed.
+    /// Returns `RedisResult` with `true` if changes were detected and slots were refreshed,
+    /// or `false` if no changes were found. Raises an error if refreshing the topology fails.
+    pub(crate) async fn check_topology_and_refresh_if_diff(
+        inner: Arc<InnerCore<C>>,
+        policy: &RefreshPolicy,
+    ) -> RedisResult<bool> {
+        let topology_changed = Self::check_for_topology_diff(inner.clone()).await;
+        if topology_changed {
+            Self::refresh_slots_and_subscriptions_with_retries(inner.clone(), policy).await?;
+        }
+        Ok(topology_changed)
+    }
+
+    async fn periodic_topology_check(inner: Arc<InnerCore<C>>, interval_duration: Duration) {
+        loop {
+            let _ = boxed_sleep(interval_duration).await;
+            // Check and refresh topology if needed
+            let should_refresh_pubsub = match Self::check_topology_and_refresh_if_diff(
+                inner.clone(),
+                &RefreshPolicy::Throttable,
+            )
+            .await
+            {
+                Ok(topology_changed) => !topology_changed,
+                Err(err) => {
+                    warn!(
+                        "Failed to refresh slots during periodic topology checks:\n{:?}",
+                        err
+                    );
+                    true
+                }
+            };
+
+            // Refresh pubsub subscriptions if the topology wasn't changed or an error occurred.
+            // This serves as a safety measure for validating the pubsub subscriptions state in case it has drifted
+            // while the topology stayed the same.
+            // For example, a failed attempt to refresh a connection which is triggered from refresh_pubsub_subscriptions()
+            // might leave a node unconnected indefinitely in case the topology is stable and no requests are attempted to this node.
+            if should_refresh_pubsub {
+                Self::refresh_pubsub_subscriptions(inner.clone()).await;
+            }
+        }
+    }
+    async fn connections_validation_task(inner: Arc<InnerCore<C>>, interval_duration: Duration) {
+        loop {
+            if let Some(disconnect_notifier) =
+                inner.glide_connection_options.disconnect_notifier.clone()
+            {
+                disconnect_notifier
+                    .wait_for_disconnect_with_timeout(&interval_duration)
+                    .await;
+            } else {
+                let _ = boxed_sleep(interval_duration).await;
+            }
+
+            Self::validate_all_user_connections(inner.clone()).await;
+        }
+    }
+
+    async fn refresh_pubsub_subscriptions(inner: Arc<InnerCore<C>>) {
+        if inner.cluster_params.read().expect(MUTEX_READ_ERR).protocol
+            != crate::types::ProtocolVersion::RESP3
+        {
+            return;
+        }
+
+        let mut addrs_to_refresh: HashSet<String> = HashSet::new();
+        {
+            let mut subs_by_address_guard = inner.subscriptions_by_address.write().await;
+            let mut unassigned_subs_guard = inner.unassigned_subscriptions.write().await;
+            let conns_read_guard = inner.conn_lock.read().expect(MUTEX_READ_ERR);
+            // validate active subscriptions location
+            subs_by_address_guard.retain(|current_address, address_subs| {
+                address_subs.retain(|kind, channels_patterns| {
+                    channels_patterns.retain(|channel_pattern| {
+                        let new_slot = get_slot(channel_pattern);
+                        let valid = if let Some((new_address, _)) = conns_read_guard
+                            .connection_for_route(&Route::new(new_slot, SlotAddr::Master))
+                        {
+                            *new_address == *current_address
+                        } else {
+                            false
+                        };
+                        // no new address or the new address differs - move to unassigned and store this address for connection reset
+                        if !valid {
+                            // need to drop the original connection for clearing the subscription in the server, avoiding possible double-receivers
+                            if conns_read_guard
+                                .connection_for_address(current_address)
+                                .is_some()
+                            {
+                                addrs_to_refresh.insert(current_address.clone());
+                            }
+
+                            unassigned_subs_guard
+                                .entry(*kind)
+                                .and_modify(|channels_patterns| {
+                                    channels_patterns.insert(channel_pattern.clone());
+                                })
+                                .or_insert(HashSet::from([channel_pattern.clone()]));
+                        }
+                        valid
+                    });
+                    !channels_patterns.is_empty()
+                });
+                !address_subs.is_empty()
+            });
+
+            // try to assign new addresses
+            unassigned_subs_guard.retain(|kind: &PubSubSubscriptionKind, channels_patterns| {
+                channels_patterns.retain(|channel_pattern| {
+                    let new_slot = get_slot(channel_pattern);
+                    if let Some((new_address, _)) = conns_read_guard
+                        .connection_for_route(&Route::new(new_slot, SlotAddr::Master))
+                    {
+                        // need to drop the new connection so the subscription will be picked up in setup_connection()
+                        addrs_to_refresh.insert(new_address.clone());
+
+                        let e = subs_by_address_guard
+                            .entry(new_address.clone())
+                            .or_insert(PubSubSubscriptionInfo::new());
+
+                        e.entry(*kind)
+                            .or_insert(HashSet::new())
+                            .insert(channel_pattern.clone());
+
+                        return false;
+                    }
+                    true
+                });
+                !channels_patterns.is_empty()
+            });
+        }
+
+        if !addrs_to_refresh.is_empty() {
+            // immediately trigger connection reestablishment
+            Self::refresh_connections(
+                inner.clone(),
+                addrs_to_refresh.into_iter().collect(),
+                RefreshConnectionType::AllConnections,
+                false,
+            )
+            .await;
+        }
+    }
+    /// Queries log2(n) nodes (where n is the number of cluster nodes) to determine whether their
+    /// topology view differs from the one currently stored in the connection manager.
+    /// Returns true if a change was detected, otherwise false.
+    async fn check_for_topology_diff(inner: Arc<InnerCore<C>>) -> bool {
+        let num_of_nodes = inner.conn_lock.read().expect(MUTEX_READ_ERR).len();
+        let num_of_nodes_to_query = std::cmp::max(num_of_nodes.ilog2() as usize, 1);
+        let (res, failed_connections) = calculate_topology_from_random_nodes(
+            &inner,
+            num_of_nodes_to_query,
+            DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES,
+        )
+        .await;
+
+        if let Ok((_, found_topology_hash)) = res {
+            if inner
+                .conn_lock
+                .read()
+                .expect(MUTEX_READ_ERR)
+                .get_current_topology_hash()
+                != found_topology_hash
+            {
+                return true;
+            }
+        }
+
+        if !failed_connections.is_empty() {
+            Self::refresh_connections(
+                inner,
+                failed_connections,
+                RefreshConnectionType::OnlyManagementConnection,
+                true,
+            )
+            .await;
+        }
+
+        false
+    }
+
+    async fn refresh_slots(inner: Arc<InnerCore<C>>, curr_retry: usize) -> RedisResult<()> {
+        // Update the slot refresh last run timestamp
+        let now = SystemTime::now();
+        let mut last_run_wlock = inner.slot_refresh_state.last_run.write().await;
+        *last_run_wlock = Some(now);
+        drop(last_run_wlock);
+        Self::refresh_slots_inner(inner, curr_retry).await
+    }
+
+    pub(crate) fn check_if_all_slots_covered(slot_map: &SlotMap) -> bool {
+        let mut slots_covered = 0;
+        for (end, slots) in slot_map.slots.iter() {
+            slots_covered += end.saturating_sub(slots.start).saturating_add(1);
+        }
+        slots_covered == SLOT_SIZE
+    }
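
Slot coverage is interval arithmetic: with `SLOT_SIZE` being 16384, the map is complete exactly when the inclusive range lengths `end - start + 1` sum to that value. A toy version over plain ranges (the real check walks the crate's `SlotMap` and uses saturating arithmetic for malformed ranges):

```rust
const SLOT_SIZE: u32 = 16384;

/// `ranges` holds inclusive (start, end) slot ranges, one per shard;
/// well-formed input (start <= end) is assumed here.
fn all_slots_covered(ranges: &[(u16, u16)]) -> bool {
    let covered: u32 = ranges.iter().map(|(s, e)| (e - s) as u32 + 1).sum();
    covered == SLOT_SIZE
}

fn main() {
    assert!(all_slots_covered(&[(0, 5460), (5461, 10922), (10923, 16383)]));
    assert!(!all_slots_covered(&[(0, 5460), (5462, 10922), (10923, 16383)])); // slot 5461 missing
}
```
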
+    // Query a node to discover slot -> master mappings
+    async fn refresh_slots_inner(inner: Arc<InnerCore<C>>, curr_retry: usize) -> RedisResult<()> {
+        let num_of_nodes = inner.conn_lock.read().expect(MUTEX_READ_ERR).len();
+        const MAX_REQUESTED_NODES: usize = 10;
+        let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES);
+        let (new_slots, topology_hash) =
+            calculate_topology_from_random_nodes(&inner, num_of_nodes_to_query, curr_retry)
+                .await
+                .0?;
+        // Create a new connection vector of the found nodes
+        let nodes = new_slots.all_node_addresses();
+        let nodes_len = nodes.len();
+        let addresses_and_connections_iter = stream::iter(nodes)
+            .fold(
+                Vec::with_capacity(nodes_len),
+                |mut addrs_and_conns, addr| {
+                    let inner = inner.clone();
+                    async move {
+                        let addr = addr.to_string();
+                        if let Some(node) = inner
+                            .conn_lock
+                            .read()
+                            .expect(MUTEX_READ_ERR)
+                            .node_for_address(addr.as_str())
+                        {
+                            addrs_and_conns.push((addr, Some(node)));
+                            return addrs_and_conns;
+                        }
+                        // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name.
+                        // We shall check if a connection already exists under the resolved IP name.
+                        let Some((host, port)) = get_host_and_port_from_addr(&addr) else {
+                            addrs_and_conns.push((addr, None));
+                            return addrs_and_conns;
+                        };
+                        let conn = get_socket_addrs(host, port)
+                            .await
+                            .ok()
+                            .map(|mut socket_addresses| {
+                                socket_addresses.find_map(|addr| {
+                                    inner
+                                        .conn_lock
+                                        .read()
+                                        .expect(MUTEX_READ_ERR)
+                                        .node_for_address(&addr.to_string())
+                                })
+                            })
+                            .unwrap_or(None);
+                        addrs_and_conns.push((addr, conn));
+                        addrs_and_conns
+                    }
+                },
+            )
+            .await;
+        let new_connections: ConnectionMap<C> = stream::iter(addresses_and_connections_iter)
+            .fold(
+                ConnectionsMap(DashMap::with_capacity(nodes_len)),
+                |connections, (addr, node)| async {
+                    let mut cluster_params = inner
+                        .get_cluster_param(|params| params.clone())
+                        .expect(MUTEX_READ_ERR);
+                    let subs_guard = inner.subscriptions_by_address.read().await;
+                    cluster_params.pubsub_subscriptions = subs_guard.get(&addr).cloned();
+                    drop(subs_guard);
+                    let node = get_or_create_conn(
+                        &addr,
+                        node,
+                        &cluster_params,
+                        RefreshConnectionType::AllConnections,
+                        inner.glide_connection_options.clone(),
+                    )
+                    .await;
+                    if let Ok(node) = node {
+                        connections.0.insert(addr, node);
+                    }
+                    connections
+                },
+            )
+            .await;
+
+        info!("refresh_slots found nodes:\n{new_connections}");
+        // Reset the current slot map and connection vector with the new ones
+        let mut write_guard = inner.conn_lock.write().expect(MUTEX_WRITE_ERR);
+        let read_from_replicas = inner
+            .get_cluster_param(|params| params.read_from_replicas.clone())
+            .expect(MUTEX_READ_ERR);
+        *write_guard = ConnectionsContainer::new(
+            new_slots,
+            new_connections,
+            read_from_replicas,
+            topology_hash,
+        );
+        Ok(())
+    }
+    /// Handles MOVED errors by updating the client's slot and node mappings based on the new primary's role:
+    ///
+    /// 1. **No Change**: If the new primary is already the current slot owner, no updates are needed.
+    /// 2. **Failover**: If the new primary is a replica within the same shard (indicating a failover),
+    ///    the slot ownership is updated by promoting the replica to the primary in the existing shard addresses.
+    /// 3. **Slot Migration**: If the new primary is an existing primary in another shard, this indicates a slot migration,
+    ///    and the slot mapping is updated to point to the new shard addresses.
+    /// 4. **Replica Moved to a Different Shard**: If the new primary is a replica in a different shard, it can be due to:
+    ///    - The replica became the primary of its shard after a failover, with new slots migrated to it.
+    ///    - The replica has moved to a different shard as the primary.
+    ///      Since further information is unknown, the replica is removed from its original shard and added as the primary of a new shard.
+    /// 5. **New Node**: If the new primary is unknown, it is added as a new node in a new shard, possibly indicating scale-out.
+    ///
+    /// # Arguments
+    /// * `inner` - Shared reference to InnerCore containing connection and slot state.
+    /// * `slot` - The slot number reported as moved.
+    /// * `new_primary` - The address of the node now responsible for the slot.
+    ///
+    /// # Returns
+    /// * `RedisResult<()>` indicating success or failure in updating slot mappings.
+    async fn update_upon_moved_error(
+        inner: Arc<InnerCore<C>>,
+        slot: u16,
+        new_primary: Arc<String>,
+    ) -> RedisResult<()> {
+        let curr_shard_addrs = inner
+            .conn_lock
+            .read()
+            .expect(MUTEX_READ_ERR)
+            .slot_map
+            .shard_addrs_for_slot(slot);
+        // Check if the new primary is part of the current shard and update if required
+        if let Some(curr_shard_addrs) = curr_shard_addrs {
+            match curr_shard_addrs.attempt_shard_role_update(new_primary.clone()) {
+                // Scenario 1: No changes needed as the new primary is already the current slot owner.
+                // Scenario 2: Failover occurred and the new primary was promoted from a replica.
+                ShardUpdateResult::AlreadyPrimary | ShardUpdateResult::Promoted => return Ok(()),
+                // The node was not found in this shard, proceed with further scenarios.
+                ShardUpdateResult::NodeNotFound => {}
+            }
+        }
+
+        // Scenario 3 & 4: Check if the new primary exists in other shards
+        let mut wlock_conn_container = inner.conn_lock.write().expect(MUTEX_WRITE_ERR);
+        let mut nodes_iter = wlock_conn_container.slot_map_nodes();
+        for (node_addr, shard_addrs_arc) in &mut nodes_iter {
+            if node_addr == new_primary {
+                let is_existing_primary = shard_addrs_arc.primary().eq(&new_primary);
+                if is_existing_primary {
+                    // Scenario 3: Slot Migration - The new primary is an existing primary in another shard.
+                    // Update the associated addresses for `slot` to `shard_addrs`.
+                    drop(nodes_iter);
+                    return wlock_conn_container
+                        .slot_map
+                        .update_slot_range(slot, shard_addrs_arc.clone());
+                } else {
+                    // Scenario 4: The MOVED error redirects to `new_primary` which is known as a replica in a shard that doesn’t own `slot`.
+                    // Remove the replica from its existing shard and treat it as a new node in a new shard.
+                    shard_addrs_arc.remove_replica(new_primary.clone())?;
+                    drop(nodes_iter);
+                    return wlock_conn_container
+                        .slot_map
+                        .add_new_primary(slot, new_primary);
+                }
+            }
+        }
+
+        drop(nodes_iter);
+        // Scenario 5: New Node - The new primary is not present in the current slots map; add it as the primary of a new shard.
+        wlock_conn_container
+            .slot_map
+            .add_new_primary(slot, new_primary)
+    }
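
Everything in this function is driven by the redirect payload: per the cluster specification, a `MOVED` error carries the slot and the new owner, e.g. `MOVED 3999 127.0.0.1:6381`. The crate obtains the pair via `err.redirect_node()`; the hand parser below only illustrates the wire format that ultimately feeds `RedirectNode::from_option_tuple`:

```rust
/// Parse "MOVED <slot> <host:port>" into (address, slot).
fn parse_moved(detail: &str) -> Option<(&str, u16)> {
    let mut parts = detail.split_whitespace();
    match (parts.next(), parts.next(), parts.next()) {
        (Some("MOVED"), Some(slot), Some(addr)) => {
            slot.parse::<u16>().ok().map(|slot| (addr, slot))
        }
        _ => None,
    }
}

fn main() {
    assert_eq!(parse_moved("MOVED 3999 127.0.0.1:6381"), Some(("127.0.0.1:6381", 3999)));
    assert_eq!(parse_moved("ASK 3999 127.0.0.1:6381"), None); // ASK takes a different path
}
```
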
+ fn into_channels( + iterator: impl Iterator< + Item = Option<(Arc, ConnectionAndAddress>)>, + >, + ) -> ( + Vec<(Option, Receiver>)>, + Vec>>, + ) { + iterator + .map(|tuple_opt| { + let (sender, receiver) = oneshot::channel(); + if let Some((cmd, conn, address)) = + tuple_opt.map(|(cmd, (address, conn))| (cmd, conn, address)) + { + ( + (Some(address.clone()), receiver), + Some(PendingRequest { + retry: 0, + sender, + info: RequestInfo { + cmd: CmdArg::Cmd { + cmd, + routing: InternalSingleNodeRouting::Connection { + address, + conn, + } + .into(), + }, + }, + }), + ) + } else { + let _ = sender.send(Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Connection not found", + ) + .into())); + ((None, receiver), None) + } + }) + .unzip() + } + let (receivers, requests): (Vec<_>, Vec<_>); + { + let connections_container = core.conn_lock.read().expect(MUTEX_READ_ERR); + if connections_container.is_empty() { + return OperationResult::Err(( + OperationTarget::FanOut, + ( + ErrorKind::AllConnectionsUnavailable, + "No connections found for multi-node operation", + ) + .into(), + )); + } + + (receivers, requests) = match routing { + MultipleNodeRoutingInfo::AllNodes => into_channels( + connections_container + .all_node_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::AllMasters => into_channels( + connections_container + .all_primary_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::MultiSlot((slots, _)) => { + into_channels(slots.iter().map(|(route, indices)| { + connections_container + .connection_for_route(route) + .map(|tuple| { + let new_cmd = + crate::cluster_routing::command_for_multi_slot_indices( + cmd.as_ref(), + indices.iter(), + ); + (Arc::new(new_cmd), tuple) + }) + })) + } + }; + } + core.pending_requests + .lock() + .unwrap() + .extend(requests.into_iter().flatten()); + + Self::aggregate_results(receivers, routing, response_policy) + .await + .map(Response::Single) + .map_err(|err| (OperationTarget::FanOut, err)) + } + + pub(crate) async fn try_cmd_request( + cmd: Arc, + routing: InternalRoutingInfo, + core: Core, + ) -> OperationResult { + let routing = match routing { + // commands that are sent to multiple nodes are handled here. + InternalRoutingInfo::MultiNode((multi_node_routing, response_policy)) => { + return Self::execute_on_multiple_nodes( + &cmd, + &multi_node_routing, + core, + response_policy, + ) + .await; + } + + InternalRoutingInfo::SingleNode(routing) => routing, + }; + trace!("route request to single node"); + + // if we reached this point, we're sending the command only to single node, and we need to find the + // right connection to the node. 
+ let (address, mut conn) = Self::get_connection(routing, core, Some(cmd.clone())) + .await + .map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_command(&cmd) + .await + .map(Response::Single) + .map_err(|err| (address.into(), err)) + } + + async fn try_pipeline_request( + pipeline: Arc, + offset: usize, + count: usize, + conn: impl Future>, + ) -> OperationResult { + trace!("try_pipeline_request"); + let (address, mut conn) = conn.await.map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_commands(&pipeline, offset, count) + .await + .map(Response::Multiple) + .map_err(|err| (OperationTarget::Node { address }, err)) + } + + async fn try_request(info: RequestInfo, core: Core) -> OperationResult { + match info.cmd { + CmdArg::Cmd { cmd, routing } => Self::try_cmd_request(cmd, routing, core).await, + CmdArg::Pipeline { + pipeline, + offset, + count, + route, + } => { + Self::try_pipeline_request( + pipeline, + offset, + count, + Self::get_connection(route, core, None), + ) + .await + } + CmdArg::ClusterScan { + cluster_scan_args, .. + } => { + let core = core; + let scan_result = cluster_scan(core, cluster_scan_args).await; + match scan_result { + Ok((scan_state_ref, values)) => { + Ok(Response::ClusterScanResult(scan_state_ref, values)) + } + // TODO: After routing issues with sending to random node on not-key based commands are resolved, + // this error should be handled in the same way as other errors and not fan-out. + Err(err) => Err((OperationTarget::FanOut, err)), + } + } + CmdArg::OperationRequest(operation_request) => match operation_request { + Operation::UpdateConnectionPassword(password) => { + core.set_cluster_param(|params| params.password = password) + .expect(MUTEX_WRITE_ERR); + Ok(Response::Single(Value::Okay)) + } + }, + } + } + + async fn get_connection( + routing: InternalSingleNodeRouting, + core: Core, + cmd: Option>, + ) -> RedisResult<(String, C)> { + let mut asking = false; + + let conn_check = match routing { + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Moved(moved_addr), + .. + } => core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_address(moved_addr.as_str()) + .map_or( + ConnectionCheck::OnlyAddress(moved_addr), + ConnectionCheck::Found, + ), + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Ask(ask_addr), + .. + } => { + asking = true; + core.conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_address(ask_addr.as_str()) + .map_or( + ConnectionCheck::OnlyAddress(ask_addr), + ConnectionCheck::Found, + ) + } + InternalSingleNodeRouting::SpecificNode(route) => { + if let Some((conn, address)) = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_route(&route) + { + ConnectionCheck::Found((conn, address)) + } else { + // No connection is found for the given route: + // - For key-based commands, attempt redirection to a random node, + // hopefully to be redirected afterwards by a MOVED error. + // - For non-key-based commands, avoid attempting redirection to a random node + // as it wouldn't result in MOVED hints and can lead to unwanted results + // (e.g., sending management command to a different node than the user asked for); instead, raise the error. 
+ let routable_cmd = cmd.and_then(|cmd| Routable::command(&*cmd)); + if routable_cmd.is_some() + && !RoutingInfo::is_key_routing_command(&routable_cmd.unwrap()) + { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found for route", + format!("{route:?}"), + ) + .into()); + } else { + warn!("No connection found for route `{route:?}`. Attempting redirection to a random node."); + ConnectionCheck::RandomConnection + } + } + } + InternalSingleNodeRouting::Random => ConnectionCheck::RandomConnection, + InternalSingleNodeRouting::Connection { address, conn } => { + return Ok((address, conn.await)); + } + InternalSingleNodeRouting::ByAddress(address) => { + let conn_option = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .connection_for_address(&address); + if let Some((address, conn)) = conn_option { + return Ok((address, conn.await)); + } else { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found", + address, + ) + .into()); + } + } + }; + + let (address, mut conn) = match conn_check { + ConnectionCheck::Found((address, connection)) => (address, connection.await), + ConnectionCheck::OnlyAddress(addr) => { + let mut this_conn_params = core.get_cluster_param(|params| params.clone())?; + let subs_guard = core.subscriptions_by_address.read().await; + this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned(); + drop(subs_guard); + match connect_and_check::( + &addr, + this_conn_params, + None, + RefreshConnectionType::AllConnections, + None, + core.glide_connection_options.clone(), + ) + .await + .get_node() + { + Ok(node) => { + let connection_clone = node.user_connection.conn.clone().await; + let connections = core.conn_lock.read().expect(MUTEX_READ_ERR); + let address = connections.replace_or_add_connection_for_address(addr, node); + drop(connections); + (address, connection_clone) + } + Err(err) => { + return Err(err); + } + } + } + ConnectionCheck::RandomConnection => { + let random_conn = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .random_connections(1, ConnectionType::User); + let (random_address, random_conn_future) = + match random_conn.and_then(|conn_iter| conn_iter.into_iter().next()) { + Some((address, future)) => (address, future), + None => { + return Err(RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No random connection found", + ))); + } + }; + + (random_address, random_conn_future.await) + } + }; + + if asking { + let _ = conn.req_packed_command(&crate::cmd::cmd("ASKING")).await; + } + Ok((address, conn)) + } + + fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll> { + let recover_future = match &mut self.state { + ConnectionState::PollComplete => return Poll::Ready(Ok(())), + ConnectionState::Recover(future) => future, + }; + match recover_future { + RecoverFuture::RecoverSlots(ref mut future) => match ready!(future.as_mut().poll(cx)) { + Ok(_) => { + trace!("Recovered!"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + Err(err) => { + trace!("Recover slots failed!"); + let next_state = if err.kind() == ErrorKind::AllConnectionsUnavailable { + ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::reconnect_to_initial_nodes(self.inner.clone()), + ))) + } else { + ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( + Self::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + ), + ))) + }; + self.state = next_state; + 
Poll::Ready(Err(err)) + } + }, + RecoverFuture::Reconnect(ref mut future) => { + ready!(future.as_mut().poll(cx)); + trace!("Reconnected connections"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + } + } + + async fn handle_loading_error( + core: Core, + info: RequestInfo, + address: String, + retry: u32, + retry_params: RetryParams, + ) -> OperationResult { + let is_primary = core + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .is_primary(&address); + + if !is_primary { + // If the connection is a replica, remove the connection and retry. + // The connection will be established again on the next call to refresh slots once the replica is no longer in loading state. + core.conn_lock + .read() + .expect(MUTEX_READ_ERR) + .remove_node(&address); + } else { + // If the connection is primary, just sleep and retry + let sleep_duration = retry_params.wait_time_for_retry(retry); + boxed_sleep(sleep_duration).await; + } + + Self::try_request(info, core).await + } + + fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { + let retry_params = self + .inner + .get_cluster_param(|params| params.retry_params.clone()) + .expect(MUTEX_READ_ERR); + let mut poll_flush_action = PollFlushAction::None; + let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap(); + if !pending_requests_guard.is_empty() { + let mut pending_requests = mem::take(&mut *pending_requests_guard); + for request in pending_requests.drain(..) { + // Drop the request if none is waiting for a response to free up resources for + // requests callers care about (load shedding). It will be ambiguous whether the + // request actually goes through regardless. + if request.sender.is_closed() { + continue; + } + + let future = Self::try_request(request.info.clone(), self.inner.clone()).boxed(); + self.in_flight_requests.push(Box::pin(Request { + retry_params: retry_params.clone(), + request: Some(request), + future: RequestState::Future { future }, + })); + } + *pending_requests_guard = pending_requests; + } + drop(pending_requests_guard); + + loop { + let retry_params = retry_params.clone(); + let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { + Poll::Ready(Some(result)) => result, + Poll::Ready(None) | Poll::Pending => break, + }; + match result { + Next::Done => {} + Next::Retry { request } => { + let future = Self::try_request(request.info.clone(), self.inner.clone()); + self.in_flight_requests.push(Box::pin(Request { + retry_params: retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RetryBusyLoadingError { request, address } => { + // TODO - do we also want to try and reconnect to replica if it is loading? 
+ let future = Self::handle_loading_error( + self.inner.clone(), + request.info.clone(), + address, + request.retry, + retry_params.clone(), + ); + self.in_flight_requests.push(Box::pin(Request { + retry_params: retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RefreshSlots { + request, + sleep_duration, + moved_redirect, + } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::RebuildSlots); + let future: Option< + RequestState + Send>>>, + > = if let Some(moved_redirect) = moved_redirect { + Some(RequestState::UpdateMoved { + future: Box::pin(ClusterConnInner::update_upon_moved_error( + self.inner.clone(), + moved_redirect.slot, + moved_redirect.address.into(), + )), + }) + } else if let Some(ref request) = request { + match sleep_duration { + Some(sleep_duration) => Some(RequestState::Sleep { + sleep: boxed_sleep(sleep_duration), + }), + None => Some(RequestState::Future { + future: Box::pin(Self::try_request( + request.info.clone(), + self.inner.clone(), + )), + }), + } + } else { + None + }; + if let Some(future) = future { + self.in_flight_requests.push(Box::pin(Request { + retry_params, + request, + future, + })); + } + } + Next::Reconnect { request, target } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target])); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + Next::ReconnectToInitialNodes { request } => { + poll_flush_action = poll_flush_action + .change_state(PollFlushAction::ReconnectFromInitialConnections); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + } + } + + if matches!(poll_flush_action, PollFlushAction::None) { + if self.in_flight_requests.is_empty() { + Poll::Ready(poll_flush_action) + } else { + Poll::Pending + } + } else { + Poll::Ready(poll_flush_action) + } + } + + fn send_refresh_error(&mut self) { + if self.refresh_error.is_some() { + if let Some(mut request) = Pin::new(&mut self.in_flight_requests) + .iter_pin_mut() + .find(|request| request.request.is_some()) + { + (*request) + .as_mut() + .respond(Err(self.refresh_error.take().unwrap())); + } else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() { + let _ = request.sender.send(Err(self.refresh_error.take().unwrap())); + } + } + } +} + +enum PollFlushAction { + None, + RebuildSlots, + Reconnect(Vec), + ReconnectFromInitialConnections, +} + +impl PollFlushAction { + fn change_state(self, next_state: PollFlushAction) -> PollFlushAction { + match (self, next_state) { + (PollFlushAction::None, next_state) => next_state, + (next_state, PollFlushAction::None) => next_state, + (PollFlushAction::ReconnectFromInitialConnections, _) + | (_, PollFlushAction::ReconnectFromInitialConnections) => { + PollFlushAction::ReconnectFromInitialConnections + } + + (PollFlushAction::RebuildSlots, _) | (_, PollFlushAction::RebuildSlots) => { + PollFlushAction::RebuildSlots + } + + (PollFlushAction::Reconnect(mut addrs), PollFlushAction::Reconnect(new_addrs)) => { + addrs.extend(new_addrs); + Self::Reconnect(addrs) + } + } + } +} + +impl Sink> for Disposable> +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut task::Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { + let 
Message { cmd, sender } = msg; + + let info = RequestInfo { cmd }; + + self.inner + .pending_requests + .lock() + .unwrap() + .push(PendingRequest { + retry: 0, + sender, + info, + }); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + trace!("poll_flush: {:?}", self.state); + loop { + self.send_refresh_error(); + + if let Err(err) = ready!(self.as_mut().poll_recover(cx)) { + // We failed to reconnect, while we will try again we will report the + // error if we can to avoid getting trapped in an infinite loop of + // trying to reconnect + self.refresh_error = Some(err); + + // Give other tasks a chance to progress before we try to recover + // again. Since the future may not have registered a wake up we do so + // now so the task is not forgotten + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + match ready!(self.poll_complete(cx)) { + PollFlushAction::None => return Poll::Ready(Ok(())), + PollFlushAction::RebuildSlots => { + self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( + ClusterConnInner::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + ), + ))); + } + PollFlushAction::Reconnect(addresses) => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::refresh_connections( + self.inner.clone(), + addresses, + RefreshConnectionType::OnlyUserConnection, + true, + ), + ))); + } + PollFlushAction::ReconnectFromInitialConnections => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::reconnect_to_initial_nodes(self.inner.clone()), + ))); + } + } + } + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // Try to drive any in flight requests to completion + match self.poll_complete(cx) { + Poll::Ready(PollFlushAction::None) => (), + Poll::Ready(_) => Err(())?, + Poll::Pending => (), + }; + // If we no longer have any requests in flight we are done (skips any reconnection + // attempts) + if self.in_flight_requests.is_empty() { + return Poll::Ready(Ok(())); + } + + self.poll_flush(cx) + } +} + +async fn calculate_topology_from_random_nodes<'a, C>( + inner: &Core, + num_of_nodes_to_query: usize, + curr_retry: usize, +) -> ( + RedisResult<( + crate::cluster_slotmap::SlotMap, + crate::cluster_topology::TopologyHash, + )>, + Vec, +) +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + let requested_nodes = if let Some(random_conns) = inner + .conn_lock + .read() + .expect(MUTEX_READ_ERR) + .random_connections(num_of_nodes_to_query, ConnectionType::PreferManagement) + { + random_conns + } else { + return ( + Err(RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No available connections to refresh slots from", + ))), + vec![], + ); + }; + let topology_join_results = + futures::future::join_all(requested_nodes.into_iter().map(|(addr, conn)| async move { + let mut conn: C = conn.await; + let res = conn.req_packed_command(&slot_cmd()).await; + (addr, res) + })) + .await; + let failed_addresses = topology_join_results + .iter() + .filter_map(|(address, res)| match res { + Err(err) if err.is_unrecoverable_error() => Some(address.clone()), + _ => None, + }) + .collect(); + let topology_values = topology_join_results.iter().filter_map(|(addr, res)| { + res.as_ref() + .ok() + .and_then(|value| get_host_and_port_from_addr(addr).map(|(host, _)| (host, value))) + }); + let tls_mode = inner + .get_cluster_param(|params| 
params.tls) + .expect(MUTEX_READ_ERR); + + let read_from_replicas = inner + .get_cluster_param(|params| params.read_from_replicas.clone()) + .expect(MUTEX_READ_ERR); + ( + calculate_topology( + topology_values, + curr_retry, + tls_mode, + num_of_nodes_to_query, + read_from_replicas, + ), + failed_addresses, + ) +} + +impl ConnectionLike for ClusterConnection +where + C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static, +{ + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + let routing = cluster_routing::RoutingInfo::for_routable(cmd).unwrap_or( + cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random), + ); + self.route_command(cmd, routing).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + async move { + let route = route_for_pipeline(pipeline)?; + self.route_pipeline(pipeline, offset, count, route.into()) + .await + } + .boxed() + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +/// Implements the process of connecting to a Redis server +/// and obtaining a connection handle. +pub trait Connect: Sized { + /// Connect to a node. + /// For TCP connections, returning a tuple of handle for command execution and the node's IP address. + /// For UNIX connections, returning a tuple of handle for command execution and None. + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a; +} + +impl Connect for MultiplexedConnection { + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (MultiplexedConnection, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + async move { + let connection_info = info.into_connection_info()?; + let client = crate::Client::open(connection_info)?; + + match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + rt @ Runtime::Tokio => { + rt.timeout( + connection_timeout, + client.get_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ), + ) + .await? 
+ } + } + } + .boxed() + } +} + +#[cfg(test)] +mod pipeline_routing_tests { + use super::route_for_pipeline; + use crate::{ + cluster_routing::{Route, SlotAddr}, + cmd, + }; + + #[test] + fn test_first_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .get("foo") // route to slot 12182 + .add_command(cmd("EVAL")); // route randomly + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::ReplicaOptional))) + ); + } + + #[test] + fn test_return_none_if_no_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL")); // route randomly + + assert_eq!(route_for_pipeline(&pipeline), Ok(None)); + } + + #[test] + fn test_prefer_primary_route_over_replica() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .get("foo") // route to replica of slot 12182 + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL"))// route randomly + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar"); // route to primary of slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } + + #[test] + fn test_raise_cross_slot_error_on_conflicting_slots() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .set("baz", "bar") // route to slot 4813 + .get("foo"); // route to slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline).unwrap_err().kind(), + crate::ErrorKind::CrossSlot + ); + } + + #[test] + fn unkeyed_commands_dont_affect_route() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .set("{foo}bar", "baz") // route to primary of slot 12182 + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar") // route to primary of slot 12182 + .cmd("DEBUG").arg("PAUSE").arg("100") // unkeyed command + .cmd("ECHO").arg("hello world"); // unkeyed command + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_client.rs b/glide-core/redis-rs/redis/src/cluster_client.rs new file mode 100644 index 0000000000..c4dc0103dc --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_client.rs @@ -0,0 +1,767 @@ +use crate::cluster_slotmap::ReadFromReplicaStrategy; +#[cfg(feature = "cluster-async")] +use crate::cluster_topology::{ + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION, +}; +use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; +use crate::types::{ErrorKind, ProtocolVersion, RedisError, RedisResult}; +use crate::{cluster, cluster::TlsMode}; +use crate::{PubSubSubscriptionInfo, PushInfo}; +use rand::Rng; +#[cfg(feature = "cluster-async")] +use std::ops::Add; +use std::time::Duration; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[cfg(feature = "cluster-async")] +use crate::cluster_async; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{retrieve_tls_certificates, TlsCertificates}; + +use tokio::sync::mpsc; + +/// Parameters specific to builder, so that +/// builder parameters may have different types +/// than final ClusterParams +#[derive(Default)] +struct BuilderParams { + password: Option, + username: Option, + read_from_replicas: 
ReadFromReplicaStrategy, + tls: Option, + #[cfg(feature = "tls-rustls")] + certs: Option, + retries_configuration: RetryParams, + connection_timeout: Option, + #[cfg(feature = "cluster-async")] + topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + connections_validation_interval: Option, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: SlotsRefreshRateLimit, + client_name: Option, + response_timeout: Option, + protocol: ProtocolVersion, + pubsub_subscriptions: Option, +} + +#[derive(Clone)] +pub(crate) struct RetryParams { + pub(crate) number_of_retries: u32, + max_wait_time: u64, + min_wait_time: u64, + exponent_base: u64, + factor: u64, +} + +impl Default for RetryParams { + fn default() -> Self { + const DEFAULT_RETRIES: u32 = 16; + const DEFAULT_MAX_RETRY_WAIT_TIME: u64 = 655360; + const DEFAULT_MIN_RETRY_WAIT_TIME: u64 = 1280; + const DEFAULT_EXPONENT_BASE: u64 = 2; + const DEFAULT_FACTOR: u64 = 10; + Self { + number_of_retries: DEFAULT_RETRIES, + max_wait_time: DEFAULT_MAX_RETRY_WAIT_TIME, + min_wait_time: DEFAULT_MIN_RETRY_WAIT_TIME, + exponent_base: DEFAULT_EXPONENT_BASE, + factor: DEFAULT_FACTOR, + } + } +} + +impl RetryParams { + pub(crate) fn wait_time_for_retry(&self, retry: u32) -> Duration { + let base_wait = self.exponent_base.pow(retry) * self.factor; + let clamped_wait = base_wait + .min(self.max_wait_time) + .max(self.min_wait_time + 1); + let jittered_wait = rand::thread_rng().gen_range(self.min_wait_time..clamped_wait); + Duration::from_millis(jittered_wait) + } +} + +/// Configuration for rate limiting slot refresh operations in a Redis cluster. +/// +/// This struct defines the interval duration between consecutive slot refresh +/// operations and an additional jitter to introduce randomness in the refresh intervals. +/// +/// # Fields +/// +/// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations. +/// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration. +#[cfg(feature = "cluster-async")] +#[derive(Clone, Copy)] +pub(crate) struct SlotsRefreshRateLimit { + pub(crate) interval_duration: Duration, + pub(crate) max_jitter_milli: u64, +} + +#[cfg(feature = "cluster-async")] +impl Default for SlotsRefreshRateLimit { + fn default() -> Self { + Self { + interval_duration: DEFAULT_SLOTS_REFRESH_WAIT_DURATION, + max_jitter_milli: DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, + } + } +} + +#[cfg(feature = "cluster-async")] +impl SlotsRefreshRateLimit { + pub(crate) fn wait_duration(&self) -> Duration { + let duration_jitter = match self.max_jitter_milli { + 0 => Duration::from_millis(0), + _ => Duration::from_millis(rand::thread_rng().gen_range(0..self.max_jitter_milli)), + }; + self.interval_duration.add(duration_jitter) + } +} +/// Redis cluster specific parameters. +#[derive(Default, Clone)] +#[doc(hidden)] +pub struct ClusterParams { + pub(crate) password: Option, + pub(crate) username: Option, + pub(crate) read_from_replicas: ReadFromReplicaStrategy, + /// tls indicates tls behavior of connections. + /// When Some(TlsMode), connections use tls and verify certification depends on TlsMode. + /// When None, connections do not use tls. 
+ pub(crate) tls: Option, + pub(crate) retry_params: RetryParams, + #[cfg(feature = "cluster-async")] + pub(crate) topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + pub(crate) slots_refresh_rate_limit: SlotsRefreshRateLimit, + #[cfg(feature = "cluster-async")] + pub(crate) connections_validation_interval: Option, + pub(crate) tls_params: Option, + pub(crate) client_name: Option, + pub(crate) connection_timeout: Duration, + pub(crate) response_timeout: Duration, + pub(crate) protocol: ProtocolVersion, + pub(crate) pubsub_subscriptions: Option, +} + +impl ClusterParams { + fn from(value: BuilderParams) -> RedisResult { + #[cfg(not(feature = "tls-rustls"))] + let tls_params = None; + + #[cfg(feature = "tls-rustls")] + let tls_params = { + let retrieved_tls_params = value.certs.clone().map(retrieve_tls_certificates); + + retrieved_tls_params.transpose()? + }; + + Ok(Self { + password: value.password, + username: value.username, + read_from_replicas: value.read_from_replicas, + tls: value.tls, + retry_params: value.retries_configuration, + connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX), + #[cfg(feature = "cluster-async")] + topology_checks_interval: value.topology_checks_interval, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: value.slots_refresh_rate_limit, + #[cfg(feature = "cluster-async")] + connections_validation_interval: value.connections_validation_interval, + tls_params, + client_name: value.client_name, + response_timeout: value.response_timeout.unwrap_or(Duration::MAX), + protocol: value.protocol, + pubsub_subscriptions: value.pubsub_subscriptions, + }) + } +} + +/// Used to configure and build a [`ClusterClient`]. +pub struct ClusterClientBuilder { + initial_nodes: RedisResult>, + builder_params: BuilderParams, +} + +impl ClusterClientBuilder { + /// Creates a new `ClusterClientBuilder` with the provided initial_nodes. + /// + /// This is the same as `ClusterClient::builder(initial_nodes)`. + pub fn new( + initial_nodes: impl IntoIterator, + ) -> ClusterClientBuilder { + ClusterClientBuilder { + initial_nodes: initial_nodes + .into_iter() + .map(|x| x.into_connection_info()) + .collect(), + builder_params: Default::default(), + } + } + + /// Creates a new [`ClusterClient`] from the parameters. + /// + /// This does not create connections to the Redis Cluster, but only performs some basic checks + /// on the initial nodes' URLs and passwords/usernames. + /// + /// When the `tls-rustls` feature is enabled and TLS credentials are provided, they are set for + /// each cluster connection. + /// + /// # Errors + /// + /// Upon failure to parse initial nodes or if the initial nodes have different passwords or + /// usernames, an error is returned. 
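For illustration, here is a small hypothetical driver of the checks described above; the node URLs are placeholders, and the expected behavior mirrors the unit tests at the bottom of this file:

```rust
use redis::cluster::ClusterClientBuilder;

fn main() {
    // Same password on every initial node: passes the credential checks.
    let ok = ClusterClientBuilder::new(vec![
        "redis://:secret@127.0.0.1:6379",
        "redis://:secret@127.0.0.1:6378",
    ])
    .build();
    assert!(ok.is_ok());
    // ok.unwrap().get_connection(None) would then actually connect to the nodes.

    // Different passwords among the initial nodes: rejected at build time.
    let err = ClusterClientBuilder::new(vec![
        "redis://:secret@127.0.0.1:6379",
        "redis://:other@127.0.0.1:6378",
    ])
    .build();
    assert!(err.is_err());
}
```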
+ pub fn build(self) -> RedisResult { + let initial_nodes = self.initial_nodes?; + + let first_node = match initial_nodes.first() { + Some(node) => node, + None => { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Initial nodes can't be empty.", + ))) + } + }; + + let mut cluster_params = ClusterParams::from(self.builder_params)?; + let password = if cluster_params.password.is_none() { + cluster_params + .password + .clone_from(&first_node.redis.password); + &cluster_params.password + } else { + &None + }; + let username = if cluster_params.username.is_none() { + cluster_params + .username + .clone_from(&first_node.redis.username); + &cluster_params.username + } else { + &None + }; + if cluster_params.tls.is_none() { + cluster_params.tls = match first_node.addr { + ConnectionAddr::TcpTls { + host: _, + port: _, + insecure, + tls_params: _, + } => Some(match insecure { + false => TlsMode::Secure, + true => TlsMode::Insecure, + }), + _ => None, + }; + } + + let mut nodes = Vec::with_capacity(initial_nodes.len()); + for mut node in initial_nodes { + if let ConnectionAddr::Unix(_) = node.addr { + return Err(RedisError::from((ErrorKind::InvalidClientConfig, + "This library cannot use unix socket because Redis's cluster command returns only cluster's IP and port."))); + } + + if password.is_some() && node.redis.password != *password { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different password among initial nodes.", + ))); + } + + if username.is_some() && node.redis.username != *username { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different username among initial nodes.", + ))); + } + + if node.redis.client_name.is_some() + && node.redis.client_name != cluster_params.client_name + { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different client_name among initial nodes.", + ))); + } + + node.redis.protocol = cluster_params.protocol; + nodes.push(node); + } + + Ok(ClusterClient { + initial_nodes: nodes, + cluster_params, + }) + } + + /// Sets client name for the new ClusterClient. + pub fn client_name(mut self, client_name: String) -> ClusterClientBuilder { + self.builder_params.client_name = Some(client_name); + self + } + + /// Sets password for the new ClusterClient. + pub fn password(mut self, password: String) -> ClusterClientBuilder { + self.builder_params.password = Some(password); + self + } + + /// Sets username for the new ClusterClient. + pub fn username(mut self, username: String) -> ClusterClientBuilder { + self.builder_params.username = Some(username); + self + } + + /// Sets number of retries for the new ClusterClient. + pub fn retries(mut self, retries: u32) -> ClusterClientBuilder { + self.builder_params.retries_configuration.number_of_retries = retries; + self + } + + /// Sets maximal wait time in milliseconds between retries for the new ClusterClient. + pub fn max_retry_wait(mut self, max_wait: u64) -> ClusterClientBuilder { + self.builder_params.retries_configuration.max_wait_time = max_wait; + self + } + + /// Sets minimal wait time in milliseconds between retries for the new ClusterClient. + pub fn min_retry_wait(mut self, min_wait: u64) -> ClusterClientBuilder { + self.builder_params.retries_configuration.min_wait_time = min_wait; + self + } + + /// Sets the factor and exponent base for the retry wait time. + /// The formula for the wait is rand(min_wait_retry .. min(max_retry_wait , factor * exponent_base ^ retry))ms. 
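As a worked example, the following standalone function restates that formula with the default constants from `RetryParams::default()` above (factor 10, base 2, min 1280 ms, max 655360 ms); `retry` counts from 0:

```rust
use rand::Rng;
use std::time::Duration;

fn wait_time_for_retry(retry: u32) -> Duration {
    let (factor, exponent_base, min_wait, max_wait) = (10u64, 2u64, 1280u64, 655_360u64);
    let base_wait = exponent_base.pow(retry) * factor;
    let clamped = base_wait.min(max_wait).max(min_wait + 1);
    // retry = 8 gives base_wait = 10 * 2^8 = 2560, so the sleep is drawn
    // uniformly from 1280..2560 ms; while base_wait <= min_wait (small retry
    // counts), the sleep stays pinned at min_wait.
    Duration::from_millis(rand::thread_rng().gen_range(min_wait..clamped))
}

fn main() {
    for retry in [0, 8, 16] {
        println!("retry {retry}: {:?}", wait_time_for_retry(retry));
    }
}
```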
+    pub fn retry_wait_formula(mut self, factor: u64, exponent_base: u64) -> ClusterClientBuilder {
+        self.builder_params.retries_configuration.factor = factor;
+        self.builder_params.retries_configuration.exponent_base = exponent_base;
+        self
+    }
+
+    /// Sets TLS mode for the new ClusterClient.
+    ///
+    /// It is extracted from the first node of initial_nodes if not set.
+    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
+    pub fn tls(mut self, tls: TlsMode) -> ClusterClientBuilder {
+        self.builder_params.tls = Some(tls);
+        self
+    }
+
+    /// Sets raw TLS certificates for the new ClusterClient.
+    ///
+    /// When set, the connection is required to be TLS secured.
+    ///
+    /// All certificates must be provided as byte streams loaded from PEM files; their consistency is
+    /// checked during the `build()` call.
+    ///
+    /// - `certificates` - `TlsCertificates` structure containing:
+    /// -- `client_tls` - Optional `ClientTlsConfig` containing byte streams for
+    /// --- `client_cert` - client's byte stream containing client certificate in PEM format
+    /// --- `client_key` - client's byte stream containing private key in PEM format
+    /// -- `root_cert` - Optional byte stream yielding PEM formatted file for root certificates.
+    ///
+    /// If `ClientTlsConfig` (cert+key pair) is not provided, then client-side authentication is not enabled.
+    /// If `root_cert` is not provided, then system root certificates are used instead.
+    #[cfg(feature = "tls-rustls")]
+    pub fn certs(mut self, certificates: TlsCertificates) -> ClusterClientBuilder {
+        self.builder_params.tls = Some(TlsMode::Secure);
+        self.builder_params.certs = Some(certificates);
+        self
+    }
+
+    /// Enables reading from replicas for all new connections (default is disabled).
+    ///
+    /// If enabled, then read queries will go to the replica nodes & write queries will go to the
+    /// primary nodes. If there are no replica nodes, then all queries will go to the primary nodes.
+    pub fn read_from_replicas(mut self) -> ClusterClientBuilder {
+        self.builder_params.read_from_replicas = ReadFromReplicaStrategy::RoundRobin;
+        self
+    }
+
+    /// Set the read strategy for this client.
+    ///
+    /// The parameter `read_strategy` can be one of:
+    /// `ReadFromReplicaStrategy::AZAffinity(availability_zone)` - attempt to access replicas in the same availability zone.
+    /// If no suitable replica is found (i.e. no replica could be found in the requested availability zone), choose any replica, falling back to the primary if needed.
+    /// `ReadFromReplicaStrategy::RoundRobin` - reads are distributed across replicas for load balancing using a round-robin algorithm, falling back to the primary if needed.
+    /// `ReadFromReplicaStrategy::AlwaysFromPrimary` ensures all read and write queries are directed to the primary node.
+    ///
+    /// # Parameters
+    /// - `read_strategy`: defines the replica routing strategy.
+    pub fn read_from(mut self, read_strategy: ReadFromReplicaStrategy) -> ClusterClientBuilder {
+        self.builder_params.read_from_replicas = read_strategy;
+        self
+    }
+
+    /// Enables periodic topology checks for this client.
+    ///
+    /// If enabled, periodic topology checks will be executed at the configured intervals to examine whether there
+    /// have been any changes in the cluster's topology. If a change is detected, it will trigger a slot refresh.
+    /// Unlike full slot refreshes, the periodic topology checks only examine a limited number of nodes to query their
+    /// topology, ensuring that the check remains quick and efficient.
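A hypothetical combination of the options described above; the interval is arbitrary, and the import path for `ReadFromReplicaStrategy` is assumed from this module's own `use` statements:

```rust
use redis::cluster::{ClusterClient, ClusterClientBuilder};
use redis::cluster_slotmap::ReadFromReplicaStrategy; // path assumed from this module's imports
use std::time::Duration;

fn build_client() -> redis::RedisResult<ClusterClient> {
    ClusterClientBuilder::new(vec!["redis://127.0.0.1:6379"])
        // Round-robin reads across replicas; AZAffinity(zone) would prefer same-zone replicas.
        .read_from(ReadFromReplicaStrategy::RoundRobin)
        // Examine a few nodes every minute and trigger a slot refresh on topology change.
        .periodic_topology_checks(Duration::from_secs(60))
        .build()
}
```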
+ #[cfg(feature = "cluster-async")] + pub fn periodic_topology_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.topology_checks_interval = Some(interval); + self + } + + /// Enables periodic connections checks for this client. + /// If enabled, the connections to the cluster nodes will be validated periodically, per configured interval. + /// In addition, for tokio runtime, passive disconnections could be detected instantly, + /// triggering reestablishment, w/o waiting for the next periodic check. + #[cfg(feature = "cluster-async")] + pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.connections_validation_interval = Some(interval); + self + } + + /// Sets the rate limit for slot refresh operations in the cluster. + /// + /// This method configures the interval duration between consecutive slot + /// refresh operations and an additional jitter to introduce randomness + /// in the refresh intervals. + /// + /// # Parameters + /// + /// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations. + /// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration. + /// + /// # Defaults + /// + /// If not set, the slots refresh rate limit configurations will be set with the default values: + /// ``` + /// #[cfg(feature = "cluster-async")] + /// use redis::cluster_topology::{DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION}; + /// ``` + /// + /// - `interval_duration`: `DEFAULT_SLOTS_REFRESH_WAIT_DURATION` + /// - `max_jitter_milli`: `DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI` + /// + #[cfg(feature = "cluster-async")] + pub fn slots_refresh_rate_limit( + mut self, + interval_duration: Duration, + max_jitter_milli: u64, + ) -> ClusterClientBuilder { + self.builder_params.slots_refresh_rate_limit = SlotsRefreshRateLimit { + interval_duration, + max_jitter_milli, + }; + self + } + + /// Enables timing out on slow connection time. + /// + /// If enabled, the cluster will only wait the given time on each connection attempt to each node. + pub fn connection_timeout(mut self, connection_timeout: Duration) -> ClusterClientBuilder { + self.builder_params.connection_timeout = Some(connection_timeout); + self + } + + /// Enables timing out on slow responses. + /// + /// If enabled, the cluster will only wait the given time to each response from each node. + pub fn response_timeout(mut self, response_timeout: Duration) -> ClusterClientBuilder { + self.builder_params.response_timeout = Some(response_timeout); + self + } + + /// Sets the protocol with which the client should communicate with the server. + pub fn use_protocol(mut self, protocol: ProtocolVersion) -> ClusterClientBuilder { + self.builder_params.protocol = protocol; + self + } + + /// Use `build()`. + #[deprecated(since = "0.22.0", note = "Use build()")] + pub fn open(self) -> RedisResult { + self.build() + } + + /// Use `read_from_replicas()`. + #[deprecated(since = "0.22.0", note = "Use read_from_replicas()")] + pub fn readonly(mut self, read_from_replicas: bool) -> ClusterClientBuilder { + self.builder_params.read_from_replicas = if read_from_replicas { + ReadFromReplicaStrategy::RoundRobin + } else { + ReadFromReplicaStrategy::AlwaysFromPrimary + }; + self + } + + /// Sets the pubsub configuration for the new ClusterClient. 
+ pub fn pubsub_subscriptions( + mut self, + pubsub_subscriptions: PubSubSubscriptionInfo, + ) -> ClusterClientBuilder { + self.builder_params.pubsub_subscriptions = Some(pubsub_subscriptions); + self + } +} + +/// This is a Redis Cluster client. +#[derive(Clone)] +pub struct ClusterClient { + initial_nodes: Vec, + cluster_params: ClusterParams, +} + +impl ClusterClient { + /// Creates a `ClusterClient` with the default parameters. + /// + /// This does not create connections to the Redis Cluster, but only performs some basic checks + /// on the initial nodes' URLs and passwords/usernames. + /// + /// # Errors + /// + /// Upon failure to parse initial nodes or if the initial nodes have different passwords or + /// usernames, an error is returned. + pub fn new( + initial_nodes: impl IntoIterator, + ) -> RedisResult { + Self::builder(initial_nodes).build() + } + + /// Creates a [`ClusterClientBuilder`] with the provided initial_nodes. + pub fn builder( + initial_nodes: impl IntoIterator, + ) -> ClusterClientBuilder { + ClusterClientBuilder::new(initial_nodes) + } + + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster::ClusterConnection`]. + /// + /// # Errors + /// + /// An error is returned if there is a failure while creating connections or slots. + pub fn get_connection( + &self, + push_sender: Option>, + ) -> RedisResult { + cluster::ClusterConnection::new( + self.cluster_params.clone(), + self.initial_nodes.clone(), + push_sender, + ) + } + + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster_async::ClusterConnection`]. + /// + /// # Errors + /// + /// An error is returned if there is a failure while creating connections or slots. + #[cfg(feature = "cluster-async")] + pub async fn get_async_connection( + &self, + push_sender: Option>, + ) -> RedisResult { + cluster_async::ClusterConnection::new( + &self.initial_nodes, + self.cluster_params.clone(), + push_sender, + ) + .await + } + + #[doc(hidden)] + pub fn get_generic_connection( + &self, + push_sender: Option>, + ) -> RedisResult> + where + C: crate::ConnectionLike + crate::cluster::Connect + Send, + { + cluster::ClusterConnection::new( + self.cluster_params.clone(), + self.initial_nodes.clone(), + push_sender, + ) + } + + #[doc(hidden)] + #[cfg(feature = "cluster-async")] + pub async fn get_async_generic_connection( + &self, + ) -> RedisResult> + where + C: crate::aio::ConnectionLike + + cluster_async::Connect + + Clone + + Send + + Sync + + Unpin + + 'static, + { + cluster_async::ClusterConnection::new( + &self.initial_nodes, + self.cluster_params.clone(), + None, + ) + .await + } + + /// Use `new()`. 
+ #[deprecated(since = "0.22.0", note = "Use new()")] + pub fn open(initial_nodes: Vec) -> RedisResult { + Self::new(initial_nodes) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "cluster-async")] + use crate::cluster_topology::{ + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION, + }; + + use super::{ClusterClient, ClusterClientBuilder, ConnectionInfo, IntoConnectionInfo}; + + fn get_connection_data() -> Vec { + vec![ + "redis://127.0.0.1:6379".into_connection_info().unwrap(), + "redis://127.0.0.1:6378".into_connection_info().unwrap(), + "redis://127.0.0.1:6377".into_connection_info().unwrap(), + ] + } + + fn get_connection_data_with_password() -> Vec { + vec![ + "redis://:password@127.0.0.1:6379" + .into_connection_info() + .unwrap(), + "redis://:password@127.0.0.1:6378" + .into_connection_info() + .unwrap(), + "redis://:password@127.0.0.1:6377" + .into_connection_info() + .unwrap(), + ] + } + + fn get_connection_data_with_username_and_password() -> Vec { + vec![ + "redis://user1:password@127.0.0.1:6379" + .into_connection_info() + .unwrap(), + "redis://user1:password@127.0.0.1:6378" + .into_connection_info() + .unwrap(), + "redis://user1:password@127.0.0.1:6377" + .into_connection_info() + .unwrap(), + ] + } + + #[test] + fn give_no_password() { + let client = ClusterClient::new(get_connection_data()).unwrap(); + assert_eq!(client.cluster_params.password, None); + } + + #[test] + fn give_password_by_initial_nodes() { + let client = ClusterClient::new(get_connection_data_with_password()).unwrap(); + assert_eq!(client.cluster_params.password, Some("password".to_string())); + } + + #[test] + fn give_username_and_password_by_initial_nodes() { + let client = ClusterClient::new(get_connection_data_with_username_and_password()).unwrap(); + assert_eq!(client.cluster_params.password, Some("password".to_string())); + assert_eq!(client.cluster_params.username, Some("user1".to_string())); + } + + #[test] + fn give_different_password_by_initial_nodes() { + let result = ClusterClient::new(vec![ + "redis://:password1@127.0.0.1:6379", + "redis://:password2@127.0.0.1:6378", + "redis://:password3@127.0.0.1:6377", + ]); + assert!(result.is_err()); + } + + #[test] + fn give_different_username_by_initial_nodes() { + let result = ClusterClient::new(vec![ + "redis://user1:password@127.0.0.1:6379", + "redis://user2:password@127.0.0.1:6378", + "redis://user1:password@127.0.0.1:6377", + ]); + assert!(result.is_err()); + } + + #[test] + fn give_username_password_by_method() { + let client = ClusterClientBuilder::new(get_connection_data_with_password()) + .password("pass".to_string()) + .username("user1".to_string()) + .build() + .unwrap(); + assert_eq!(client.cluster_params.password, Some("pass".to_string())); + assert_eq!(client.cluster_params.username, Some("user1".to_string())); + } + + #[test] + fn give_empty_initial_nodes() { + let client = ClusterClient::new(Vec::::new()); + assert!(client.is_err()) + } + + #[cfg(feature = "cluster-async")] + #[test] + fn give_slots_refresh_rate_limit_configurations() { + let interval_dur = std::time::Duration::from_secs(20); + let client = ClusterClientBuilder::new(get_connection_data()) + .slots_refresh_rate_limit(interval_dur, 500) + .build() + .unwrap(); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .interval_duration, + interval_dur + ); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .max_jitter_milli, + 500 + ); + } + + #[cfg(feature = "cluster-async")] + #[test] + fn 
dont_give_slots_refresh_rate_limit_configurations_uses_defaults() { + let client = ClusterClientBuilder::new(get_connection_data()) + .build() + .unwrap(); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .interval_duration, + DEFAULT_SLOTS_REFRESH_WAIT_DURATION + ); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .max_jitter_milli, + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_pipeline.rs b/glide-core/redis-rs/redis/src/cluster_pipeline.rs new file mode 100644 index 0000000000..9da1fee781 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_pipeline.rs @@ -0,0 +1,151 @@ +use crate::cluster::ClusterConnection; +use crate::cmd::{cmd, Cmd}; +use crate::types::{ + from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value, +}; + +pub(crate) const UNROUTABLE_ERROR: (ErrorKind, &str) = ( + ErrorKind::ClientError, + "This command cannot be safely routed in cluster mode", +); + +fn is_illegal_cmd(cmd: &str) -> bool { + matches!( + cmd, + "BGREWRITEAOF" | "BGSAVE" | "BITOP" | "BRPOPLPUSH" | + // All commands that start with "CLIENT" + "CLIENT" | "CLIENT GETNAME" | "CLIENT KILL" | "CLIENT LIST" | "CLIENT SETNAME" | + // All commands that start with "CONFIG" + "CONFIG" | "CONFIG GET" | "CONFIG RESETSTAT" | "CONFIG REWRITE" | "CONFIG SET" | + "DBSIZE" | + "ECHO" | "EVALSHA" | + "FLUSHALL" | "FLUSHDB" | + "INFO" | + "KEYS" | + "LASTSAVE" | + "MGET" | "MOVE" | "MSET" | "MSETNX" | + "PFMERGE" | "PFCOUNT" | "PING" | "PUBLISH" | + "RANDOMKEY" | "RENAME" | "RENAMENX" | "RPOPLPUSH" | + "SAVE" | "SCAN" | + // All commands that start with "SCRIPT" + "SCRIPT" | "SCRIPT EXISTS" | "SCRIPT FLUSH" | "SCRIPT KILL" | "SCRIPT LOAD" | + "SDIFF" | "SDIFFSTORE" | + // All commands that start with "SENTINEL" + "SENTINEL" | "SENTINEL GET MASTER ADDR BY NAME" | "SENTINEL MASTER" | "SENTINEL MASTERS" | + "SENTINEL MONITOR" | "SENTINEL REMOVE" | "SENTINEL SENTINELS" | "SENTINEL SET" | + "SENTINEL SLAVES" | "SHUTDOWN" | "SINTER" | "SINTERSTORE" | "SLAVEOF" | + // All commands that start with "SLOWLOG" + "SLOWLOG" | "SLOWLOG GET" | "SLOWLOG LEN" | "SLOWLOG RESET" | + "SMOVE" | "SORT" | "SUNION" | "SUNIONSTORE" | + "TIME" + ) +} + +/// Represents a Redis Cluster command pipeline. +#[derive(Clone)] +pub struct ClusterPipeline { + commands: Vec, + ignored_commands: HashSet, +} + +/// A cluster pipeline is almost identical to a normal [Pipeline](crate::pipeline::Pipeline), with two exceptions: +/// * It does not support transactions +/// * The following commands can not be used in a cluster pipeline: +/// ```text +/// BGREWRITEAOF, BGSAVE, BITOP, BRPOPLPUSH +/// CLIENT GETNAME, CLIENT KILL, CLIENT LIST, CLIENT SETNAME, CONFIG GET, +/// CONFIG RESETSTAT, CONFIG REWRITE, CONFIG SET +/// DBSIZE +/// ECHO, EVALSHA +/// FLUSHALL, FLUSHDB +/// INFO +/// KEYS +/// LASTSAVE +/// MGET, MOVE, MSET, MSETNX +/// PFMERGE, PFCOUNT, PING, PUBLISH +/// RANDOMKEY, RENAME, RENAMENX, RPOPLPUSH +/// SAVE, SCAN, SCRIPT EXISTS, SCRIPT FLUSH, SCRIPT KILL, SCRIPT LOAD, SDIFF, SDIFFSTORE, +/// SENTINEL GET MASTER ADDR BY NAME, SENTINEL MASTER, SENTINEL MASTERS, SENTINEL MONITOR, +/// SENTINEL REMOVE, SENTINEL SENTINELS, SENTINEL SET, SENTINEL SLAVES, SHUTDOWN, SINTER, +/// SINTERSTORE, SLAVEOF, SLOWLOG GET, SLOWLOG LEN, SLOWLOG RESET, SMOVE, SORT, SUNION, SUNIONSTORE +/// TIME +/// ``` +impl ClusterPipeline { + /// Create an empty pipeline. 
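To illustrate the restriction, here is a sketch of a pipeline containing one of the commands listed above being rejected at `query()` time; the node address is a placeholder, and a reachable cluster is assumed for the connection step:

```rust
fn main() {
    let nodes = vec!["redis://127.0.0.1:6379/"];
    let client = redis::cluster::ClusterClient::new(nodes).unwrap();
    let mut con = client.get_connection(None).unwrap(); // needs a reachable cluster

    let mut pipe = redis::cluster::cluster_pipe();
    pipe.cmd("KEYS").arg("*"); // KEYS appears in the unroutable list above
    let res: redis::RedisResult<redis::Value> = pipe.query(&mut con);
    // Rejected before anything is sent:
    // "Command 'KEYS' can't be executed in a cluster pipeline."
    assert!(res.is_err());
}
```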
+ pub fn new() -> ClusterPipeline { + Self::with_capacity(0) + } + + /// Creates an empty pipeline with pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> ClusterPipeline { + ClusterPipeline { + commands: Vec::with_capacity(capacity), + ignored_commands: HashSet::new(), + } + } + + pub(crate) fn commands(&self) -> &Vec { + &self.commands + } + + /// Executes the pipeline and fetches the return values: + /// + /// ```rust,no_run + /// # let nodes = vec!["redis://127.0.0.1:6379/"]; + /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut pipe = redis::cluster::cluster_pipe(); + /// let (k1, k2) : (i32, i32) = pipe + /// .cmd("SET").arg("key_1").arg(42).ignore() + /// .cmd("SET").arg("key_2").arg(43).ignore() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn query(&self, con: &mut ClusterConnection) -> RedisResult { + for cmd in &self.commands { + let cmd_name = std::str::from_utf8(cmd.arg_idx(0).unwrap_or(b"")) + .unwrap_or("") + .trim() + .to_ascii_uppercase(); + + if is_illegal_cmd(&cmd_name) { + fail!(( + UNROUTABLE_ERROR.0, + UNROUTABLE_ERROR.1, + format!("Command '{cmd_name}' can't be executed in a cluster pipeline.") + )) + } + } + + from_owned_redis_value(if self.commands.is_empty() { + Value::Array(vec![]) + } else { + self.make_pipeline_results(con.execute_pipeline(self)?) + }) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query of the pipeline fails. + /// + /// This is equivalent to a call to query like this: + /// + /// ```rust,no_run + /// # let nodes = vec!["redis://127.0.0.1:6379/"]; + /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut pipe = redis::cluster::cluster_pipe(); + /// let _ : () = pipe.cmd("SET").arg("key_1").arg(42).ignore().query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn execute(&self, con: &mut ClusterConnection) { + self.query::<()>(con).unwrap(); + } +} + +/// Shortcut for creating a new cluster pipeline. +pub fn cluster_pipe() -> ClusterPipeline { + ClusterPipeline::new() +} + +implement_pipeline_commands!(ClusterPipeline); diff --git a/glide-core/redis-rs/redis/src/cluster_routing.rs b/glide-core/redis-rs/redis/src/cluster_routing.rs new file mode 100644 index 0000000000..eab3bf398a --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_routing.rs @@ -0,0 +1,2026 @@ +use rand::Rng; + +use crate::cluster_topology::get_slot; +use crate::cmd::{Arg, Cmd}; +use crate::types::Value; +use crate::{ErrorKind, RedisError, RedisResult}; +use core::cmp::Ordering; +use std::borrow::Cow; +use std::cmp::min; +use std::collections::HashMap; +use std::iter::Once; +use std::sync::Arc; +use std::sync::{RwLock, RwLockWriteGuard}; + +#[derive(Clone)] +pub(crate) enum Redirect { + Moved(String), + Ask(String), +} + +/// Logical bitwise aggregating operators. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum LogicalAggregateOp { + /// Aggregate by bitwise && + And, + // Or, omitted due to dead code warnings. ATM this value isn't constructed anywhere +} + +/// Numerical aggregating operators. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum AggregateOp { + /// Choose minimal value + Min, + /// Sum all values + Sum, + // Max, omitted due to dead code warnings. 
ATM this value isn't constructed anywhere
+}
+
+/// Policy defining how to combine multiple responses into one.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum ResponsePolicy {
+    /// Wait for one request to succeed and return its results. Return error if all requests fail.
+    OneSucceeded,
+    /// Returns the first succeeded non-empty result; if all results are empty, returns `Nil`; otherwise, returns the last received error.
+    FirstSucceededNonEmptyOrAllEmpty,
+    /// Waits for all requests to succeed, then returns one of the successes. Returns the first received error if any request fails.
+    AllSucceeded,
+    /// Aggregate success results according to a logical bitwise operator. Return error on any failed request or on a response that doesn't conform to 0 or 1.
+    AggregateLogical(LogicalAggregateOp),
+    /// Aggregate success results according to a numeric operator. Return error on any failed request or on a response that isn't an integer.
+    Aggregate(AggregateOp),
+    /// Aggregate array responses into a single array. Return error on any failed request or on a response that isn't an array.
+    CombineArrays,
+    /// Handling is not defined by the Redis standard; such responses receive special-case handling.
+    Special,
+    /// Combines multiple map responses into a single map.
+    CombineMaps,
+}
+
+/// Defines whether a request should be routed to a single node, or multiple ones.
+#[derive(Debug, Clone, PartialEq)]
+pub enum RoutingInfo {
+    /// Route to single node
+    SingleNode(SingleNodeRoutingInfo),
+    /// Route to multiple nodes
+    MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
+}
+
+/// Defines which single node should receive a request.
+#[derive(Debug, Clone, PartialEq)]
+pub enum SingleNodeRoutingInfo {
+    /// Route to any node at random
+    Random,
+    /// Route to any *primary* node
+    RandomPrimary,
+    /// Route to the node that matches the [Route]
+    SpecificNode(Route),
+    /// Route to the node with the given address.
+    ByAddress {
+        /// DNS hostname of the node
+        host: String,
+        /// port of the node
+        port: u16,
+    },
+}
+
+impl From<Option<Route>> for SingleNodeRoutingInfo {
+    fn from(value: Option<Route>) -> Self {
+        value
+            .map(SingleNodeRoutingInfo::SpecificNode)
+            .unwrap_or(SingleNodeRoutingInfo::Random)
+    }
+}
+
+/// Defines which collection of nodes should receive a request
+#[derive(Debug, Clone, PartialEq)]
+pub enum MultipleNodeRoutingInfo {
+    /// Route to all nodes in the cluster
+    AllNodes,
+    /// Route to all primaries in the cluster
+    AllMasters,
+    /// Routes the request to multiple slots.
+    /// This variant contains instructions for splitting a multi-slot command (e.g., MGET, MSET) into sub-commands.
+    /// Each tuple consists of a `Route` representing the target node for the subcommand,
+    /// and a vector of argument indices from the original command that should be copied to each subcommand.
+    /// The `MultiSlotArgPattern` specifies the pattern of the command’s arguments, indicating how they are organized
+    /// (e.g., only keys, key-value pairs, etc.).
+    MultiSlot((Vec<(Route, Vec<usize>)>, MultiSlotArgPattern)),
+}
+
+/// Takes a routable and an iterator of indices, which is assumed to be created from `MultipleNodeRoutingInfo::MultiSlot`,
+/// and returns a command with the arguments matching the indices.
+pub fn command_for_multi_slot_indices<'a, 'b>(
+    original_cmd: &'a impl Routable,
+    indices: impl Iterator<Item = &'b usize> + 'a,
+) -> Cmd
+where
+    'b: 'a,
+{
+    let mut new_cmd = Cmd::new();
+    let command_length = 1; // TODO - the +1 should change if we have multi-slot commands with 2 command words.
+    new_cmd.arg(original_cmd.arg_idx(0));
+    for index in indices {
+        new_cmd.arg(original_cmd.arg_idx(index + command_length));
+    }
+    new_cmd
+}
+
+/// Aggregate numeric responses.
+pub fn aggregate(values: Vec<Value>, op: AggregateOp) -> RedisResult<Value> {
+    let initial_value = match op {
+        AggregateOp::Min => i64::MAX,
+        AggregateOp::Sum => 0,
+    };
+    let result = values.into_iter().try_fold(initial_value, |acc, curr| {
+        let int = match curr {
+            Value::Int(int) => int,
+            _ => {
+                return RedisResult::Err(
+                    (
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into(),
+                );
+            }
+        };
+        let acc = match op {
+            AggregateOp::Min => min(acc, int),
+            AggregateOp::Sum => acc + int,
+        };
+        Ok(acc)
+    })?;
+    Ok(Value::Int(result))
+}
+
+/// Aggregate numeric responses by a boolean operator.
+pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> RedisResult<Value> {
+    let initial_value = match op {
+        LogicalAggregateOp::And => true,
+    };
+    let results = values.into_iter().try_fold(Vec::new(), |acc, curr| {
+        let values = match curr {
+            Value::Array(values) => values,
+            _ => {
+                return RedisResult::Err(
+                    (
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into(),
+                );
+            }
+        };
+        let mut acc = if acc.is_empty() {
+            vec![initial_value; values.len()]
+        } else {
+            acc
+        };
+        for (index, value) in values.into_iter().enumerate() {
+            let int = match value {
+                Value::Int(int) => int,
+                _ => {
+                    return Err((
+                        ErrorKind::TypeError,
+                        "expected array of integers as response",
+                    )
+                        .into());
+                }
+            };
+            acc[index] = match op {
+                LogicalAggregateOp::And => acc[index] && (int > 0),
+            };
+        }
+        Ok(acc)
+    })?;
+    Ok(Value::Array(
+        results
+            .into_iter()
+            .map(|result| Value::Int(result as i64))
+            .collect(),
+    ))
+}
+
+/// Aggregate array responses into a single map.
+pub fn combine_map_results(values: Vec<Value>) -> RedisResult<Value> {
+    let mut map: HashMap<Vec<u8>, i64> = HashMap::new();
+
+    for value in values {
+        match value {
+            Value::Array(elements) => {
+                let mut iter = elements.into_iter();
+
+                while let Some(key) = iter.next() {
+                    if let Value::BulkString(key_bytes) = key {
+                        if let Some(Value::Int(value)) = iter.next() {
+                            *map.entry(key_bytes).or_insert(0) += value;
+                        } else {
+                            return Err((ErrorKind::TypeError, "expected integer value").into());
+                        }
+                    } else {
+                        return Err((ErrorKind::TypeError, "expected string key").into());
+                    }
+                }
+            }
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    let result_vec: Vec<(Value, Value)> = map
+        .into_iter()
+        .map(|(k, v)| (Value::BulkString(k), Value::Int(v)))
+        .collect();
+
+    Ok(Value::Map(result_vec))
+}
+
+/// Aggregate array responses into a single array.
+pub fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
+    let mut results = Vec::new();
+
+    for value in values {
+        match value {
+            Value::Array(values) => results.extend(values),
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    Ok(Value::Array(results))
+}
+
+// An iterator that yields `Cow<[usize]>` representing grouped result indices according to a specified argument pattern.
+// This type is used to combine multi-slot array responses.
+type MultiSlotResIdxIter<'a> = std::iter::Map<
+    std::slice::Iter<'a, (Route, Vec<usize>)>,
+    fn(&'a (Route, Vec<usize>)) -> Cow<'a, [usize]>,
+>;
+
+/// Generates an iterator that yields a vector of result indices for each slot within the final merged results array for a multi-slot command response.
+/// The indices are calculated based on the `args_pattern` and the positions of the arguments for each slot-specific request in the original multi-slot request,
+/// ensuring that the results are ordered according to the structure of the initial multi-slot command.
+///
+/// # Arguments
+/// * `route_arg_indices` - A reference to a vector where each element is a tuple containing a route and
+///   the corresponding argument indices for that route.
+/// * `args_pattern` - Specifies the argument pattern (e.g., `KeysOnly`, `KeyValuePairs`, ..), which defines how the indices are grouped for each slot.
+///
+/// # Returns
+/// An iterator yielding `Cow<[usize]>` with the grouped result indices based on the specified argument pattern.
+///
+/// For example, given the command `MSET foo bar foo2 bar2 {foo}foo3 bar3` with the `KeyValuePairs` pattern:
+/// - `route_arg_indices` would include:
+///   - Slot of "foo" with argument indices `[0, 1, 4, 5]` (where `{foo}foo3` hashes to the same slot as "foo" due to curly braces).
+///   - Slot of "foo2" with argument indices `[2, 3]`.
+/// - Using the `KeyValuePairs` pattern, each key-value pair contributes a single response, yielding three responses total.
+/// - Therefore, the iterator generated by this function would yield grouped result indices as follows:
+///   - Slot "foo" is mapped to `[0, 2]` in the final result order.
+///   - Slot "foo2" is mapped to `[1]`.
+fn calculate_multi_slot_result_indices<'a>(
+    route_arg_indices: &'a [(Route, Vec<usize>)],
+    args_pattern: &MultiSlotArgPattern,
+) -> RedisResult<MultiSlotResIdxIter<'a>> {
+    let check_indices_input = |step_count: usize| {
+        for (_, indices) in route_arg_indices {
+            if indices.len() % step_count != 0 {
+                return Err(RedisError::from((
+                    ErrorKind::ClientError,
+                    "Invalid indices input detected",
+                    format!(
+                        "Expected argument pattern with tuples of size {step_count}, but found indices: {indices:?}"
+                    ),
+                )));
+            }
+        }
+        Ok(())
+    };
+
+    match args_pattern {
+        MultiSlotArgPattern::KeysOnly => Ok(route_arg_indices
+            .iter()
+            .map(|(_, indices)| Cow::Borrowed(indices))),
+        MultiSlotArgPattern::KeysAndLastArg => {
+            // The last index corresponds to the path, skip it
+            Ok(route_arg_indices
+                .iter()
+                .map(|(_, indices)| Cow::Borrowed(&indices[..indices.len() - 1])))
+        }
+        MultiSlotArgPattern::KeyWithTwoArgTriples => {
+            // For each triplet (key, path, value) we receive a single response.
+            // For example, for argument indices: [(_, [0,1,2]), (_, [3,4,5,9,10,11]), (_, [6,7,8])]
+            // The resulting grouped indices would be: [0], [1, 3], [2]
+            check_indices_input(3)?;
+            Ok(route_arg_indices.iter().map(|(_, indices)| {
+                Cow::Owned(
+                    indices
+                        .iter()
+                        .step_by(3)
+                        .map(|idx| idx / 3)
+                        .collect::<Vec<_>>(),
+                )
+            }))
+        }
+        MultiSlotArgPattern::KeyValuePairs =>
+        // For each pair (key, value) we receive a single response.
+        // For example, for argument indices: [(_, [0,1]), (_, [2,3,6,7]), (_, [4,5])]
+        // The resulting grouped indices would be: [0], [1, 3], [2]
+        {
+            check_indices_input(2)?;
+            Ok(route_arg_indices.iter().map(|(_, indices)| {
+                Cow::Owned(
+                    indices
+                        .iter()
+                        .step_by(2)
+                        .map(|idx| idx / 2)
+                        .collect::<Vec<_>>(),
+                )
+            }))
+        }
+    }
+}
+
+/// Merges the results of a multi-slot command from the `values` field, where each entry is expected to be an array of results.
+/// The combined results are ordered according to the sequence in which they appeared in the original command.
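+///
+/// For example (sketch): with per-slot results `[[v2], [v1, v3], [v0]]` whose routes carried
+/// argument indices `[2]`, `[1, 3]` and `[0]` under the `KeysOnly` pattern, the merged result
+/// is `[v0, v1, v2, v3]`.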
+///
+/// # Arguments
+///
+/// * `values` - A vector of `Value`s, where each `Value` is expected to be an array representing results
+///   from separate slots in a multi-slot command. Each `Value::Array` within `values` corresponds to
+///   the results associated with a specific slot, as indicated by `route_arg_indices`.
+///
+/// * `route_arg_indices` - A reference to a vector of tuples, where each tuple represents a route and a vector of
+///   argument indices associated with that route. The route indicates the slot, while the indices vector
+///   specifies the positions of arguments relevant to this slot. This is used to construct `sorting_order`,
+///   which guides the placement of results in the final array.
+///
+/// * `args_pattern` - Specifies the argument pattern (e.g., `KeysOnly`, `KeyValuePairs`, ...).
+///   The pattern defines how the argument indices are grouped for each slot and determines
+///   the ordering of results from `values` as they are placed in the final combined array.
+///
+/// # Returns
+///
+/// Returns a `RedisResult` containing the final ordered array (`Value::Array`) of combined results.
+pub(crate) fn combine_and_sort_array_results(
+    values: Vec<Value>,
+    route_arg_indices: &[(Route, Vec<usize>)],
+    args_pattern: &MultiSlotArgPattern,
+) -> RedisResult<Value> {
+    let result_indices = calculate_multi_slot_result_indices(route_arg_indices, args_pattern)?;
+    let mut results = Vec::new();
+    results.resize(
+        values.iter().fold(0, |acc, value| match value {
+            Value::Array(values) => values.len() + acc,
+            _ => 0,
+        }),
+        Value::Nil,
+    );
+    if values.len() != result_indices.len() {
+        return Err(RedisError::from((
+            ErrorKind::ClientError,
+            "Mismatch in the number of multi-slot results compared to the expected result count.",
+            format!(
+                "Expected: {:?}, Found: {:?}",
+                values.len(),
+                result_indices.len()
+            ),
+        )));
+    }
+
+    for (key_indices, value) in result_indices.into_iter().zip(values) {
+        match value {
+            Value::Array(values) => {
+                assert_eq!(values.len(), key_indices.len());
+                for (index, value) in key_indices.iter().zip(values) {
+                    results[*index] = value;
+                }
+            }
+            _ => {
+                return Err((ErrorKind::TypeError, "expected array of values as response").into());
+            }
+        }
+    }
+
+    Ok(Value::Array(results))
+}
+
+fn get_route(is_readonly: bool, key: &[u8]) -> Route {
+    let slot = get_slot(key);
+    if is_readonly {
+        Route::new(slot, SlotAddr::ReplicaOptional)
+    } else {
+        Route::new(slot, SlotAddr::Master)
+    }
+}
+
+/// Represents the pattern of argument structures in multi-slot commands,
+/// defining how the arguments are organized in the command.
+#[derive(Debug, Clone, PartialEq)]
+pub enum MultiSlotArgPattern {
+    /// Pattern where only keys are provided in the command.
+    /// For example: `MGET key1 key2`
+    KeysOnly,
+
+    /// Pattern where each key is followed by a corresponding value.
+    /// For example: `MSET key1 value1 key2 value2`
+    KeyValuePairs,
+
+    /// Pattern where a list of keys is followed by a shared parameter.
+    /// For example: `JSON.MGET key1 key2 key3 path`
+    KeysAndLastArg,
+
+    /// Pattern where each key is followed by two associated arguments, forming key-argument-argument triples.
+    /// For example: `JSON.MSET key1 path1 value1 key2 path2 value2`
+    KeyWithTwoArgTriples,
+}
+
+/// Takes the given `routable` and creates a multi-slot routing info.
+/// This is used for commands like MSET & MGET, where if the command's keys
+/// are hashed to multiple slots, the command should be split into sub-commands,
+/// each targeting a single slot.
+/// The results of these sub-commands are then
+/// usually reassembled using `combine_and_sort_array_results`. In order to do this,
+/// `MultipleNodeRoutingInfo::MultiSlot` contains the routes for each sub-command, and
+/// the indices in the final combined result for each result from the sub-command.
+///
+/// If all keys are routed to the same slot, there's no need to split the command,
+/// so a single node routing info will be returned.
+///
+/// # Arguments
+/// * `routable` - The command or structure containing key-related data that can be routed.
+/// * `cmd` - A byte slice representing the command name or opcode (e.g., `b"MGET"`).
+/// * `first_key_index` - The starting index in the command where the first key is located.
+/// * `args_pattern` - Specifies how keys and values are patterned in the command (e.g., `KeysOnly`, `KeyValuePairs`).
+///
+/// # Returns
+/// `Some(RoutingInfo)` if routing info is created, indicating the command targets multiple slots or a single slot;
+/// `None` if no routing info could be derived.
+fn multi_shard<R>(
+    routable: &R,
+    cmd: &[u8],
+    first_key_index: usize,
+    args_pattern: MultiSlotArgPattern,
+) -> Option<RoutingInfo>
+where
+    R: Routable + ?Sized,
+{
+    let is_readonly = is_readonly_cmd(cmd);
+    let mut routes = HashMap::new();
+    let mut curr_arg_idx = 0;
+    let incr_add_next_arg = |arg_indices: &mut Vec<usize>, mut curr_arg_idx: usize| {
+        curr_arg_idx += 1;
+        // Ensure there's a value following the key
+        routable.arg_idx(curr_arg_idx)?;
+        arg_indices.push(curr_arg_idx);
+        Some(curr_arg_idx)
+    };
+    while let Some(arg) = routable.arg_idx(first_key_index + curr_arg_idx) {
+        let route = get_route(is_readonly, arg);
+        let arg_indices = routes.entry(route).or_insert(Vec::new());
+
+        arg_indices.push(curr_arg_idx);
+
+        match args_pattern {
+            MultiSlotArgPattern::KeysOnly => {} // no additional handling needed for keys-only commands
+            MultiSlotArgPattern::KeyValuePairs => {
+                // Increment to the value paired with the current key and add its index
+                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
+            }
+            MultiSlotArgPattern::KeysAndLastArg => {
+                // Check if the command has more keys or if the next argument is a path
+                if routable
+                    .arg_idx(first_key_index + curr_arg_idx + 2)
+                    .is_none()
+                {
+                    // Last key reached; add the path argument index for each route and break
+                    let path_idx = curr_arg_idx + 1;
+                    for (_, arg_indices) in routes.iter_mut() {
+                        arg_indices.push(path_idx);
+                    }
+                    break;
+                }
+            }
+            MultiSlotArgPattern::KeyWithTwoArgTriples => {
+                // Increment to the first argument associated with the current key and add its index
+                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
+                // Increment to the second argument associated with the current key and add its index
+                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
+            }
+        }
+        curr_arg_idx += 1;
+    }
+
+    let mut routes: Vec<(Route, Vec<usize>)> = routes.into_iter().collect();
+    if routes.is_empty() {
+        return None;
+    }
+
+    Some(if routes.len() == 1 {
+        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(routes.pop().unwrap().0))
+    } else {
+        RoutingInfo::MultiNode((
+            MultipleNodeRoutingInfo::MultiSlot((routes, args_pattern)),
+            ResponsePolicy::for_command(cmd),
+        ))
+    })
+}
+
+impl ResponsePolicy {
+    /// Parse the command for the matching response policy.
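+    ///
+    /// A small illustrative sketch (assuming the public `redis::cluster_routing` module path):
+    ///
+    /// ```rust,no_run
+    /// use redis::cluster_routing::{AggregateOp, ResponsePolicy};
+    /// // DEL responses are summed across shards; unknown commands map to no policy.
+    /// assert_eq!(
+    ///     ResponsePolicy::for_command(b"DEL"),
+    ///     Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
+    /// );
+    /// assert_eq!(ResponsePolicy::for_command(b"NOTACOMMAND"), None);
+    /// ```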
+ pub fn for_command(cmd: &[u8]) -> Option { + match cmd { + b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)), + + b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK" + | b"LATENCY RESET" | b"PUBSUB NUMPAT" => { + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + } + + b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)), + + b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" + | b"CLIENT SETINFO" | b"CONFIG SET" | b"CONFIG RESETSTAT" | b"CONFIG REWRITE" + | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH" + | b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"MEMORY PURGE" | b"MSET" | b"JSON.MSET" + | b"PING" | b"SCRIPT FLUSH" | b"SCRIPT LOAD" | b"SLOWLOG RESET" | b"UNWATCH" + | b"WATCH" => Some(ResponsePolicy::AllSucceeded), + + b"KEYS" + | b"FT._ALIASLIST" + | b"FT._LIST" + | b"MGET" + | b"JSON.MGET" + | b"SLOWLOG GET" + | b"PUBSUB CHANNELS" + | b"PUBSUB SHARDCHANNELS" => Some(ResponsePolicy::CombineArrays), + + b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps), + + b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded), + + // This isn't based on response_tips, but on the discussion here - https://github.com/redis/redis/issues/12410 + b"RANDOMKEY" => Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty), + + b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" | b"LATENCY DOCTOR" + | b"LATENCY LATEST" => Some(ResponsePolicy::Special), + + b"FUNCTION STATS" => Some(ResponsePolicy::Special), + + b"MEMORY MALLOC-STATS" | b"MEMORY DOCTOR" | b"MEMORY STATS" => { + Some(ResponsePolicy::Special) + } + + b"INFO" => Some(ResponsePolicy::Special), + + _ => None, + } + } +} + +enum RouteBy { + AllNodes, + AllPrimaries, + FirstKey, + MultiShard(MultiSlotArgPattern), + Random, + SecondArg, + SecondArgAfterKeyCount, + SecondArgSlot, + StreamsIndex, + ThirdArgAfterKeyCount, + Undefined, +} + +fn base_routing(cmd: &[u8]) -> RouteBy { + match cmd { + b"ACL SETUSER" + | b"ACL DELUSER" + | b"ACL SAVE" + | b"CLIENT SETNAME" + | b"CLIENT SETINFO" + | b"SLOWLOG GET" + | b"SLOWLOG LEN" + | b"SLOWLOG RESET" + | b"CONFIG SET" + | b"CONFIG RESETSTAT" + | b"CONFIG REWRITE" + | b"SCRIPT FLUSH" + | b"SCRIPT LOAD" + | b"LATENCY RESET" + | b"LATENCY GRAPH" + | b"LATENCY HISTOGRAM" + | b"LATENCY HISTORY" + | b"LATENCY DOCTOR" + | b"LATENCY LATEST" + | b"PUBSUB NUMPAT" + | b"PUBSUB CHANNELS" + | b"PUBSUB NUMSUB" + | b"PUBSUB SHARDCHANNELS" + | b"PUBSUB SHARDNUMSUB" + | b"SCRIPT KILL" + | b"FUNCTION KILL" + | b"FUNCTION STATS" => RouteBy::AllNodes, + + b"DBSIZE" + | b"FLUSHALL" + | b"FLUSHDB" + | b"FT._ALIASLIST" + | b"FT._LIST" + | b"FUNCTION DELETE" + | b"FUNCTION FLUSH" + | b"FUNCTION LOAD" + | b"FUNCTION RESTORE" + | b"INFO" + | b"KEYS" + | b"MEMORY DOCTOR" + | b"MEMORY MALLOC-STATS" + | b"MEMORY PURGE" + | b"MEMORY STATS" + | b"PING" + | b"SCRIPT EXISTS" + | b"UNWATCH" + | b"WAIT" + | b"RANDOMKEY" + | b"WAITAOF" => RouteBy::AllPrimaries, + + b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" | b"WATCH" => { + RouteBy::MultiShard(MultiSlotArgPattern::KeysOnly) + } + + b"MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyValuePairs), + b"JSON.MGET" => RouteBy::MultiShard(MultiSlotArgPattern::KeysAndLastArg), + b"JSON.MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyWithTwoArgTriples), + // TODO - special handling - b"SCAN" + b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" => RouteBy::Undefined, + + b"BLMPOP" | b"BZMPOP" | b"EVAL" | b"EVALSHA" | b"EVALSHA_RO" | 
b"EVAL_RO" | b"FCALL" + | b"FCALL_RO" => RouteBy::ThirdArgAfterKeyCount, + + b"BITOP" + | b"MEMORY USAGE" + | b"PFDEBUG" + | b"XGROUP CREATE" + | b"XGROUP CREATECONSUMER" + | b"XGROUP DELCONSUMER" + | b"XGROUP DESTROY" + | b"XGROUP SETID" + | b"XINFO CONSUMERS" + | b"XINFO GROUPS" + | b"XINFO STREAM" + | b"OBJECT ENCODING" + | b"OBJECT FREQ" + | b"OBJECT IDLETIME" + | b"OBJECT REFCOUNT" => RouteBy::SecondArg, + + b"LMPOP" | b"SINTERCARD" | b"ZDIFF" | b"ZINTER" | b"ZINTERCARD" | b"ZMPOP" | b"ZUNION" => { + RouteBy::SecondArgAfterKeyCount + } + + b"XREAD" | b"XREADGROUP" => RouteBy::StreamsIndex, + + // keyless commands with more arguments, whose arguments might be wrongly taken to be keys. + // TODO - double check these, in order to find better ways to route some of them. + b"ACL DRYRUN" + | b"ACL GENPASS" + | b"ACL GETUSER" + | b"ACL HELP" + | b"ACL LIST" + | b"ACL LOG" + | b"ACL USERS" + | b"ACL WHOAMI" + | b"AUTH" + | b"BGSAVE" + | b"CLIENT GETNAME" + | b"CLIENT GETREDIR" + | b"CLIENT ID" + | b"CLIENT INFO" + | b"CLIENT KILL" + | b"CLIENT PAUSE" + | b"CLIENT REPLY" + | b"CLIENT TRACKINGINFO" + | b"CLIENT UNBLOCK" + | b"CLIENT UNPAUSE" + | b"CLUSTER COUNT-FAILURE-REPORTS" + | b"CLUSTER INFO" + | b"CLUSTER KEYSLOT" + | b"CLUSTER MEET" + | b"CLUSTER MYSHARDID" + | b"CLUSTER NODES" + | b"CLUSTER REPLICAS" + | b"CLUSTER RESET" + | b"CLUSTER SET-CONFIG-EPOCH" + | b"CLUSTER SHARDS" + | b"CLUSTER SLOTS" + | b"COMMAND COUNT" + | b"COMMAND GETKEYS" + | b"COMMAND LIST" + | b"COMMAND" + | b"CONFIG GET" + | b"DEBUG" + | b"ECHO" + | b"FUNCTION LIST" + | b"LASTSAVE" + | b"LOLWUT" + | b"MODULE LIST" + | b"MODULE LOAD" + | b"MODULE LOADEX" + | b"MODULE UNLOAD" + | b"READONLY" + | b"READWRITE" + | b"SAVE" + | b"SCRIPT SHOW" + | b"TFCALL" + | b"TFCALLASYNC" + | b"TFUNCTION DELETE" + | b"TFUNCTION LIST" + | b"TFUNCTION LOAD" + | b"TIME" => RouteBy::Random, + + b"CLUSTER ADDSLOTS" + | b"CLUSTER COUNTKEYSINSLOT" + | b"CLUSTER DELSLOTS" + | b"CLUSTER DELSLOTSRANGE" + | b"CLUSTER GETKEYSINSLOT" + | b"CLUSTER SETSLOT" => RouteBy::SecondArgSlot, + + _ => RouteBy::FirstKey, + } +} + +impl RoutingInfo { + /// Returns true if the `cmd` should be routed to all nodes. + pub fn is_all_nodes(cmd: &[u8]) -> bool { + matches!(base_routing(cmd), RouteBy::AllNodes) + } + + /// Returns true if the `cmd` is a key-based command that triggers MOVED errors. + /// A key-based command is one that will be accepted only by the slot owner, + /// while other nodes will respond with a MOVED error redirecting to the relevant primary owner. + pub fn is_key_routing_command(cmd: &[u8]) -> bool { + match base_routing(cmd) { + RouteBy::FirstKey + | RouteBy::SecondArg + | RouteBy::SecondArgAfterKeyCount + | RouteBy::ThirdArgAfterKeyCount + | RouteBy::SecondArgSlot + | RouteBy::StreamsIndex + | RouteBy::MultiShard(_) => { + if matches!(cmd, b"SPUBLISH") { + // SPUBLISH does not return MOVED errors within the slot's shard. This means that even if READONLY wasn't sent to a replica, + // executing SPUBLISH FOO BAR on that replica will succeed. This behavior differs from true key-based commands, + // such as SET FOO BAR, where a non-readonly replica would return a MOVED error if READONLY is off. + // Consequently, SPUBLISH does not meet the requirement of being a command that triggers MOVED errors. 
+                    // TODO: remove this when PRIMARY_PREFERRED route for SPUBLISH is added
+                    false
+                } else {
+                    true
+                }
+            }
+            RouteBy::AllNodes | RouteBy::AllPrimaries | RouteBy::Random | RouteBy::Undefined => {
+                false
+            }
+        }
+    }
+
+    /// Returns the routing info for `r`.
+    pub fn for_routable<R>(r: &R) -> Option<RoutingInfo>
+    where
+        R: Routable + ?Sized,
+    {
+        let cmd = &r.command()?[..];
+        match base_routing(cmd) {
+            RouteBy::AllNodes => Some(RoutingInfo::MultiNode((
+                MultipleNodeRoutingInfo::AllNodes,
+                ResponsePolicy::for_command(cmd),
+            ))),
+
+            RouteBy::AllPrimaries => Some(RoutingInfo::MultiNode((
+                MultipleNodeRoutingInfo::AllMasters,
+                ResponsePolicy::for_command(cmd),
+            ))),
+
+            RouteBy::MultiShard(arg_pattern) => multi_shard(r, cmd, 1, arg_pattern),
+
+            RouteBy::Random => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
+
+            RouteBy::ThirdArgAfterKeyCount => {
+                let key_count = r
+                    .arg_idx(2)
+                    .and_then(|x| std::str::from_utf8(x).ok())
+                    .and_then(|x| x.parse::<usize>().ok())?;
+                if key_count == 0 {
+                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                } else {
+                    r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key))
+                }
+            }
+
+            RouteBy::SecondArg => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)),
+
+            RouteBy::SecondArgAfterKeyCount => {
+                let key_count = r
+                    .arg_idx(1)
+                    .and_then(|x| std::str::from_utf8(x).ok())
+                    .and_then(|x| x.parse::<usize>().ok())?;
+                if key_count == 0 {
+                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                } else {
+                    r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key))
+                }
+            }
+
+            RouteBy::StreamsIndex => {
+                let streams_position = r.position(b"STREAMS")?;
+                r.arg_idx(streams_position + 1)
+                    .map(|key| RoutingInfo::for_key(cmd, key))
+            }
+
+            RouteBy::SecondArgSlot => r
+                .arg_idx(2)
+                .and_then(|arg| std::str::from_utf8(arg).ok())
+                .and_then(|slot| slot.parse::<u16>().ok())
+                .map(|slot| {
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        slot,
+                        SlotAddr::Master,
+                    )))
+                }),
+
+            RouteBy::FirstKey => match r.arg_idx(1) {
+                Some(key) => Some(RoutingInfo::for_key(cmd, key)),
+                None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
+            },
+
+            RouteBy::Undefined => None,
+        }
+    }
+
+    fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo {
+        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route(
+            is_readonly_cmd(cmd),
+            key,
+        )))
+    }
+}
+
+/// Returns true if the given `routable` represents a readonly command.
+pub fn is_readonly(routable: &impl Routable) -> bool {
+    match routable.command() {
+        Some(cmd) => is_readonly_cmd(cmd.as_slice()),
+        None => false,
+    }
+}
+
+/// Returns `true` if the given `cmd` is a readonly command.
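+///
+/// For example (a minimal sketch):
+///
+/// ```rust,no_run
+/// use redis::cluster_routing::is_readonly_cmd;
+/// assert!(is_readonly_cmd(b"GET"));
+/// assert!(!is_readonly_cmd(b"SET"));
+/// ```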
+pub fn is_readonly_cmd(cmd: &[u8]) -> bool { + matches!( + cmd, + b"ACL CAT" + | b"ACL DELUSER" + | b"ACL DRYRUN" + | b"ACL GENPASS" + | b"ACL GETUSER" + | b"ACL HELP" + | b"ACL LIST" + | b"ACL LOAD" + | b"ACL LOG" + | b"ACL SAVE" + | b"ACL SETUSER" + | b"ACL USERS" + | b"ACL WHOAMI" + | b"AUTH" + | b"BGREWRITEAOF" + | b"BGSAVE" + | b"BITCOUNT" + | b"BITFIELD_RO" + | b"BITPOS" + | b"CLIENT ID" + | b"CLIENT CACHING" + | b"CLIENT CAPA" + | b"CLIENT GETNAME" + | b"CLIENT GETREDIR" + | b"CLIENT HELP" + | b"CLIENT INFO" + | b"CLIENT KILL" + | b"CLIENT LIST" + | b"CLIENT NO-EVICT" + | b"CLIENT NO-TOUCH" + | b"CLIENT PAUSE" + | b"CLIENT REPLY" + | b"CLIENT SETINFO" + | b"CLIENT SETNAME" + | b"CLIENT TRACKING" + | b"CLIENT TRACKINGINFO" + | b"CLIENT UNBLOCK" + | b"CLIENT UNPAUSE" + | b"CLUSTER COUNT-FAILURE-REPORTS" + | b"CLUSTER COUNTKEYSINSLOT" + | b"CLUSTER FAILOVER" + | b"CLUSTER GETKEYSINSLOT" + | b"CLUSTER HELP" + | b"CLUSTER INFO" + | b"CLUSTER KEYSLOT" + | b"CLUSTER LINKS" + | b"CLUSTER MYID" + | b"CLUSTER MYSHARDID" + | b"CLUSTER NODES" + | b"CLUSTER REPLICATE" + | b"CLUSTER SAVECONFIG" + | b"CLUSTER SHARDS" + | b"CLUSTER SLOTS" + | b"COMMAND COUNT" + | b"COMMAND DOCS" + | b"COMMAND GETKEYS" + | b"COMMAND GETKEYSANDFLAGS" + | b"COMMAND HELP" + | b"COMMAND INFO" + | b"COMMAND LIST" + | b"CONFIG GET" + | b"CONFIG HELP" + | b"CONFIG RESETSTAT" + | b"CONFIG REWRITE" + | b"CONFIG SET" + | b"DBSIZE" + | b"DUMP" + | b"ECHO" + | b"EVAL_RO" + | b"EVALSHA_RO" + | b"EXISTS" + | b"EXPIRETIME" + | b"FCALL_RO" + | b"FT.AGGREGATE" + | b"FT.EXPLAIN" + | b"FT.EXPLAINCLI" + | b"FT.INFO" + | b"FT.PROFILE" + | b"FT.SEARCH" + | b"FT._ALIASLIST" + | b"FT._LIST" + | b"FUNCTION DUMP" + | b"FUNCTION HELP" + | b"FUNCTION KILL" + | b"FUNCTION LIST" + | b"FUNCTION STATS" + | b"GEODIST" + | b"GEOHASH" + | b"GEOPOS" + | b"GEORADIUSBYMEMBER_RO" + | b"GEORADIUS_RO" + | b"GEOSEARCH" + | b"GET" + | b"GETBIT" + | b"GETRANGE" + | b"HELLO" + | b"HEXISTS" + | b"HGET" + | b"HGETALL" + | b"HKEYS" + | b"HLEN" + | b"HMGET" + | b"HRANDFIELD" + | b"HSCAN" + | b"HSTRLEN" + | b"HVALS" + | b"JSON.ARRINDEX" + | b"JSON.ARRLEN" + | b"JSON.DEBUG" + | b"JSON.GET" + | b"JSON.OBJLEN" + | b"JSON.OBJKEYS" + | b"JSON.MGET" + | b"JSON.RESP" + | b"JSON.STRLEN" + | b"JSON.TYPE" + | b"INFO" + | b"KEYS" + | b"LASTSAVE" + | b"LATENCY DOCTOR" + | b"LATENCY GRAPH" + | b"LATENCY HELP" + | b"LATENCY HISTOGRAM" + | b"LATENCY HISTORY" + | b"LATENCY LATEST" + | b"LATENCY RESET" + | b"LCS" + | b"LINDEX" + | b"LLEN" + | b"LOLWUT" + | b"LPOS" + | b"LRANGE" + | b"MEMORY DOCTOR" + | b"MEMORY HELP" + | b"MEMORY MALLOC-STATS" + | b"MEMORY PURGE" + | b"MEMORY STATS" + | b"MEMORY USAGE" + | b"MGET" + | b"MODULE HELP" + | b"MODULE LIST" + | b"MODULE LOAD" + | b"MODULE LOADEX" + | b"MODULE UNLOAD" + | b"OBJECT ENCODING" + | b"OBJECT FREQ" + | b"OBJECT HELP" + | b"OBJECT IDLETIME" + | b"OBJECT REFCOUNT" + | b"PEXPIRETIME" + | b"PFCOUNT" + | b"PING" + | b"PTTL" + | b"PUBLISH" + | b"PUBSUB CHANNELS" + | b"PUBSUB HELP" + | b"PUBSUB NUMPAT" + | b"PUBSUB NUMSUB" + | b"PUBSUB SHARDCHANNELS" + | b"PUBSUB SHARDNUMSUB" + | b"RANDOMKEY" + | b"REPLICAOF" + | b"RESET" + | b"ROLE" + | b"SAVE" + | b"SCAN" + | b"SCARD" + | b"SCRIPT DEBUG" + | b"SCRIPT EXISTS" + | b"SCRIPT FLUSH" + | b"SCRIPT KILL" + | b"SCRIPT LOAD" + | b"SCRIPT SHOW" + | b"SDIFF" + | b"SELECT" + | b"SHUTDOWN" + | b"SINTER" + | b"SINTERCARD" + | b"SISMEMBER" + | b"SMEMBERS" + | b"SMISMEMBER" + | b"SLOWLOG GET" + | b"SLOWLOG HELP" + | b"SLOWLOG LEN" + | b"SLOWLOG RESET" + | b"SORT_RO" + | b"SPUBLISH" + | 
b"SRANDMEMBER" + | b"SSCAN" + | b"SSUBSCRIBE" + | b"STRLEN" + | b"SUBSCRIBE" + | b"SUBSTR" + | b"SUNION" + | b"SUNSUBSCRIBE" + | b"TIME" + | b"TOUCH" + | b"TTL" + | b"TYPE" + | b"UNSUBSCRIBE" + | b"XINFO CONSUMERS" + | b"XINFO GROUPS" + | b"XINFO HELP" + | b"XINFO STREAM" + | b"XLEN" + | b"XPENDING" + | b"XRANGE" + | b"XREAD" + | b"XREVRANGE" + | b"ZCARD" + | b"ZCOUNT" + | b"ZDIFF" + | b"ZINTER" + | b"ZINTERCARD" + | b"ZLEXCOUNT" + | b"ZMSCORE" + | b"ZRANDMEMBER" + | b"ZRANGE" + | b"ZRANGEBYLEX" + | b"ZRANGEBYSCORE" + | b"ZRANK" + | b"ZREVRANGE" + | b"ZREVRANGEBYLEX" + | b"ZREVRANGEBYSCORE" + | b"ZREVRANK" + | b"ZSCAN" + | b"ZSCORE" + | b"ZUNION" + ) +} + +/// Objects that implement this trait define a request that can be routed by a cluster client to different nodes in the cluster. +pub trait Routable { + /// Convenience function to return ascii uppercase version of the + /// the first argument (i.e., the command). + fn command(&self) -> Option> { + let primary_command = self.arg_idx(0).map(|x| x.to_ascii_uppercase())?; + let mut primary_command = match primary_command.as_slice() { + b"XGROUP" | b"OBJECT" | b"SLOWLOG" | b"FUNCTION" | b"MODULE" | b"COMMAND" + | b"PUBSUB" | b"CONFIG" | b"MEMORY" | b"XINFO" | b"CLIENT" | b"ACL" | b"SCRIPT" + | b"CLUSTER" | b"LATENCY" => primary_command, + _ => { + return Some(primary_command); + } + }; + + Some(match self.arg_idx(1) { + Some(secondary_command) => { + let previous_len = primary_command.len(); + primary_command.reserve(secondary_command.len() + 1); + primary_command.extend(b" "); + primary_command.extend(secondary_command); + let current_len = primary_command.len(); + primary_command[previous_len + 1..current_len].make_ascii_uppercase(); + primary_command + } + None => primary_command, + }) + } + + /// Returns a reference to the data for the argument at `idx`. + fn arg_idx(&self, idx: usize) -> Option<&[u8]>; + + /// Returns index of argument that matches `candidate`, if it exists + fn position(&self, candidate: &[u8]) -> Option; +} + +impl Routable for Cmd { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + self.arg_idx(idx) + } + + fn position(&self, candidate: &[u8]) -> Option { + self.args_iter().position(|a| match a { + Arg::Simple(d) => d.eq_ignore_ascii_case(candidate), + _ => false, + }) + } +} + +impl Routable for Value { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + match self { + Value::Array(args) => match args.get(idx) { + Some(Value::BulkString(ref data)) => Some(&data[..]), + _ => None, + }, + _ => None, + } + } + + fn position(&self, candidate: &[u8]) -> Option { + match self { + Value::Array(args) => args.iter().position(|a| match a { + Value::BulkString(d) => d.eq_ignore_ascii_case(candidate), + _ => false, + }), + _ => None, + } + } +} + +#[derive(Debug, Hash, Clone)] +pub(crate) struct Slot { + pub(crate) start: u16, + pub(crate) end: u16, + pub(crate) master: String, + pub(crate) replicas: Vec, +} + +impl Slot { + pub fn new(s: u16, e: u16, m: String, r: Vec) -> Self { + Self { + start: s, + end: e, + master: m, + replicas: r, + } + } + + #[allow(dead_code)] // used in tests + pub(crate) fn master(&self) -> &str { + self.master.as_str() + } + + #[allow(dead_code)] // used in tests + pub fn replicas(&self) -> Vec { + self.replicas.clone() + } +} + +/// What type of node should a request be routed to, assuming read from replica is enabled. 
+#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)] +pub enum SlotAddr { + /// The request must be routed to primary node + Master, + /// The request may be routed to a replica node. + /// For example, a GET command can be routed either to replica or primary. + ReplicaOptional, + /// The request must be routed to replica node, if one exists. + /// For example, by user requested routing. + ReplicaRequired, +} + +/// Represents the result of checking a shard for the status of a node. +/// +/// This enum indicates whether a given node is already the primary, has been promoted to a primary from a replica, +/// or is not found in the shard at all. +/// +/// Variants: +/// - `AlreadyPrimary`: The specified node is already the primary for the shard, so no changes are needed. +/// - `Promoted`: The specified node was found as a replica and successfully promoted to primary. +/// - `NodeNotFound`: The specified node is neither the current primary nor a replica within the shard. +#[derive(PartialEq, Debug)] +pub(crate) enum ShardUpdateResult { + AlreadyPrimary, + Promoted, + NodeNotFound, +} + +const READ_LK_ERR_SHARDADDRS: &str = "Failed to acquire read lock for ShardAddrs"; +const WRITE_LK_ERR_SHARDADDRS: &str = "Failed to acquire write lock for ShardAddrs"; +/// This is just a simplified version of [`Slot`], +/// which stores only the master and [optional] replica +/// to avoid the need to choose a replica each time +/// a command is executed +#[derive(Debug)] +pub(crate) struct ShardAddrs { + primary: RwLock>, + replicas: RwLock>>, +} + +impl PartialEq for ShardAddrs { + fn eq(&self, other: &Self) -> bool { + let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS); + let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS); + + let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS); + let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS); + + *self_primary == *other_primary && *self_replicas == *other_replicas + } +} + +impl Eq for ShardAddrs {} + +impl PartialOrd for ShardAddrs { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ShardAddrs { + fn cmp(&self, other: &Self) -> Ordering { + let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS); + let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS); + + let primary_cmp = self_primary.cmp(&other_primary); + if primary_cmp == Ordering::Equal { + let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS); + let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS); + return self_replicas.cmp(&other_replicas); + } + + primary_cmp + } +} + +impl ShardAddrs { + pub(crate) fn new(primary: Arc, replicas: Vec>) -> Self { + let primary = RwLock::new(primary); + let replicas = RwLock::new(replicas); + Self { primary, replicas } + } + + pub(crate) fn new_with_primary(primary: Arc) -> Self { + Self::new(primary, Vec::default()) + } + + pub(crate) fn primary(&self) -> Arc { + self.primary.read().expect(READ_LK_ERR_SHARDADDRS).clone() + } + + pub(crate) fn replicas(&self) -> std::sync::RwLockReadGuard>> { + self.replicas.read().expect(READ_LK_ERR_SHARDADDRS) + } + + /// Attempts to update the shard roles based on the provided `new_primary`. + /// + /// This function evaluates whether the specified `new_primary` node is already + /// the primary, a replica that can be promoted to primary, or a node not present + /// in the shard. It handles three scenarios: + /// + /// 1. 
**Already Primary**: If the `new_primary` is already the current primary, + /// the function returns `ShardUpdateResult::AlreadyPrimary` and no changes are made. + /// + /// 2. **Promoted**: If the `new_primary` is found in the list of replicas, it is promoted + /// to primary by swapping it with the current primary, and the function returns + /// `ShardUpdateResult::Promoted`. + /// + /// 3. **Node Not Found**: If the `new_primary` is neither the current primary nor a replica, + /// the function returns `ShardUpdateResult::NodeNotFound` to indicate that the node is + /// not part of the current shard. + /// + /// # Arguments: + /// * `new_primary` - Representing the node to be promoted or checked. + /// + /// # Returns: + /// * `ShardUpdateResult` - The result of the role update operation. + pub(crate) fn attempt_shard_role_update(&self, new_primary: Arc) -> ShardUpdateResult { + let mut primary_lock = self.primary.write().expect(WRITE_LK_ERR_SHARDADDRS); + let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS); + + // If the new primary is already the current primary, return early. + if *primary_lock == new_primary { + return ShardUpdateResult::AlreadyPrimary; + } + + // If the new primary is found among replicas, swap it with the current primary. + if let Some(replica_idx) = Self::replica_index(&replicas_lock, new_primary.clone()) { + std::mem::swap(&mut *primary_lock, &mut replicas_lock[replica_idx]); + return ShardUpdateResult::Promoted; + } + + // If the new primary isn't part of the shard. + ShardUpdateResult::NodeNotFound + } + + fn replica_index( + replicas: &RwLockWriteGuard<'_, Vec>>, + target_replica: Arc, + ) -> Option { + replicas + .iter() + .position(|curr_replica| **curr_replica == *target_replica) + } + + /// Removes the specified `replica_to_remove` from the shard's replica list if it exists. + /// This method searches for the replica's index and removes it from the list. If the replica + /// is not found, it returns an error. + /// + /// # Arguments + /// * `replica_to_remove` - The address of the replica to be removed. + /// + /// # Returns + /// * `RedisResult<()>` - `Ok(())` if the replica was successfully removed, or an error if the + /// replica was not found. + pub(crate) fn remove_replica(&self, replica_to_remove: Arc) -> RedisResult<()> { + let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS); + if let Some(index) = Self::replica_index(&replicas_lock, replica_to_remove.clone()) { + replicas_lock.remove(index); + Ok(()) + } else { + Err(RedisError::from(( + ErrorKind::ClientError, + "Couldn't remove replica", + format!("Replica {replica_to_remove:?} not found"), + ))) + } + } +} + +impl<'a> IntoIterator for &'a ShardAddrs { + type Item = Arc; + type IntoIter = std::iter::Chain>, std::vec::IntoIter>>; + + fn into_iter(self) -> Self::IntoIter { + let primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS).clone(); + let replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS).clone(); + + std::iter::once(primary).chain(replicas) + } +} + +/// Defines the slot and the [`SlotAddr`] to which +/// a command should be sent +#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)] +pub struct Route(u16, SlotAddr); + +impl Route { + /// Returns a new Route. + pub fn new(slot: u16, slot_addr: SlotAddr) -> Self { + Self(slot, slot_addr) + } + + /// Returns the slot number of the route. + pub fn slot(&self) -> u16 { + self.0 + } + + /// Returns the slot address of the route. 
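+    ///
+    /// ```rust,no_run
+    /// use redis::cluster_routing::{Route, SlotAddr};
+    /// let route = Route::new(42, SlotAddr::Master);
+    /// assert_eq!(route.slot_addr(), SlotAddr::Master);
+    /// ```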
+ pub fn slot_addr(&self) -> SlotAddr { + self.1 + } + + /// Returns a new Route for a random primary node + pub fn new_random_primary() -> Self { + Self::new(random_slot(), SlotAddr::Master) + } +} + +/// Choose a random slot from `0..SLOT_SIZE` (excluding) +fn random_slot() -> u16 { + let mut rng = rand::thread_rng(); + rng.gen_range(0..crate::cluster_topology::SLOT_SIZE) +} + +#[cfg(test)] +mod tests_routing { + use super::{ + command_for_multi_slot_indices, AggregateOp, MultiSlotArgPattern, MultipleNodeRoutingInfo, + ResponsePolicy, Route, RoutingInfo, ShardAddrs, SingleNodeRoutingInfo, SlotAddr, + }; + use crate::cluster_routing::ShardUpdateResult; + use crate::{cluster_topology::slot, cmd, parser::parse_redis_value, Value}; + use core::panic; + use std::sync::{Arc, RwLock}; + + #[test] + fn test_routing_info_mixed_capatalization() { + let mut upper = cmd("XREAD"); + upper.arg("STREAMS").arg("foo").arg(0); + + let mut lower = cmd("xread"); + lower.arg("streams").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&upper).unwrap(), + RoutingInfo::for_routable(&lower).unwrap() + ); + + let mut mixed = cmd("xReAd"); + mixed.arg("StReAmS").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&lower).unwrap(), + RoutingInfo::for_routable(&mixed).unwrap() + ); + } + + #[test] + fn test_routing_info() { + let mut test_cmds = vec![]; + + // RoutingInfo::AllMasters + let mut test_cmd = cmd("FLUSHALL"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // RoutingInfo::AllNodes + test_cmd = cmd("ECHO"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // Routing key is 2nd arg ("42") + test_cmd = cmd("SET"); + test_cmd.arg("42"); + test_cmds.push(test_cmd); + + // Routing key is 3rd arg ("FOOBAR") + test_cmd = cmd("XINFO"); + test_cmd.arg("GROUPS").arg("FOOBAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd = "0" == RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("0").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd != "0" == RoutingInfo::Slot) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("4").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 3rd arg + test_cmd = cmd("XREAD"); + test_cmd.arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 4th arg + test_cmd = cmd("XREAD"); + test_cmd.arg("FOO").arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + for cmd in test_cmds { + let value = parse_redis_value(&cmd.get_packed_command()).unwrap(); + assert_eq!( + RoutingInfo::for_routable(&value).unwrap(), + RoutingInfo::for_routable(&cmd).unwrap(), + ); + } + + // Assert expected RoutingInfo explicitly: + + for cmd in [cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("PING")] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::AllSucceeded) + ))) + ); + } + + assert_eq!( + RoutingInfo::for_routable(&cmd("DBSIZE")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("SCRIPT KILL")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::OneSucceeded) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("INFO")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Special) + ))) + ); + + 
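+        // KEYS fans out to all primaries and its per-node array replies are combined.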
assert_eq!( + RoutingInfo::for_routable(&cmd("KEYS")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::CombineArrays) + ))) + ); + + for cmd in vec![ + cmd("SCAN"), + cmd("SHUTDOWN"), + cmd("SLAVEOF"), + cmd("REPLICAOF"), + ] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + None, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + + for cmd in [ + cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0), + cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + ); + } + + // While FCALL with N keys is expected to be routed to a specific node + assert_eq!( + RoutingInfo::for_routable(cmd("FCALL").arg("foo").arg(1).arg("mykey")), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"mykey"), SlotAddr::Master)) + )) + ); + + for (cmd, expected) in [ + ( + cmd("EVAL") + .arg(r#"redis.call("GET, KEYS[1]");"#) + .arg(1) + .arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)), + )), + ), + ( + cmd("XGROUP") + .arg("CREATE") + .arg("mystream") + .arg("workers") + .arg("$") + .arg("MKSTREAM"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XINFO").arg("GROUPS").arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"foo"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ( + cmd("XREADGROUP") + .arg("GROUP") + .arg("wkrs") + .arg("consmrs") + .arg("STREAMS") + .arg("mystream"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XREAD") + .arg("COUNT") + .arg("2") + .arg("STREAMS") + .arg("mystream") + .arg("writers") + .arg("0-0") + .arg("0-0"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + expected, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + } + + #[test] + fn test_slot_for_packed_cmd() { + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10, + 244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::ReplicaOptional)))) if slot == 964)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241, + 197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233, + 247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 
10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210)); + } + + #[test] + fn test_multi_shard_keys_only() { + let mut cmd = cmd("DEL"); + cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(4813, SlotAddr::Master), vec![2]); + expected.insert(Route(5061, SlotAddr::Master), vec![1, 3]); + expected.insert(Route(12182, SlotAddr::Master), vec![0]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::Aggregate(AggregateOp::Sum))))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes && args_pattern == MultiSlotArgPattern::KeysOnly + }), + "expected={expected:?}\nrouting={routing:?}" + ); + + let mut cmd = crate::cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2]); + expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3]); + expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::CombineArrays)))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes && args_pattern == MultiSlotArgPattern::KeysOnly + }), + "expected={expected:?}\nrouting={routing:?}" + ); + } + + #[test] + fn test_multi_shard_key_value_pairs() { + let mut cmd = cmd("MSET"); + cmd.arg("foo") // key slot 12182 + .arg("bar") // value + .arg("foo2") // key slot 1044 + .arg("bar2") // value + .arg("{foo}foo3") // key slot 12182 + .arg("bar3"); // value + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(1044, SlotAddr::Master), vec![2, 3]); + expected.insert(Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::AllSucceeded)))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes && args_pattern == MultiSlotArgPattern::KeyValuePairs + }), + "expected={expected:?}\nrouting={routing:?}" + ); + } + + #[test] + fn test_multi_shard_keys_and_path() { + let mut cmd = cmd("JSON.MGET"); + cmd.arg("foo") // key slot 12182 + .arg("bar") // key slot 5061 + .arg("baz") // key slot 4813 + .arg("{bar}vaz") // key slot 5061 + .arg("$.f.a"); // path + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2, 4]); + expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3, 4]); + expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0, 4]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::CombineArrays)))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes && args_pattern == MultiSlotArgPattern::KeysAndLastArg + }), + "expected={expected:?}\nrouting={routing:?}" + ); + } + + #[test] + fn 
test_multi_shard_key_with_two_arg_triples() { + let mut cmd = cmd("JSON.MSET"); + cmd + .arg("foo") // key slot 12182 + .arg("$.a") // path + .arg("bar") // value + .arg("foo2") // key slot 1044 + .arg("$.f.a") // path + .arg("bar2") // value + .arg("{foo}foo3") // key slot 12182 + .arg("$.f.a") // path + .arg("bar3"); // value + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(1044, SlotAddr::Master), vec![3, 4, 5]); + expected.insert(Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::AllSucceeded)))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes && args_pattern == MultiSlotArgPattern::KeyWithTwoArgTriples + }), + "expected={expected:?}\nrouting={routing:?}" + ); + } + + #[test] + fn test_command_creation_for_multi_shard() { + let mut original_cmd = cmd("DEL"); + original_cmd + .arg("foo") + .arg("bar") + .arg("baz") + .arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&original_cmd); + let expected = [vec![0], vec![1, 3], vec![2]]; + + let mut indices: Vec<_> = match routing { + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::MultiSlot((vec, MultiSlotArgPattern::KeysOnly)), + _, + ))) => vec.into_iter().map(|(_, indices)| indices).collect(), + _ => panic!("unexpected routing: {routing:?}"), + }; + indices.sort_by(|prev, next| prev.iter().next().unwrap().cmp(next.iter().next().unwrap())); // sorting because the `for_routable` doesn't return values in a consistent order between runs. + + for (index, indices) in indices.into_iter().enumerate() { + let cmd = command_for_multi_slot_indices(&original_cmd, indices.iter()); + let expected_indices = &expected[index]; + assert_eq!(original_cmd.arg_idx(0), cmd.arg_idx(0)); + for (index, target_index) in expected_indices.iter().enumerate() { + let target_index = target_index + 1; + assert_eq!(original_cmd.arg_idx(target_index), cmd.arg_idx(index + 1)); + } + } + } + + #[test] + fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() { + let mut cmd = cmd("DEL"); + cmd.arg("foo").arg("{foo}bar").arg("{foo}baz"); + let routing = RoutingInfo::for_routable(&cmd); + + assert!( + matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(12182, SlotAddr::Master)) + )) + ), + "{routing:?}" + ); + } + + #[test] + fn test_combining_results_into_single_array_only_keys() { + // For example `MGET foo bar baz {baz}baz2 {bar}bar2 {foo}foo2` + let res1 = Value::Array(vec![Value::Nil, Value::Okay]); + let res2 = Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::BulkString("4".as_bytes().to_vec()), + ]); + let res3 = Value::Array(vec![Value::SimpleString("2".to_string()), Value::Int(3)]); + let results = super::combine_and_sort_array_results( + vec![res1, res2, res3], + &[ + (Route(4813, SlotAddr::Master), vec![2, 3]), + (Route(5061, SlotAddr::Master), vec![1, 4]), + (Route(12182, SlotAddr::Master), vec![0, 5]), + ], + &MultiSlotArgPattern::KeysOnly, + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::SimpleString("2".to_string()), + Value::BulkString("1".as_bytes().to_vec()), + Value::Nil, + Value::Okay, + Value::BulkString("4".as_bytes().to_vec()), + Value::Int(3), + ]) + ); + } + + #[test] + fn test_combining_results_into_single_array_key_value_paires() { + // For example `MSET foo bar foo2 
bar2 {foo}foo3 bar3` + let res1 = Value::Array(vec![Value::Okay]); + let res2 = Value::Array(vec![Value::BulkString("1".as_bytes().to_vec()), Value::Nil]); + let results = super::combine_and_sort_array_results( + vec![res1, res2], + &[ + (Route(1044, SlotAddr::Master), vec![2, 3]), + (Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]), + ], + &MultiSlotArgPattern::KeyValuePairs, + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::Okay, + Value::Nil + ]) + ); + } + + #[test] + fn test_combining_results_into_single_array_keys_and_path() { + // For example `JSON.MGET foo bar {foo}foo2 $.a` + let res1 = Value::Array(vec![Value::Okay]); + let res2 = Value::Array(vec![Value::BulkString("1".as_bytes().to_vec()), Value::Nil]); + let results = super::combine_and_sort_array_results( + vec![res1, res2], + &[ + (Route(5061, SlotAddr::Master), vec![2, 3]), + (Route(12182, SlotAddr::Master), vec![0, 1, 3]), + ], + &MultiSlotArgPattern::KeysAndLastArg, + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::Nil, + Value::Okay, + ]) + ); + } + + #[test] + fn test_combining_results_into_single_array_key_with_two_arg_triples() { + // For example `JSON.MSET foo $.a bar foo2 $.f.a bar2 {foo}foo3 $.f bar3` + let res1 = Value::Array(vec![Value::Okay]); + let res2 = Value::Array(vec![Value::BulkString("1".as_bytes().to_vec()), Value::Nil]); + let results = super::combine_and_sort_array_results( + vec![res1, res2], + &[ + (Route(5061, SlotAddr::Master), vec![3, 4, 5]), + (Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]), + ], + &MultiSlotArgPattern::KeyWithTwoArgTriples, + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::Okay, + Value::Nil + ]) + ); + } + + #[test] + fn test_combine_map_results() { + let input = vec![]; + let result = super::combine_map_results(input).unwrap(); + assert_eq!(result, Value::Map(vec![])); + + let input = vec![ + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(5), + Value::BulkString(b"key2".to_vec()), + Value::Int(10), + ]), + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(3), + Value::BulkString(b"key3".to_vec()), + Value::Int(15), + ]), + ]; + let result = super::combine_map_results(input).unwrap(); + let mut expected = vec![ + (Value::BulkString(b"key1".to_vec()), Value::Int(8)), + (Value::BulkString(b"key2".to_vec()), Value::Int(10)), + (Value::BulkString(b"key3".to_vec()), Value::Int(15)), + ]; + expected.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + let mut result_vec = match result { + Value::Map(v) => v, + _ => panic!("Expected Map"), + }; + result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + assert_eq!(result_vec, expected); + + let input = vec![Value::Int(5)]; + let result = super::combine_map_results(input); + assert!(result.is_err()); + } + + fn create_shard_addrs(primary: &str, replicas: Vec<&str>) -> ShardAddrs { + ShardAddrs { + primary: RwLock::new(Arc::new(primary.to_string())), + replicas: RwLock::new( + replicas + .into_iter() + .map(|r| Arc::new(r.to_string())) + .collect(), + ), + } + } + + #[test] + fn test_attempt_shard_role_update_already_primary() { + let shard_addrs 
= create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
+        let result = shard_addrs.attempt_shard_role_update(Arc::new("node1:6379".to_string()));
+        assert_eq!(result, ShardUpdateResult::AlreadyPrimary);
+    }
+
+    #[test]
+    fn test_attempt_shard_role_update_promoted() {
+        let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
+        let result = shard_addrs.attempt_shard_role_update(Arc::new("node2:6379".to_string()));
+        assert_eq!(result, ShardUpdateResult::Promoted);
+
+        let primary = shard_addrs.primary.read().unwrap().clone();
+        assert_eq!(primary.as_str(), "node2:6379");
+
+        let replicas = shard_addrs.replicas.read().unwrap();
+        assert_eq!(replicas.len(), 2);
+        assert!(replicas.iter().any(|r| r.as_str() == "node1:6379"));
+    }
+
+    #[test]
+    fn test_attempt_shard_role_update_node_not_found() {
+        let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
+        let result = shard_addrs.attempt_shard_role_update(Arc::new("node4:6379".to_string()));
+        assert_eq!(result, ShardUpdateResult::NodeNotFound);
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/cluster_slotmap.rs b/glide-core/redis-rs/redis/src/cluster_slotmap.rs
new file mode 100644
index 0000000000..88e7549323
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/cluster_slotmap.rs
@@ -0,0 +1,1141 @@
+use std::sync::Arc;
+use std::{
+    collections::{BTreeMap, HashSet},
+    fmt::Display,
+    sync::atomic::AtomicUsize,
+};
+
+use dashmap::DashMap;
+
+use crate::cluster_routing::{Route, ShardAddrs, Slot, SlotAddr};
+use crate::ErrorKind;
+use crate::RedisError;
+use crate::RedisResult;
+
+pub(crate) type NodesMap = DashMap<Arc<String>, Arc<ShardAddrs>>;
+
+#[derive(Debug)]
+pub(crate) struct SlotMapValue {
+    pub(crate) start: u16,
+    pub(crate) addrs: Arc<ShardAddrs>,
+    pub(crate) last_used_replica: Arc<AtomicUsize>,
+}
+
+#[derive(Debug, Default, Clone, PartialEq)]
+/// Represents the client's read-from-replica strategy.
+pub enum ReadFromReplicaStrategy {
+    #[default]
+    /// Always read from the primary, in order to get the freshest data.
+    AlwaysFromPrimary,
+    /// Spread the read requests between all replicas in a round robin manner.
+    /// If no replica is available, route the requests to the primary.
+    RoundRobin,
+    /// Spread the read requests between replicas in the same client's availability zone in a round robin manner,
+    /// falling back to other replicas or the primary if needed.
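+    /// For example, a client configured with `AZAffinity("us-east-1a".to_string())` (a
+    /// hypothetical zone name) would prefer replicas located in `us-east-1a`.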
+
+#[derive(Debug, Default)]
+pub(crate) struct SlotMap {
+    pub(crate) slots: BTreeMap<u16, SlotMapValue>,
+    nodes_map: NodesMap,
+    read_from_replica: ReadFromReplicaStrategy,
+}
+
+fn get_address_from_slot(
+    slot: &SlotMapValue,
+    read_from_replica: ReadFromReplicaStrategy,
+    slot_addr: SlotAddr,
+) -> Arc<String> {
+    let addrs = &slot.addrs;
+    if slot_addr == SlotAddr::Master || addrs.replicas().is_empty() {
+        return addrs.primary();
+    }
+    match read_from_replica {
+        ReadFromReplicaStrategy::AlwaysFromPrimary => addrs.primary(),
+        ReadFromReplicaStrategy::RoundRobin => {
+            let index = slot
+                .last_used_replica
+                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+                % addrs.replicas().len();
+            addrs.replicas()[index].clone()
+        }
+        ReadFromReplicaStrategy::AZAffinity(_az) => todo!(), // Drop sync client
+    }
+}
+
+impl SlotMap {
+    pub(crate) fn new_with_read_strategy(read_from_replica: ReadFromReplicaStrategy) -> Self {
+        SlotMap {
+            slots: BTreeMap::new(),
+            nodes_map: DashMap::new(),
+            read_from_replica,
+        }
+    }
+
+    pub(crate) fn new(slots: Vec<Slot>, read_from_replica: ReadFromReplicaStrategy) -> Self {
+        let mut slot_map = SlotMap::new_with_read_strategy(read_from_replica);
+        let mut shard_id = 0;
+        for slot in slots {
+            let primary = Arc::new(slot.master);
+            // Get the shard addresses if the primary is already in nodes_map;
+            // otherwise, create a new ShardAddrs and add it
+            let shard_addrs_arc = slot_map
+                .nodes_map
+                .entry(primary.clone())
+                .or_insert_with(|| {
+                    shard_id += 1;
+                    let replicas: Vec<Arc<String>> =
+                        slot.replicas.into_iter().map(Arc::new).collect();
+                    Arc::new(ShardAddrs::new(primary, replicas))
+                })
+                .clone();
+
+            // Add all replicas to nodes_map with a reference to the same ShardAddrs if not already present
+            shard_addrs_arc.replicas().iter().for_each(|replica| {
+                slot_map
+                    .nodes_map
+                    .entry(replica.clone())
+                    .or_insert(shard_addrs_arc.clone());
+            });
+
+            // Insert the slot value into the slots map
+            slot_map.slots.insert(
+                slot.end,
+                SlotMapValue {
+                    addrs: shard_addrs_arc.clone(),
+                    start: slot.start,
+                    last_used_replica: Arc::new(AtomicUsize::new(0)),
+                },
+            );
+        }
+        slot_map
+    }
+
+    pub(crate) fn nodes_map(&self) -> &NodesMap {
+        &self.nodes_map
+    }
+
+    pub fn is_primary(&self, address: &String) -> bool {
+        self.nodes_map
+            .get(address)
+            .map_or(false, |shard_addrs| *shard_addrs.primary() == *address)
+    }
+
+    pub fn slot_value_for_route(&self, route: &Route) -> Option<&SlotMapValue> {
+        let slot = route.slot();
+        self.slots
+            .range(slot..)
+            .next()
+            .and_then(|(end, slot_value)| {
+                if slot <= *end && slot_value.start <= slot {
+                    Some(slot_value)
+                } else {
+                    None
+                }
+            })
+    }
+
+    pub fn slot_addr_for_route(&self, route: &Route) -> Option<Arc<String>> {
+        self.slot_value_for_route(route).map(|slot_value| {
+            get_address_from_slot(
+                slot_value,
+                self.read_from_replica.clone(),
+                route.slot_addr(),
+            )
+        })
+    }
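Both lookups above lean on the `slots` tree being keyed by each range's *end* slot: `range(slot..).next()` yields the first range whose end is at or past the queried slot, and the `start <= slot` check confirms the slot is actually covered. A toy sketch of the same trick, with plain tuples standing in for `SlotMapValue`:

```rust
use std::collections::BTreeMap;

/// Interval map keyed by range end; the value holds (start, owner).
fn owner_of<'a>(slots: &'a BTreeMap<u16, (u16, &'a str)>, slot: u16) -> Option<&'a str> {
    slots
        .range(slot..) // first entry whose end is >= slot
        .next()
        .and_then(|(_end, (start, owner))| (*start <= slot).then_some(*owner))
}

fn main() {
    let mut slots = BTreeMap::new();
    slots.insert(1000, (1, "node1:6379"));    // covers [1, 1000]
    slots.insert(2000, (1002, "node2:6379")); // covers [1002, 2000]
    assert_eq!(owner_of(&slots, 500), Some("node1:6379"));
    assert_eq!(owner_of(&slots, 1001), None); // falls in the gap between ranges
    assert_eq!(owner_of(&slots, 1002), Some("node2:6379"));
}
```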
+
+    /// Retrieves the shard addresses (`ShardAddrs`) for the specified `slot` by looking it up in the `slots` tree,
+    /// returning a reference to the stored shard addresses if found.
+    pub(crate) fn shard_addrs_for_slot(&self, slot: u16) -> Option<Arc<ShardAddrs>> {
+        self.slots
+            .range(slot..)
+            .next()
+            .map(|(_, slot_value)| slot_value.addrs.clone())
+    }
+
+    pub fn addresses_for_all_primaries(&self) -> HashSet<Arc<String>> {
+        self.nodes_map
+            .iter()
+            .map(|map_item| {
+                let shard_addrs = map_item.value();
+                shard_addrs.primary().clone()
+            })
+            .collect()
+    }
+
+    pub fn all_node_addresses(&self) -> HashSet<Arc<String>> {
+        self.nodes_map
+            .iter()
+            .map(|map_item| {
+                let node_addr = map_item.key();
+                node_addr.clone()
+            })
+            .collect()
+    }
+
+    pub fn addresses_for_multi_slot<'a, 'b>(
+        &'a self,
+        routes: &'b [(Route, Vec<usize>)],
+    ) -> impl Iterator<Item = Option<Arc<String>>> + 'a
+    where
+        'b: 'a,
+    {
+        routes
+            .iter()
+            .map(|(route, _)| self.slot_addr_for_route(route))
+    }
+
+    // Returns the slots that are assigned to the given address.
+    pub(crate) fn get_slots_of_node(&self, node_address: Arc<String>) -> Vec<u16> {
+        self.slots
+            .iter()
+            .filter_map(|(end, slot_value)| {
+                let addrs = &slot_value.addrs;
+                if addrs.primary() == node_address || addrs.replicas().contains(&node_address) {
+                    Some(slot_value.start..(*end + 1))
+                } else {
+                    None
+                }
+            })
+            .flatten()
+            .collect()
+    }
+
+    pub(crate) fn get_node_address_for_slot(
+        &self,
+        slot: u16,
+        slot_addr: SlotAddr,
+    ) -> Option<Arc<String>> {
+        self.slots.range(slot..).next().and_then(|(_, slot_value)| {
+            if slot_value.start <= slot {
+                Some(get_address_from_slot(
+                    slot_value,
+                    self.read_from_replica.clone(),
+                    slot_addr,
+                ))
+            } else {
+                None
+            }
+        })
+    }
+
+    /// Inserts a single slot into the `slots` map, associating it with a new `SlotMapValue`
+    /// that contains the shard addresses (`shard_addrs`) and represents a range of just the given slot.
+    ///
+    /// # Returns
+    /// * `Option<SlotMapValue>` - Returns the previous `SlotMapValue` if a slot already existed for the given key,
+    ///   or `None` if the slot was newly inserted.
+    fn insert_single_slot(
+        &mut self,
+        slot: u16,
+        shard_addrs: Arc<ShardAddrs>,
+    ) -> Option<SlotMapValue> {
+        self.slots.insert(
+            slot,
+            SlotMapValue {
+                start: slot,
+                addrs: shard_addrs,
+                last_used_replica: Arc::new(AtomicUsize::new(0)),
+            },
+        )
+    }
+
+    /// Creates a new shard-addresses entry that contains only the primary node, adds it to the nodes map,
+    /// and updates the slots tree for the given `slot` to point to the new primary.
+    pub(crate) fn add_new_primary(&mut self, slot: u16, node_addr: Arc<String>) -> RedisResult<()> {
+        let shard_addrs = Arc::new(ShardAddrs::new_with_primary(node_addr.clone()));
+        self.nodes_map.insert(node_addr, shard_addrs.clone());
+        self.update_slot_range(slot, shard_addrs)
+    }
+
+    fn shard_addrs_equal(shard1: &Arc<ShardAddrs>, shard2: &Arc<ShardAddrs>) -> bool {
+        Arc::ptr_eq(shard1, shard2)
+    }
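Note that `shard_addrs_equal` compares identity (`Arc::ptr_eq`), not contents: two ranges are merge-eligible only when they literally share the same `ShardAddrs` allocation, which `SlotMap::new` arranges by handing every slot of a shard a clone of one `Arc`. A small sketch of why structural equality would not be enough:

```rust
use std::sync::Arc;

fn main() {
    let a = Arc::new(String::from("node1:6379"));
    let clone = a.clone();                            // same allocation
    let twin = Arc::new(String::from("node1:6379"));  // equal contents, new allocation

    assert!(Arc::ptr_eq(&a, &clone));  // identity: merge-eligible
    assert!(!Arc::ptr_eq(&a, &twin));  // distinct allocations, even though...
    assert_eq!(a, twin);               // ...PartialEq still says "equal"
}
```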
+
+    /// Updates the end of an existing slot range in the `slots` tree. This function removes the slot entry
+    /// associated with the current end (`curr_end`) and reinserts it with a new end value (`new_end`).
+    ///
+    /// The operation effectively shifts the range's end boundary from `curr_end` to `new_end`, while keeping the
+    /// rest of the slot's data (e.g., shard addresses) unchanged.
+    ///
+    /// # Parameters:
+    /// - `curr_end`: The current end of the slot range that will be removed.
+    /// - `new_end`: The new end of the slot range where the slot data will be reinserted.
+    fn update_end_range(&mut self, curr_end: u16, new_end: u16) -> RedisResult<()> {
+        if let Some(curr_slot_val) = self.slots.remove(&curr_end) {
+            self.slots.insert(new_end, curr_slot_val);
+            return Ok(());
+        }
+        Err(RedisError::from((
+            ErrorKind::ClientError,
+            "Couldn't find a slot range with the given end in the slot map",
+            format!("end: {curr_end:?}"),
+        )))
+    }
+
+    /// Attempts to merge the current `slot` with the next slot range in the `slots` map, if they are consecutive
+    /// and share the same shard addresses. If the next slot's starting position is exactly `slot + 1`
+    /// and the shard addresses match, the next slot's starting point is moved to `slot`, effectively merging
+    /// the slot into the existing range.
+    ///
+    /// # Parameters:
+    /// - `slot`: The slot to attempt to merge with the next slot.
+    /// - `new_addrs`: The shard addresses to compare with the next slot's shard addresses.
+    ///
+    /// # Returns:
+    /// - `bool`: Returns `true` if the merge was successful, otherwise `false`.
+    fn try_merge_to_next_range(&mut self, slot: u16, new_addrs: Arc<ShardAddrs>) -> bool {
+        if let Some((_next_end, next_slot_value)) = self.slots.range_mut((slot + 1)..).next() {
+            if next_slot_value.start == slot + 1
+                && Self::shard_addrs_equal(&next_slot_value.addrs, &new_addrs)
+            {
+                next_slot_value.start = slot;
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Attempts to merge the current slot with the previous slot range in the `slots` map, if they are consecutive
+    /// and share the same shard addresses. If the previous slot ends at `slot - 1` and the shard addresses match,
+    /// the end of the previous slot is extended to `slot`, effectively merging the slot into the existing range.
+    ///
+    /// # Parameters:
+    /// - `slot`: The slot to attempt to merge with the previous slot.
+    /// - `new_addrs`: The shard addresses to compare with the previous slot's shard addresses.
+    ///
+    /// # Returns:
+    /// - `RedisResult<bool>`: Returns `Ok(true)` if the merge was successful, otherwise `Ok(false)`.
+    fn try_merge_to_prev_range(
+        &mut self,
+        slot: u16,
+        new_addrs: Arc<ShardAddrs>,
+    ) -> RedisResult<bool> {
+        if let Some((prev_end, prev_slot_value)) = self.slots.range_mut(..slot).next_back() {
+            if *prev_end == slot - 1 && Self::shard_addrs_equal(&prev_slot_value.addrs, &new_addrs)
+            {
+                let prev_end = *prev_end;
+                self.update_end_range(prev_end, slot)?;
+                return Ok(true);
+            }
+        }
+        Ok(false)
+    }
+
+    /// Updates the slot range in the `slots` map to point to new shard addresses.
+    ///
+    /// This function handles the following scenarios when updating the slot mapping:
+    ///
+    /// **Scenario 1 - Same Shard Owner**:
+    /// - If the slot is already associated with the same shard addresses, no changes are needed.
+    ///
+    /// **Scenario 2 - Single Slot Range**:
+    /// - If the slot is the only slot in the current range (i.e., `start == end == slot`),
+    ///   the function simply replaces the shard addresses for this slot with the new shard addresses.
+    ///
+    /// **Scenario 3 - Slot Matches the End of a Range**:
+    /// - If the slot is the last slot in the current range (`slot == end`), the function
+    ///   adjusts the range by decrementing the end of the current range by 1 (making the
+    ///   new end equal to `end - 1`). The current slot is then removed and a new entry is
+    ///   inserted for the slot with the new shard addresses.
+    ///
+    /// **Scenario 4 - Slot Matches the Start of a Range**:
+    /// - If the slot is the first slot in the current range (`slot == start`), the function
+    ///   increments the start of the current range by 1 (making the new start equal to
+    ///   `start + 1`).
A new entry is then inserted for the slot with the new shard addresses. + /// + /// **Scenario 5 - Slot is Within a Range**: + /// - If the slot falls between the start and end of a current range (`start < slot < end`), + /// the function splits the current range into two. The range before the slot (`start` to + /// `slot - 1`) remains with the old shard addresses, a new entry for the slot is added + /// with the new shard addresses, and the range after the slot (`slot + 1` to `end`) is + /// reinserted with the old shard addresses. + /// + /// **Scenario 6 - Slot is Not Covered**: + /// - If the slot is not part of any existing range, a new entry is simply inserted into + /// the `slots` tree with the new shard addresses. + /// + /// # Parameters: + /// - `slot`: The specific slot that needs to be updated. + /// - `new_addrs`: The new shard addresses to associate with the slot. + /// + /// # Returns: + /// - `RedisResult<()>`: Indicates the success or failure of the operation. + pub(crate) fn update_slot_range( + &mut self, + slot: u16, + new_addrs: Arc, + ) -> RedisResult<()> { + let curr_tree_node = + self.slots + .range_mut(slot..) + .next() + .and_then(|(&end, slot_map_value)| { + if slot >= slot_map_value.start && slot <= end { + Some((end, slot_map_value)) + } else { + None + } + }); + + if let Some((curr_end, curr_slot_val)) = curr_tree_node { + // Scenario 1: Same shard owner + if Self::shard_addrs_equal(&curr_slot_val.addrs, &new_addrs) { + return Ok(()); + } + // Scenario 2: The slot is the only slot in the current range + else if curr_slot_val.start == curr_end && curr_slot_val.start == slot { + // Replace the shard addresses of the current slot value + curr_slot_val.addrs = new_addrs; + // Scenario 3: Slot matches the end of the current range + } else if slot == curr_end { + // Merge with the next range if shard addresses match + if self.try_merge_to_next_range(slot, new_addrs.clone()) { + // Adjust current range end + self.update_end_range(curr_end, curr_end - 1)?; + } else { + // Insert as a standalone slot + let curr_slot_val = self.insert_single_slot(curr_end, new_addrs); + if let Some(curr_slot_val) = curr_slot_val { + // Adjust current range end + self.slots.insert(curr_end - 1, curr_slot_val); + } + } + + // Scenario 4: Slot matches the start of the current range + } else if slot == curr_slot_val.start { + // Adjust current range start + curr_slot_val.start += 1; + // Attempt to merge with the previous range + if !self.try_merge_to_prev_range(slot, new_addrs.clone())? { + // Insert as a standalone slot + self.insert_single_slot(slot, new_addrs); + } + + // Scenario 5: Slot is within the current range + } else if slot > curr_slot_val.start && slot < curr_end { + // We will split the current range into three parts: + // A: [start, slot - 1], which will remain owned by the current shard, + // B: [slot, slot], which will be owned by the new shard addresses, + // C: [slot + 1, end], which will remain owned by the current shard. + + let start: u16 = curr_slot_val.start; + let addrs = curr_slot_val.addrs.clone(); + let last_used_replica = curr_slot_val.last_used_replica.clone(); + + // Modify the current slot range to become part C: [slot + 1, end], still owned by the current shard. + curr_slot_val.start = slot + 1; + + // Create and insert a new SlotMapValue representing part A: [start, slot - 1], + // still owned by the current shard, into the slot map. 
+ self.slots.insert( + slot - 1, + SlotMapValue { + start, + addrs, + last_used_replica, + }, + ); + + // Insert the new shard addresses into the slot map as part B: [slot, slot], + // which will be owned by the new shard addresses. + self.insert_single_slot(slot, new_addrs); + } + // Scenario 6: Slot isn't covered by any existing range + } else { + // Try merging with the previous or next range; if no merge is possible, insert as a standalone slot + if !self.try_merge_to_prev_range(slot, new_addrs.clone())? + && !self.try_merge_to_next_range(slot, new_addrs.clone()) + { + self.insert_single_slot(slot, new_addrs); + } + } + Ok(()) + } +} + +impl Display for SlotMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Strategy: {:?}. Slot mapping:", self.read_from_replica)?; + for (end, slot_map_value) in self.slots.iter() { + let addrs = &slot_map_value.addrs; + writeln!( + f, + "({}-{}): primary: {}, replicas: {:?}", + slot_map_value.start, + end, + addrs.primary(), + addrs.replicas() + )?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests_cluster_slotmap { + use super::*; + + fn process_expected(expected: Vec<&str>) -> HashSet> { + as IntoIterator>::into_iter(HashSet::from_iter(expected)) + .map(|s| Arc::new(s.to_string())) + .collect() + } + + fn process_expected_with_option(expected: Vec>) -> Vec> { + expected + .into_iter() + .filter_map(|opt| opt.map(|s| Arc::new(s.to_string()))) + .collect() + } + + #[test] + fn test_slot_map_retrieve_routes() { + let slot_map = SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + + assert!(slot_map + .slot_addr_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + assert_eq!( + "node1:6379", + *slot_map + .slot_addr_for_route(&Route::new(1, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + *slot_map + .slot_addr_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + *slot_map + .slot_addr_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + "node2:6379", + *slot_map + .slot_addr_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + *slot_map + .slot_addr_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + *slot_map + .slot_addr_for_route(&Route::new(2000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(2001, SlotAddr::Master)) + .is_none()); + } + + fn get_slot_map(read_from_replica: ReadFromReplicaStrategy) -> SlotMap { + SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + Slot::new( + 2001, + 3000, + "node3:6379".to_owned(), + vec![ + "replica4:6379".to_owned(), + "replica5:6379".to_owned(), + "replica6:6379".to_owned(), + ], + ), + Slot::new( + 3001, + 4000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + ], + read_from_replica, + ) + } + + #[test] + fn test_slot_map_get_all_primaries() { + let slot_map = 
get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.addresses_for_all_primaries(); + assert_eq!( + addresses, + process_expected(vec!["node1:6379", "node2:6379", "node3:6379"]) + ); + } + + #[test] + fn test_slot_map_get_all_nodes() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.all_node_addresses(); + assert_eq!( + addresses, + process_expected(vec![ + "node1:6379", + "node2:6379", + "node3:6379", + "replica1:6379", + "replica2:6379", + "replica3:6379", + "replica4:6379", + "replica5:6379", + "replica6:6379" + ]) + ); + } + + #[test] + fn test_slot_map_get_multi_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::Master), vec![]), + (Route::new(2001, SlotAddr::ReplicaOptional), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert!(addresses.contains(&Some(Arc::new("node1:6379".to_string())))); + assert!( + addresses.contains(&Some(Arc::new("replica4:6379".to_string()))) + || addresses.contains(&Some(Arc::new("replica5:6379".to_string()))) + || addresses.contains(&Some(Arc::new("replica6:6379".to_string()))) + ); + } + + /// This test is needed in order to verify that if the MultiSlot route finds the same node for more than a single route, + /// that node's address will appear multiple times, in the same order. + #[test] + fn test_slot_map_get_repeating_addresses_when_the_same_node_is_found_in_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2001, SlotAddr::Master), vec![]), + (Route::new(2, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + (Route::new(3, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2003, SlotAddr::Master), vec![]), + ]; + let addresses: Vec> = slot_map + .addresses_for_multi_slot(&routes) + .flatten() + .collect(); + + assert_eq!( + addresses, + process_expected_with_option(vec![ + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379") + ]) + ); + } + + #[test] + fn test_slot_map_get_none_when_slot_is_missing_from_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(5000, SlotAddr::Master), vec![]), + (Route::new(6000, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + ]; + let addresses: Vec> = slot_map + .addresses_for_multi_slot(&routes) + .flatten() + .collect(); + + assert_eq!( + addresses, + process_expected_with_option(vec![ + Some("replica1:6379"), + None, + None, + Some("node3:6379") + ]) + ); + } + + #[test] + fn test_slot_map_rotate_read_replicas() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let route = Route::new(2001, SlotAddr::ReplicaOptional); + let mut addresses = vec![ + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + ]; + addresses.sort(); + assert_eq!( + addresses, + vec!["replica4:6379", "replica5:6379", "replica6:6379"] + .into_iter() + .map(|s| Arc::new(s.to_string())) + .collect::>() + ); + } + + #[test] + fn test_get_slots_of_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + 
assert_eq!( + slot_map.get_slots_of_node(Arc::new("node1:6379".to_string())), + (1..1001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node(Arc::new("node2:6379".to_string())), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node(Arc::new("replica3:6379".to_string())), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node(Arc::new("replica4:6379".to_string())), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node(Arc::new("replica5:6379".to_string())), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node(Arc::new("replica6:6379".to_string())), + (2001..3001).collect::>() + ); + } + + fn create_slot(start: u16, end: u16, master: &str, replicas: Vec<&str>) -> Slot { + Slot::new( + start, + end, + master.to_owned(), + replicas.into_iter().map(|r| r.to_owned()).collect(), + ) + } + + fn assert_equal_slot_maps(this: SlotMap, expected: Vec) { + for ((end, slot_value), expected_slot) in this.slots.iter().zip(expected.iter()) { + assert_eq!(*end, expected_slot.end); + assert_eq!(slot_value.start, expected_slot.start); + let shard_addrs = &slot_value.addrs; + assert_eq!(*shard_addrs.primary(), expected_slot.master); + let _ = shard_addrs + .replicas() + .iter() + .zip(expected_slot.replicas.iter()) + .map(|(curr, expected)| { + assert_eq!(**curr, *expected); + }); + } + } + + fn assert_slot_map_and_shard_addrs( + slot_map: SlotMap, + slot: u16, + new_shard_addrs: Arc, + expected_slots: Vec, + ) { + assert!(SlotMap::shard_addrs_equal( + &slot_map.shard_addrs_for_slot(slot).unwrap(), + &new_shard_addrs + )); + assert_equal_slot_maps(slot_map, expected_slots); + } + + #[test] + fn test_update_slot_range_single_slot_range() { + let test_slot = 8000; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 8000, "node1:6379", vec!["replica1:6379"]), + create_slot(8001, 16383, "node3:6379", vec!["replica3:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8001) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, test_slot, "node3:6379", vec!["replica3:6379"]), + create_slot(test_slot + 1, 16383, "node3:6379", vec!["replica3:6379"]), + ]; + + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_matches_end_range_merge_ranges() { + let test_slot = 7999; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + assert_slot_map_and_shard_addrs(slot_map, 
test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_matches_end_range_cant_merge_ranges() { + let test_slot = 7999; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = Arc::new(ShardAddrs::new( + Arc::new("node3:6379".to_owned()), + vec![Arc::new("replica3:6379".to_owned())], + )); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, test_slot, "node3:6379", vec!["replica3:6379"]), + create_slot(test_slot + 1, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_matches_start_range_merge_ranges() { + let test_slot = 8000; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(7999) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot + 1, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_matches_start_range_cant_merge_ranges() { + let test_slot = 8000; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = Arc::new(ShardAddrs::new( + Arc::new("node3:6379".to_owned()), + vec![Arc::new("replica3:6379".to_owned())], + )); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, test_slot, "node3:6379", vec!["replica3:6379"]), + create_slot(test_slot + 1, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_is_within_a_range() { + let test_slot = 4000; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, test_slot, "node2:6379", 
vec!["replica2:6379"]), + create_slot(test_slot + 1, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_is_not_covered_cant_merge_ranges() { + let test_slot = 7998; + let before_slots = vec![ + create_slot(0, 7000, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, 7000, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, test_slot, "node2:6379", vec!["replica2:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_is_not_covered_merge_with_next() { + let test_slot = 7999; + let before_slots = vec![ + create_slot(0, 7000, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, 7000, "node1:6379", vec!["replica1:6379"]), + create_slot(test_slot, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_slot_is_not_covered_merge_with_prev() { + let test_slot = 7001; + let before_slots = vec![ + create_slot(0, 7000, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new(before_slots, ReadFromReplicaStrategy::AlwaysFromPrimary); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(7000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, test_slot, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_same_shard_owner_no_change_needed() { + let test_slot = 7000; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new( + before_slots.clone(), + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(7000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(test_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = before_slots; + assert_slot_map_and_shard_addrs(slot_map, test_slot, new_shard_addrs, 
after_slots); + } + + #[test] + fn test_update_slot_range_max_slot_matches_end_range() { + let max_slot = 16383; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new( + before_slots.clone(), + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(7000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(max_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, max_slot - 1, "node2:6379", vec!["replica2:6379"]), + create_slot(max_slot, max_slot, "node1:6379", vec!["replica1:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, max_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_max_slot_single_slot_range() { + let max_slot = 16383; + let before_slots = vec![ + create_slot(0, 16382, "node1:6379", vec!["replica1:6379"]), + create_slot(16383, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new( + before_slots.clone(), + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(0) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(max_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(0, max_slot - 1, "node1:6379", vec!["replica1:6379"]), + create_slot(max_slot, max_slot, "node1:6379", vec!["replica1:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, max_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_min_slot_matches_start_range() { + let min_slot = 0; + let before_slots = vec![ + create_slot(0, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new( + before_slots.clone(), + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(8000) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(min_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(min_slot, min_slot, "node2:6379", vec!["replica2:6379"]), + create_slot(min_slot + 1, 7999, "node1:6379", vec!["replica1:6379"]), + create_slot(8000, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, min_slot, new_shard_addrs, after_slots); + } + + #[test] + fn test_update_slot_range_min_slot_single_slot_range() { + let min_slot = 0; + let before_slots = vec![ + create_slot(0, 0, "node1:6379", vec!["replica1:6379"]), + create_slot(1, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + + let mut slot_map = SlotMap::new( + before_slots.clone(), + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + let new_shard_addrs = slot_map + .shard_addrs_for_slot(1) + .expect("Couldn't find shard address for slot"); + + let res = slot_map.update_slot_range(min_slot, new_shard_addrs.clone()); + assert!(res.is_ok(), "{res:?}"); + + let after_slots = vec![ + create_slot(min_slot, min_slot, "node2:6379", vec!["replica2:6379"]), + create_slot(min_slot + 1, 16383, "node2:6379", vec!["replica2:6379"]), + ]; + assert_slot_map_and_shard_addrs(slot_map, min_slot, new_shard_addrs, after_slots); + } 
+} diff --git a/glide-core/redis-rs/redis/src/cluster_topology.rs b/glide-core/redis-rs/redis/src/cluster_topology.rs new file mode 100644 index 0000000000..b3a4a200d5 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_topology.rs @@ -0,0 +1,661 @@ +//! This module provides the functionality to refresh and calculate the cluster topology for Redis Cluster. + +use crate::cluster::get_connection_addr; +#[cfg(feature = "cluster-async")] +use crate::cluster_client::SlotsRefreshRateLimit; +use crate::cluster_routing::Slot; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap}; +use crate::{cluster::TlsMode, ErrorKind, RedisError, RedisResult, Value}; +#[cfg(all(feature = "cluster-async", not(feature = "tokio-comp")))] +use async_std::sync::RwLock; +use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::hash::{Hash, Hasher}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; +#[cfg(all(feature = "cluster-async", feature = "tokio-comp"))] +use tokio::sync::RwLock; +use tracing::info; + +// Exponential backoff constants for retrying a slot refresh +/// The default number of refresh topology retries in the same call +pub const DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES: usize = 3; +/// The default base duration for retrying topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_BASE_DURATION_MILLIS: u64 = 500; +/// The default base factor for retrying topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_BASE_FACTOR: f64 = 1.5; +// Constants for the intervals between two independent consecutive refresh slots calls +/// The default wait duration between two consecutive refresh slots calls +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_WAIT_DURATION: Duration = Duration::from_secs(15); +/// The default maximum jitter duration to add to the refresh slots wait duration +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI: u64 = 15 * 1000; // 15 seconds + +pub(crate) const SLOT_SIZE: u16 = 16384; +pub(crate) type TopologyHash = u64; + +/// Represents the state of slot refresh operations. +#[cfg(feature = "cluster-async")] +pub(crate) struct SlotRefreshState { + /// Indicates if a slot refresh is currently in progress + pub(crate) in_progress: AtomicBool, + /// The last slot refresh run timestamp + pub(crate) last_run: Arc>>, + pub(crate) rate_limiter: SlotsRefreshRateLimit, +} + +#[cfg(feature = "cluster-async")] +impl SlotRefreshState { + pub(crate) fn new(rate_limiter: SlotsRefreshRateLimit) -> Self { + Self { + in_progress: AtomicBool::new(false), + last_run: Arc::new(RwLock::new(None)), + rate_limiter, + } + } +} + +#[derive(Debug)] +pub(crate) struct TopologyView { + pub(crate) hash_value: TopologyHash, + pub(crate) nodes_count: u16, + slots_and_count: (u16, Vec), +} + +impl PartialEq for TopologyView { + fn eq(&self, other: &Self) -> bool { + self.hash_value == other.hash_value + } +} + +impl Eq for TopologyView {} + +pub(crate) fn slot(key: &[u8]) -> u16 { + crc16::State::::calculate(key) % SLOT_SIZE +} + +fn get_hashtag(key: &[u8]) -> Option<&[u8]> { + let open = key.iter().position(|v| *v == b'{'); + let open = match open { + Some(open) => open, + None => return None, + }; + + let close = key[open..].iter().position(|v| *v == b'}'); + let close = match close { + Some(close) => close, + None => return None, + }; + + let rv = &key[open + 1..open + close]; + if rv.is_empty() { + None + } else { + Some(rv) + } +} + +/// Returns the slot that matches `key`. 
+pub fn get_slot(key: &[u8]) -> u16 {
+    let key = match get_hashtag(key) {
+        Some(tag) => tag,
+        None => key,
+    };
+
+    slot(key)
+}
+
+// Parse slot data from a raw redis value.
+pub(crate) fn parse_and_count_slots(
+    raw_slot_resp: &Value,
+    tls: Option<TlsMode>,
+    // The DNS address of the node from which `raw_slot_resp` was received.
+    addr_of_answering_node: &str,
+) -> RedisResult<(u16, Vec<Slot>)> {
+    // Parse response.
+    let mut slots = Vec::with_capacity(2);
+    let mut count = 0;
+
+    if let Value::Array(items) = raw_slot_resp {
+        let mut iter = items.iter();
+        while let Some(Value::Array(item)) = iter.next() {
+            if item.len() < 3 {
+                continue;
+            }
+
+            let start = if let Value::Int(start) = item[0] {
+                start as u16
+            } else {
+                continue;
+            };
+
+            let end = if let Value::Int(end) = item[1] {
+                end as u16
+            } else {
+                continue;
+            };
+
+            let mut nodes: Vec<String> = item
+                .iter()
+                .skip(2)
+                .filter_map(|node| {
+                    if let Value::Array(node) = node {
+                        if node.len() < 2 {
+                            return None;
+                        }
+                        // According to the CLUSTER SLOTS documentation:
+                        // If the received hostname is an empty string or NULL, clients should utilize the hostname of the responding node.
+                        // However, if the received hostname is "?", it should be regarded as an indication of an unknown node.
+                        let hostname = if let Value::BulkString(ref ip) = node[0] {
+                            let hostname = String::from_utf8_lossy(ip);
+                            if hostname.is_empty() {
+                                addr_of_answering_node.into()
+                            } else if hostname == "?" {
+                                return None;
+                            } else {
+                                hostname
+                            }
+                        } else if let Value::Nil = node[0] {
+                            addr_of_answering_node.into()
+                        } else {
+                            return None;
+                        };
+                        if hostname.is_empty() {
+                            return None;
+                        }
+
+                        let port = if let Value::Int(port) = node[1] {
+                            port as u16
+                        } else {
+                            return None;
+                        };
+                        Some(
+                            get_connection_addr(hostname.into_owned(), port, tls, None).to_string(),
+                        )
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+
+            if nodes.is_empty() {
+                continue;
+            }
+            count += end - start;
+
+            let mut replicas = nodes.split_off(1);
+            // We sort the replicas, because different nodes in a cluster might return the same slot view
+            // with a different order of replicas, which would cause identical views to be evaluated as not equal.
+            replicas.sort_unstable();
+            slots.push(Slot::new(start, end, nodes.pop().unwrap(), replicas));
+        }
+    }
+    if slots.is_empty() {
+        return Err(RedisError::from((
+            ErrorKind::ResponseError,
+            "Error parsing slots: No healthy node found",
+            format!("Raw slot map response: {:?}", raw_slot_resp),
+        )));
+    }
+
+    Ok((count, slots))
+}
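`get_slot` implements the standard Redis Cluster key-to-slot mapping: CRC16 (XMODEM variant) of the key, or of the non-empty hashtag between `{` and `}` when present, taken modulo 16384. A self-contained sketch with a bitwise CRC16; the asserted values are the well-known `CLUSTER KEYSLOT` results for `foo` and `bar`, which also appear as routes in the tests earlier in this diff:

```rust
/// Bitwise CRC16/XMODEM (poly 0x1021, init 0), the checksum Redis Cluster uses.
fn crc16_xmodem(data: &[u8]) -> u16 {
    let mut crc: u16 = 0;
    for &byte in data {
        crc ^= (byte as u16) << 8;
        for _ in 0..8 {
            crc = if crc & 0x8000 != 0 {
                (crc << 1) ^ 0x1021
            } else {
                crc << 1
            };
        }
    }
    crc
}

/// Hash only the `{...}` hashtag when present and non-empty, like `get_hashtag`.
fn key_slot(key: &[u8]) -> u16 {
    let effective = key
        .iter()
        .position(|&b| b == b'{')
        .and_then(|open| {
            key[open..]
                .iter()
                .position(|&b| b == b'}')
                .map(|close| &key[open + 1..open + close])
        })
        .filter(|tag| !tag.is_empty())
        .unwrap_or(key);
    crc16_xmodem(effective) % 16384
}

fn main() {
    assert_eq!(key_slot(b"foo"), 12182);
    assert_eq!(key_slot(b"{foo}bar"), 12182); // hashtag forces the same slot as "foo"
    assert_eq!(key_slot(b"bar"), 5061);
}
```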
+
+fn calculate_hash<T: Hash>(t: &T) -> u64 {
+    let mut s = DefaultHasher::new();
+    t.hash(&mut s);
+    s.finish()
+}
+
+pub(crate) fn calculate_topology<'a>(
+    topology_views: impl Iterator<Item = (&'a str, &'a Value)>,
+    curr_retry: usize,
+    tls_mode: Option<TlsMode>,
+    num_of_queried_nodes: usize,
+    read_from_replica: ReadFromReplicaStrategy,
+) -> RedisResult<(SlotMap, TopologyHash)> {
+    let mut hash_view_map = HashMap::new();
+    for (host, view) in topology_views {
+        if let Ok(slots_and_count) = parse_and_count_slots(view, tls_mode, host) {
+            let hash_value = calculate_hash(&slots_and_count);
+            let topology_entry = hash_view_map.entry(hash_value).or_insert(TopologyView {
+                hash_value,
+                nodes_count: 0,
+                slots_and_count,
+            });
+            topology_entry.nodes_count += 1;
+        }
+    }
+    let mut non_unique_max_node_count = false;
+    let mut vec_iter = hash_view_map.into_values();
+    let mut most_frequent_topology = match vec_iter.next() {
+        Some(view) => view,
+        None => {
+            return Err(RedisError::from((
+                ErrorKind::ResponseError,
+                "No topology views found",
+            )));
+        }
+    };
+    // Find the most frequent topology view
+    for curr_view in vec_iter {
+        match most_frequent_topology
+            .nodes_count
+            .cmp(&curr_view.nodes_count)
+        {
+            std::cmp::Ordering::Less => {
+                most_frequent_topology = curr_view;
+                non_unique_max_node_count = false;
+            }
+            std::cmp::Ordering::Greater => continue,
+            std::cmp::Ordering::Equal => {
+                non_unique_max_node_count = true;
+                let seen_slot_count = most_frequent_topology.slots_and_count.0;
+
+                // Between equally frequent views, prefer the one with higher slot coverage.
+                if let std::cmp::Ordering::Less = seen_slot_count.cmp(&curr_view.slots_and_count.0)
+                {
+                    most_frequent_topology = curr_view;
+                }
+            }
+        }
+    }
+
+    let parse_and_built_result = |most_frequent_topology: TopologyView| {
+        info!(
+            "calculate_topology found topology map:\n{:?}",
+            most_frequent_topology
+        );
+        let slots_data = most_frequent_topology.slots_and_count.1;
+        Ok((
+            SlotMap::new(slots_data, read_from_replica),
+            most_frequent_topology.hash_value,
+        ))
+    };
+
+    if non_unique_max_node_count {
+        // More than a single most-frequent view was found.
+        // If we've reached the last retry, or if it's a 2-node cluster, return the view with the
+        // highest slot coverage among the most-agreed-upon views.
+ if curr_retry >= DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES || num_of_queried_nodes < 3 { + return parse_and_built_result(most_frequent_topology); + } + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: Failed to obtain a majority in topology views", + ))); + } + + // The rate of agreement of the topology view is determined by assessing the number of nodes that share this view out of the total number queried + let agreement_rate = most_frequent_topology.nodes_count as f32 / num_of_queried_nodes as f32; + const MIN_AGREEMENT_RATE: f32 = 0.2; + if agreement_rate >= MIN_AGREEMENT_RATE { + parse_and_built_result(most_frequent_topology) + } else { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: The accuracy of the topology view is too low", + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster_routing::ShardAddrs; + + #[test] + fn test_get_hashtag() { + assert_eq!(get_hashtag(&b"foo{bar}baz"[..]), Some(&b"bar"[..])); + assert_eq!(get_hashtag(&b"foo{}{baz}"[..]), None); + assert_eq!(get_hashtag(&b"foo{{bar}}zap"[..]), Some(&b"{bar"[..])); + } + + fn slot_value_with_replicas(start: u16, end: u16, nodes: Vec<(&str, u16)>) -> Value { + let mut node_values: Vec = nodes + .iter() + .map(|(host, port)| { + Value::Array(vec![ + Value::BulkString(host.as_bytes().to_vec()), + Value::Int(*port as i64), + ]) + }) + .collect(); + let mut slot_vec = vec![Value::Int(start as i64), Value::Int(end as i64)]; + slot_vec.append(&mut node_values); + Value::Array(slot_vec) + } + + fn slot_value(start: u16, end: u16, node: &str, port: u16) -> Value { + slot_value_with_replicas(start, end, vec![(node, port)]) + } + + #[test] + fn parse_slots_with_different_replicas_order_returns_the_same_view() { + let view1 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_2", 6379), + ("replica1_3", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_1", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_3", 6379), + ("replica1_2", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ("replica2_1", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "foo").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "foo").unwrap(); + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let check = res1 + .1 + .into_iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(check); + } + + #[test] + fn parse_slots_returns_slots_with_host_name_if_missing() { + let view = Value::Array(vec![slot_value(0, 4000, "", 6379)]); + + let (slot_count, slots) = parse_and_count_slots(&view, None, "node").unwrap(); + assert_eq!(slot_count, 4000); + assert_eq!(slots[0].master(), "node:6379"); + } + + #[test] + fn 
should_parse_and_hash_regardless_of_missing_host_name_and_replicas_order() { + let view1 = Value::Array(vec![ + slot_value(0, 4000, "", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("node3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("", 6379), + ("replica3_3", 6379), + ("replica3_2", 6379), + ("replica3_1", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "node1").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "node3").unwrap(); + + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let equality_check = res1 + .1 + .iter() + .zip(&res2.1) + .all(|(first, second)| first.start == second.start && first.end == second.end); + assert!(equality_check); + let replicas_check = res1 + .1 + .iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(replicas_check); + } + + enum ViewType { + SingleNodeViewFullCoverage, + SingleNodeViewMissingSlots, + TwoNodesViewFullCoverage, + TwoNodesViewMissingSlots, + } + fn get_view(view_type: &ViewType) -> (&str, Value) { + match view_type { + ViewType::SingleNodeViewFullCoverage => ( + "first", + Value::Array(vec![slot_value(0, 16383, "node1", 6379)]), + ), + ViewType::SingleNodeViewMissingSlots => ( + "second", + Value::Array(vec![slot_value(0, 4000, "node1", 6379)]), + ), + ViewType::TwoNodesViewFullCoverage => ( + "third", + Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 16383, "node2", 6380), + ]), + ), + ViewType::TwoNodesViewMissingSlots => ( + "fourth", + Value::Array(vec![ + slot_value(0, 3000, "node3", 6381), + slot_value(4001, 16383, "node4", 6382), + ]), + ), + } + } + + fn get_node_addr(name: &str, port: u16) -> Arc { + Arc::new(ShardAddrs::new(format!("{name}:{port}").into(), Vec::new())) + } + + fn collect_shard_addrs(slot_map: &SlotMap) -> Vec> { + let mut shard_addrs: Vec> = slot_map + .nodes_map() + .iter() + .map(|map_item| { + let shard_addrs = map_item.value(); + shard_addrs.clone() + }) + .collect(); + shard_addrs.sort_unstable(); + shard_addrs + } + + #[test] + fn test_topology_calculator_4_nodes_queried_has_a_majority_success() { + // 4 nodes queried (1 error): Has a majority, single_node_view should be chosen + let queried_nodes: usize = 4; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + ]; + + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res = collect_shard_addrs(&topology_view); + let node_1 = get_node_addr("node1", 6379); + let expected = vec![node_1]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_majority_has_more_retries_raise_error() { + // 3 nodes queried: No majority, should return an error + let queried_nodes = 3; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let topology_view = 
calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + assert!(topology_view.is_err()); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_majority_last_retry_success() { + // 3 nodes queried:: No majority, last retry, should get the view that has a full slot coverage + let queried_nodes = 3; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 3, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res = collect_shard_addrs(&topology_view); + let node_1 = get_node_addr("node1", 6379); + let node_2 = get_node_addr("node2", 6380); + let expected = vec![node_1, node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_2_nodes_queried_no_majority_return_full_slot_coverage_view() { + // 2 nodes queried: No majority, should get the view that has a full slot coverage + let queried_nodes = 2; + let topology_results = [ + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res = collect_shard_addrs(&topology_view); + let node_1 = get_node_addr("node1", 6379); + let node_2 = get_node_addr("node2", 6380); + let expected = vec![node_1, node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_2_nodes_queried_no_majority_no_full_coverage_prefer_fuller_coverage( + ) { + // 2 nodes queried: No majority, no full slot coverage, should return error + let queried_nodes = 2; + let topology_results = [ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res = collect_shard_addrs(&topology_view); + let node_1 = get_node_addr("node3", 6381); + let node_2 = get_node_addr("node4", 6382); + let expected = vec![node_1, node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_full_coverage_prefer_majority() { + // 2 nodes queried: No majority, no full slot coverage, should return error + let queried_nodes = 2; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewMissingSlots), + get_view(&ViewType::SingleNodeViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res = collect_shard_addrs(&topology_view); + let node_1 = get_node_addr("node1", 6379); + let expected = vec![node_1]; + assert_eq!(res, expected); + } +} diff --git a/glide-core/redis-rs/redis/src/cmd.rs b/glide-core/redis-rs/redis/src/cmd.rs new file mode 100644 index 0000000000..3e248dad6f --- /dev/null +++ b/glide-core/redis-rs/redis/src/cmd.rs @@ -0,0 +1,663 @@ +#[cfg(feature = "aio")] +use futures_util::{ + 
future::BoxFuture, + task::{Context, Poll}, + Stream, StreamExt, +}; +#[cfg(feature = "aio")] +use std::pin::Pin; +use std::{fmt, io}; + +use crate::connection::ConnectionLike; +use crate::pipeline::Pipeline; +use crate::types::{from_owned_redis_value, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs}; + +/// An argument to a redis command +#[derive(Clone)] +pub enum Arg { + /// A normal argument + Simple(D), + /// A cursor argument created from `cursor_arg()` + Cursor, +} + +/// Represents redis commands. +#[derive(Clone)] +pub struct Cmd { + data: Vec, + // Arg::Simple contains the offset that marks the end of the argument + args: Vec>, + cursor: Option, + // If it's true command's response won't be read from socket. Useful for Pub/Sub. + no_response: bool, +} + +/// Represents a redis iterator. +pub struct Iter<'a, T: FromRedisValue> { + batch: std::vec::IntoIter, + cursor: u64, + con: &'a mut (dyn ConnectionLike + 'a), + cmd: Cmd, +} + +impl<'a, T: FromRedisValue> Iterator for Iter<'a, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. + loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if self.cursor == 0 { + return None; + } + + let packed_cmd = self.cmd.get_packed_command_with_cursor(self.cursor)?; + let rv = self.con.req_packed_command(&packed_cmd).ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cursor = cur; + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +use crate::aio::ConnectionLike as AsyncConnection; + +/// The inner future of AsyncIter +#[cfg(feature = "aio")] +struct AsyncIterInner<'a, T: FromRedisValue + 'a> { + batch: std::vec::IntoIter, + con: &'a mut (dyn AsyncConnection + Send + 'a), + cmd: Cmd, +} + +/// Represents the state of AsyncIter +#[cfg(feature = "aio")] +enum IterOrFuture<'a, T: FromRedisValue + 'a> { + Iter(AsyncIterInner<'a, T>), + Future(BoxFuture<'a, (AsyncIterInner<'a, T>, Option)>), + Empty, +} + +/// Represents a redis iterator that can be used with async connections. +#[cfg(feature = "aio")] +pub struct AsyncIter<'a, T: FromRedisValue + 'a> { + inner: IterOrFuture<'a, T>, +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a> AsyncIterInner<'a, T> { + #[inline] + pub async fn next_item(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. 
+ loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if let Some(cursor) = self.cmd.cursor { + if cursor == 0 { + return None; + } + } else { + return None; + } + + let rv = self.con.req_packed_command(&self.cmd).await.ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cmd.cursor = Some(cur); + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a + Unpin + Send> AsyncIter<'a, T> { + /// ```rust,no_run + /// # use redis::AsyncCommands; + /// # async fn scan_set() -> redis::RedisResult<()> { + /// # let client = redis::Client::open("redis://127.0.0.1/")?; + /// # let mut con = client.get_async_connection(None).await?; + /// con.sadd("my_set", 42i32).await?; + /// con.sadd("my_set", 43i32).await?; + /// let mut iter: redis::AsyncIter = con.sscan("my_set").await?; + /// while let Some(element) = iter.next_item().await { + /// assert!(element == 42 || element == 43); + /// } + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub async fn next_item(&mut self) -> Option { + StreamExt::next(self).await + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + Unpin + Send + 'a> Stream for AsyncIter<'a, T> { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let inner = std::mem::replace(&mut this.inner, IterOrFuture::Empty); + match inner { + IterOrFuture::Iter(mut iter) => { + let fut = async move { + let next_item = iter.next_item().await; + (iter, next_item) + }; + this.inner = IterOrFuture::Future(Box::pin(fut)); + Pin::new(this).poll_next(cx) + } + IterOrFuture::Future(mut fut) => match fut.as_mut().poll(cx) { + Poll::Pending => { + this.inner = IterOrFuture::Future(fut); + Poll::Pending + } + Poll::Ready((iter, value)) => { + this.inner = IterOrFuture::Iter(iter); + Poll::Ready(value) + } + }, + IterOrFuture::Empty => unreachable!(), + } + } +} + +fn countdigits(mut v: usize) -> usize { + let mut result = 1; + loop { + if v < 10 { + return result; + } + if v < 100 { + return result + 1; + } + if v < 1000 { + return result + 2; + } + if v < 10000 { + return result + 3; + } + + v /= 10000; + result += 4; + } +} + +#[inline] +fn bulklen(len: usize) -> usize { + 1 + countdigits(len) + 2 + len + 2 +} + +fn args_len<'a, I>(args: I, cursor: u64) -> usize +where + I: IntoIterator> + ExactSizeIterator, +{ + let mut total_len = countdigits(args.len()).saturating_add(3); + for item in args { + total_len += bulklen(match item { + Arg::Cursor => countdigits(cursor as usize), + Arg::Simple(val) => val.len(), + }); + } + total_len +} + +pub(crate) fn cmd_len(cmd: &Cmd) -> usize { + args_len(cmd.args_iter(), cmd.cursor.unwrap_or(0)) +} + +fn encode_command<'a, I>(args: I, cursor: u64) -> Vec +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut cmd = Vec::new(); + write_command_to_vec(&mut cmd, args, cursor); + cmd +} + +fn write_command_to_vec<'a, I>(cmd: &mut Vec, args: I, cursor: u64) +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let total_len = args_len(args.clone(), cursor); + + cmd.reserve(total_len); + + write_command(cmd, args, cursor).unwrap() +} + +fn write_command<'a, I>(cmd: &mut (impl ?Sized + io::Write), args: I, cursor: u64) -> io::Result<()> +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut buf = ::itoa::Buffer::new(); + + cmd.write_all(b"*")?; + let s = buf.format(args.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + let mut cursor_bytes = 
itoa::Buffer::new(); + for item in args { + let bytes = match item { + Arg::Cursor => cursor_bytes.format(cursor).as_bytes(), + Arg::Simple(val) => val, + }; + + cmd.write_all(b"$")?; + let s = buf.format(bytes.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + cmd.write_all(bytes)?; + cmd.write_all(b"\r\n")?; + } + Ok(()) +} + +impl RedisWrite for Cmd { + fn write_arg(&mut self, arg: &[u8]) { + self.data.extend_from_slice(arg); + self.args.push(Arg::Simple(self.data.len())); + } + + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + use std::io::Write; + write!(self.data, "{arg}").unwrap(); + self.args.push(Arg::Simple(self.data.len())); + } +} + +impl Default for Cmd { + fn default() -> Cmd { + Cmd::new() + } +} + +/// A command acts as a builder interface to creating encoded redis +/// requests. This allows you to easily assemble a packed command +/// by chaining arguments together. +/// +/// Basic example: +/// +/// ```rust +/// redis::Cmd::new().arg("SET").arg("my_key").arg(42); +/// ``` +/// +/// There is also a helper function called `cmd` which makes it a +/// tiny bit shorter: +/// +/// ```rust +/// redis::cmd("SET").arg("my_key").arg(42); +/// ``` +/// +/// Because Rust currently does not have an ideal system +/// for lifetimes of temporaries, sometimes you need to hold on to +/// the initially generated command: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let mut cmd = redis::cmd("SMEMBERS"); +/// let mut iter : redis::Iter = cmd.arg("my_set").clone().iter(&mut con).unwrap(); +/// ``` +impl Cmd { + /// Creates a new empty command. + pub fn new() -> Cmd { + Cmd { + data: vec![], + args: vec![], + cursor: None, + no_response: false, + } + } + + /// Creates a new empty command, with at least the requested capacity. + pub fn with_capacity(arg_count: usize, size_of_data: usize) -> Cmd { + Cmd { + data: Vec::with_capacity(size_of_data), + args: Vec::with_capacity(arg_count), + cursor: None, + no_response: false, + } + } + + /// Get the capacities for the internal buffers. + #[cfg(test)] + #[allow(dead_code)] + pub(crate) fn capacity(&self) -> (usize, usize) { + (self.args.capacity(), self.data.capacity()) + } + + /// Appends an argument to the command. The argument passed must + /// be a type that implements `ToRedisArgs`. Most primitive types as + /// well as vectors of primitive types implement it. + /// + /// For instance all of the following are valid: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// redis::cmd("SET").arg(&["my_key", "my_value"]); + /// redis::cmd("SET").arg("my_key").arg(42); + /// redis::cmd("SET").arg("my_key").arg(b"my_value"); + /// ``` + #[inline] + pub fn arg(&mut self, arg: T) -> &mut Cmd { + arg.write_redis_args(self); + self + } + + /// Works similar to `arg` but adds a cursor argument. This is always + /// an integer and also flips the command implementation to support a + /// different mode for the iterators where the iterator will ask for + /// another batch of items when the local data is exhausted. 
+ /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut cmd = redis::cmd("SSCAN"); + /// let mut iter : redis::Iter = + /// cmd.arg("my_set").cursor_arg(0).clone().iter(&mut con).unwrap(); + /// for x in iter { + /// // do something with the item + /// } + /// ``` + #[inline] + pub fn cursor_arg(&mut self, cursor: u64) -> &mut Cmd { + assert!(!self.in_scan_mode()); + self.cursor = Some(cursor); + self.args.push(Arg::Cursor); + self + } + + /// Returns the packed command as a byte vector. + #[inline] + pub fn get_packed_command(&self) -> Vec { + let mut cmd = Vec::new(); + self.write_packed_command(&mut cmd); + cmd + } + + pub(crate) fn write_packed_command(&self, cmd: &mut Vec) { + write_command_to_vec(cmd, self.args_iter(), self.cursor.unwrap_or(0)) + } + + pub(crate) fn write_packed_command_preallocated(&self, cmd: &mut Vec) { + write_command(cmd, self.args_iter(), self.cursor.unwrap_or(0)).unwrap() + } + + /// Like `get_packed_command` but replaces the cursor with the + /// provided value. If the command is not in scan mode, `None` + /// is returned. + #[inline] + fn get_packed_command_with_cursor(&self, cursor: u64) -> Option> { + if !self.in_scan_mode() { + None + } else { + Some(encode_command(self.args_iter(), cursor)) + } + } + + /// Returns true if the command is in scan mode. + #[inline] + pub fn in_scan_mode(&self) -> bool { + self.cursor.is_some() + } + + /// Sends the command as query to the connection and converts the + /// result to the target redis value. This is the general way how + /// you can retrieve data. + #[inline] + pub fn query(&self, con: &mut dyn ConnectionLike) -> RedisResult { + match con.req_command(self) { + Ok(val) => from_owned_redis_value(val), + Err(e) => Err(e), + } + } + + /// Async version of `query`. + #[inline] + #[cfg(feature = "aio")] + pub async fn query_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let val = con.req_packed_command(self).await?; + from_owned_redis_value(val) + } + + /// Similar to `query()` but returns an iterator over the items of the + /// bulk result or iterator. In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the iterator will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others. + /// + /// One specialty of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[inline] + pub fn iter(self, con: &mut dyn ConnectionLike) -> RedisResult> { + let rv = con.req_command(&self)?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + + Ok(Iter { + batch: batch.into_iter(), + cursor, + con, + cmd: self, + }) + } + + /// Similar to `iter()` but returns an AsyncIter over the items of the + /// bulk result or iterator. A [futures::Stream](https://docs.rs/futures/0.3.3/futures/stream/trait.Stream.html) + /// is implemented on AsyncIter. 
In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the stream will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others in async contexts. + /// + /// One specialty of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[cfg(feature = "aio")] + #[inline] + pub async fn iter_async<'a, T: FromRedisValue + 'a>( + mut self, + con: &'a mut (dyn AsyncConnection + Send), + ) -> RedisResult> { + let rv = con.req_packed_command(&self).await?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + if cursor == 0 { + self.cursor = None; + } else { + self.cursor = Some(cursor); + } + + Ok(AsyncIter { + inner: IterOrFuture::Iter(AsyncIterInner { + batch: batch.into_iter(), + con, + cmd: self, + }), + }) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query fails because of an error. This is + /// mainly useful in examples and for simple commands like setting + /// keys. + /// + /// This is equivalent to a call of query like this: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let _ : () = redis::cmd("PING").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn execute(&self, con: &mut dyn ConnectionLike) { + self.query::<()>(con).unwrap(); + } + + /// Returns an iterator over the arguments in this command (including the command name itself) + pub fn args_iter(&self) -> impl Clone + ExactSizeIterator> { + let mut prev = 0; + self.args.iter().map(move |arg| match *arg { + Arg::Simple(i) => { + let arg = Arg::Simple(&self.data[prev..i]); + prev = i; + arg + } + + Arg::Cursor => Arg::Cursor, + }) + } + + // Get a reference to the argument at `idx` + #[cfg(feature = "cluster")] + pub(crate) fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + if idx >= self.args.len() { + return None; + } + + let start = if idx == 0 { + 0 + } else { + match self.args[idx - 1] { + Arg::Simple(n) => n, + _ => 0, + } + }; + let end = match self.args[idx] { + Arg::Simple(n) => n, + _ => 0, + }; + if start == 0 && end == 0 { + return None; + } + Some(&self.data[start..end]) + } + + /// Client won't read and wait for results. Currently only used for Pub/Sub commands in RESP3. + #[inline] + pub fn set_no_response(&mut self, nr: bool) -> &mut Cmd { + self.no_response = nr; + self + } + + /// Check whether command's result will be waited for. + #[inline] + pub fn is_no_response(&self) -> bool { + self.no_response + } +} + +impl fmt::Debug for Cmd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let res = self + .args_iter() + .map(|arg| { + let bytes = match arg { + Arg::Cursor => b"", + Arg::Simple(val) => val, + }; + std::str::from_utf8(bytes).unwrap_or_default() + }) + .collect::>(); + f.debug_struct("Cmd").field("args", &res).finish() + } +} + +/// Shortcut function to creating a command with a single argument. 
+///
+/// The first argument of a redis command is always the name of the command,
+/// which needs to be a string. This is the recommended way to start a
+/// command pipe.
+///
+/// ```rust
+/// redis::cmd("PING");
+/// ```
+pub fn cmd(name: &str) -> Cmd {
+    let mut rv = Cmd::new();
+    rv.arg(name);
+    rv
+}
+
+/// Packs a bunch of commands into a request. This is generally a quite
+/// useless function as this functionality is nicely wrapped through the
+/// `Cmd` object, but in some cases it can be useful. The return value
+/// of this can then be sent to the low level `ConnectionLike` methods.
+///
+/// Example:
+///
+/// ```rust
+/// # use redis::ToRedisArgs;
+/// let mut args = vec![];
+/// args.extend("SET".to_redis_args());
+/// args.extend("my_key".to_redis_args());
+/// args.extend(42.to_redis_args());
+/// let cmd = redis::pack_command(&args);
+/// assert_eq!(cmd, b"*3\r\n$3\r\nSET\r\n$6\r\nmy_key\r\n$2\r\n42\r\n".to_vec());
+/// ```
+pub fn pack_command(args: &[Vec<u8>]) -> Vec<u8> {
+    encode_command(args.iter().map(|x| Arg::Simple(&x[..])), 0)
+}
+
+/// Shortcut for creating a new pipeline.
+pub fn pipe() -> Pipeline {
+    Pipeline::new()
+}
+
+#[cfg(test)]
+#[cfg(feature = "cluster")]
+mod tests {
+    use super::Cmd;
+
+    #[test]
+    fn test_cmd_arg_idx() {
+        let mut c = Cmd::new();
+        assert_eq!(c.arg_idx(0), None);
+
+        c.arg("SET");
+        assert_eq!(c.arg_idx(0), Some(&b"SET"[..]));
+        assert_eq!(c.arg_idx(1), None);
+
+        c.arg("foo").arg("42");
+        assert_eq!(c.arg_idx(1), Some(&b"foo"[..]));
+        assert_eq!(c.arg_idx(2), Some(&b"42"[..]));
+        assert_eq!(c.arg_idx(3), None);
+        assert_eq!(c.arg_idx(4), None);
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/commands/cluster_scan.rs b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs
new file mode 100644
index 0000000000..0fccb0e6f5
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs
@@ -0,0 +1,750 @@
+use crate::aio::ConnectionLike;
+use crate::cluster_async::{
+    ClusterConnInner, Connect, Core, InternalRoutingInfo, InternalSingleNodeRouting, RefreshPolicy,
+    Response, MUTEX_READ_ERR,
+};
+use crate::cluster_routing::SlotAddr;
+use crate::cluster_topology::SLOT_SIZE;
+use crate::{cmd, from_redis_value, Cmd, ErrorKind, RedisError, RedisResult, Value};
+use async_trait::async_trait;
+use std::sync::Arc;
+use strum_macros::Display;
+
+/// This module contains the implementation of scanning operations in a Redis cluster.
+///
+/// The [`ClusterScanArgs`] struct represents the arguments for a cluster scan operation,
+/// including the scan state reference, match pattern, count, and object type.
+///
+/// The [`ScanStateRC`] struct is a wrapper for managing the state of a scan operation in a cluster.
+/// It holds a reference to the scan state and provides methods for accessing the state.
+///
+/// The [`ClusterInScan`] trait defines the methods for interacting with a Redis cluster during scanning,
+/// including retrieving address information, refreshing the slot mapping, and routing commands to a specific address.
+///
+/// The [`ScanState`] struct represents the state of a scan operation in a Redis cluster.
+/// It holds information about the current scan state, including the cursor position, the scanned slots map,
+/// the address being scanned, and the address's epoch.
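+///
+/// The scanned-slots map is a fixed bit set over all 16384 cluster hash slots:
+/// at 64 bits per `u64` it needs 16384 / 64 = 256 words (2 KiB), and
+/// `END_OF_SCAN` (16385, one past the last slot) signals that every slot has
+/// been visited. An illustrative sketch of the same bit arithmetic (the names
+/// below are local to the example, not part of this module):
+///
+/// ```rust
+/// const SLOTS: usize = 16384; // cluster keyspace size
+/// const WORDS: usize = SLOTS / 64; // 256 u64 words = 2 KiB
+///
+/// fn mark_scanned(map: &mut [u64; WORDS], slot: usize) {
+///     map[slot / 64] |= 1 << (slot % 64);
+/// }
+///
+/// fn next_unscanned(map: &[u64; WORDS]) -> Option<usize> {
+///     map.iter().enumerate().find_map(|(i, word)| {
+///         // trailing_ones() is the index of the lowest zero bit in the word
+///         (*word != u64::MAX).then(|| i * 64 + word.trailing_ones() as usize)
+///     })
+/// }
+///
+/// let mut map = [0u64; WORDS];
+/// mark_scanned(&mut map, 0);
+/// mark_scanned(&mut map, 1);
+/// assert_eq!(next_unscanned(&map), Some(2));
+/// ```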
+
+const BITS_PER_U64: usize = u64::BITS as usize;
+const NUM_OF_SLOTS: usize = SLOT_SIZE as usize;
+const BITS_ARRAY_SIZE: usize = NUM_OF_SLOTS / BITS_PER_U64;
+const END_OF_SCAN: u16 = NUM_OF_SLOTS as u16 + 1;
+type SlotsBitsArray = [u64; BITS_ARRAY_SIZE];
+
+#[derive(Clone)]
+pub(crate) struct ClusterScanArgs {
+    pub(crate) scan_state_cursor: ScanStateRC,
+    match_pattern: Option<Vec<u8>>,
+    count: Option<usize>,
+    object_type: Option<ObjectType>,
+}
+
+#[derive(Debug, Clone, Display)]
+/// Represents the type of an object in Redis.
+pub enum ObjectType {
+    /// Represents a string object in Redis.
+    String,
+    /// Represents a list object in Redis.
+    List,
+    /// Represents a set object in Redis.
+    Set,
+    /// Represents a sorted set object in Redis.
+    ZSet,
+    /// Represents a hash object in Redis.
+    Hash,
+    /// Represents a stream object in Redis.
+    Stream,
+}
+
+impl ClusterScanArgs {
+    pub(crate) fn new(
+        scan_state_cursor: ScanStateRC,
+        match_pattern: Option<Vec<u8>>,
+        count: Option<usize>,
+        object_type: Option<ObjectType>,
+    ) -> Self {
+        Self {
+            scan_state_cursor,
+            match_pattern,
+            count,
+            object_type,
+        }
+    }
+}
+
+#[derive(PartialEq, Debug, Clone, Default)]
+pub enum ScanStateStage {
+    #[default]
+    Initiating,
+    InProgress,
+    Finished,
+}
+
+#[derive(Debug, Clone, Default)]
+/// A wrapper struct for managing the state of a scan operation in a cluster.
+/// It holds a reference to the scan state and provides methods for accessing the state.
+/// The `status` field indicates the status of the scan operation.
+pub struct ScanStateRC {
+    scan_state_rc: Arc<Option<ScanState>>,
+    status: ScanStateStage,
+}
+
+impl ScanStateRC {
+    /// Creates a new instance of [`ScanStateRC`] from a given [`ScanState`].
+    fn from_scan_state(scan_state: ScanState) -> Self {
+        Self {
+            scan_state_rc: Arc::new(Some(scan_state)),
+            status: ScanStateStage::InProgress,
+        }
+    }
+
+    /// Creates a new instance of [`ScanStateRC`].
+    ///
+    /// This method initializes the [`ScanStateRC`] with a reference to a [`ScanState`] that is initially set to `None`.
+    /// An empty ScanState is equivalent to a 0 cursor.
+    pub fn new() -> Self {
+        Self {
+            scan_state_rc: Arc::new(None),
+            status: ScanStateStage::Initiating,
+        }
+    }
+    /// Creates a new instance of [`ScanStateRC`] with a finished status and an empty scan state.
+    fn create_finished() -> Self {
+        Self {
+            scan_state_rc: Arc::new(None),
+            status: ScanStateStage::Finished,
+        }
+    }
+    /// Returns `true` if the scan state is finished.
+    pub fn is_finished(&self) -> bool {
+        self.status == ScanStateStage::Finished
+    }
+
+    /// Returns a clone of the scan state, if it exists.
+    pub(crate) fn get_state_from_wrapper(&self) -> Option<ScanState> {
+        if self.status == ScanStateStage::Initiating || self.status == ScanStateStage::Finished {
+            None
+        } else {
+            self.scan_state_rc.as_ref().clone()
+        }
+    }
+}
+
+/// This trait defines the methods for interacting with a Redis cluster during scanning.
+#[async_trait]
+pub(crate) trait ClusterInScan {
+    /// Retrieves the address associated with a given slot in the cluster.
+    async fn get_address_by_slot(&self, slot: u16) -> RedisResult<Arc<String>>;
+
+    /// Retrieves the epoch of a given address in the cluster.
+    /// The epoch represents the version of the address, which is updated when a failover occurs or slots migrate in.
+    async fn get_address_epoch(&self, address: &str) -> Result<u64, RedisError>;
+
+    /// Retrieves the slots assigned to a given address in the cluster.
+    async fn get_slots_of_address(&self, address: Arc<String>) -> Vec<u16>;
+
+    /// Routes a Redis command to a specific address in the cluster.
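+    /// In the scan flow this is what pins each SCAN invocation to the node currently
+    /// being scanned, instead of letting the cluster router pick a node.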
+    async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult<Value>;
+
+    /// Check if all slots are covered by the cluster
+    async fn are_all_slots_covered(&self) -> bool;
+
+    /// Check if the topology of the cluster has changed and refresh the slots if needed
+    async fn refresh_if_topology_changed(&self) -> RedisResult<bool>;
+}
+
+/// Represents the state of a scan operation in a Redis cluster.
+///
+/// This struct holds information about the current scan state, including the cursor position,
+/// the scanned slots map, the address being scanned, and the address's epoch.
+#[derive(PartialEq, Debug, Clone)]
+pub(crate) struct ScanState {
+    // the real cursor in the scan operation
+    cursor: u64,
+    // a map of the slots that have been scanned
+    scanned_slots_map: SlotsBitsArray,
+    // the address currently being scanned, derived from the next slot set to 0 in the
+    // scanned_slots_map and the address that "owns" that slot in the SlotMap
+    pub(crate) address_in_scan: Arc<String>,
+    // the epoch represents the version of the address: when a failover happens or slots
+    // migrate in, the epoch is incremented
+    address_epoch: u64,
+    // the status of the scan operation
+    scan_status: ScanStateStage,
+}
+
+impl ScanState {
+    /// Create a new instance of ScanState.
+    ///
+    /// # Arguments
+    ///
+    /// * `cursor` - The cursor position.
+    /// * `scanned_slots_map` - The scanned slots map.
+    /// * `address_in_scan` - The address being scanned.
+    /// * `address_epoch` - The epoch of the address being scanned.
+    /// * `scan_status` - The status of the scan operation.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of ScanState.
+    pub fn new(
+        cursor: u64,
+        scanned_slots_map: SlotsBitsArray,
+        address_in_scan: Arc<String>,
+        address_epoch: u64,
+        scan_status: ScanStateStage,
+    ) -> Self {
+        Self {
+            cursor,
+            scanned_slots_map,
+            address_in_scan,
+            address_epoch,
+            scan_status,
+        }
+    }
+
+    fn create_finished_state() -> Self {
+        Self {
+            cursor: 0,
+            scanned_slots_map: [0; BITS_ARRAY_SIZE],
+            address_in_scan: Default::default(),
+            address_epoch: 0,
+            scan_status: ScanStateStage::Finished,
+        }
+    }
+
+    /// Initialize a new scan operation.
+    /// This method creates a new scan state with the cursor set to 0, the scanned slots map initialized to 0,
+    /// and the address set to the address associated with slot 0.
+    /// The address epoch is set to the epoch of the address; if it cannot be retrieved, it defaults to 0.
+    /// If the address for slot 0 cannot be retrieved, the method returns an error.
+    async fn initiate_scan<C: ClusterInScan + ?Sized>(connection: &C) -> RedisResult<ScanState> {
+        let new_scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        let new_cursor = 0;
+        let address = connection.get_address_by_slot(0).await?;
+        let address_epoch = connection.get_address_epoch(&address).await.unwrap_or(0);
+        Ok(ScanState::new(
+            new_cursor,
+            new_scanned_slots_map,
+            address,
+            address_epoch,
+            ScanStateStage::InProgress,
+        ))
+    }
+
+    /// Get the next slot to be scanned based on the scanned slots map.
+    /// If all slots have been scanned, the method returns [`END_OF_SCAN`].
+    fn get_next_slot(&self, scanned_slots_map: &SlotsBitsArray) -> Option<u16> {
+        let all_slots_scanned = scanned_slots_map.iter().all(|&word| word == u64::MAX);
+        if all_slots_scanned {
+            return Some(END_OF_SCAN);
+        }
+        for (i, slot) in scanned_slots_map.iter().enumerate() {
+            let mut mask = 1;
+            for j in 0..BITS_PER_U64 {
+                if (slot & mask) == 0 {
+                    return Some((i * BITS_PER_U64 + j) as u16);
+                }
+                mask <<= 1;
+            }
+        }
+        None
+    }
+
+    /// Update the scan state without updating the scanned slots map.
+    /// This method is used when the address epoch has changed and we can't determine which slots are new.
+    /// In that case we skip updating the scanned slots map and only update the address and cursor.
+    async fn creating_state_without_slot_changes<C: ClusterInScan + ?Sized>(
+        &self,
+        connection: &C,
+    ) -> RedisResult<ScanState> {
+        let next_slot = self.get_next_slot(&self.scanned_slots_map).unwrap_or(0);
+        let new_address = if next_slot == END_OF_SCAN {
+            return Ok(ScanState::create_finished_state());
+        } else {
+            connection.get_address_by_slot(next_slot).await
+        };
+        match new_address {
+            Ok(address) => {
+                let new_epoch = connection.get_address_epoch(&address).await.unwrap_or(0);
+                Ok(ScanState::new(
+                    0,
+                    self.scanned_slots_map,
+                    address,
+                    new_epoch,
+                    ScanStateStage::InProgress,
+                ))
+            }
+            Err(err) => Err(err),
+        }
+    }
+
+    /// Update the scan state and get the next address to scan.
+    /// This method is called when the cursor reaches 0, indicating that the current address has been scanned.
+    /// It updates the scan state based on the scanned slots map and retrieves the next address to scan.
+    /// If the address epoch has changed, the method skips updating the scanned slots map and only updates the address and cursor.
+    /// If the address epoch has not changed, the method updates the scanned slots map with the slots owned by the address.
+    /// The method returns the new scan state with the updated cursor, scanned slots map, address, and epoch.
+    async fn create_updated_scan_state_for_completed_address<C: ClusterInScan + ?Sized>(
+        &mut self,
+        connection: &C,
+    ) -> RedisResult<ScanState> {
+        connection
+            .refresh_if_topology_changed()
+            .await
+            .map_err(|err| {
+                RedisError::from((
+                    ErrorKind::ResponseError,
+                    "Error during cluster scan: failed to refresh slots",
+                    format!("{:?}", err),
+                ))
+            })?;
+        let mut scanned_slots_map = self.scanned_slots_map;
+        // If the address epoch changed, it means some of the address's slots are new, so we can't
+        // tell which slots have been there from the beginning and which migrated in (or out and
+        // back in) later.
+        // In this case we skip updating the scanned_slots_map and only update the address and the cursor.
+        let new_address_epoch = connection
+            .get_address_epoch(&self.address_in_scan)
+            .await
+            .unwrap_or(0);
+        if new_address_epoch != self.address_epoch {
+            return self.creating_state_without_slot_changes(connection).await;
+        }
+        // If the epoch hasn't changed, all slots owned by the address after the refresh are valid
+        // to mark as scanned, so we update the scanned_slots_map with the slots owned by the address.
+        let slots_scanned = connection
+            .get_slots_of_address(self.address_in_scan.clone())
+            .await;
+        for slot in slots_scanned {
+            let slot_index = slot as usize / BITS_PER_U64;
+            let slot_bit = slot as usize % BITS_PER_U64;
+            scanned_slots_map[slot_index] |= 1 << slot_bit;
+        }
+        // Get the next address to scan and its parameters, based on the next slot set to 0 in the scanned_slots_map
+        let next_slot = self.get_next_slot(&scanned_slots_map).unwrap_or(0);
+        let new_address = if next_slot == END_OF_SCAN {
+            return Ok(ScanState::create_finished_state());
+        } else {
+            connection.get_address_by_slot(next_slot).await
+        };
+        match new_address {
+            Ok(new_address) => {
+                let new_epoch = connection
+                    .get_address_epoch(&new_address)
+                    .await
+                    .unwrap_or(0);
+                let new_cursor = 0;
+                Ok(ScanState::new(
+                    new_cursor,
+                    scanned_slots_map,
+                    new_address,
+                    new_epoch,
+                    ScanStateStage::InProgress,
+                ))
+            }
+            Err(err) => Err(err),
+        }
+    }
+}
+
+// Implement the [`ClusterInScan`] trait for [`InnerCore`] of async cluster connection.
+#[async_trait]
+impl<C> ClusterInScan for Core<C>
+where
+    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
+{
+    async fn get_address_by_slot(&self, slot: u16) -> RedisResult<Arc<String>> {
+        let address = self
+            .get_address_from_slot(slot, SlotAddr::ReplicaRequired)
+            .await;
+        match address {
+            Some(addr) => Ok(addr),
+            None => {
+                if self.are_all_slots_covered().await {
+                    Err(RedisError::from((
+                        ErrorKind::IoError,
+                        "Failed to get a connection to the node covering the slot, please check the cluster configuration",
+                    )))
+                } else {
+                    Err(RedisError::from((
+                        ErrorKind::NotAllSlotsCovered,
+                        "Not all slots are covered by the cluster, please check the cluster configuration",
+                    )))
+                }
+            }
+        }
+    }
+
+    async fn get_address_epoch(&self, address: &str) -> Result<u64, RedisError> {
+        self.as_ref().get_address_epoch(address).await
+    }
+    async fn get_slots_of_address(&self, address: Arc<String>) -> Vec<u16> {
+        self.as_ref().get_slots_of_address(address).await
+    }
+    async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult<Value> {
+        let routing = InternalRoutingInfo::SingleNode(InternalSingleNodeRouting::ByAddress(
+            address.to_string(),
+        ));
+        let core = self.to_owned();
+        let response = ClusterConnInner::<C>::try_cmd_request(Arc::new(cmd), routing, core)
+            .await
+            .map_err(|err| err.1)?;
+        match response {
+            Response::Single(value) => Ok(value),
+            _ => Err(RedisError::from((
+                ErrorKind::ClientError,
+                "Expected a single response, got an unexpected response",
+            ))),
+        }
+    }
+    async fn are_all_slots_covered(&self) -> bool {
+        ClusterConnInner::<C>::check_if_all_slots_covered(
+            &self.conn_lock.read().expect(MUTEX_READ_ERR).slot_map,
+        )
+    }
+    async fn refresh_if_topology_changed(&self) -> RedisResult<bool> {
+        ClusterConnInner::check_topology_and_refresh_if_diff(
+            self.to_owned(),
+            // The cluster SCAN implementation must refresh the slots when a topology change is found
+            // to ensure the scan logic is correct.
+            &RefreshPolicy::NotThrottable,
+        )
+        .await
+    }
+}
+
+/// Perform a cluster scan operation.
+/// This function performs a scan operation in a Redis cluster using the given [`ClusterInScan`] connection.
+/// It scans the cluster for keys based on the given `ClusterScanArgs` arguments.
+/// The function returns a tuple containing the new scan state cursor and the keys found in the scan operation.
+/// If the scan operation fails, an error is returned.
+///
+/// # Arguments
+/// * `core` - The connection to the Redis cluster.
+/// * `cluster_scan_args` - The arguments for the cluster scan operation.
+///
+/// # Returns
+/// A tuple containing the new scan state cursor and the keys found in the scan operation.
+/// If the scan operation fails, an error is returned.
+pub(crate) async fn cluster_scan<C>(
+    core: C,
+    cluster_scan_args: ClusterScanArgs,
+) -> RedisResult<(ScanStateRC, Vec<Value>)>
+where
+    C: ClusterInScan,
+{
+    let ClusterScanArgs {
+        scan_state_cursor,
+        match_pattern,
+        count,
+        object_type,
+    } = cluster_scan_args;
+    // If scan_state is None, we are starting a new scan
+    let scan_state = match scan_state_cursor.get_state_from_wrapper() {
+        Some(state) => state,
+        None => match ScanState::initiate_scan(&core).await {
+            Ok(state) => state,
+            Err(err) => {
+                return Err(err);
+            }
+        },
+    };
+    // Send the actual scan command to the address in the scan_state
+    let scan_result = send_scan(
+        &scan_state,
+        &core,
+        match_pattern.clone(),
+        count,
+        object_type.clone(),
+    )
+    .await;
+    let ((new_cursor, new_keys), mut scan_state): ((u64, Vec<Value>), ScanState) = match scan_result
+    {
+        Ok(scan_result) => (from_redis_value(&scan_result)?, scan_state.clone()),
+        Err(err) => match err.kind() {
+            // The scan command may fail to route to the address because the address is no longer
+            // part of the cluster, or because the connection to it cannot be established for some
+            // other reason. If the failure is one we may be able to recover from, such as a
+            // failover, a scale down, or a transient network issue, we retry the scan against the
+            // address that owns the next slot to be scanned.
+            ErrorKind::IoError
+            | ErrorKind::AllConnectionsUnavailable
+            | ErrorKind::ConnectionNotFoundForRoute => {
+                let retry =
+                    retry_scan(&scan_state, &core, match_pattern, count, object_type).await?;
+                (from_redis_value(&retry.0?)?, retry.1)
+            }
+            _ => return Err(err),
+        },
+    };
+
+    // If the cursor is 0, we have finished scanning the current address,
+    // so we update the scan state to get the next address to scan
+    if new_cursor == 0 {
+        scan_state = scan_state
+            .create_updated_scan_state_for_completed_address(&core)
+            .await?;
+    }
+
+    // If the scan state is finished, we have finished scanning all the addresses
+    if scan_state.scan_status == ScanStateStage::Finished {
+        return Ok((ScanStateRC::create_finished(), new_keys));
+    }
+
+    scan_state = ScanState::new(
+        new_cursor,
+        scan_state.scanned_slots_map,
+        scan_state.address_in_scan,
+        scan_state.address_epoch,
+        ScanStateStage::InProgress,
+    );
+    Ok((ScanStateRC::from_scan_state(scan_state), new_keys))
+}
+
+// Send the scan command to the address in the scan_state
+async fn send_scan<C>(
+    scan_state: &ScanState,
+    core: &C,
+    match_pattern: Option<Vec<u8>>,
+    count: Option<usize>,
+    object_type: Option<ObjectType>,
+) -> RedisResult<Value>
+where
+    C: ClusterInScan,
+{
+    let mut scan_command = cmd("SCAN");
+    scan_command.arg(scan_state.cursor);
+    if let Some(match_pattern) = match_pattern {
+        scan_command.arg("MATCH").arg(match_pattern);
+    }
+    if let Some(count) = count {
+        scan_command.arg("COUNT").arg(count);
+    }
+    if let Some(object_type) = object_type {
+        scan_command.arg("TYPE").arg(object_type.to_string());
+    }
+
+    core.route_command(scan_command, &scan_state.address_in_scan)
+        .await
+}
+
+// If the scan command failed to route to the address, we first refresh the slots and check
+// whether all slots are covered by the cluster; if so, we try to get a new address to scan,
+// which handles the failover case.
+// If not all slots are covered by the cluster, we return an error indicating that the cluster
+// is not well configured.
+// If all slots are covered but we fail to get a new address to scan, we return an error saying so.
+// If we get a new address but the scan command still fails to route to it, we return an error
+// indicating that we failed to route the command.
+async fn retry_scan<C>(
+    scan_state: &ScanState,
+    core: &C,
+    match_pattern: Option<Vec<u8>>,
+    count: Option<usize>,
+    object_type: Option<ObjectType>,
+) -> RedisResult<(RedisResult<Value>, ScanState)>
+where
+    C: ClusterInScan,
+{
+    // TODO: This mechanism of refreshing on failure to route to an address should be part of the routing mechanism.
+    // After the routing mechanism is updated to handle this case, the refresh below should be removed.
+    core.refresh_if_topology_changed().await.map_err(|err| {
+        RedisError::from((
+            ErrorKind::ResponseError,
+            "Error during cluster scan: failed to refresh slots",
+            format!("{:?}", err),
+        ))
+    })?;
+    if !core.are_all_slots_covered().await {
+        return Err(RedisError::from((
+            ErrorKind::NotAllSlotsCovered,
+            "Not all slots are covered by the cluster, please check the cluster configuration",
+        )));
+    }
+    // If for some reason we failed to reach the address, we don't know whether it was a scale
+    // down or a failover. Since it might have been a scale down, we can't keep going with the
+    // current cursor as if we were at the same point on the new address; instead we resolve the
+    // address that owns the next slot that hasn't been scanned
+    // and start from the beginning of that address.
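+    // The trade-off: restarting the new owner at cursor 0 may re-deliver keys the caller
+    // has already seen, but it avoids skipping keys in slots that were not yet marked as scanned.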
+    let next_slot = scan_state
+        .get_next_slot(&scan_state.scanned_slots_map)
+        .unwrap_or(0);
+    let address = core.get_address_by_slot(next_slot).await?;
+
+    let new_epoch = core.get_address_epoch(&address).await.unwrap_or(0);
+    let scan_state = &ScanState::new(
+        0,
+        scan_state.scanned_slots_map,
+        address,
+        new_epoch,
+        ScanStateStage::InProgress,
+    );
+    let res = (
+        send_scan(scan_state, core, match_pattern, count, object_type).await,
+        scan_state.clone(),
+    );
+    Ok(res)
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_creation_of_empty_scan_wrapper() {
+        let scan_state_wrapper = ScanStateRC::new();
+        assert!(scan_state_wrapper.status == ScanStateStage::Initiating);
+    }
+
+    #[test]
+    fn test_creation_of_scan_state_wrapper_from() {
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map: [0; BITS_ARRAY_SIZE],
+            address_in_scan: String::from("address1").into(),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+
+        let scan_state_wrapper = ScanStateRC::from_scan_state(scan_state);
+        assert!(!scan_state_wrapper.is_finished());
+    }
+
+    #[test]
+    // Test the get_next_slot method
+    fn test_scan_state_get_next_slot() {
+        let scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map,
+            address_in_scan: String::from("address1").into(),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(0));
+        // Set the first slot to 1
+        let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        scanned_slots_map[0] = 1;
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map,
+            address_in_scan: String::from("address1").into(),
+            address_epoch: 1,
+            scan_status: ScanStateStage::InProgress,
+        };
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(1));
+    }
+    // Create a mock connection
+    struct MockConnection;
+    #[async_trait]
+    impl ClusterInScan for MockConnection {
+        async fn refresh_if_topology_changed(&self) -> RedisResult<bool> {
+            Ok(true)
+        }
+        async fn get_address_by_slot(&self, _slot: u16) -> RedisResult<Arc<String>> {
+            Ok("mock_address".to_string().into())
+        }
+        async fn get_address_epoch(&self, _address: &str) -> Result<u64, RedisError> {
+            Ok(0)
+        }
+        async fn get_slots_of_address(&self, address: Arc<String>) -> Vec<u16> {
+            if address.as_str() == "mock_address" {
+                vec![3, 4, 5]
+            } else {
+                vec![0, 1, 2]
+            }
+        }
+        async fn route_command(&self, _: Cmd, _: &str) -> RedisResult<Value> {
+            unimplemented!()
+        }
+        async fn are_all_slots_covered(&self) -> bool {
+            true
+        }
+    }
+    // Test the initiate_scan function
+    #[tokio::test]
+    async fn test_initiate_scan() {
+        let connection = MockConnection;
+        let scan_state = ScanState::initiate_scan(&connection).await.unwrap();
+
+        // Assert that the scan state is initialized correctly
+        assert_eq!(scan_state.cursor, 0);
+        assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE]);
+        assert_eq!(
+            scan_state.address_in_scan,
+            "mock_address".to_string().into()
+        );
+        assert_eq!(scan_state.address_epoch, 0);
+    }
+
+    // Test the get_next_slot function
+    #[test]
+    fn test_get_next_slot() {
+        let scan_state = ScanState {
+            cursor: 0,
+            scanned_slots_map: [0; BITS_ARRAY_SIZE],
+            address_in_scan: "".to_string().into(),
+            address_epoch: 0,
+            scan_status: ScanStateStage::InProgress,
+        };
+        // When the first bit of each u64 word is set to 1, the next slot should be 1
+        let scanned_slots_map: SlotsBitsArray = [1; BITS_ARRAY_SIZE];
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(1));
+
+        // When all slots are scanned, the next slot should be END_OF_SCAN (16385)
+        let scanned_slots_map: SlotsBitsArray = [u64::MAX; BITS_ARRAY_SIZE];
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(16385));
+
+        // When the first, second, fourth, sixth and eighth slots are scanned, the next slot should be 2
+        let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE];
+        scanned_slots_map[0] = 171; // 0b10101011
+        let next_slot = scan_state.get_next_slot(&scanned_slots_map);
+        assert_eq!(next_slot, Some(2));
+    }
+
+    // Test the create_updated_scan_state_for_completed_address function
+    #[tokio::test]
+    async fn test_update_scan_state_and_get_next_address() {
+        let connection = MockConnection;
+        let scan_state = ScanState::initiate_scan(&connection).await;
+        let updated_scan_state = scan_state
+            .unwrap()
+            .create_updated_scan_state_for_completed_address(&connection)
+            .await
+            .unwrap();
+
+        // cursor should be reset to 0
+        assert_eq!(updated_scan_state.cursor, 0);
+
+        // address_in_scan should be updated to the new address
+        assert_eq!(
+            updated_scan_state.address_in_scan,
+            "mock_address".to_string().into()
+        );
+
+        // address_epoch should be updated to the new address epoch
+        assert_eq!(updated_scan_state.address_epoch, 0);
+    }
+
+    #[tokio::test]
+    async fn test_update_scan_state_without_updating_scanned_map() {
+        let connection = MockConnection;
+        let scan_state = ScanState::new(
+            0,
+            [0; BITS_ARRAY_SIZE],
+            "address".to_string().into(),
+            0,
+            ScanStateStage::InProgress,
+        );
+        let scanned_slots_map = scan_state.scanned_slots_map;
+        let updated_scan_state = scan_state
+            .creating_state_without_slot_changes(&connection)
+            .await
+            .unwrap();
+        assert_eq!(updated_scan_state.scanned_slots_map, scanned_slots_map);
+        assert_eq!(updated_scan_state.cursor, 0);
+        assert_eq!(
+            updated_scan_state.address_in_scan,
+            "mock_address".to_string().into()
+        );
+        assert_eq!(updated_scan_state.address_epoch, 0);
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/commands/json.rs b/glide-core/redis-rs/redis/src/commands/json.rs
new file mode 100644
index 0000000000..d63f70c86f
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/commands/json.rs
@@ -0,0 +1,390 @@
+use crate::cmd::{cmd, Cmd};
+use crate::connection::ConnectionLike;
+use crate::pipeline::Pipeline;
+use crate::types::{FromRedisValue, RedisResult, ToRedisArgs};
+use crate::RedisError;
+
+#[cfg(feature = "cluster")]
+use crate::commands::ClusterPipeline;
+
+use serde::ser::Serialize;
+
+macro_rules! implement_json_commands {
+    (
+        $lifetime: lifetime
+        $(
+            $(#[$attr:meta])+
+            fn $name:ident<$($tyargs:ident : $ty:ident),*>(
+                $($argname:ident: $argty:ty),*) $body:block
+        )*
+    ) => (
+
+        /// Implements RedisJSON commands for connection like objects. This
+        /// allows you to send commands straight to a connection or client. It
+        /// is also implemented for redis results of clients which makes for
+        /// very convenient access in some basic cases.
+        ///
+        /// This allows you to use nicer syntax for some common operations.
+        /// For instance this code:
+        ///
+        /// ```rust,no_run
+        /// use redis::JsonCommands;
+        /// use serde_json::json;
+        /// # fn do_something() -> redis::RedisResult<()> {
+        /// let client = redis::Client::open("redis://127.0.0.1/")?;
+        /// let mut con = client.get_connection(None)?;
+        /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).execute(&mut con);
+        /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query(&mut con), Ok(String::from(r#"[{"item":42}]"#)));
+        /// # Ok(()) }
+        /// ```
+        ///
+        /// Will become this:
+        ///
+        /// ```rust,no_run
+        /// use redis::JsonCommands;
+        /// use serde_json::json;
+        /// # fn do_something() -> redis::RedisResult<()> {
+        /// let client = redis::Client::open("redis://127.0.0.1/")?;
+        /// let mut con = client.get_connection(None)?;
+        /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string())?;
+        /// assert_eq!(con.json_get("my_key", "$"), Ok(String::from(r#"[{"item":42}]"#)));
+        /// assert_eq!(con.json_get("my_key", "$.item"), Ok(String::from(r#"[42]"#)));
+        /// # Ok(()) }
+        /// ```
+        ///
+        /// With RedisJSON commands, you have to note that all results will be wrapped
+        /// in square brackets (or empty brackets if not found). If you want to deserialize it
+        /// with e.g. `serde_json` you have to use `Vec<T>` for your output type instead of `T`.
+        pub trait JsonCommands : ConnectionLike + Sized {
+            $(
+                $(#[$attr])*
+                #[inline]
+                #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)]
+                fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>(
+                    &mut self $(, $argname: $argty)*) -> RedisResult<RV>
+                    { Cmd::$name($($argname),*)?.query(self) }
+            )*
+        }
+
+        impl Cmd {
+            $(
+                $(#[$attr])*
+                #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)]
+                pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> RedisResult<Cmd> {
+                    $body
+                }
+            )*
+        }
+
+        /// Implements RedisJSON commands over asynchronous connections. This
+        /// allows you to send commands straight to a connection or client.
+        ///
+        /// This allows you to use nicer syntax for some common operations.
+        /// For instance this code:
+        ///
+        /// ```rust,no_run
+        /// use redis::JsonAsyncCommands;
+        /// use serde_json::json;
+        /// # async fn do_something() -> redis::RedisResult<()> {
+        /// let client = redis::Client::open("redis://127.0.0.1/")?;
+        /// let mut con = client.get_async_connection(None).await?;
+        /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).query_async(&mut con).await?;
+        /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query_async(&mut con).await, Ok(String::from(r#"[{"item":42}]"#)));
+        /// # Ok(()) }
+        /// ```
+        ///
+        /// Will become this:
+        ///
+        /// ```rust,no_run
+        /// use redis::JsonAsyncCommands;
+        /// use serde_json::json;
+        /// # async fn do_something() -> redis::RedisResult<()> {
+        /// use redis::Commands;
+        /// let client = redis::Client::open("redis://127.0.0.1/")?;
+        /// let mut con = client.get_async_connection(None).await?;
+        /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string()).await?;
+        /// assert_eq!(con.json_get("my_key", "$").await, Ok(String::from(r#"[{"item":42}]"#)));
+        /// assert_eq!(con.json_get("my_key", "$.item").await, Ok(String::from(r#"[42]"#)));
+        /// # Ok(()) }
+        /// ```
+        ///
+        /// With RedisJSON commands, you have to note that all results will be wrapped
+        /// in square brackets (or empty brackets if not found). If you want to deserialize it
+        /// with e.g. `serde_json` you have to use `Vec<T>` for your output type instead of `T`.
+        ///
+        #[cfg(feature = "aio")]
+        pub trait JsonAsyncCommands : crate::aio::ConnectionLike + Send + Sized {
+            $(
+                $(#[$attr])*
+                #[inline]
+                #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)]
+                fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>(
+                    & $lifetime mut self
+                    $(, $argname: $argty)*
+                ) -> $crate::types::RedisFuture<'a, RV>
+                where
+                    RV: FromRedisValue,
+                {
+                    Box::pin(async move {
+                        $body?.query_async(self).await
+                    })
+                }
+            )*
+        }
+
+        /// Implements RedisJSON commands for pipelines. Unlike the regular
+        /// commands trait, this returns the pipeline rather than a result
+        /// directly. Other than that it works the same however.
+        impl Pipeline {
+            $(
+                $(#[$attr])*
+                #[inline]
+                #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)]
+                pub fn $name<$lifetime, $($tyargs: $ty),*>(
+                    &mut self $(, $argname: $argty)*
+                ) -> RedisResult<&mut Self> {
+                    self.add_command($body?);
+                    Ok(self)
+                }
+            )*
+        }
+
+        /// Implements RedisJSON commands for cluster pipelines. Unlike the regular
+        /// commands trait, this returns the cluster pipeline rather than a result
+        /// directly. Other than that it works the same however.
+        #[cfg(feature = "cluster")]
+        impl ClusterPipeline {
+            $(
+                $(#[$attr])*
+                #[inline]
+                #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)]
+                pub fn $name<$lifetime, $($tyargs: $ty),*>(
+                    &mut self $(, $argname: $argty)*
+                ) -> RedisResult<&mut Self> {
+                    self.add_command($body?);
+                    Ok(self)
+                }
+            )*
+        }
+
+    )
+}
+
+implement_json_commands! {
+    'a
+
+    /// Append the JSON `value` to the array at `path` after the last element in it.
+    fn json_arr_append<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRAPPEND");
+
+        cmd.arg(key)
+            .arg(path)
+            .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Index the array at `path`, returning the first occurrence of `value`.
+    fn json_arr_index<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRINDEX");
+
+        cmd.arg(key)
+            .arg(path)
+            .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Same as `json_arr_index` except it takes a `start` and a `stop` value; setting these to `0`
+    /// means they have no effect on the query.
+    ///
+    /// The default values for `start` and `stop` are `0`, so pass those in if you want them to take no effect
+    fn json_arr_index_ss<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, value: &'a V, start: &'a isize, stop: &'a isize) {
+        let mut cmd = cmd("JSON.ARRINDEX");
+
+        cmd.arg(key)
+            .arg(path)
+            .arg(serde_json::to_string(value)?)
+            .arg(start)
+            .arg(stop);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Inserts the JSON `value` in the array at `path` before the `index` (shifts to the right).
+    ///
+    /// `index` must be within the array's range.
+    fn json_arr_insert<K: ToRedisArgs, P: ToRedisArgs, V: Serialize>(key: K, path: P, index: i64, value: &'a V) {
+        let mut cmd = cmd("JSON.ARRINSERT");
+
+        cmd.arg(key)
+            .arg(path)
+            .arg(index)
+            .arg(serde_json::to_string(value)?);
+
+        Ok::<_, RedisError>(cmd)
+
+    }
+
+    /// Reports the length of the JSON Array at `path` in `key`.
+    fn json_arr_len<K: ToRedisArgs, P: ToRedisArgs>(key: K, path: P) {
+        let mut cmd = cmd("JSON.ARRLEN");
+
+        cmd.arg(key)
+            .arg(path);
+
+        Ok::<_, RedisError>(cmd)
+    }
+
+    /// Removes and returns an element from the `index` in the array.
+    ///
+    /// `index` defaults to `-1` (the end of the array).
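+    ///
+    /// A minimal usage sketch (illustrative; assumes a local server with RedisJSON loaded):
+    ///
+    /// ```rust,no_run
+    /// # use redis::JsonCommands;
+    /// # use serde_json::json;
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// con.json_set("my_key", "$", &json!([1, 2, 3]).to_string())?;
+    /// // With a `$` path the reply is wrapped in an array, hence `Vec<String>`:
+    /// let popped: Vec<String> = con.json_arr_pop("my_key", "$", -1)?;
+    /// # Ok(()) }
+    /// ```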
+ fn json_arr_pop(key: K, path: P, index: i64) { + let mut cmd = cmd("JSON.ARRPOP"); + + cmd.arg(key) + .arg(path) + .arg(index); + + Ok::<_, RedisError>(cmd) + } + + /// Trims an array so that it contains only the specified inclusive range of elements. + /// + /// This command is extremely forgiving and using it with out-of-range indexes will not produce an error. + /// There are a few differences between how RedisJSON v2.0 and legacy versions handle out-of-range indexes. + fn json_arr_trim(key: K, path: P, start: i64, stop: i64) { + let mut cmd = cmd("JSON.ARRTRIM"); + + cmd.arg(key) + .arg(path) + .arg(start) + .arg(stop); + + Ok::<_, RedisError>(cmd) + } + + /// Clears container values (Arrays/Objects), and sets numeric values to 0. + fn json_clear(key: K, path: P) { + let mut cmd = cmd("JSON.CLEAR"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Deletes a value at `path`. + fn json_del(key: K, path: P) { + let mut cmd = cmd("JSON.DEL"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Gets JSON Value(s) at `path`. + /// + /// Runs `JSON.GET` if key is singular, `JSON.MGET` if there are multiple keys. + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + fn json_get(key: K, path: P) { + let mut cmd = cmd(if key.is_single_arg() { "JSON.GET" } else { "JSON.MGET" }); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Increments the number value stored at `path` by `number`. + fn json_num_incr_by(key: K, path: P, value: i64) { + let mut cmd = cmd("JSON.NUMINCRBY"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Returns the keys in the object that's referenced by `path`. + fn json_obj_keys(key: K, path: P) { + let mut cmd = cmd("JSON.OBJKEYS"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the number of keys in the JSON Object at `path` in `key`. + fn json_obj_len(key: K, path: P) { + let mut cmd = cmd("JSON.OBJLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Sets the JSON Value at `path` in `key`. + fn json_set(key: K, path: P, value: &'a V) { + let mut cmd = cmd("JSON.SET"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + } + + /// Appends the `json-string` values to the string at `path`. + fn json_str_append(key: K, path: P, value: V) { + let mut cmd = cmd("JSON.STRAPPEND"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the length of the JSON String at `path` in `key`. + fn json_str_len(key: K, path: P) { + let mut cmd = cmd("JSON.STRLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Toggle a `boolean` value stored at `path`. + fn json_toggle(key: K, path: P) { + let mut cmd = cmd("JSON.TOGGLE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the type of JSON value at `path`. 
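+    ///
+    /// A usage sketch (illustrative; assumes RedisJSON is loaded and that `$`-path
+    /// replies arrive wrapped in an array):
+    ///
+    /// ```rust,no_run
+    /// # use redis::JsonCommands;
+    /// # use serde_json::json;
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// con.json_set("my_key", "$", &json!({"item": 42}).to_string())?;
+    /// let ty: Vec<String> = con.json_type("my_key", "$.item")?;
+    /// assert_eq!(ty, vec!["integer"]);
+    /// # Ok(()) }
+    /// ```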
+ fn json_type(key: K, path: P) { + let mut cmd = cmd("JSON.TYPE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } +} + +impl JsonCommands for T where T: ConnectionLike {} + +#[cfg(feature = "aio")] +impl JsonAsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {} diff --git a/glide-core/redis-rs/redis/src/commands/macros.rs b/glide-core/redis-rs/redis/src/commands/macros.rs new file mode 100644 index 0000000000..9e7d4373c0 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/macros.rs @@ -0,0 +1,275 @@ +macro_rules! implement_commands { + ( + $lifetime: lifetime + $( + $(#[$attr:meta])+ + fn $name:ident<$($tyargs:ident : $ty:ident),*>( + $($argname:ident: $argty:ty),*) $body:block + )* + ) => + ( + /// Implements common redis commands for connection like objects. This + /// allows you to send commands straight to a connection or client. It + /// is also implemented for redis results of clients which makes for + /// very convenient access in some basic cases. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// redis::cmd("SET").arg("my_key").arg(42).execute(&mut con); + /// assert_eq!(redis::cmd("GET").arg("my_key").query(&mut con), Ok(42)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// con.set("my_key", 42)?; + /// assert_eq!(con.get("my_key"), Ok(42)); + /// # Ok(()) } + /// ``` + pub trait Commands : ConnectionLike+Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>( + &mut self $(, $argname: $argty)*) -> RedisResult + { Cmd::$name($($argname),*).query(self) } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate the keys space for keys matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate set elements. + #[inline] + fn sscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate sorted set elements. 
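+        ///
+        /// A sketch of driving the iterator (illustrative; ZSCAN yields members and
+        /// scores as a flat alternating sequence, read here as strings):
+        ///
+        /// ```rust,no_run
+        /// # use redis::Commands;
+        /// # fn run() -> redis::RedisResult<()> {
+        /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+        /// # let mut con = client.get_connection(None)?;
+        /// con.zadd("my_zset", "member1", 1.0)?;
+        /// let iter: redis::Iter<String> = con.zscan("my_zset")?;
+        /// for entry in iter {
+        ///     // yields "member1", then its score "1"
+        /// }
+        /// # Ok(()) }
+        /// ```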
+ #[inline] + fn zscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. + #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + } + + impl Cmd { + $( + $(#[$attr])* + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> Self { + ::std::mem::replace($body, Cmd::new()) + } + )* + } + + /// Implements common redis commands over asynchronous connections. This + /// allows you to send commands straight to a connection or client. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// redis::cmd("SET").arg("my_key").arg(42i32).query_async(&mut con).await?; + /// assert_eq!(redis::cmd("GET").arg("my_key").query_async(&mut con).await, Ok(42i32)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// con.set("my_key", 42i32).await?; + /// assert_eq!(con.get("my_key").await, Ok(42i32)); + /// # Ok(()) } + /// ``` + #[cfg(feature = "aio")] + pub trait AsyncCommands : crate::aio::ConnectionLike + Send + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>( + & $lifetime mut self + $(, $argname: $argty)* + ) -> crate::types::RedisFuture<'a, RV> + where + RV: FromRedisValue, + { + Box::pin(async move { ($body).query_async(self).await }) + } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements. 
+ #[inline] + fn sscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements. + #[inline] + fn zscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. + #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + } + + /// Implements common redis commands for pipelines. Unlike the regular + /// commands trait, this returns the pipeline rather than a result + /// directly. Other than that it works the same however. + impl Pipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + + // Implements common redis commands for cluster pipelines. Unlike the regular + // commands trait, this returns the cluster pipeline rather than a result + // directly. Other than that it works the same however. + #[cfg(feature = "cluster")] + impl ClusterPipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + ) +} diff --git a/glide-core/redis-rs/redis/src/commands/mod.rs b/glide-core/redis-rs/redis/src/commands/mod.rs new file mode 100644 index 0000000000..22a68cc987 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/mod.rs @@ -0,0 +1,2187 @@ +use crate::cmd::{cmd, Cmd, Iter}; +use crate::connection::{Connection, ConnectionLike, Msg}; +use crate::pipeline::Pipeline; +use crate::types::{ + ExistenceCheck, Expiry, FromRedisValue, NumericBehavior, RedisResult, RedisWrite, SetExpiry, + ToRedisArgs, +}; + +#[macro_use] +mod macros; + +#[cfg(feature = "json")] +#[cfg_attr(docsrs, doc(cfg(feature = "json")))] +mod json; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub(crate) mod cluster_scan; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ObjectType; + +#[cfg(feature = "json")] +pub use json::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use json::JsonAsyncCommands; + +#[cfg(feature = "cluster")] +use crate::cluster_pipeline::ClusterPipeline; + +#[cfg(feature = "geospatial")] +use crate::geo; + +#[cfg(feature = "streams")] +use crate::streams; + +#[cfg(feature = "acl")] +use crate::acl; +use crate::RedisConnectionInfo; + +implement_commands! { + 'a + // most common operations + + /// Get the value of a key. If key is a vec this becomes an `MGET`. 
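+    ///
+    /// A minimal sketch of both shapes (illustrative; assumes a reachable local server):
+    ///
+    /// ```rust,no_run
+    /// # use redis::Commands;
+    /// # fn run() -> redis::RedisResult<()> {
+    /// # let client = redis::Client::open("redis://127.0.0.1/")?;
+    /// # let mut con = client.get_connection(None)?;
+    /// let one: Option<String> = con.get("key1")?;                  // GET key1
+    /// let many: Vec<Option<String>> = con.get(&["key1", "key2"])?; // MGET key1 key2
+    /// # Ok(()) }
+    /// ```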
+ fn get(key: K) { + cmd(if key.is_single_arg() { "GET" } else { "MGET" }).arg(key) + } + + /// Get values of keys + fn mget(key: K){ + cmd("MGET").arg(key) + } + + /// Gets all keys matching pattern + fn keys(key: K) { + cmd("KEYS").arg(key) + } + + /// Set the string value of a key. + fn set(key: K, value: V) { + cmd("SET").arg(key).arg(value) + } + + /// Set the string value of a key with options. + fn set_options(key: K, value: V, options: SetOptions) { + cmd("SET").arg(key).arg(value).arg(options) + } + + /// Sets multiple keys to their values. + #[allow(deprecated)] + #[deprecated(since = "0.22.4", note = "Renamed to mset() to reflect Redis name")] + fn set_multiple(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Sets multiple keys to their values. + fn mset(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Set the value and expiration of a key. + fn set_ex(key: K, value: V, seconds: u64) { + cmd("SETEX").arg(key).arg(seconds).arg(value) + } + + /// Set the value and expiration in milliseconds of a key. + fn pset_ex(key: K, value: V, milliseconds: u64) { + cmd("PSETEX").arg(key).arg(milliseconds).arg(value) + } + + /// Set the value of a key, only if the key does not exist + fn set_nx(key: K, value: V) { + cmd("SETNX").arg(key).arg(value) + } + + /// Sets multiple keys to their values failing if at least one already exists. + fn mset_nx(items: &'a [(K, V)]) { + cmd("MSETNX").arg(items) + } + + /// Set the string value of a key and return its old value. + fn getset(key: K, value: V) { + cmd("GETSET").arg(key).arg(value) + } + + /// Get a range of bytes/substring from the value of a key. Negative values provide an offset from the end of the value. + fn getrange(key: K, from: isize, to: isize) { + cmd("GETRANGE").arg(key).arg(from).arg(to) + } + + /// Overwrite the part of the value stored in key at the specified offset. + fn setrange(key: K, offset: isize, value: V) { + cmd("SETRANGE").arg(key).arg(offset).arg(value) + } + + /// Delete one or more keys. + fn del(key: K) { + cmd("DEL").arg(key) + } + + /// Determine if a key exists. + fn exists(key: K) { + cmd("EXISTS").arg(key) + } + + /// Determine the type of a key. + fn key_type(key: K) { + cmd("TYPE").arg(key) + } + + /// Set a key's time to live in seconds. + fn expire(key: K, seconds: i64) { + cmd("EXPIRE").arg(key).arg(seconds) + } + + /// Set the expiration for a key as a UNIX timestamp. + fn expire_at(key: K, ts: i64) { + cmd("EXPIREAT").arg(key).arg(ts) + } + + /// Set a key's time to live in milliseconds. + fn pexpire(key: K, ms: i64) { + cmd("PEXPIRE").arg(key).arg(ms) + } + + /// Set the expiration for a key as a UNIX timestamp in milliseconds. + fn pexpire_at(key: K, ts: i64) { + cmd("PEXPIREAT").arg(key).arg(ts) + } + + /// Remove the expiration from a key. + fn persist(key: K) { + cmd("PERSIST").arg(key) + } + + /// Get the expiration time of a key. + fn ttl(key: K) { + cmd("TTL").arg(key) + } + + /// Get the expiration time of a key in milliseconds. 
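+ ///
+ /// For illustration (key name and expiry are placeholders):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let _: () = con.set_ex("key1", "value", 10)?;
+ /// // PTTL reports the remaining time to live in milliseconds.
+ /// let ms: isize = con.pttl("key1")?;
+ /// assert!(ms <= 10_000);
+ /// # Ok(()) }
+ /// ```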
+ fn pttl(key: K) { + cmd("PTTL").arg(key) + } + + /// Get the value of a key and set expiration + fn get_ex(key: K, expire_at: Expiry) { + let (option, time_arg) = match expire_at { + Expiry::EX(sec) => ("EX", Some(sec)), + Expiry::PX(ms) => ("PX", Some(ms)), + Expiry::EXAT(timestamp_sec) => ("EXAT", Some(timestamp_sec)), + Expiry::PXAT(timestamp_ms) => ("PXAT", Some(timestamp_ms)), + Expiry::PERSIST => ("PERSIST", None), + }; + + cmd("GETEX").arg(key).arg(option).arg(time_arg) + } + + /// Get the value of a key and delete it + fn get_del(key: K) { + cmd("GETDEL").arg(key) + } + + /// Rename a key. + fn rename(key: K, new_key: N) { + cmd("RENAME").arg(key).arg(new_key) + } + + /// Rename a key, only if the new key does not exist. + fn rename_nx(key: K, new_key: N) { + cmd("RENAMENX").arg(key).arg(new_key) + } + + /// Unlink one or more keys. + fn unlink(key: K) { + cmd("UNLINK").arg(key) + } + + // common string operations + + /// Append a value to a key. + fn append(key: K, value: V) { + cmd("APPEND").arg(key).arg(value) + } + + /// Increment the numeric value of a key by the given amount. This + /// issues a `INCRBY` or `INCRBYFLOAT` depending on the type. + fn incr(key: K, delta: V) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "INCRBYFLOAT" + } else { + "INCRBY" + }).arg(key).arg(delta) + } + + /// Decrement the numeric value of a key by the given amount. + fn decr(key: K, delta: V) { + cmd("DECRBY").arg(key).arg(delta) + } + + /// Sets or clears the bit at offset in the string value stored at key. + fn setbit(key: K, offset: usize, value: bool) { + cmd("SETBIT").arg(key).arg(offset).arg(i32::from(value)) + } + + /// Returns the bit value at offset in the string value stored at key. + fn getbit(key: K, offset: usize) { + cmd("GETBIT").arg(key).arg(offset) + } + + /// Count set bits in a string. + fn bitcount(key: K) { + cmd("BITCOUNT").arg(key) + } + + /// Count set bits in a string in a range. + fn bitcount_range(key: K, start: usize, end: usize) { + cmd("BITCOUNT").arg(key).arg(start).arg(end) + } + + /// Perform a bitwise AND between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_and(dstkey: D, srckeys: S) { + cmd("BITOP").arg("AND").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise OR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_or(dstkey: D, srckeys: S) { + cmd("BITOP").arg("OR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise XOR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_xor(dstkey: D, srckeys: S) { + cmd("BITOP").arg("XOR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise NOT of the key (containing string values) + /// and store the result in the destination key. + fn bit_not(dstkey: D, srckey: S) { + cmd("BITOP").arg("NOT").arg(dstkey).arg(srckey) + } + + /// Get the length of the value stored in a key. + fn strlen(key: K) { + cmd("STRLEN").arg(key) + } + + // hash operations + + /// Gets a single (or multiple) fields from a hash. + fn hget(key: K, field: F) { + cmd(if field.is_single_arg() { "HGET" } else { "HMGET" }).arg(key).arg(field) + } + + /// Deletes a single (or multiple) fields from a hash. + fn hdel(key: K, field: F) { + cmd("HDEL").arg(key).arg(field) + } + + /// Sets a single field in a hash. 
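+ ///
+ /// A short sketch (hash, field and value names are placeholders):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let _: () = con.hset("my_hash", "field1", "value1")?;
+ /// let value: String = con.hget("my_hash", "field1")?;
+ /// # Ok(()) }
+ /// ```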
+ fn hset(key: K, field: F, value: V) { + cmd("HSET").arg(key).arg(field).arg(value) + } + + /// Sets a single field in a hash if it does not exist. + fn hset_nx(key: K, field: F, value: V) { + cmd("HSETNX").arg(key).arg(field).arg(value) + } + + /// Sets a multiple fields in a hash. + fn hset_multiple(key: K, items: &'a [(F, V)]) { + cmd("HMSET").arg(key).arg(items) + } + + /// Increments a value. + fn hincr(key: K, field: F, delta: D) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "HINCRBYFLOAT" + } else { + "HINCRBY" + }).arg(key).arg(field).arg(delta) + } + + /// Checks if a field in a hash exists. + fn hexists(key: K, field: F) { + cmd("HEXISTS").arg(key).arg(field) + } + + /// Gets all the keys in a hash. + fn hkeys(key: K) { + cmd("HKEYS").arg(key) + } + + /// Gets all the values in a hash. + fn hvals(key: K) { + cmd("HVALS").arg(key) + } + + /// Gets all the fields and values in a hash. + fn hgetall(key: K) { + cmd("HGETALL").arg(key) + } + + /// Gets the length of a hash. + fn hlen(key: K) { + cmd("HLEN").arg(key) + } + + // list operations + + /// Pop an element from a list, push it to another list + /// and return it; or block until one is available + fn blmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction, timeout: f64) { + cmd("BLMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir).arg(timeout) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names; or blocks until one is available. + fn blmpop(timeout: f64, numkeys: usize, key: K, dir: Direction, count: usize){ + cmd("BLMPOP").arg(timeout).arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Remove and get the first element in a list, or block until one is available. + fn blpop(key: K, timeout: f64) { + cmd("BLPOP").arg(key).arg(timeout) + } + + /// Remove and get the last element in a list, or block until one is available. + fn brpop(key: K, timeout: f64) { + cmd("BRPOP").arg(key).arg(timeout) + } + + /// Pop a value from a list, push it to another list and return it; + /// or block until one is available. + fn brpoplpush(srckey: S, dstkey: D, timeout: f64) { + cmd("BRPOPLPUSH").arg(srckey).arg(dstkey).arg(timeout) + } + + /// Get an element from a list by its index. + fn lindex(key: K, index: isize) { + cmd("LINDEX").arg(key).arg(index) + } + + /// Insert an element before another element in a list. + fn linsert_before( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("BEFORE").arg(pivot).arg(value) + } + + /// Insert an element after another element in a list. + fn linsert_after( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("AFTER").arg(pivot).arg(value) + } + + /// Returns the length of the list stored at key. + fn llen(key: K) { + cmd("LLEN").arg(key) + } + + /// Pop an element a list, push it to another list and return it + fn lmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction) { + cmd("LMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names. + fn lmpop( numkeys: usize, key: K, dir: Direction, count: usize) { + cmd("LMPOP").arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Removes and returns the up to `count` first elements of the list stored at key. + /// + /// If `count` is not specified, then defaults to first element. 
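+ ///
+ /// A sketch of both forms, assuming the `Option<NonZeroUsize>` count
+ /// parameter used by upstream redis-rs (the list name is a placeholder):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use std::num::NonZeroUsize;
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let _: () = con.rpush("my_list", &["a", "b", "c"])?;
+ /// // Without a count, a single element is popped.
+ /// let head: Option<String> = con.lpop("my_list", None)?;
+ /// // With a count, up to that many elements are returned.
+ /// let pair: Vec<String> = con.lpop("my_list", NonZeroUsize::new(2))?;
+ /// # Ok(()) }
+ /// ```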
+ fn lpop(key: K, count: Option) { + cmd("LPOP").arg(key).arg(count) + } + + /// Returns the index of the first matching value of the list stored at key. + fn lpos(key: K, value: V, options: LposOptions) { + cmd("LPOS").arg(key).arg(value).arg(options) + } + + /// Insert all the specified values at the head of the list stored at key. + fn lpush(key: K, value: V) { + cmd("LPUSH").arg(key).arg(value) + } + + /// Inserts a value at the head of the list stored at key, only if key + /// already exists and holds a list. + fn lpush_exists(key: K, value: V) { + cmd("LPUSHX").arg(key).arg(value) + } + + /// Returns the specified elements of the list stored at key. + fn lrange(key: K, start: isize, stop: isize) { + cmd("LRANGE").arg(key).arg(start).arg(stop) + } + + /// Removes the first count occurrences of elements equal to value + /// from the list stored at key. + fn lrem(key: K, count: isize, value: V) { + cmd("LREM").arg(key).arg(count).arg(value) + } + + /// Trim an existing list so that it will contain only the specified + /// range of elements specified. + fn ltrim(key: K, start: isize, stop: isize) { + cmd("LTRIM").arg(key).arg(start).arg(stop) + } + + /// Sets the list element at index to value + fn lset(key: K, index: isize, value: V) { + cmd("LSET").arg(key).arg(index).arg(value) + } + + /// Removes and returns the up to `count` last elements of the list stored at key + /// + /// If `count` is not specified, then defaults to last element. + fn rpop(key: K, count: Option) { + cmd("RPOP").arg(key).arg(count) + } + + /// Pop a value from a list, push it to another list and return it. + fn rpoplpush(key: K, dstkey: D) { + cmd("RPOPLPUSH").arg(key).arg(dstkey) + } + + /// Insert all the specified values at the tail of the list stored at key. + fn rpush(key: K, value: V) { + cmd("RPUSH").arg(key).arg(value) + } + + /// Inserts value at the tail of the list stored at key, only if key + /// already exists and holds a list. + fn rpush_exists(key: K, value: V) { + cmd("RPUSHX").arg(key).arg(value) + } + + // set commands + + /// Add one or more members to a set. + fn sadd(key: K, member: M) { + cmd("SADD").arg(key).arg(member) + } + + /// Get the number of members in a set. + fn scard(key: K) { + cmd("SCARD").arg(key) + } + + /// Subtract multiple sets. + fn sdiff(keys: K) { + cmd("SDIFF").arg(keys) + } + + /// Subtract multiple sets and store the resulting set in a key. + fn sdiffstore(dstkey: D, keys: K) { + cmd("SDIFFSTORE").arg(dstkey).arg(keys) + } + + /// Intersect multiple sets. + fn sinter(keys: K) { + cmd("SINTER").arg(keys) + } + + /// Intersect multiple sets and store the resulting set in a key. + fn sinterstore(dstkey: D, keys: K) { + cmd("SINTERSTORE").arg(dstkey).arg(keys) + } + + /// Determine if a given value is a member of a set. + fn sismember(key: K, member: M) { + cmd("SISMEMBER").arg(key).arg(member) + } + + /// Determine if given values are members of a set. + fn smismember(key: K, members: M) { + cmd("SMISMEMBER").arg(key).arg(members) + } + + /// Get all the members in a set. + fn smembers(key: K) { + cmd("SMEMBERS").arg(key) + } + + /// Move a member from one set to another. + fn smove(srckey: S, dstkey: D, member: M) { + cmd("SMOVE").arg(srckey).arg(dstkey).arg(member) + } + + /// Remove and return a random member from a set. + fn spop(key: K) { + cmd("SPOP").arg(key) + } + + /// Get one random member from a set. + fn srandmember(key: K) { + cmd("SRANDMEMBER").arg(key) + } + + /// Get multiple random members from a set. 
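+ ///
+ /// A minimal sketch (the set name is a placeholder):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let _: () = con.sadd("my_set", &["a", "b", "c"])?;
+ /// // Sample two random members without removing them.
+ /// let sample: Vec<String> = con.srandmember_multiple("my_set", 2)?;
+ /// # Ok(()) }
+ /// ```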
+ fn srandmember_multiple(key: K, count: usize) { + cmd("SRANDMEMBER").arg(key).arg(count) + } + + /// Remove one or more members from a set. + fn srem(key: K, member: M) { + cmd("SREM").arg(key).arg(member) + } + + /// Add multiple sets. + fn sunion(keys: K) { + cmd("SUNION").arg(keys) + } + + /// Add multiple sets and store the resulting set in a key. + fn sunionstore(dstkey: D, keys: K) { + cmd("SUNIONSTORE").arg(dstkey).arg(keys) + } + + // sorted set commands + + /// Add one member to a sorted set, or update its score if it already exists. + fn zadd(key: K, member: M, score: S) { + cmd("ZADD").arg(key).arg(score).arg(member) + } + + /// Add multiple members to a sorted set, or update its score if it already exists. + fn zadd_multiple(key: K, items: &'a [(S, M)]) { + cmd("ZADD").arg(key).arg(items) + } + + /// Get the number of members in a sorted set. + fn zcard(key: K) { + cmd("ZCARD").arg(key) + } + + /// Count the members in a sorted set with scores within the given values. + fn zcount(key: K, min: M, max: MM) { + cmd("ZCOUNT").arg(key).arg(min).arg(max) + } + + /// Increments the member in a sorted set at key by delta. + /// If the member does not exist, it is added with delta as its score. + fn zincr(key: K, member: M, delta: D) { + cmd("ZINCRBY").arg(key).arg(delta).arg(member) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. + fn zinterstore(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zinterstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zinterstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zinterstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + /// Count the number of members in a sorted set between a given lexicographical range. 
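+ ///
+ /// For illustration, counting the whole set via the open range `-`/`+`
+ /// (the key is a placeholder):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// // Members share a score, so ordering is purely lexicographical.
+ /// let _: () = con.zadd_multiple("my_zset", &[(0, "a"), (0, "b")])?;
+ /// let n: usize = con.zlexcount("my_zset", "-", "+")?;
+ /// # Ok(()) }
+ /// ```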
+ fn zlexcount(key: K, min: M, max: MM) { + cmd("ZLEXCOUNT").arg(key).arg(min).arg(max) + } + + /// Removes and returns the member with the highest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmax(key: K, timeout: f64) { + cmd("BZPOPMAX").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the highest scores in a sorted set + fn zpopmax(key: K, count: isize) { + cmd("ZPOPMAX").arg(key).arg(count) + } + + /// Removes and returns the member with the lowest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmin(key: K, timeout: f64) { + cmd("BZPOPMIN").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the lowest scores in a sorted set + fn zpopmin(key: K, count: isize) { + cmd("ZPOPMIN").arg(key).arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_max(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. + fn zmpop_max(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_min(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + fn zmpop_min(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Return up to count random members in a sorted set (or 1 if `count == None`) + fn zrandmember(key: K, count: Option) { + cmd("ZRANDMEMBER").arg(key).arg(count) + } + + /// Return up to count random members in a sorted set with scores + fn zrandmember_withscores(key: K, count: isize) { + cmd("ZRANDMEMBER").arg(key).arg(count).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by index + fn zrange(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index with scores. + fn zrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by lexicographical range. + fn zrangebylex(key: K, min: M, max: MM) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrangebylex_limit( + key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by lexicographical range. 
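+ ///
+ /// A minimal sketch; note that `max` precedes `min` in the REV form
+ /// (the key is a placeholder):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let members: Vec<String> = con.zrevrangebylex("my_zset", "+", "-")?;
+ /// # Ok(()) }
+ /// ```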
+ fn zrevrangebylex(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrevrangebylex_limit( + key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score. + fn zrangebyscore(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrangebyscore_withscores(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrangebyscore_limit + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrangebyscore_limit_withscores + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set. + fn zrank(key: K, member: M) { + cmd("ZRANK").arg(key).arg(member) + } + + /// Remove one or more members from a sorted set. + fn zrem(key: K, members: M) { + cmd("ZREM").arg(key).arg(members) + } + + /// Remove all members in a sorted set between the given lexicographical range. + fn zrembylex(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Remove all members in a sorted set within the given indexes. + fn zremrangebyrank(key: K, start: isize, stop: isize) { + cmd("ZREMRANGEBYRANK").arg(key).arg(start).arg(stop) + } + + /// Remove all members in a sorted set within the given scores. + fn zrembyscore(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score. + fn zrevrangebyscore(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrevrangebyscore_withscores(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrevrangebyscore_limit + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrevrangebyscore_limit_withscores + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set, with scores ordered from high to low. 
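+ ///
+ /// For illustration (`None` is returned when the member is missing;
+ /// key and member names are placeholders):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let rank: Option<u64> = con.zrevrank("my_zset", "a")?;
+ /// # Ok(()) }
+ /// ```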
+ fn zrevrank(key: K, member: M) { + cmd("ZREVRANK").arg(key).arg(member) + } + + /// Get the score associated with the given member in a sorted set. + fn zscore(key: K, member: M) { + cmd("ZSCORE").arg(key).arg(member) + } + + /// Get the scores associated with multiple members in a sorted set. + fn zscore_multiple(key: K, members: &'a [M]) { + cmd("ZMSCORE").arg(key).arg(members) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. + fn zunionstore(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zunionstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zunionstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zunionstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + // hyperloglog commands + + /// Adds the specified elements to the specified HyperLogLog. + fn pfadd(key: K, element: E) { + cmd("PFADD").arg(key).arg(element) + } + + /// Return the approximated cardinality of the set(s) observed by the + /// HyperLogLog at key(s). + fn pfcount(key: K) { + cmd("PFCOUNT").arg(key) + } + + /// Merge N different HyperLogLogs into a single one. + fn pfmerge(dstkey: D, srckeys: S) { + cmd("PFMERGE").arg(dstkey).arg(srckeys) + } + + /// Posts a message to the given channel. + fn publish(channel: K, message: E) { + cmd("PUBLISH").arg(channel).arg(message) + } + + // Object commands + + /// Returns the encoding of a key. + fn object_encoding(key: K) { + cmd("OBJECT").arg("ENCODING").arg(key) + } + + /// Returns the time in seconds since the last access of a key. + fn object_idletime(key: K) { + cmd("OBJECT").arg("IDLETIME").arg(key) + } + + /// Returns the logarithmic access frequency counter of a key. 
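+ ///
+ /// This only succeeds when the server runs an LFU `maxmemory-policy`;
+ /// a minimal sketch (the key is a placeholder):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// // Errors at runtime unless an LFU eviction policy is configured.
+ /// let freq: isize = con.object_freq("my_key")?;
+ /// # Ok(()) }
+ /// ```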
+ fn object_freq(key: K) { + cmd("OBJECT").arg("FREQ").arg(key) + } + + /// Returns the reference count of a key. + fn object_refcount(key: K) { + cmd("OBJECT").arg("REFCOUNT").arg(key) + } + + // ACL commands + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will reload the ACLs from the file, + /// replacing all the current ACL rules with the ones defined in the file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_load<>() { + cmd("ACL").arg("LOAD") + } + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will save the currently defined + /// ACLs from the server memory to the ACL file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_save<>() { + cmd("ACL").arg("SAVE") + } + + /// Shows the currently active ACL rules in the Redis server. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_list<>() { + cmd("ACL").arg("LIST") + } + + /// Shows a list of all the usernames of the currently configured users in + /// the Redis ACL system. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_users<>() { + cmd("ACL").arg("USERS") + } + + /// Returns all the rules defined for an existing ACL user. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_getuser(username: K) { + cmd("ACL").arg("GETUSER").arg(username) + } + + /// Creates an ACL user without any privilege. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser(username: K) { + cmd("ACL").arg("SETUSER").arg(username) + } + + /// Creates an ACL user with the specified rules or modify the rules of + /// an existing user. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser_rules(username: K, rules: &'a [acl::Rule]) { + cmd("ACL").arg("SETUSER").arg(username).arg(rules) + } + + /// Delete all the specified ACL users and terminate all the connections + /// that are authenticated with such users. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_deluser(usernames: &'a [K]) { + cmd("ACL").arg("DELUSER").arg(usernames) + } + + /// Shows the available ACL categories. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat<>() { + cmd("ACL").arg("CAT") + } + + /// Shows all the Redis commands in the specified category. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat_categoryname(categoryname: K) { + cmd("ACL").arg("CAT").arg(categoryname) + } + + /// Generates a 256-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass<>() { + cmd("ACL").arg("GENPASS") + } + + /// Generates a 1-to-1024-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass_bits<>(bits: isize) { + cmd("ACL").arg("GENPASS").arg(bits) + } + + /// Returns the username the current connection is authenticated with. 
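+ ///
+ /// A minimal sketch:
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// // Returns "default" unless the connection authenticated as another user.
+ /// let user: String = con.acl_whoami()?;
+ /// # Ok(()) }
+ /// ```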
+ #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_whoami<>() { + cmd("ACL").arg("WHOAMI") + } + + /// Shows a list of recent ACL security events + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log<>(count: isize) { + cmd("ACL").arg("LOG").arg(count) + + } + + /// Clears the ACL log. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log_reset<>() { + cmd("ACL").arg("LOG").arg("RESET") + } + + /// Returns a helpful text describing the different subcommands. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_help<>() { + cmd("ACL").arg("HELP") + } + + // + // geospatial commands + // + + /// Adds the specified geospatial items to the specified key. + /// + /// Every member has to be written as a tuple of `(longitude, latitude, + /// member_name)`. It can be a single tuple, or a vector of tuples. + /// + /// `longitude, latitude` can be set using [`redis::geo::Coord`][1]. + /// + /// [1]: ./geo/struct.Coord.html + /// + /// Returns the number of elements added to the sorted set, not including + /// elements already existing for which the score was updated. + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, Connection, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn add_point(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", (Coord::lon_lat(13.361389, 38.115556), "Palermo")) + /// } + /// + /// fn add_point_with_tuples(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", ("13.361389", "38.115556", "Palermo")) + /// } + /// + /// fn add_many_points(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", &[ + /// ("13.361389", "38.115556", "Palermo"), + /// ("15.087269", "37.502669", "Catania") + /// ]) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_add(key: K, members: M) { + cmd("GEOADD").arg(key).arg(members) + } + + /// Return the distance between two members in the geospatial index + /// represented by the sorted set. + /// + /// If one or both the members are missing, the command returns NULL, so + /// it may be convenient to parse its response as either `Option` or + /// `Option`. + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Unit; + /// + /// fn get_dists(con: &mut redis::Connection) { + /// let x: RedisResult = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Catania", + /// Unit::Kilometers + /// ); + /// // x is Ok(166.2742) + /// + /// let x: RedisResult> = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Atlantis", + /// Unit::Meters + /// ); + /// // x is Ok(None) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_dist( + key: K, + member1: M1, + member2: M2, + unit: geo::Unit + ) { + cmd("GEODIST") + .arg(key) + .arg(member1) + .arg(member2) + .arg(unit) + } + + /// Return valid [Geohash][1] strings representing the position of one or + /// more members of the geospatial index represented by the sorted set at + /// key. 
+ /// + /// [1]: https://en.wikipedia.org/wiki/Geohash + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// + /// fn get_hash(con: &mut redis::Connection) { + /// let x: RedisResult> = con.geo_hash("my_gis", "Palermo"); + /// // x is vec!["sqc8b49rny0"] + /// + /// let x: RedisResult> = con.geo_hash("my_gis", &["Palermo", "Catania"]); + /// // x is vec!["sqc8b49rny0", "sqdtr74hyu0"] + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_hash(key: K, members: M) { + cmd("GEOHASH").arg(key).arg(members) + } + + /// Return the positions of all the specified members of the geospatial + /// index represented by the sorted set at key. + /// + /// Every position is a pair of `(longitude, latitude)`. [`redis::geo::Coord`][1] + /// can be used to convert these value in a struct. + /// + /// [1]: ./geo/struct.Coord.html + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn get_position(con: &mut redis::Connection) { + /// let x: RedisResult>> = con.geo_pos("my_gis", &["Palermo", "Catania"]); + /// // x is [ [ 13.361389, 38.115556 ], [ 15.087269, 37.502669 ] ]; + /// + /// let x: Vec> = con.geo_pos("my_gis", "Palermo").unwrap(); + /// // x[0].longitude is 13.361389 + /// // x[0].latitude is 38.115556 + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_pos(key: K, members: M) { + cmd("GEOPOS").arg(key).arg(members) + } + + /// Return the members of a sorted set populated with geospatial information + /// using [`geo_add`](#method.geo_add), which are within the borders of the area + /// specified with the center location and the maximum distance from the center + /// (the radius). + /// + /// Every item in the result can be read with [`redis::geo::RadiusSearchResult`][1], + /// which support the multiple formats returned by `GEORADIUS`. + /// + /// [1]: ./geo/struct.RadiusSearchResult.html + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::{RadiusOptions, RadiusSearchResult, RadiusOrder, Unit}; + /// + /// fn radius(con: &mut redis::Connection) -> Vec { + /// let opts = RadiusOptions::default().with_dist().order(RadiusOrder::Asc); + /// con.geo_radius("my_gis", 15.90, 37.21, 51.39, Unit::Kilometers, opts).unwrap() + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius( + key: K, + longitude: f64, + latitude: f64, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUS") + .arg(key) + .arg(longitude) + .arg(latitude) + .arg(radius) + .arg(unit) + .arg(options) + } + + /// Retrieve members selected by distance with the center of `member`. The + /// member itself is always contained in the results. + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius_by_member( + key: K, + member: M, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUSBYMEMBER") + .arg(key) + .arg(member) + .arg(radius) + .arg(unit) + .arg(options) + } + + // + // streams commands + // + + /// Ack pending stream messages checked out by a consumer. + /// + /// ```text + /// XACK ... 
+ /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xack( + key: K, + group: G, + ids: &'a [I]) { + cmd("XACK") + .arg(key) + .arg(group) + .arg(ids) + } + + + /// Add a stream message by `key`. Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [field value] [field value] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd( + key: K, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD").arg(key).arg(id).arg(items) + } + + + /// BTreeMap variant for adding a stream message by `key`. + /// Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_map( + key: K, + id: ID, + map: BTM + ) { + cmd("XADD").arg(key).arg(id).arg(map) + } + + /// Add a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [field value] [field value] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen< + K: ToRedisArgs, + ID: ToRedisArgs, + F: ToRedisArgs, + V: ToRedisArgs + >( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(items) + } + + + /// BTreeMap variant for adding a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen_map( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + map: BTM + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(map) + } + + + + /// Claim pending, unacked messages, after some period of time, + /// currently checked out by another consumer. + /// + /// This method only accepts the must-have arguments for claiming messages. + /// If optional arguments are required, see `xclaim_options` below. + /// + /// ```text + /// XCLAIM [ ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xclaim( + key: K, + group: G, + consumer: C, + min_idle_time: MIT, + ids: &'a [ID] + ) { + cmd("XCLAIM") + .arg(key) + .arg(group) + .arg(consumer) + .arg(min_idle_time) + .arg(ids) + } + + /// This is the optional arguments version for claiming unacked, pending messages + /// currently checked out by another consumer. + /// + /// ```no_run + /// use redis::{Connection,Commands,RedisResult}; + /// use redis::streams::{StreamClaimOptions,StreamClaimReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Claim all pending messages for key "k1", + /// // from group "g1", checked out by consumer "c1" + /// // for 10ms with RETRYCOUNT 2 and FORCE + /// + /// let opts = StreamClaimOptions::default() + /// .with_force() + /// .retry(2); + /// let results: RedisResult = + /// con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts); + /// + /// // All optional arguments return a `Result` with one exception: + /// // Passing JUSTID returns only the message `id` and omits the HashMap for each message. 
+ ///
+ /// let opts = StreamClaimOptions::default()
+ ///     .with_justid();
+ /// let results: RedisResult<Vec<String>> =
+ ///     con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts);
+ /// ```
+ ///
+ /// ```text
+ /// XCLAIM <key> <group> <consumer> <min-idle-time> [<ID-1> <ID-2>]
+ ///        [IDLE <milliseconds>] [TIME <mstime>] [RETRYCOUNT <count>]
+ ///        [FORCE] [JUSTID]
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xclaim_options<
+ K: ToRedisArgs,
+ G: ToRedisArgs,
+ C: ToRedisArgs,
+ MIT: ToRedisArgs,
+ ID: ToRedisArgs
+ >(
+ key: K,
+ group: G,
+ consumer: C,
+ min_idle_time: MIT,
+ ids: &'a [ID],
+ options: streams::StreamClaimOptions
+ ) {
+ cmd("XCLAIM")
+ .arg(key)
+ .arg(group)
+ .arg(consumer)
+ .arg(min_idle_time)
+ .arg(ids)
+ .arg(options)
+ }
+
+
+ /// Deletes a list of `id`s for a given stream `key`.
+ ///
+ /// ```text
+ /// XDEL <key> [<ID-1> <ID-2> ... <ID-N>]
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xdel<K: ToRedisArgs, ID: ToRedisArgs>(
+ key: K,
+ ids: &'a [ID]
+ ) {
+ cmd("XDEL").arg(key).arg(ids)
+ }
+
+
+ /// This command is used for creating a consumer `group`. It expects the stream key
+ /// to already exist. Otherwise, use `xgroup_create_mkstream` if it doesn't.
+ /// The `id` is the starting message id all consumers should read from. Use `$` if you want
+ /// all consumers to read from the last message added to the stream.
+ ///
+ /// ```text
+ /// XGROUP CREATE <key> <groupname> <id or $>
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xgroup_create<K: ToRedisArgs, G: ToRedisArgs, ID: ToRedisArgs>(
+ key: K,
+ group: G,
+ id: ID
+ ) {
+ cmd("XGROUP")
+ .arg("CREATE")
+ .arg(key)
+ .arg(group)
+ .arg(id)
+ }
+
+
+ /// This is the alternate version for creating a consumer `group`
+ /// which makes the stream if it doesn't exist.
+ ///
+ /// ```text
+ /// XGROUP CREATE <key> <groupname> <id or $> [MKSTREAM]
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xgroup_create_mkstream<
+ K: ToRedisArgs,
+ G: ToRedisArgs,
+ ID: ToRedisArgs
+ >(
+ key: K,
+ group: G,
+ id: ID
+ ) {
+ cmd("XGROUP")
+ .arg("CREATE")
+ .arg(key)
+ .arg(group)
+ .arg(id)
+ .arg("MKSTREAM")
+ }
+
+
+ /// Alter which `id` you want consumers to begin reading from an existing
+ /// consumer `group`.
+ ///
+ /// ```text
+ /// XGROUP SETID <key> <groupname> <id or $>
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xgroup_setid<K: ToRedisArgs, G: ToRedisArgs, ID: ToRedisArgs>(
+ key: K,
+ group: G,
+ id: ID
+ ) {
+ cmd("XGROUP")
+ .arg("SETID")
+ .arg(key)
+ .arg(group)
+ .arg(id)
+ }
+
+
+ /// Destroy an existing consumer `group` for a given stream `key`.
+ ///
+ /// ```text
+ /// XGROUP DESTROY <key> <groupname>
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xgroup_destroy<K: ToRedisArgs, G: ToRedisArgs>(
+ key: K,
+ group: G
+ ) {
+ cmd("XGROUP").arg("DESTROY").arg(key).arg(group)
+ }
+
+ /// This deletes a `consumer` from an existing consumer `group`
+ /// for a given stream `key`.
+ ///
+ /// ```text
+ /// XGROUP DELCONSUMER <key> <groupname> <consumername>
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xgroup_delconsumer<K: ToRedisArgs, G: ToRedisArgs, C: ToRedisArgs>(
+ key: K,
+ group: G,
+ consumer: C
+ ) {
+ cmd("XGROUP")
+ .arg("DELCONSUMER")
+ .arg(key)
+ .arg(group)
+ .arg(consumer)
+ }
+
+
+ /// This returns all info details about
+ /// which consumers have read messages for a given consumer `group`.
+ /// Take note of the StreamInfoConsumersReply return type.
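+ ///
+ /// A minimal sketch (stream and group names are placeholders):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// use redis::streams::StreamInfoConsumersReply;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let reply: StreamInfoConsumersReply = con.xinfo_consumers("k1", "g1")?;
+ /// # Ok(()) }
+ /// ```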
+ /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO CONSUMERS + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_consumers( + key: K, + group: G + ) { + cmd("XINFO") + .arg("CONSUMERS") + .arg(key) + .arg(group) + } + + + /// Returns all consumer `group`s created for a given stream `key`. + /// Take note of the StreamInfoGroupsReply return type. + /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO GROUPS + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_groups(key: K) { + cmd("XINFO").arg("GROUPS").arg(key) + } + + + /// Returns info about high-level stream details + /// (first & last message `id`, length, number of groups, etc.) + /// Take note of the StreamInfoStreamReply return type. + /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO STREAM + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_stream(key: K) { + cmd("XINFO").arg("STREAM").arg(key) + } + + /// Returns the number of messages for a given stream `key`. + /// + /// ```text + /// XLEN + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xlen(key: K) { + cmd("XLEN").arg(key) + } + + + /// This is a basic version of making XPENDING command calls which only + /// passes a stream `key` and consumer `group` and it + /// returns details about which consumers have pending messages + /// that haven't been acked. + /// + /// You can use this method along with + /// `xclaim` or `xclaim_options` for determining which messages + /// need to be retried. + /// + /// Take note of the StreamPendingReply return type. + /// + /// ```text + /// XPENDING [ []] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending( + key: K, + group: G + ) { + cmd("XPENDING").arg(key).arg(group) + } + + + /// This XPENDING version returns a list of all messages over the range. + /// You can use this for paginating pending messages (but without the message HashMap). + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. + /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + } + + + /// An alternate version of `xpending_count` which filters by `consumer` name. + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. 
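+ ///
+ /// A minimal sketch (names, the `-`/`+` range and the count are placeholders):
+ ///
+ /// ```rust,no_run
+ /// # fn example() -> redis::RedisResult<()> {
+ /// use redis::Commands;
+ /// use redis::streams::StreamPendingCountReply;
+ /// let client = redis::Client::open("redis://127.0.0.1/")?;
+ /// let mut con = client.get_connection(None)?;
+ /// let reply: StreamPendingCountReply =
+ ///     con.xpending_consumer_count("k1", "g1", "-", "+", 10, "consumer-1")?;
+ /// # Ok(()) }
+ /// ```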
+ /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_consumer_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs, + CN: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C, + consumer: CN + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + .arg(consumer) + } + + /// Returns a range of messages in a given stream `key`. + /// + /// Set `start` to `-` to begin at the first message. + /// Set `end` to `+` to end the most recent message. + /// You can pass message `id` to both `start` and `end`. + /// + /// Take note of the StreamRangeReply return type. + /// + /// ```text + /// XRANGE key start end + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange( + key: K, + start: S, + end: E + ) { + cmd("XRANGE").arg(key).arg(start).arg(end) + } + + + /// A helper method for automatically returning all messages in a stream by `key`. + /// **Use with caution!** + /// + /// ```text + /// XRANGE key - + + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_all(key: K) { + cmd("XRANGE").arg(key).arg("-").arg("+") + } + + + /// A method for paginating a stream by `key`. + /// + /// ```text + /// XRANGE key start end [COUNT ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_count( + key: K, + start: S, + end: E, + count: C + ) { + cmd("XRANGE") + .arg(key) + .arg(start) + .arg(end) + .arg("COUNT") + .arg(count) + } + + + /// Read a list of `id`s for each stream `key`. + /// This is the basic form of reading streams. + /// For more advanced control, like blocking, limiting, or reading by consumer `group`, + /// see `xread_options`. + /// + /// ```text + /// XREAD STREAMS key_1 key_2 ... key_N ID_1 ID_2 ... ID_N + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xread( + keys: &'a [K], + ids: &'a [ID] + ) { + cmd("XREAD").arg("STREAMS").arg(keys).arg(ids) + } + + /// This method handles setting optional arguments for + /// `XREAD` or `XREADGROUP` Redis commands. + /// ```no_run + /// use redis::{Connection,RedisResult,Commands}; + /// use redis::streams::{StreamReadOptions,StreamReadReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Read 10 messages from the start of the stream, + /// // without registering as a consumer group. + /// + /// let opts = StreamReadOptions::default() + /// .count(10); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &["0"], &opts); + /// + /// // Read all undelivered messages for a given + /// // consumer group. Be advised: the consumer group must already + /// // exist before making this call. Also note: we're passing + /// // '>' as the id here, which means all undelivered messages. + /// + /// let opts = StreamReadOptions::default() + /// .group("group-1", "consumer-1"); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &[">"], &opts); + /// ``` + /// + /// ```text + /// XREAD [BLOCK ] [COUNT ] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... ID_N + /// + /// XREADGROUP [GROUP group-name consumer-name] [BLOCK ] [COUNT ] [NOACK] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... 
ID_N
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xread_options<K: ToRedisArgs, ID: ToRedisArgs>(
+ keys: &'a [K],
+ ids: &'a [ID],
+ options: &'a streams::StreamReadOptions
+ ) {
+ cmd(if options.read_only() {
+ "XREAD"
+ } else {
+ "XREADGROUP"
+ })
+ .arg(options)
+ .arg("STREAMS")
+ .arg(keys)
+ .arg(ids)
+ }
+
+ /// This is the reverse version of `xrange`.
+ /// The same rules apply for `start` and `end` here.
+ ///
+ /// ```text
+ /// XREVRANGE key end start
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xrevrange<K: ToRedisArgs, E: ToRedisArgs, S: ToRedisArgs>(
+ key: K,
+ end: E,
+ start: S
+ ) {
+ cmd("XREVRANGE").arg(key).arg(end).arg(start)
+ }
+
+ /// This is the reverse version of `xrange_all`.
+ /// The same rules apply for `start` and `end` here.
+ ///
+ /// ```text
+ /// XREVRANGE key + -
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xrevrange_all<K: ToRedisArgs>(key: K) {
+ cmd("XREVRANGE").arg(key).arg("+").arg("-")
+ }
+
+ /// This is the reverse version of `xrange_count`.
+ /// The same rules apply for `start` and `end` here.
+ ///
+ /// ```text
+ /// XREVRANGE key end start [COUNT <n>]
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xrevrange_count<K: ToRedisArgs, E: ToRedisArgs, S: ToRedisArgs, C: ToRedisArgs>(
+ key: K,
+ end: E,
+ start: S,
+ count: C
+ ) {
+ cmd("XREVRANGE")
+ .arg(key)
+ .arg(end)
+ .arg(start)
+ .arg("COUNT")
+ .arg(count)
+ }
+
+
+ /// Trim a stream `key` to a MAXLEN count.
+ ///
+ /// ```text
+ /// XTRIM <key> MAXLEN [~|=] <count> (Same as XADD MAXLEN option)
+ /// ```
+ #[cfg(feature = "streams")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "streams")))]
+ fn xtrim<K: ToRedisArgs>(
+ key: K,
+ maxlen: streams::StreamMaxlen
+ ) {
+ cmd("XTRIM").arg(key).arg(maxlen)
+ }
+}
+
+/// Allows pubsub callbacks to stop receiving messages.
+///
+/// Arbitrary data may be returned from `Break`.
+pub enum ControlFlow<U> {
+ /// Continues.
+ Continue,
+ /// Breaks with a value.
+ Break(U),
+}
+
+/// The PubSub trait allows subscribing to one or more channels
+/// and receiving a callback whenever a message arrives.
+///
+/// Each method handles subscribing to the list of keys, waiting for
+/// messages, and unsubscribing from the same list of channels once
+/// a ControlFlow::Break is encountered.
+///
+/// Once (p)subscribe returns `Ok(U)`, the connection is again safe to use
+/// for calling other methods.
+///
+/// # Examples
+///
+/// ```rust,no_run
+/// # fn do_something() -> redis::RedisResult<()> {
+/// use redis::{PubSubCommands, ControlFlow};
+/// let client = redis::Client::open("redis://127.0.0.1/")?;
+/// let mut con = client.get_connection(None)?;
+/// let mut count = 0;
+/// con.subscribe(&["foo"], |msg| {
+///     // do something with message
+///     assert_eq!(msg.get_channel(), Ok(String::from("foo")));
+///
+///     // increment messages seen counter
+///     count += 1;
+///     match count {
+///         // stop after receiving 10 messages
+///         10 => ControlFlow::Break(()),
+///         _ => ControlFlow::Continue,
+///     }
+/// })?;
+/// # Ok(()) }
+/// ```
+// TODO In the future, it would be nice to implement Try such that `?` will work
+// within the closure.
+pub trait PubSubCommands: Sized {
+ /// Subscribe to a list of channels using SUBSCRIBE and run the provided
+ /// closure for each message received.
+ ///
+ /// For every `Msg` passed to the provided closure, either
+ /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This
+ /// method will not return until `ControlFlow::Break` is observed.
+ fn subscribe<C, F, U>(&mut self, _: C, _: F) -> RedisResult<U>
+ where
+ F: FnMut(Msg) -> ControlFlow<U>,
+ C: ToRedisArgs;
+
+ /// Subscribe to a list of channels using PSUBSCRIBE and run the provided
+ /// closure for each message received.
+ ///
+ /// For every `Msg` passed to the provided closure, either
+ /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This
+ /// method will not return until `ControlFlow::Break` is observed.
+ fn psubscribe<P, F, U>(&mut self, _: P, _: F) -> RedisResult<U>
+ where
+ F: FnMut(Msg) -> ControlFlow<U>,
+ P: ToRedisArgs;
+}
+
+impl<T> Commands for T where T: ConnectionLike {}
+
+#[cfg(feature = "aio")]
+impl<T> AsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {}
+
+impl PubSubCommands for Connection {
+ fn subscribe<C, F, U>(&mut self, channels: C, mut func: F) -> RedisResult<U>
+ where
+ F: FnMut(Msg) -> ControlFlow<U>,
+ C: ToRedisArgs,
+ {
+ let mut pubsub = self.as_pubsub();
+ pubsub.subscribe(channels)?;
+
+ loop {
+ let msg = pubsub.get_message()?;
+ match func(msg) {
+ ControlFlow::Continue => continue,
+ ControlFlow::Break(value) => return Ok(value),
+ }
+ }
+ }
+
+ fn psubscribe<P, F, U>(&mut self, patterns: P, mut func: F) -> RedisResult<U>
+ where
+ F: FnMut(Msg) -> ControlFlow<U>,
+ P: ToRedisArgs,
+ {
+ let mut pubsub = self.as_pubsub();
+ pubsub.psubscribe(patterns)?;
+
+ loop {
+ let msg = pubsub.get_message()?;
+ match func(msg) {
+ ControlFlow::Continue => continue,
+ ControlFlow::Break(value) => return Ok(value),
+ }
+ }
+ }
+}
+
+/// Options for the [LPOS](https://redis.io/commands/lpos) command
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use redis::{Commands, RedisResult, LposOptions};
+/// fn fetch_list_position(
+///     con: &mut redis::Connection,
+///     key: &str,
+///     value: &str,
+///     count: usize,
+///     rank: isize,
+///     maxlen: usize,
+/// ) -> RedisResult<Vec<usize>> {
+///     let opts = LposOptions::default()
+///         .count(count)
+///         .rank(rank)
+///         .maxlen(maxlen);
+///     con.lpos(key, value, opts)
+/// }
+/// ```
+#[derive(Default)]
+pub struct LposOptions {
+ count: Option<usize>,
+ maxlen: Option<usize>,
+ rank: Option<isize>,
+}
+
+impl LposOptions {
+ /// Limit the results to the first N matching items.
+ pub fn count(mut self, n: usize) -> Self {
+ self.count = Some(n);
+ self
+ }
+
+ /// Start matching at the item with the given rank; a negative rank
+ /// searches from the tail of the list.
+ pub fn rank(mut self, n: isize) -> Self {
+ self.rank = Some(n);
+ self
+ }
+
+ /// Limit the search to N items in the list.
+ pub fn maxlen(mut self, n: usize) -> Self {
+ self.maxlen = Some(n);
+ self
+ }
+}
+
+impl ToRedisArgs for LposOptions {
+ fn write_redis_args<W>(&self, out: &mut W)
+ where
+ W: ?Sized + RedisWrite,
+ {
+ if let Some(n) = self.count {
+ out.write_arg(b"COUNT");
+ out.write_arg_fmt(n);
+ }
+
+ if let Some(n) = self.rank {
+ out.write_arg(b"RANK");
+ out.write_arg_fmt(n);
+ }
+
+ if let Some(n) = self.maxlen {
+ out.write_arg(b"MAXLEN");
+ out.write_arg_fmt(n);
+ }
+ }
+
+ fn is_single_arg(&self) -> bool {
+ false
+ }
+}
+
+/// Enum for the LEFT | RIGHT args used by some commands
+pub enum Direction {
+ /// Targets the first element (head) of the list
+ Left,
+ /// Targets the last element (tail) of the list
+ Right,
+}
+
+impl ToRedisArgs for Direction {
+ fn write_redis_args<W>(&self, out: &mut W)
+ where
+ W: ?Sized + RedisWrite,
+ {
+ let s: &[u8] = match self {
+ Direction::Left => b"LEFT",
+ Direction::Right => b"RIGHT",
+ };
+ out.write_arg(s);
+ }
+}
+
+/// Options for the [SET](https://redis.io/commands/set) command
+///
+/// # Example
+/// ```rust,no_run
+/// use redis::{Commands, RedisResult, SetOptions, SetExpiry, ExistenceCheck};
+/// fn set_key_value(
+///     con: &mut redis::Connection,
+///     key: &str,
+///     value: &str,
+/// ) -> RedisResult<Option<String>> {
+///     let opts = SetOptions::default()
+///         .conditional_set(ExistenceCheck::NX)
+///         .get(true)
+///         .with_expiration(SetExpiry::EX(60));
+///     con.set_options(key, value, opts)
+/// }
+/// ```
+#[derive(Clone, Copy, Default)]
+pub struct SetOptions {
+ conditional_set: Option<ExistenceCheck>,
+ get: bool,
+ expiration: Option<SetExpiry>,
+}
+
+impl SetOptions {
+ /// Set the existence check for the SET command
+ pub fn conditional_set(mut self, existence_check: ExistenceCheck) -> Self {
+ self.conditional_set = Some(existence_check);
+ self
+ }
+
+ /// Set the GET option for the SET command
+ pub fn get(mut self, get: bool) -> Self {
+ self.get = get;
+ self
+ }
+
+ /// Set the expiration for the SET command
+ pub fn with_expiration(mut self, expiration: SetExpiry) -> Self {
+ self.expiration = Some(expiration);
+ self
+ }
+}
+
+impl ToRedisArgs for SetOptions {
+ fn write_redis_args<W>(&self, out: &mut W)
+ where
+ W: ?Sized + RedisWrite,
+ {
+ if let Some(ref conditional_set) = self.conditional_set {
+ match conditional_set {
+ ExistenceCheck::NX => {
+ out.write_arg(b"NX");
+ }
+ ExistenceCheck::XX => {
+ out.write_arg(b"XX");
+ }
+ }
+ }
+ if self.get {
+ out.write_arg(b"GET");
+ }
+ if let Some(ref expiration) = self.expiration {
+ match expiration {
+ SetExpiry::EX(secs) => {
+ out.write_arg(b"EX");
+ out.write_arg(format!("{}", secs).as_bytes());
+ }
+ SetExpiry::PX(millis) => {
+ out.write_arg(b"PX");
+ out.write_arg(format!("{}", millis).as_bytes());
+ }
+ SetExpiry::EXAT(unix_time) => {
+ out.write_arg(b"EXAT");
+ out.write_arg(format!("{}", unix_time).as_bytes());
+ }
+ SetExpiry::PXAT(unix_time) => {
+ out.write_arg(b"PXAT");
+ out.write_arg(format!("{}", unix_time).as_bytes());
+ }
+ SetExpiry::KEEPTTL => {
+ out.write_arg(b"KEEPTTL");
+ }
+ }
+ }
+ }
+}
+
+/// Creates HELLO command for RESP3 with RedisConnectionInfo
+pub fn resp3_hello(connection_info: &RedisConnectionInfo) -> Cmd {
+ let mut hello_cmd = cmd("HELLO");
+ hello_cmd.arg("3");
+ if let Some(password) = &connection_info.password {
+ let username: &str = match connection_info.username.as_ref() {
+ None => "default",
+ Some(username) => username,
+ };
+ hello_cmd.arg("AUTH").arg(username).arg(password);
+ }
+ hello_cmd
+}
diff --git a/glide-core/redis-rs/redis/src/connection.rs
new file mode 100644
index 0000000000..f75b9df494
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/connection.rs
@@ -0,0 +1,1997 @@
+use std::collections::{HashSet, VecDeque};
+use std::fmt;
+use std::io::{self, Write};
+use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
+use std::ops::DerefMut;
+use std::path::PathBuf;
+use std::str::{from_utf8, FromStr};
+use std::time::Duration;
+
+use crate::cmd::{cmd, pipe, Cmd};
+use crate::parser::Parser;
+use crate::pipeline::Pipeline;
+use crate::types::{
+    from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
+    ToRedisArgs, Value,
+};
+use crate::{from_owned_redis_value, ProtocolVersion};
+
+#[cfg(unix)]
+use std::os::unix::net::UnixStream;
+use std::vec::IntoIter;
+
+use crate::commands::resp3_hello;
+#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
+use native_tls::{TlsConnector, TlsStream};
+
+#[cfg(feature = "tls-rustls")]
+use rustls::{RootCertStore, StreamOwned};
+#[cfg(feature = "tls-rustls")]
+use std::sync::Arc;
+
+use crate::push_manager::PushManager;
+use crate::PushInfo;
+
+#[cfg(all(
+    feature = "tls-rustls",
+    not(feature = "tls-native-tls"),
+    not(feature = "tls-rustls-webpki-roots")
+))]
+use rustls_native_certs::load_native_certs;
+
+#[cfg(feature = "tls-rustls")]
+use crate::tls::TlsConnParams;
+
+// Non-exhaustive to prevent construction outside this crate
+#[cfg(not(feature = "tls-rustls"))]
+#[derive(Clone, Debug)]
+#[non_exhaustive]
+pub struct TlsConnParams;
+
+static DEFAULT_PORT: u16 = 6379;
+
+#[inline(always)]
+fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
+    let socket = TcpStream::connect(addr)?;
+    #[cfg(feature = "tcp_nodelay")]
+    socket.set_nodelay(true)?;
+    #[cfg(feature = "keep-alive")]
+    {
+        // For now rely on system defaults for the keepalive parameters.
+        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
+        // These conversions cannot fail in practice.
+        let socket2: socket2::Socket = socket.into();
+        socket2.set_tcp_keepalive(&KEEP_ALIVE)?;
+        Ok(socket2.into())
+    }
+    #[cfg(not(feature = "keep-alive"))]
+    {
+        Ok(socket)
+    }
+}
+
+#[inline(always)]
+fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
+    let socket = TcpStream::connect_timeout(addr, timeout)?;
+    #[cfg(feature = "tcp_nodelay")]
+    socket.set_nodelay(true)?;
+    #[cfg(feature = "keep-alive")]
+    {
+        // For now rely on system defaults for the keepalive parameters.
+        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
+        // These conversions cannot fail in practice.
+        let socket2: socket2::Socket = socket.into();
+        socket2.set_tcp_keepalive(&KEEP_ALIVE)?;
+        Ok(socket2.into())
+    }
+    #[cfg(not(feature = "keep-alive"))]
+    {
+        Ok(socket)
+    }
+}
+
+/// This function takes a redis URL string and parses it into a URL
+/// as used by rust-url. This is necessary as the default parser does
+/// not understand how redis URLs function.
+pub fn parse_redis_url(input: &str) -> Option<url::Url> {
+    match url::Url::parse(input) {
+        Ok(result) => match result.scheme() {
+            "redis" | "rediss" | "redis+unix" | "unix" => Some(result),
+            _ => None,
+        },
+        Err(_) => None,
+    }
+}
+
+/// TlsMode indicates whether to verify the server's certificate.
+/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more.
+#[derive(Clone, Copy)]
+pub enum TlsMode {
+    /// Secure: verify the server's certificate.
+    Secure,
+    /// Insecure: do not verify the server's certificate.
+    Insecure,
+}
+
+/// Defines the connection address.
+///
+/// Not all connection addresses are supported on all platforms. For instance
+/// to connect to a unix socket you need to run this on an operating system
+/// that supports them.
+#[derive(Clone, Debug)]
+pub enum ConnectionAddr {
+    /// Format for this is `(host, port)`.
+    Tcp(String, u16),
+    /// Format for this is `(host, port)`.
+    TcpTls {
+        /// Hostname
+        host: String,
+        /// Port
+        port: u16,
+        /// Disable hostname verification when connecting.
+        ///
+        /// # Warning
+        ///
+        /// You should think very carefully before you use this method. If hostname
+        /// verification is not used, any valid certificate for any site will be
+        /// trusted for use from any other. This introduces a significant
+        /// vulnerability to man-in-the-middle attacks.
+        insecure: bool,
+
+        /// TLS certificates and client key.
+        tls_params: Option<TlsConnParams>,
+    },
+    /// Format for this is the path to the unix socket.
+    Unix(PathBuf),
+}
+
+impl PartialEq for ConnectionAddr {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
+                host1 == host2 && port1 == port2
+            }
+            (
+                ConnectionAddr::TcpTls {
+                    host: host1,
+                    port: port1,
+                    insecure: insecure1,
+                    tls_params: _,
+                },
+                ConnectionAddr::TcpTls {
+                    host: host2,
+                    port: port2,
+                    insecure: insecure2,
+                    tls_params: _,
+                },
+            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
+            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
+            _ => false,
+        }
+    }
+}
+
+impl Eq for ConnectionAddr {}
+
+impl ConnectionAddr {
+    /// Checks if this address is supported.
+    ///
+    /// Because not all platforms support all connection addresses this is a
+    /// quick way to figure out if a connection method is supported. Currently
+    /// this only affects unix connections which are only supported on unix
+    /// platforms and on older versions of rust also require an explicit feature
+    /// to be enabled.
+    pub fn is_supported(&self) -> bool {
+        match *self {
+            ConnectionAddr::Tcp(_, _) => true,
+            ConnectionAddr::TcpTls { .. } => {
+                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
+            }
+            ConnectionAddr::Unix(_) => cfg!(unix),
+        }
+    }
+}
+
+impl fmt::Display for ConnectionAddr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // Cluster::get_connection_info depends on the return value from this function
+        match *self {
+            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
+            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
+            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
+        }
+    }
+}
+
+/// Holds the connection information that redis should use for connecting.
+#[derive(Clone, Debug)]
+pub struct ConnectionInfo {
+    /// A connection address for where to connect to.
+    pub addr: ConnectionAddr,
+
+    /// A redis connection info for how to handshake with redis.
+    pub redis: RedisConnectionInfo,
+}
+
+/// Types of pubsub subscriptions
+/// See the server's pub/sub documentation for more details
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)]
+pub enum PubSubSubscriptionKind {
+    /// Exact channel name.
+    /// Receives messages which are published to a specific channel using PUBLISH command.
+    Exact = 0,
+    /// Pattern-based channel name.
+    /// Receives messages which are published to channels matched by glob pattern using PUBLISH command.
+    Pattern = 1,
+    /// Sharded pubsub mode.
+    /// Receives messages which are published to a specific channel using SPUBLISH command.
+    Sharded = 2,
+}
+
+impl From<PubSubSubscriptionKind> for usize {
+    fn from(val: PubSubSubscriptionKind) -> Self {
+        val as usize
+    }
+}
+
+/// Type for pubsub channels/patterns
+pub type PubSubChannelOrPattern = Vec<u8>;
+
+/// Type for pubsub subscriptions
+pub type PubSubSubscriptionInfo = HashMap<PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>>;
+
+/// Redis specific/connection independent information used to establish a connection to redis.
+#[derive(Clone, Debug, Default)]
+pub struct RedisConnectionInfo {
+    /// The database number to use. This is usually `0`.
+    pub db: i64,
+    /// Optionally a username that should be used for connection.
+    pub username: Option<String>,
+    /// Optionally a password that should be used for connection.
+    pub password: Option<String>,
+    /// Version of the protocol to use.
+    pub protocol: ProtocolVersion,
+    /// Optionally a client name that should be used for connection
+    pub client_name: Option<String>,
+    /// Optionally, pubsub subscriptions that should be set up for the connection
+    pub pubsub_subscriptions: Option<PubSubSubscriptionInfo>,
+}
+
+impl FromStr for ConnectionInfo {
+    type Err = RedisError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        s.into_connection_info()
+    }
+}
+
+/// Converts an object into a connection info struct. This allows the
+/// constructor of the client to accept connection information in a
+/// range of different formats.
+pub trait IntoConnectionInfo {
+    /// Converts the object into a connection info object.
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
+}
+
+impl IntoConnectionInfo for ConnectionInfo {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        Ok(self)
+    }
+}
+
+/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
+///
+/// - Basic: `redis://127.0.0.1:6379`
+/// - Username & Password: `redis://user:password@127.0.0.1:6379`
+/// - Password only: `redis://:password@127.0.0.1:6379`
+/// - Specifying DB: `redis://127.0.0.1:6379/0`
+/// - Enabling TLS: `rediss://127.0.0.1:6379`
+/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
+impl<'a> IntoConnectionInfo for &'a str {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        match parse_redis_url(self) {
+            Some(u) => u.into_connection_info(),
+            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
+        }
+    }
+}
+
+impl<T> IntoConnectionInfo for (T, u16)
+where
+    T: Into<String>,
+{
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        Ok(ConnectionInfo {
+            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
+            redis: RedisConnectionInfo::default(),
+        })
+    }
+}
+
+/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
+///
+/// - Basic: `redis://127.0.0.1:6379`
+/// - Username & Password: `redis://user:password@127.0.0.1:6379`
+/// - Password only: `redis://:password@127.0.0.1:6379`
+/// - Specifying DB: `redis://127.0.0.1:6379/0`
+/// - Enabling TLS: `rediss://127.0.0.1:6379`
+/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
+impl IntoConnectionInfo for String {
+    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
+        match parse_redis_url(&self) {
+            Some(u) => u.into_connection_info(),
+            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
+        }
+    }
+}
+
+fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
+    let host = match url.host() {
+        Some(host) => {
+            // Here we manually match host's enum arms and call their to_string().
+            // Because url.host().to_string() will add `[` and `]` for ipv6:
+            // https://docs.rs/url/latest/src/url/host.rs.html#170
+            // And these brackets will break host.parse::<IpAddr>() when
+            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
+            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
+            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
+            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
+            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
+            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
+            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
+            match host {
+                url::Host::Domain(path) => path.to_string(),
+                url::Host::Ipv4(v4) => v4.to_string(),
+                url::Host::Ipv6(v6) => v6.to_string(),
+            }
+        }
+        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
+    };
+    let port = url.port().unwrap_or(DEFAULT_PORT);
+    let addr = if url.scheme() == "rediss" {
+        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
+        {
+            match url.fragment() {
+                Some("insecure") => ConnectionAddr::TcpTls {
+                    host,
+                    port,
+                    insecure: true,
+                    tls_params: None,
+                },
+                Some(_) => fail!((
+                    ErrorKind::InvalidClientConfig,
+                    "only #insecure is supported as URL fragment"
+                )),
+                _ => ConnectionAddr::TcpTls {
+                    host,
+                    port,
+                    insecure: false,
+                    tls_params: None,
+                },
+            }
+        }
+
+        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
+        fail!((
+            ErrorKind::InvalidClientConfig,
+            "can't connect with TLS, the feature is not enabled"
+        ));
+    } else {
+        ConnectionAddr::Tcp(host, port)
+    };
+    let query: HashMap<_, _> = url.query_pairs().collect();
+    Ok(ConnectionInfo {
+        addr,
+        redis: RedisConnectionInfo {
+            db: match url.path().trim_matches('/') {
+                "" => 0,
+                path => path.parse::<i64>().map_err(|_| -> RedisError {
+                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
+                })?,
+            },
+            username: if url.username().is_empty() {
+                None
+            } else {
+                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
+                    Ok(decoded) => Some(decoded.into_owned()),
+                    Err(_) => fail!((
+                        ErrorKind::InvalidClientConfig,
+                        "Username is not valid UTF-8 string"
+                    )),
+                }
+            },
+            password: match url.password() {
+                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
+                    Ok(decoded) => Some(decoded.into_owned()),
+                    Err(_) => fail!((
+                        ErrorKind::InvalidClientConfig,
+                        "Password is not valid UTF-8 string"
+                    )),
+                },
+                None => None,
+            },
+            protocol: match query.get("resp3") {
+                Some(v) => {
+                    if v == "true" {
+                        ProtocolVersion::RESP3
+                    } else {
+                        ProtocolVersion::RESP2
+                    }
+                }
+                _ => ProtocolVersion::RESP2,
+            },
+            client_name: None,
+            pubsub_subscriptions: None,
+        },
+    })
+}
+
+#[cfg(unix)]
+fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
+    let query: HashMap<_, _> = url.query_pairs().collect();
+    Ok(ConnectionInfo {
+        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
+            (ErrorKind::InvalidClientConfig, "Missing path").into()
+        })?),
+        redis: RedisConnectionInfo {
+            db: match query.get("db") {
+                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
+                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
+                })?,
+
+                None => 0,
+            },
+            username: query.get("user").map(|username| username.to_string()),
+            password: query.get("pass").map(|password| password.to_string()),
+            protocol: match query.get("resp3") {
+                Some(v) => {
+                    if v == "true" {
+                        ProtocolVersion::RESP3
+                    } else {
+                        ProtocolVersion::RESP2
+                    }
+                }
+                _ =>
ProtocolVersion::RESP2, + }, + client_name: None, + pubsub_subscriptions: None, + }, + }) +} + +#[cfg(not(unix))] +fn url_to_unix_connection_info(_: url::Url) -> RedisResult { + fail!(( + ErrorKind::InvalidClientConfig, + "Unix sockets are not available on this platform." + )); +} + +impl IntoConnectionInfo for url::Url { + fn into_connection_info(self) -> RedisResult { + match self.scheme() { + "redis" | "rediss" => url_to_tcp_connection_info(self), + "unix" | "redis+unix" => url_to_unix_connection_info(self), + _ => fail!(( + ErrorKind::InvalidClientConfig, + "URL provided is not a redis URL" + )), + } + } +} + +struct TcpConnection { + reader: TcpStream, + open: bool, +} + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +struct TcpNativeTlsConnection { + reader: TlsStream, + open: bool, +} + +#[cfg(feature = "tls-rustls")] +struct TcpRustlsConnection { + reader: StreamOwned, + open: bool, +} + +#[cfg(unix)] +struct UnixConnection { + sock: UnixStream, + open: bool, +} + +enum ActualConnection { + Tcp(TcpConnection), + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + TcpNativeTls(Box), + #[cfg(feature = "tls-rustls")] + TcpRustls(Box), + #[cfg(unix)] + Unix(UnixConnection), +} + +#[cfg(feature = "tls-rustls-insecure")] +struct NoCertificateVerification { + supported: rustls::crypto::WebPkiSupportedAlgorithms, +} + +#[cfg(feature = "tls-rustls-insecure")] +impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls_pki_types::CertificateDer<'_>, + _intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported.supported_schemes() + } +} + +#[cfg(feature = "tls-rustls-insecure")] +impl fmt::Debug for NoCertificateVerification { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NoCertificateVerification").finish() + } +} + +/// Represents a stateful redis TCP connection. +pub struct Connection { + con: ActualConnection, + parser: Parser, + db: i64, + + /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`. + /// + /// This flag is checked when attempting to send a command, and if it's raised, we attempt to + /// exit the pubsub state before executing the new request. + pubsub: bool, + + // Field indicating which protocol to use for server communications. + protocol: ProtocolVersion, + + /// `PushManager` instance for the connection. + /// This is used to manage Push messages in RESP3 mode. + push_manager: PushManager, +} + +/// Represents a pubsub connection. +pub struct PubSub<'a> { + con: &'a mut Connection, + waiting_messages: VecDeque, +} + +/// Represents a pubsub message. 
+#[derive(Debug)] +pub struct Msg { + payload: Value, + channel: Value, + pattern: Option, +} + +impl ActualConnection { + pub fn new(addr: &ConnectionAddr, timeout: Option) -> RedisResult { + Ok(match *addr { + ConnectionAddr::Tcp(ref host, ref port) => { + let addr = (host.as_str(), *port); + let tcp = match timeout { + None => connect_tcp(addr)?, + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in addr.to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tcp, + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::Tcp(TcpConnection { + reader: tcp, + open: true, + }) + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + .. + } => { + let tls_connector = if insecure { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + .build()? + } else { + TlsConnector::new()? + }; + let addr = (host.as_str(), port); + let tls = match timeout { + None => { + let tcp = connect_tcp(addr)?; + match tls_connector.connect(host, tcp) { + Ok(res) => res, + Err(e) => { + fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string())); + } + } + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host.as_str(), port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection { + reader: tls, + open: true, + })) + } + #[cfg(feature = "tls-rustls")] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + ref tls_params, + } => { + let host: &str = host; + let config = create_rustls_config(insecure, tls_params.clone())?; + let conn = rustls::ClientConnection::new( + Arc::new(config), + rustls_pki_types::ServerName::try_from(host)?.to_owned(), + )?; + let reader = match timeout { + None => { + let tcp = connect_tcp((host, port))?; + StreamOwned::new(conn, tcp) + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host, port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => StreamOwned::new(conn, tcp), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + + ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true })) + } + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + ConnectionAddr::TcpTls { .. 
} => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to TCP with TLS without the tls feature" + )); + } + #[cfg(unix)] + ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection { + sock: UnixStream::connect(path)?, + open: true, + }), + #[cfg(not(unix))] + ConnectionAddr::Unix(ref _path) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to unix sockets \ + on this platform" + )); + } + }) + } + + pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + match *self { + ActualConnection::Tcp(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let result = connection.sock.write_all(bytes).map_err(RedisError::from); + match result { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + } + } + + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_write_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_write_timeout(dur)?; + } + } + Ok(()) + } + + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_read_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_read_timeout(dur)?; + } + } + Ok(()) + } + + pub fn is_open(&self) -> bool { + match *self { + ActualConnection::Tcp(TcpConnection { open, .. 
}) => open, + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { open, .. }) => open, + } + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) fn create_rustls_config( + insecure: bool, + tls_params: Option, +) -> RedisResult { + use crate::tls::ClientTlsParams; + + #[allow(unused_mut)] + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "tls-rustls-webpki-roots")] + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + #[cfg(all( + feature = "tls-rustls", + not(feature = "tls-native-tls"), + not(feature = "tls-rustls-webpki-roots") + ))] + for cert in load_native_certs()? { + root_store.add(cert)?; + } + + let config = rustls::ClientConfig::builder(); + let config = if let Some(tls_params) = tls_params { + let config_builder = + config.with_root_certificates(tls_params.root_cert_store.unwrap_or(root_store)); + + if let Some(ClientTlsParams { + client_cert_chain: client_cert, + client_key, + }) = tls_params.client_tls_params + { + config_builder + .with_client_auth_cert(client_cert, client_key) + .map_err(|err| { + RedisError::from(( + ErrorKind::InvalidClientConfig, + "Unable to build client with TLS parameters provided.", + err.to_string(), + )) + })? + } else { + config_builder.with_no_client_auth() + } + } else { + config + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + match (insecure, cfg!(feature = "tls-rustls-insecure")) { + #[cfg(feature = "tls-rustls-insecure")] + (true, true) => { + let mut config = config; + config.enable_sni = false; + // nosemgrep + config + .dangerous() + .set_certificate_verifier(Arc::new(NoCertificateVerification { + supported: rustls::crypto::ring::default_provider() + .signature_verification_algorithms, + })); + + Ok(config) + } + (true, false) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot create insecure client without tls-rustls-insecure feature" + )); + } + _ => Ok(config), + } +} + +fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> { + let mut command = cmd("AUTH"); + if let Some(username) = &connection_info.username { + command.arg(username); + } + let password = connection_info.password.as_ref().unwrap(); + let err = match command.arg(password).query::(con) { + Ok(Value::Okay) => return Ok(()), + Ok(_) => { + fail!(( + ErrorKind::ResponseError, + "Redis server refused to authenticate, returns Ok() != Value::Okay" + )); + } + Err(e) => e, + }; + let err_msg = err.detail().ok_or(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + ))?; + if !err_msg.contains("wrong number of arguments for 'auth' command") { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )); + } + + // fallback to AUTH version <= 5 + let mut command = cmd("AUTH"); + match command.arg(password).query::(con) { + Ok(Value::Okay) => Ok(()), + _ => fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )), + } +} + +pub fn connect( + connection_info: &ConnectionInfo, + timeout: Option, +) -> RedisResult { + let con = ActualConnection::new(&connection_info.addr, timeout)?; + setup_connection(con, &connection_info.redis) +} + +#[cfg(not(feature = "disable-client-setinfo"))] +pub(crate) fn client_set_info_pipeline() -> 
Pipeline { + let mut pipeline = crate::pipe(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-NAME") + .arg(std::env!("GLIDE_NAME")) + .ignore(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-VER") + .arg(std::env!("GLIDE_VERSION")) + .ignore(); + pipeline +} + +fn setup_connection( + con: ActualConnection, + connection_info: &RedisConnectionInfo, +) -> RedisResult { + let mut rv = Connection { + con, + parser: Parser::new(), + db: connection_info.db, + pubsub: false, + protocol: connection_info.protocol, + push_manager: PushManager::new(), + }; + + if connection_info.protocol != ProtocolVersion::RESP2 { + let hello_cmd = resp3_hello(connection_info); + let val: RedisResult = hello_cmd.query(&mut rv); + if let Err(err) = val { + return Err(get_resp3_hello_command_error(err)); + } + } else if connection_info.password.is_some() { + connect_auth(&mut rv, connection_info)?; + } + if connection_info.db != 0 { + match cmd("SELECT") + .arg(connection_info.db) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to switch database" + )), + } + } + + if connection_info.client_name.is_some() { + match cmd("CLIENT") + .arg("SETNAME") + .arg(connection_info.client_name.as_ref().unwrap()) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + + // result is ignored, as per the command's instructions. + // https://redis.io/commands/client-setinfo/ + #[cfg(not(feature = "disable-client-setinfo"))] + let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv); + + Ok(rv) +} + +/// Implements the "stateless" part of the connection interface that is used by the +/// different objects in redis-rs. Primarily it obviously applies to `Connection` +/// object but also some other objects implement the interface (for instance +/// whole clients or certain redis results). +/// +/// Generally clients and connections (as well as redis results of those) implement +/// this trait. Actual connections provide more functionality which can be used +/// to implement things like `PubSub` but they also can modify the intrinsic +/// state of the TCP connection. This is not possible with `ConnectionLike` +/// implementors because that functionality is not exposed. +pub trait ConnectionLike { + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult; + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + /// Important - this function is meant for internal usage, since it's + /// easy to pass incorrect `offset` & `count` parameters, which might + /// cause the connection to enter an erroneous state. Users shouldn't + /// call it, instead using the Pipeline::query function. + #[doc(hidden)] + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult>; + + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + self.req_packed_command(&pcmd) + } + + /// Returns the database this connection is bound to. 
Note that this + /// information might be unreliable because it's initially cached and + /// also might be incorrect if the connection like object is not + /// actually connected. + fn get_db(&self) -> i64; + + /// Does this connection support pipelining? + #[doc(hidden)] + fn supports_pipelining(&self) -> bool { + true + } + + /// Check that all connections it has are available (`PING` internally). + fn check_connection(&mut self) -> bool; + + /// Returns the connection status. + /// + /// The connection is open until any `read_response` call recieved an + /// invalid response from the server (most likely a closed or dropped + /// connection, otherwise a Redis protocol error). When using unix + /// sockets the connection is open until writing a command failed with a + /// `BrokenPipe` error. + fn is_open(&self) -> bool; +} + +/// A connection is an object that represents a single redis connection. It +/// provides basic support for sending encoded commands into a redis connection +/// and to read a response from it. It's bound to a single database and can +/// only be created from the client. +/// +/// You generally do not much with this object other than passing it to +/// `Cmd` objects. +impl Connection { + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + self.send_bytes(cmd)?; + Ok(()) + } + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + pub fn recv_response(&mut self) -> RedisResult { + self.read_response() + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_write_timeout(dur) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_read_timeout(dur) + } + + /// Creates a [`PubSub`] instance for this connection. + pub fn as_pubsub(&mut self) -> PubSub<'_> { + // NOTE: The pubsub flag is intentionally not raised at this time since + // running commands within the pubsub state should not try and exit from + // the pubsub state. + PubSub::new(self) + } + + fn exit_pubsub(&mut self) -> RedisResult<()> { + let res = self.clear_active_subscriptions(); + if res.is_ok() { + self.pubsub = false; + } else { + // Raise the pubsub flag to indicate the connection is "stuck" in that state. + self.pubsub = true; + } + + res + } + + /// Get the inner connection out of a PubSub + /// + /// Any active subscriptions are unsubscribed. In the event of an error, the connection is + /// dropped. + fn clear_active_subscriptions(&mut self) -> RedisResult<()> { + // Responses to unsubscribe commands return in a 3-tuple with values + // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs). + // The "count of remaining subs" includes both pattern subscriptions and non pattern + // subscriptions. 
Thus, to accurately drain all unsubscribe messages received from the + // server, both commands need to be executed at once. + { + // Prepare both unsubscribe commands + let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command(); + let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command(); + + // Execute commands + self.send_bytes(&unsubscribe)?; + self.send_bytes(&punsubscribe)?; + } + + // Receive responses + // + // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe + // commands. There may be more responses if there are active subscriptions. In this case, + // messages are received until the _subscription count_ in the responses reach zero. + let mut received_unsub = false; + let mut received_punsub = false; + if self.protocol != ProtocolVersion::RESP2 { + while let Value::Push { kind, data } = from_owned_redis_value(self.recv_response()?)? { + if data.len() >= 2 { + if let Value::Int(num) = data[1] { + if resp3_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &kind, + num as isize, + ) { + break; + } + } + } + } + } else { + loop { + let res: (Vec, (), isize) = from_owned_redis_value(self.recv_response()?)?; + if resp2_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &res.0, + res.2, + ) { + break; + } + } + } + + // Finally, the connection is back in its normal state since all subscriptions were + // cancelled *and* all unsubscribe messages were received. + Ok(()) + } + + /// Fetches a single response from the connection. + fn read_response(&mut self) -> RedisResult { + let result = match self.con { + ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => { + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref mut sock, .. 
}) => { + let result = self.parser.parse_value(sock); + self.push_manager.try_send(&result); + result + } + }; + // shutdown connection on protocol error + if let Err(e) = &result { + let shutdown = match e.as_io_error() { + Some(e) => e.kind() == io::ErrorKind::UnexpectedEof, + None => false, + }; + if shutdown { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + match self.con { + ActualConnection::Tcp(ref mut connection) => { + let _ = connection.reader.shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let _ = connection.reader.shutdown(); + connection.open = false; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let _ = connection.sock.shutdown(net::Shutdown::Both); + connection.open = false; + } + } + } + } + result + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } + + fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + let result = self.con.send_bytes(bytes); + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + result + } +} + +impl ConnectionLike for Connection { + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(&pcmd)?; + if cmd.is_no_response() { + return Ok(Value::Nil); + } + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(cmd)?; + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + if self.pubsub { + self.exit_pubsub()?; + } + self.send_bytes(cmd)?; + let mut rv = vec![]; + let mut first_err = None; + let mut count = count; + let mut idx = 0; + while idx < (offset + count) { + // When processing a transaction, some responses may be errors. + // We need to keep processing the rest of the responses in that case, + // so bailing early with `?` would not be correct. 
+ // See: https://github.com/redis-rs/redis-rs/issues/436 + let response = self.read_response(); + match response { + Ok(item) => { + // RESP3 can insert push data between command replies + if let Value::Push { + kind: _kind, + data: _data, + } = item + { + // if that is the case we have to extend the loop and handle push data + count += 1; + } else if idx >= offset { + rv.push(item); + } + } + Err(err) => { + if first_err.is_none() { + first_err = Some(err); + } + } + } + idx += 1; + } + + first_err.map_or(Ok(rv), Err) + } + + fn get_db(&self) -> i64 { + self.db + } + + fn check_connection(&mut self) -> bool { + cmd("PING").query::(self).is_ok() + } + + fn is_open(&self) -> bool { + self.con.is_open() + } +} + +impl ConnectionLike for T +where + C: ConnectionLike, + T: DerefMut, +{ + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.deref_mut().req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.deref_mut().req_packed_commands(cmd, offset, count) + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.deref_mut().req_command(cmd) + } + + fn get_db(&self) -> i64 { + self.deref().get_db() + } + + fn supports_pipelining(&self) -> bool { + self.deref().supports_pipelining() + } + + fn check_connection(&mut self) -> bool { + self.deref_mut().check_connection() + } + + fn is_open(&self) -> bool { + self.deref().is_open() + } +} + +/// The pubsub object provides convenient access to the redis pubsub +/// system. Once created you can subscribe and unsubscribe from channels +/// and listen in on messages. +/// +/// Example: +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// let client = redis::Client::open("redis://127.0.0.1/")?; +/// let mut con = client.get_connection(None)?; +/// let mut pubsub = con.as_pubsub(); +/// pubsub.subscribe("channel_1")?; +/// pubsub.subscribe("channel_2")?; +/// +/// loop { +/// let msg = pubsub.get_message()?; +/// let payload : String = msg.get_payload()?; +/// println!("channel '{}': {}", msg.get_channel_name(), payload); +/// } +/// # } +/// ``` +impl<'a> PubSub<'a> { + fn new(con: &'a mut Connection) -> Self { + Self { + con, + waiting_messages: VecDeque::new(), + } + } + + fn cache_messages_until_received_response(&mut self, cmd: &mut Cmd) -> RedisResult<()> { + if self.con.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + let mut response = cmd.query(self.con)?; + loop { + if let Some(msg) = Msg::from_value(&response) { + self.waiting_messages.push_back(msg); + } else { + return Ok(()); + } + response = self.con.recv_response()?; + } + } + + /// Subscribes to a new channel. + pub fn subscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel)) + } + + /// Subscribes to a new channel with a pattern. + pub fn psubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel)) + } + + /// Unsubscribes from a channel. + pub fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel)) + } + + /// Unsubscribes from a channel with a pattern. + pub fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel)) + } + + /// Fetches the next message from the pubsub connection. 
Blocks until
+    /// a message becomes available. This currently does not provide a
+    /// way not to block :(
+    ///
+    /// The message itself is still generic and can be converted into an
+    /// appropriate type through the helper methods on it.
+    pub fn get_message(&mut self) -> RedisResult<Msg> {
+        if let Some(msg) = self.waiting_messages.pop_front() {
+            return Ok(msg);
+        }
+        loop {
+            if let Some(msg) = Msg::from_value(&self.con.recv_response()?) {
+                return Ok(msg);
+            } else {
+                continue;
+            }
+        }
+    }
+
+    /// Sets the read timeout for the connection.
+    ///
+    /// If the provided value is `None`, then `get_message` call will
+    /// block indefinitely. It is an error to pass the zero `Duration` to this
+    /// method.
+    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
+        self.con.set_read_timeout(dur)
+    }
+}
+
+impl<'a> Drop for PubSub<'a> {
+    fn drop(&mut self) {
+        let _ = self.con.exit_pubsub();
+    }
+}
+
+/// This holds the data that comes from listening to a pubsub
+/// connection. It only contains actual message data.
+impl Msg {
+    /// Tries to convert provided [`Value`] into [`Msg`].
+    #[allow(clippy::unnecessary_to_owned)]
+    pub fn from_value(value: &Value) -> Option<Self> {
+        let mut pattern = None;
+        let payload;
+        let channel;
+
+        if let Value::Push { kind, data } = value {
+            let mut iter: IntoIter<Value> = data.to_vec().into_iter();
+            if kind == &PushKind::Message || kind == &PushKind::SMessage {
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else if kind == &PushKind::PMessage {
+                pattern = Some(iter.next()?);
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else {
+                return None;
+            }
+        } else {
+            let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
+            let mut iter = raw_msg.into_iter();
+            let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
+            if msg_type == "message" {
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else if msg_type == "pmessage" {
+                pattern = Some(iter.next()?);
+                channel = iter.next()?;
+                payload = iter.next()?;
+            } else {
+                return None;
+            }
+        };
+        Some(Msg {
+            payload,
+            channel,
+            pattern,
+        })
+    }
+
+    /// Tries to convert provided [`PushInfo`] into [`Msg`].
+    pub fn from_push_info(push_info: &PushInfo) -> Option<Self> {
+        let mut pattern = None;
+        let payload;
+        let channel;
+
+        let mut iter = push_info.data.iter().cloned();
+        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
+            channel = iter.next()?;
+            payload = iter.next()?;
+        } else if push_info.kind == PushKind::PMessage {
+            pattern = Some(iter.next()?);
+            channel = iter.next()?;
+            payload = iter.next()?;
+        } else {
+            return None;
+        }
+
+        Some(Msg {
+            payload,
+            channel,
+            pattern,
+        })
+    }
+
+    /// Returns the channel this message came on.
+    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
+        from_redis_value(&self.channel)
+    }
+
+    /// Convenience method to get a string version of the channel. Unless
+    /// your channel contains non utf-8 bytes you can always use this
+    /// method. If the channel is not a valid string (which really should
+    /// not happen) then the return value is `"?"`.
+    pub fn get_channel_name(&self) -> &str {
+        match self.channel {
+            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
+            _ => "?",
+        }
+    }
+
+    /// Returns the message's payload in a specific format.
+    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
+        from_redis_value(&self.payload)
+    }
+
+    /// Returns the bytes that are the message's payload. This can be used
+    /// as an alternative to the `get_payload` function if you are interested
+    /// in the raw bytes in it.
+    pub fn get_payload_bytes(&self) -> &[u8] {
+        match self.payload {
+            Value::BulkString(ref bytes) => bytes,
+            _ => b"",
+        }
+    }
+
+    /// Returns true if the message was constructed from a pattern
+    /// subscription.
+    #[allow(clippy::wrong_self_convention)]
+    pub fn from_pattern(&self) -> bool {
+        self.pattern.is_some()
+    }
+
+    /// If the message was constructed from a message pattern this can be
+    /// used to find out which one. It's recommended to match against
+    /// an `Option<String>` so that you do not need to use `from_pattern`
+    /// to figure out if a pattern was set.
+    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
+        match self.pattern {
+            None => from_redis_value(&Value::Nil),
+            Some(ref x) => from_redis_value(x),
+        }
+    }
+}
+
+/// This function simplifies transaction management slightly. What it
+/// does is automatically watching keys and then going into a transaction
+/// loop until it succeeds. Once it goes through the results are
+/// returned.
+///
+/// To use the transaction two pieces of information are needed: a list
+/// of all the keys that need to be watched for modifications and a
+/// closure with the code that should be executed in the context of the
+/// transaction. The closure is invoked with a fresh pipeline in atomic
+/// mode. To use the transaction the function needs to return the result
+/// from querying the pipeline with the connection.
+///
+/// The end result of the transaction is then available as the return
+/// value from the function call.
+///
+/// Example:
+///
+/// ```rust,no_run
+/// use redis::Commands;
+/// # fn do_something() -> redis::RedisResult<()> {
+/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+/// # let mut con = client.get_connection(None).unwrap();
+/// let key = "the_key";
+/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
+///     let old_val : isize = con.get(key)?;
+///     pipe
+///         .set(key, old_val + 1).ignore()
+///         .get(key).query(con)
+/// })?;
+/// println!("The incremented number is: {}", new_val);
+/// # Ok(()) }
+/// ```
+pub fn transaction<
+    C: ConnectionLike,
+    K: ToRedisArgs,
+    T,
+    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
+>(
+    con: &mut C,
+    keys: &[K],
+    func: F,
+) -> RedisResult<T> {
+    let mut func = func;
+    loop {
+        cmd("WATCH").arg(keys).query::<()>(con)?;
+        let mut p = pipe();
+        let response: Option<T> = func(con, p.atomic())?;
+        match response {
+            None => {
+                continue;
+            }
+            Some(response) => {
+                // make sure no watch is left in the connection, even if
+                // someone forgot to use the pipeline.
+                cmd("UNWATCH").query::<()>(con)?;
+                return Ok(response);
+            }
+        }
+    }
+}
+// TODO: support sharded channels in both of the clearing logics below.
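The two helpers that follow implement the termination rule used by `clear_active_subscriptions` above: draining stops only once both an `unsubscribe` and a `punsubscribe` confirmation have been observed *and* the reported count of remaining subscriptions has reached zero. A minimal, self-contained sketch of that rule (not part of the diff; the `is_cleared` function is illustrative only, mirroring the helpers below without the redis-rs types):

```rust
// Illustrative sketch of the drain-loop termination rule; mirrors the
// resp2/resp3 helpers below but is independent of redis-rs types.
fn is_cleared(received_unsub: &mut bool, received_punsub: &mut bool, kind: &str, num: isize) -> bool {
    match kind {
        "unsubscribe" => *received_unsub = true,
        "punsubscribe" => *received_punsub = true,
        _ => {}
    }
    // Stop only when both confirmations arrived and nothing remains subscribed.
    *received_unsub && *received_punsub && num == 0
}

fn main() {
    let (mut unsub, mut punsub) = (false, false);
    // First ack: UNSUBSCRIBE confirmed, but one pattern subscription remains.
    assert!(!is_cleared(&mut unsub, &mut punsub, "unsubscribe", 1));
    // Second ack: PUNSUBSCRIBE confirmed and the remaining count is zero.
    assert!(is_cleared(&mut unsub, &mut punsub, "punsubscribe", 0));
}
```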
+ +/// Common logic for clearing subscriptions in RESP2 async/sync +pub fn resp2_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &[u8], + num: isize, +) -> bool { + match kind.first() { + Some(&b'u') => *received_unsub = true, + Some(&b'p') => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for clearing subscriptions in RESP3 async/sync +pub fn resp3_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &PushKind, + num: isize, +) -> bool { + match kind { + PushKind::Unsubscribe => *received_unsub = true, + PushKind::PUnsubscribe => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for checking real cause of hello3 command error +pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError { + if let Some(detail) = err.detail() { + if detail.starts_with("unknown command `HELLO`") { + return ( + ErrorKind::RESP3NotSupported, + "Redis Server doesn't support HELLO command therefore resp3 cannot be used", + ) + .into(); + } + } + err +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_redis_url() { + let cases = vec![ + ("redis://127.0.0.1", true), + ("redis://[::1]", true), + ("redis+unix:///run/redis.sock", true), + ("unix:///run/redis.sock", true), + ("http://127.0.0.1", false), + ("tcp://127.0.0.1", false), + ]; + for (url, expected) in cases.into_iter() { + let res = parse_redis_url(url); + assert_eq!( + res.is_some(), + expected, + "Parsed result of `{url}` is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info() { + let cases = vec![ + ( + url::Url::parse("redis://127.0.0.1").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://[::1]").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("::1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("example.com".to_string(), 6379), + redis: RedisConnectionInfo { + db: 2, + username: Some("%johndoe%".to_string()), + password: Some("#@<>$".to_string()), + ..Default::default() + }, + }, + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url.clone()).unwrap(); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); + assert_eq!( + res.redis.db, expected.redis.db, + "db of {url} is not expected", + ); + assert_eq!( + res.redis.username, expected.redis.username, + "username of {url} is not expected", + ); + assert_eq!( + res.redis.password, expected.redis.password, + "password of {url} is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info_failed() { + let cases = vec![ + (url::Url::parse("redis://").unwrap(), "Missing hostname"), + ( + url::Url::parse("redis://127.0.0.1/db").unwrap(), + "Invalid database number", + ), + ( + url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(), + "Username is not valid UTF-8 string", + ), + ( + url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(), + "Password is not valid UTF-8 string", + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url).unwrap_err(); + assert_eq!( + res.kind(), + crate::ErrorKind::InvalidClientConfig, + "{}", + &res, + ); + #[allow(deprecated)] + let desc = 
std::error::Error::description(&res);
+            assert_eq!(desc, expected, "{}", &res);
+            assert_eq!(res.detail(), None, "{}", &res);
+        }
+    }
+
+    #[test]
+    #[cfg(unix)]
+    fn test_url_to_unix_connection_info() {
+        let cases = vec![
+            (
+                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 0,
+                        username: None,
+                        password: None,
+                        protocol: ProtocolVersion::RESP2,
+                        client_name: None,
+                        pubsub_subscriptions: None,
+                    },
+                },
+            ),
+            (
+                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 1,
+                        ..Default::default()
+                    },
+                },
+            ),
+            (
+                url::Url::parse(
+                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
+                )
+                .unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/example.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 2,
+                        username: Some("%johndoe%".to_string()),
+                        password: Some("#@<>$".to_string()),
+                        ..Default::default()
+                    },
+                },
+            ),
+            (
+                url::Url::parse(
+                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
+                )
+                .unwrap(),
+                ConnectionInfo {
+                    addr: ConnectionAddr::Unix("/example.sock".into()),
+                    redis: RedisConnectionInfo {
+                        db: 2,
+                        username: Some("%johndoe%".to_string()),
+                        password: Some("&?= *+".to_string()),
+                        ..Default::default()
+                    },
+                },
+            ),
+        ];
+        for (url, expected) in cases.into_iter() {
+            assert_eq!(
+                ConnectionAddr::Unix(url.to_file_path().unwrap()),
+                expected.addr,
+                "addr of {url} is not expected",
+            );
+            let res = url_to_unix_connection_info(url.clone()).unwrap();
+            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
+            assert_eq!(
+                res.redis.db, expected.redis.db,
+                "db of {url} is not expected",
+            );
+            assert_eq!(
+                res.redis.username, expected.redis.username,
+                "username of {url} is not expected",
+            );
+            assert_eq!(
+                res.redis.password, expected.redis.password,
+                "password of {url} is not expected",
+            );
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/src/geo.rs b/glide-core/redis-rs/redis/src/geo.rs
new file mode 100644
index 0000000000..6195264a7c
--- /dev/null
+++ b/glide-core/redis-rs/redis/src/geo.rs
@@ -0,0 +1,361 @@
+//! Defines types to use with the geospatial commands.
+
+use super::{ErrorKind, RedisResult};
+use crate::types::{FromRedisValue, RedisWrite, ToRedisArgs, Value};
+
+macro_rules! invalid_type_error {
+    ($v:expr, $det:expr) => {{
+        fail!((
+            ErrorKind::TypeError,
+            "Response was of incompatible type",
+            format!("{:?} (response was {:?})", $det, $v)
+        ));
+    }};
+}
+
+/// Units used by [`geo_dist`][1] and [`geo_radius`][2].
+///
+/// [1]: ../trait.Commands.html#method.geo_dist
+/// [2]: ../trait.Commands.html#method.geo_radius
+pub enum Unit {
+    /// Represents meters.
+    Meters,
+    /// Represents kilometers.
+    Kilometers,
+    /// Represents miles.
+    Miles,
+    /// Represents feet.
+    Feet,
+}
+
+impl ToRedisArgs for Unit {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        let unit = match *self {
+            Unit::Meters => "m",
+            Unit::Kilometers => "km",
+            Unit::Miles => "mi",
+            Unit::Feet => "ft",
+        };
+        out.write_arg(unit.as_bytes());
+    }
+}
+
+/// A coordinate (longitude, latitude). Can be used with [`geo_pos`][1]
+/// to parse response from Redis.
+///
+/// [1]: ../trait.Commands.html#method.geo_pos
+///
+/// `T` is the type of every value.
+///
+/// * You may want to use either `f64` or `f32` if you want to perform mathematical operations.
+/// * To keep the raw value from Redis, use `String`.
+#[allow(clippy::derive_partial_eq_without_eq)] // allow f32/f64 here, which don't implement Eq
+#[derive(Debug, PartialEq)]
+pub struct Coord<T> {
+    /// Longitude
+    pub longitude: T,
+    /// Latitude
+    pub latitude: T,
+}
+
+impl<T> Coord<T> {
+    /// Create a new Coord with the (longitude, latitude)
+    pub fn lon_lat(longitude: T, latitude: T) -> Coord<T> {
+        Coord {
+            longitude,
+            latitude,
+        }
+    }
+}
+
+impl<T: FromRedisValue> FromRedisValue for Coord<T> {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        let values: Vec<T> = FromRedisValue::from_redis_value(v)?;
+        let mut values = values.into_iter();
+        let (longitude, latitude) = match (values.next(), values.next(), values.next()) {
+            (Some(longitude), Some(latitude), None) => (longitude, latitude),
+            _ => invalid_type_error!(v, "Expect a pair of numbers"),
+        };
+        Ok(Coord {
+            longitude,
+            latitude,
+        })
+    }
+}
+
+impl<T: ToRedisArgs> ToRedisArgs for Coord<T> {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        ToRedisArgs::write_redis_args(&self.longitude, out);
+        ToRedisArgs::write_redis_args(&self.latitude, out);
+    }
+
+    fn is_single_arg(&self) -> bool {
+        false
+    }
+}
+
+/// Options to sort results from [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands
+///
+/// [1]: https://redis.io/commands/georadius
+/// [2]: https://redis.io/commands/georadiusbymember
+#[derive(Default)]
+pub enum RadiusOrder {
+    /// Don't sort the results
+    #[default]
+    Unsorted,
+
+    /// Sort returned items from the nearest to the farthest, relative to the center.
+    Asc,
+
+    /// Sort returned items from the farthest to the nearest, relative to the center.
+    Desc,
+}
+
+/// Options for the [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands
+///
+/// [1]: https://redis.io/commands/georadius
+/// [2]: https://redis.io/commands/georadiusbymember
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use redis::{Commands, RedisResult};
+/// use redis::geo::{RadiusSearchResult, RadiusOptions, RadiusOrder, Unit};
+/// fn nearest_in_radius(
+///     con: &mut redis::Connection,
+///     key: &str,
+///     longitude: f64,
+///     latitude: f64,
+///     meters: f64,
+///     limit: usize,
+/// ) -> RedisResult<Vec<RadiusSearchResult>> {
+///     let opts = RadiusOptions::default()
+///         .order(RadiusOrder::Asc)
+///         .limit(limit);
+///     con.geo_radius(key, longitude, latitude, meters, Unit::Meters, opts)
+/// }
+/// ```
+#[derive(Default)]
+pub struct RadiusOptions {
+    with_coord: bool,
+    with_dist: bool,
+    count: Option<usize>,
+    order: RadiusOrder,
+    store: Option<Vec<Vec<u8>>>,
+    store_dist: Option<Vec<Vec<u8>>>,
+}
+
+impl RadiusOptions {
+    /// Limit the results to the first N matching items.
+    pub fn limit(mut self, n: usize) -> Self {
+        self.count = Some(n);
+        self
+    }
+
+    /// Return the distance of the returned items from the specified center.
+    /// The distance is returned in the same unit as the unit specified as the
+    /// radius argument of the command.
+    pub fn with_dist(mut self) -> Self {
+        self.with_dist = true;
+        self
+    }
+
+    /// Return the `longitude, latitude` coordinates of the matching items.
+    pub fn with_coord(mut self) -> Self {
+        self.with_coord = true;
+        self
+    }
+
+    /// Sort the returned items
+    pub fn order(mut self, o: RadiusOrder) -> Self {
+        self.order = o;
+        self
+    }
+
+    /// Store the results in a sorted set at `key`, instead of returning them.
+    ///
+    /// This feature can't be used with any `with_*` method.
+    pub fn store<K: ToRedisArgs>(mut self, key: K) -> Self {
+        self.store = Some(ToRedisArgs::to_redis_args(&key));
+        self
+    }
+
+    /// Store the results in a sorted set at `key`, with the distance from the
+    /// center as its score. This feature can't be used with any `with_*` method.
+    pub fn store_dist<K: ToRedisArgs>(mut self, key: K) -> Self {
+        self.store_dist = Some(ToRedisArgs::to_redis_args(&key));
+        self
+    }
+}
+
+impl ToRedisArgs for RadiusOptions {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        if self.with_coord {
+            out.write_arg(b"WITHCOORD");
+        }
+
+        if self.with_dist {
+            out.write_arg(b"WITHDIST");
+        }
+
+        if let Some(n) = self.count {
+            out.write_arg(b"COUNT");
+            out.write_arg_fmt(n);
+        }
+
+        match self.order {
+            RadiusOrder::Asc => out.write_arg(b"ASC"),
+            RadiusOrder::Desc => out.write_arg(b"DESC"),
+            _ => (),
+        };
+
+        if let Some(ref store) = self.store {
+            out.write_arg(b"STORE");
+            for i in store {
+                out.write_arg(i);
+            }
+        }
+
+        if let Some(ref store_dist) = self.store_dist {
+            out.write_arg(b"STOREDIST");
+            for i in store_dist {
+                out.write_arg(i);
+            }
+        }
+    }
+
+    fn is_single_arg(&self) -> bool {
+        false
+    }
+}
+
+/// Contains an item returned by [`geo_radius`][1] and [`geo_radius_by_member`][2].
+///
+/// [1]: ../trait.Commands.html#method.geo_radius
+/// [2]: ../trait.Commands.html#method.geo_radius_by_member
+pub struct RadiusSearchResult {
+    /// The name that was found.
+    pub name: String,
+    /// The coordinate if available.
+    pub coord: Option<Coord<f64>>,
+    /// The distance if available.
+    pub dist: Option<f64>,
+}
+
+impl FromRedisValue for RadiusSearchResult {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        // If we receive only the member name, it will be a plain string
+        if let Ok(name) = FromRedisValue::from_redis_value(v) {
+            return Ok(RadiusSearchResult {
+                name,
+                coord: None,
+                dist: None,
+            });
+        }
+
+        // Try to parse the result from multiple values
+        if let Value::Array(ref items) = *v {
+            if let Some(result) = RadiusSearchResult::parse_multi_values(items) {
+                return Ok(result);
+            }
+        }
+
+        invalid_type_error!(v, "Response type not RadiusSearchResult compatible.");
+    }
+}
+
+impl RadiusSearchResult {
+    fn parse_multi_values(items: &[Value]) -> Option<Self> {
+        let mut iter = items.iter();
+
+        // The first item is always the member name.
+        let name: String = match iter.next().map(FromRedisValue::from_redis_value) {
+            Some(Ok(n)) => n,
+            _ => return None,
+        };
+
+        let mut next = iter.next();
+
+        // The next element, if present, is the distance.
+        let dist = match next.map(FromRedisValue::from_redis_value) {
+            Some(Ok(c)) => {
+                next = iter.next();
+                Some(c)
+            }
+            _ => None,
+        };
+
+        // Finally, if present, the last item is the coordinates.
+        let coord = match next.map(FromRedisValue::from_redis_value) {
+            Some(Ok(c)) => Some(c),
+            _ => None,
+        };
+
+        Some(RadiusSearchResult { name, coord, dist })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Coord, RadiusOptions, RadiusOrder};
+    use crate::types::ToRedisArgs;
+    use std::str;
+
+    macro_rules! assert_args {
assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } + } + + #[test] + fn test_coord_to_args() { + let member = ("Palermo", Coord::lon_lat("13.361389", "38.115556")); + assert_args!(&member, "Palermo", "13.361389", "38.115556"); + } + + #[test] + fn test_radius_options() { + // Without options, should not generate any argument + let empty = RadiusOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + // Some combinations with WITH* options + let opts = RadiusOptions::default; + + assert_args!(opts().with_coord().with_dist(), "WITHCOORD", "WITHDIST"); + + assert_args!(opts().limit(50), "COUNT", "50"); + + assert_args!(opts().limit(50).store("x"), "COUNT", "50", "STORE", "x"); + + assert_args!( + opts().limit(100).store_dist("y"), + "COUNT", + "100", + "STOREDIST", + "y" + ); + + assert_args!( + opts().order(RadiusOrder::Asc).limit(10).with_dist(), + "WITHDIST", + "COUNT", + "10", + "ASC" + ); + } +} diff --git a/glide-core/redis-rs/redis/src/lib.rs b/glide-core/redis-rs/redis/src/lib.rs new file mode 100644 index 0000000000..0c960f3b4e --- /dev/null +++ b/glide-core/redis-rs/redis/src/lib.rs @@ -0,0 +1,507 @@ +//! redis-rs is a Rust implementation of a Redis client library. It exposes +//! a general purpose interface to Redis and also provides specific helpers for +//! commonly used functionality. +//! +//! The crate is called `redis` and you can depend on it via cargo: +//! +//! ```ini +//! [dependencies.redis] +//! version = "*" +//! ``` +//! +//! If you want to use the git version: +//! +//! ```ini +//! [dependencies.redis] +//! git = "https://github.com/redis-rs/redis-rs.git" +//! ``` +//! +//! # Basic Operation +//! +//! redis-rs exposes two API levels: a low- and a high-level part. +//! The high-level part does not expose all the functionality of redis and +//! might take some liberties in how it speaks the protocol. The low-level +//! part of the API allows you to express any request on the redis level. +//! You can fluently switch between both API levels at any point. +//! +//! ## Connection Handling +//! +//! For connecting to redis you can use a client object which then can produce +//! actual connections. Connections and clients as well as results of +//! connections and clients are considered `ConnectionLike` objects and +//! can be used anywhere a request is made. +//! +//! The full canonical way to get a connection is to create a client and +//! to ask for a connection from it: +//! +//! ```rust,no_run +//! extern crate redis; +//! +//! fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! +//! /* do something here */ +//! +//! Ok(()) +//! } +//! ``` +//! +//! ## Optional Features +//! +//! There are a few features defined that can enable additional functionality +//! if so desired. Some of them are turned on by default. +//! +//! * `acl`: enables acl support (enabled by default) +//! * `aio`: enables async IO support (enabled by default) +//! * `geospatial`: enables geospatial support (enabled by default) +//! * `script`: enables script support (enabled by default) +//! * `r2d2`: enables r2d2 connection pool support (optional) +//! * `ahash`: enables ahash map/set support & uses ahash internally (+7-10% performance) (optional) +//! 
+//! * `cluster`: enables redis cluster support (optional)
+//! * `cluster-async`: enables async redis cluster support (optional)
+//! * `tokio-comp`: enables support for tokio (optional)
+//! * `connection-manager`: enables support for automatic reconnection (optional)
+//! * `keep-alive`: enables keep-alive option on socket by means of `socket2` crate (optional)
+//!
+//! ## Connection Parameters
+//!
+//! redis-rs knows different ways to define where a connection should
+//! go. The parameter to `Client::open` needs to implement the
+//! `IntoConnectionInfo` trait of which there are three implementations:
+//!
+//! * string slices in `redis://` URL format.
+//! * URL objects from the redis-url crate.
+//! * `ConnectionInfo` objects.
+//!
+//! The URL format is `redis://[<username>][:<password>@]<hostname>[:port][/<db>]`
+//!
+//! If Unix socket support is available you can use a unix URL in this format:
+//!
+//! `redis+unix:///<path>[?db=<db>[&pass=<password>][&user=<username>]]`
+//!
+//! For compatibility with some other redis libraries, the "unix" scheme
+//! is also supported:
+//!
+//! `unix:///<path>[?db=<db>][&pass=<password>][&user=<username>]]`
+//!
+//! ## Executing Low-Level Commands
+//!
+//! To execute low-level commands you can use the `cmd` function which allows
+//! you to build redis requests. Once you have configured a command object
+//! to your liking you can send a query into any `ConnectionLike` object:
+//!
+//! ```rust,no_run
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> {
+//!     let _ : () = redis::cmd("SET").arg("my_key").arg(42).query(con)?;
+//!     Ok(())
+//! }
+//! ```
+//!
+//! Upon querying the return value is a result object. If you do not care
+//! about the actual return value (other than that it is not a failure)
+//! you can always type annotate it to the unit type `()`.
+//!
+//! Note that commands with a sub-command (like "MEMORY USAGE", "ACL WHOAMI",
+//! "LATENCY HISTORY", etc) must specify the sub-command as a separate `arg`:
+//!
+//! ```rust,no_run
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<isize> {
+//!     // This will result in a server error: "unknown command `MEMORY USAGE`"
+//!     // because "USAGE" is technically a sub-command of "MEMORY".
+//!     redis::cmd("MEMORY USAGE").arg("my_key").query(con)?;
+//!
+//!     // However, this will work as you'd expect
+//!     redis::cmd("MEMORY").arg("USAGE").arg("my_key").query(con)
+//! }
+//! ```
+//!
+//! ## Executing High-Level Commands
+//!
+//! The high-level interface is similar. For it to become available you
+//! need to use the `Commands` trait in which case all `ConnectionLike`
+//! objects the library provides will also have high-level methods which
+//! make working with the protocol easier:
+//!
+//! ```rust,no_run
+//! extern crate redis;
+//! use redis::Commands;
+//!
+//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> {
+//!     let _ : () = con.set("my_key", 42)?;
+//!     Ok(())
+//! }
+//! ```
+//!
+//! Note that high-level commands are work in progress and many are still
+//! missing!
+//!
+//! ## Type Conversions
+//!
+//! Because redis inherently is mostly type-less and the protocol is not
+//! exactly friendly to developers, this library provides flexible support
+//! for casting values to the intended results. This is driven through the `FromRedisValue` and `ToRedisArgs` traits.
+//!
+//! The `arg` method of the command will accept a wide range of types through
+//! the `ToRedisArgs` trait and the `query` method of a command can convert the
+//! value to what you expect the function to return through the `FromRedisValue`
+//! trait. This is quite flexible and allows vectors, tuples, hashsets, hashmaps
+//! as well as optional values:
+//!
+//! ```rust,no_run
+//! # use redis::Commands;
+//! # use std::collections::{HashMap, HashSet};
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let count : i32 = con.get("my_counter")?;
+//! let count = con.get("my_counter").unwrap_or(0i32);
+//! let k : Option<String> = con.get("missing_key")?;
+//! let name : String = con.get("my_name")?;
+//! let bin : Vec<u8> = con.get("my_binary")?;
+//! let map : HashMap<String, i32> = con.hgetall("my_hash")?;
+//! let keys : Vec<String> = con.hkeys("my_hash")?;
+//! let mems : HashSet<i32> = con.smembers("my_set")?;
+//! let (k1, k2) : (String, String) = con.get(&["k1", "k2"])?;
+//! # Ok(())
+//! # }
+//! ```
+//!
+//! # Iteration Protocol
+//!
+//! In addition to sending a single query, iterators are also supported. When
+//! used with regular bulk responses they don't give you much over querying and
+//! converting into a vector (both use a vector internally) but they can also
+//! be used with `SCAN` like commands in which case iteration will send more
+//! queries until the cursor is exhausted:
+//!
+//! ```rust,ignore
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let mut iter : redis::Iter<isize> = redis::cmd("SSCAN").arg("my_set")
+//!     .cursor_arg(0).clone().iter(&mut con)?;
+//! for x in iter {
+//!     // do something with the item
+//! }
+//! # Ok(()) }
+//! ```
+//!
+//! As you can see the cursor argument needs to be defined with `cursor_arg`
+//! instead of `arg` so that the library knows which argument needs updating
+//! as the query is run for more items.
+//!
+//! # Pipelining
+//!
+//! In addition to simple queries you can also send command pipelines. This
+//! is provided through the `pipe` function. It works very similar to sending
+//! individual commands but you can send more than one in one go. This also
+//! allows you to ignore individual results so that matching on the end result
+//! is easier:
+//!
+//! ```rust,no_run
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let (k1, k2) : (i32, i32) = redis::pipe()
+//!     .cmd("SET").arg("key_1").arg(42).ignore()
+//!     .cmd("SET").arg("key_2").arg(43).ignore()
+//!     .cmd("GET").arg("key_1")
+//!     .cmd("GET").arg("key_2").query(&mut con)?;
+//! # Ok(()) }
+//! ```
+//!
+//! If you want the pipeline to be wrapped in a `MULTI`/`EXEC` block you can
+//! easily do that by switching the pipeline into `atomic` mode. From the
+//! caller's point of view nothing changes, the pipeline itself will take
+//! care of the rest for you:
+//!
+//! ```rust,no_run
+//! # fn do_something() -> redis::RedisResult<()> {
+//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+//! # let mut con = client.get_connection(None).unwrap();
+//! let (k1, k2) : (i32, i32) = redis::pipe()
+//!     .atomic()
+//!     .cmd("SET").arg("key_1").arg(42).ignore()
+//!     .cmd("SET").arg("key_2").arg(43).ignore()
+//!     .cmd("GET").arg("key_1")
+//!     .cmd("GET").arg("key_2").query(&mut con)?;
+//! # Ok(()) }
+//! ```
+//!
+//!
You can also use high-level commands on pipelines: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let (k1, k2) : (i32, i32) = redis::pipe() +//! .atomic() +//! .set("key_1", 42).ignore() +//! .set("key_2", 43).ignore() +//! .get("key_1") +//! .get("key_2").query(&mut con)?; +//! # Ok(()) } +//! ``` +//! +//! # Transactions +//! +//! Transactions are available through atomic pipelines. In order to use +//! them in a more simple way you can use the `transaction` function of a +//! connection: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! use redis::Commands; +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let key = "the_key"; +//! let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { +//! let old_val : isize = con.get(key)?; +//! pipe +//! .set(key, old_val + 1).ignore() +//! .get(key).query(con) +//! })?; +//! println!("The incremented number is: {}", new_val); +//! # Ok(()) } +//! ``` +//! +//! For more information see the `transaction` function. +//! +//! # PubSub +//! +//! Pubsub is currently work in progress but provided through the `PubSub` +//! connection object. Due to the fact that Rust does not have support +//! for async IO in libnative yet, the API does not provide a way to +//! read messages with any form of timeout yet. +//! +//! Example usage: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! let mut pubsub = con.as_pubsub(); +//! pubsub.subscribe("channel_1")?; +//! pubsub.subscribe("channel_2")?; +//! +//! loop { +//! let msg = pubsub.get_message()?; +//! let payload : String = msg.get_payload()?; +//! println!("channel '{}': {}", msg.get_channel_name(), payload); +//! } +//! # } +//! ``` +//! +#![cfg_attr( + feature = "script", + doc = r##" +# Scripts + +Lua scripts are supported through the `Script` type in a convenient +way (it does not support pipelining currently). It will automatically +load the script if it does not exist and invoke it. + +Example: + +```rust,no_run +# fn do_something() -> redis::RedisResult<()> { +# let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +# let mut con = client.get_connection(None).unwrap(); +let script = redis::Script::new(r" + return tonumber(ARGV[1]) + tonumber(ARGV[2]); +"); +let result : isize = script.arg(1).arg(2).invoke(&mut con)?; +assert_eq!(result, 3); +# Ok(()) } +``` +"## +)] +//! +#![cfg_attr( + feature = "aio", + doc = r##" +# Async + +In addition to the synchronous interface that's been explained above there also exists an +asynchronous interface based on [`futures`][] and [`tokio`][]. + +This interface exists under the `aio` (async io) module (which requires that the `aio` feature +is enabled) and largely mirrors the synchronous with a few concessions to make it fit the +constraints of `futures`. 
+ +```rust,no_run +use futures::prelude::*; +use redis::AsyncCommands; + +# #[tokio::main] +# async fn main() -> redis::RedisResult<()> { +let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +let mut con = client.get_async_connection(None).await?; + +con.set("key1", b"foo").await?; + +redis::cmd("SET").arg(&["key2", "bar"]).query_async(&mut con).await?; + +let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; +assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); +# Ok(()) } +``` +"## +)] +//! +//! [`futures`]:https://crates.io/crates/futures +//! [`tokio`]:https://tokio.rs + +#![deny(non_camel_case_types)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, warn(rustdoc::broken_intra_doc_links))] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] + +// public api +pub use crate::client::Client; +pub use crate::client::GlideConnectionOptions; +pub use crate::cmd::{cmd, pack_command, pipe, Arg, Cmd, Iter}; +pub use crate::commands::{ + Commands, ControlFlow, Direction, LposOptions, PubSubCommands, SetOptions, +}; +pub use crate::connection::{ + parse_redis_url, transaction, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, + IntoConnectionInfo, Msg, PubSub, PubSubChannelOrPattern, PubSubSubscriptionInfo, + PubSubSubscriptionKind, RedisConnectionInfo, TlsMode, +}; +pub use crate::parser::{parse_redis_value, Parser}; +pub use crate::pipeline::Pipeline; +pub use push_manager::{PushInfo, PushManager}; + +#[cfg(feature = "script")] +#[cfg_attr(docsrs, doc(cfg(feature = "script")))] +pub use crate::script::{Script, ScriptInvocation}; + +// preserve grouping and order +#[rustfmt::skip] +pub use crate::types::{ + // utility functions + from_redis_value, + from_owned_redis_value, + + // error kinds + ErrorKind, + + // conversion traits + FromRedisValue, + + // utility types + InfoDict, + NumericBehavior, + Expiry, + SetExpiry, + ExistenceCheck, + + // error and result types + RedisError, + RedisResult, + RedisWrite, + ToRedisArgs, + + // low level values + Value, + PushKind, + VerbatimFormat, + ProtocolVersion +}; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use crate::{ + cmd::AsyncIter, commands::AsyncCommands, parser::parse_redis_value_async, types::RedisFuture, +}; + +mod macros; +mod pipeline; + +#[cfg(feature = "acl")] +#[cfg_attr(docsrs, doc(cfg(feature = "acl")))] +pub mod acl; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub mod aio; + +#[cfg(feature = "json")] +pub use crate::commands::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use crate::commands::JsonAsyncCommands; + +#[cfg(feature = "geospatial")] +#[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] +pub mod geo; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +/// Used for ReadFromReplicaStrategy information. +pub mod cluster_slotmap; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ObjectType; + +#[cfg(feature = "cluster")] +mod cluster_client; + +/// for testing purposes +pub mod testing { + #[cfg(feature = "cluster")] + pub use crate::cluster_client::ClusterParams; +} + +#[cfg(feature = "cluster")] +mod cluster_pipeline; + +/// Routing information for cluster commands. 
+#[cfg(feature = "cluster")] +pub mod cluster_routing; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster_topology; + +#[cfg(feature = "r2d2")] +#[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))] +mod r2d2; + +#[cfg(feature = "streams")] +#[cfg_attr(docsrs, doc(cfg(feature = "streams")))] +pub mod streams; + +#[cfg(feature = "cluster-async")] +pub mod cluster_async; + +#[cfg(feature = "sentinel")] +pub mod sentinel; + +#[cfg(feature = "tls-rustls")] +mod tls; + +#[cfg(feature = "tls-rustls")] +pub use crate::tls::{ClientTlsConfig, TlsCertificates}; + +mod client; +mod cmd; +mod commands; +mod connection; +mod parser; +mod push_manager; +mod script; +mod types; diff --git a/glide-core/redis-rs/redis/src/macros.rs b/glide-core/redis-rs/redis/src/macros.rs new file mode 100644 index 0000000000..b8886cc759 --- /dev/null +++ b/glide-core/redis-rs/redis/src/macros.rs @@ -0,0 +1,7 @@ +#![macro_use] + +macro_rules! fail { + ($expr:expr) => { + return Err(::std::convert::From::from($expr)) + }; +} diff --git a/glide-core/redis-rs/redis/src/parser.rs b/glide-core/redis-rs/redis/src/parser.rs new file mode 100644 index 0000000000..1f42a774f1 --- /dev/null +++ b/glide-core/redis-rs/redis/src/parser.rs @@ -0,0 +1,658 @@ +use std::{ + io::{self, Read}, + str, +}; + +use crate::types::{ + ErrorKind, InternalValue, PushKind, RedisError, RedisResult, ServerError, ServerErrorKind, + Value, VerbatimFormat, +}; + +use combine::{ + any, + error::StreamError, + opaque, + parser::{ + byte::{crlf, take_until_bytes}, + combinator::{any_send_sync_partial_state, AnySendSyncPartialState}, + range::{recognize, take}, + }, + stream::{PointerOffset, RangeStream, StreamErrorFor}, + ParseError, Parser as _, +}; +use num_bigint::BigInt; + +const MAX_RECURSE_DEPTH: usize = 100; + +fn err_parser(line: &str) -> ServerError { + let mut pieces = line.splitn(2, ' '); + let kind = match pieces.next().unwrap() { + "ERR" => ServerErrorKind::ResponseError, + "EXECABORT" => ServerErrorKind::ExecAbortError, + "LOADING" => ServerErrorKind::BusyLoadingError, + "NOSCRIPT" => ServerErrorKind::NoScriptError, + "MOVED" => ServerErrorKind::Moved, + "ASK" => ServerErrorKind::Ask, + "TRYAGAIN" => ServerErrorKind::TryAgain, + "CLUSTERDOWN" => ServerErrorKind::ClusterDown, + "CROSSSLOT" => ServerErrorKind::CrossSlot, + "MASTERDOWN" => ServerErrorKind::MasterDown, + "READONLY" => ServerErrorKind::ReadOnly, + "NOTBUSY" => ServerErrorKind::NotBusy, + code => { + return ServerError::ExtensionError { + code: code.to_string(), + detail: pieces.next().map(|str| str.to_string()), + } + } + }; + let detail = pieces.next().map(|str| str.to_string()); + ServerError::KnownError { kind, detail } +} + +pub fn get_push_kind(kind: String) -> PushKind { + match kind.as_str() { + "invalidate" => PushKind::Invalidate, + "message" => PushKind::Message, + "pmessage" => PushKind::PMessage, + "smessage" => PushKind::SMessage, + "unsubscribe" => PushKind::Unsubscribe, + "punsubscribe" => PushKind::PUnsubscribe, + "sunsubscribe" => PushKind::SUnsubscribe, + "subscribe" => PushKind::Subscribe, + "psubscribe" => PushKind::PSubscribe, + "ssubscribe" => PushKind::SSubscribe, + _ => PushKind::Other(kind), + } +} + +fn value<'a, I>( + count: Option, +) -> impl combine::Parser +where + I: RangeStream, + I::Error: combine::ParseError, +{ + let count = count.unwrap_or(1); + + opaque!(any_send_sync_partial_state( + any() + .then_partial(move |&mut b| { + if b == b'*' && count > MAX_RECURSE_DEPTH { + combine::unexpected_any("Maximum 
recursion depth exceeded").left() + } else { + combine::value(b).right() + } + }) + .then_partial(move |&mut b| { + let line = || { + recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then( + |line: &[u8]| { + str::from_utf8(&line[..line.len() - 2]) + .map_err(StreamErrorFor::::other) + }, + ) + }; + + let simple_string = || { + line().map(|line| { + if line == "OK" { + InternalValue::Okay + } else { + InternalValue::SimpleString(line.into()) + } + }) + }; + + let int = || { + line().and_then(|line| { + line.trim().parse::().map_err(|_| { + StreamErrorFor::::message_static_message( + "Expected integer, got garbage", + ) + }) + }) + }; + + let bulk_string = || { + int().then_partial(move |size| { + if *size < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + take(*size as usize) + .map(|bs: &[u8]| InternalValue::BulkString(bs.to_vec())) + .skip(crlf()) + .right() + } + }) + }; + let blob = || { + int().then_partial(move |size| { + take(*size as usize) + .map(|bs: &[u8]| String::from_utf8_lossy(bs).to_string()) + .skip(crlf()) + }) + }; + + let array = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Array) + .right() + } + }) + }; + + let error = || line().map(err_parser); + let map = || { + int().then_partial(move |&mut kv_length| { + let length = kv_length as usize * 2; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut x = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + x.push((k, v)) + } + } + InternalValue::Map(x) + }, + ) + }) + }; + let attribute = || { + int().then_partial(move |&mut kv_length| { + // + 1 is for data! 
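+                // (A RESP3 attribute frame carries `kv_length` key-value pairs
+                // followed by the value they annotate, hence the extra element.)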
+ let length = kv_length as usize * 2 + 1; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut attributes = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + attributes.push((k, v)) + } + } + InternalValue::Attribute { + data: Box::new(it.next().unwrap()), + attributes, + } + }, + ) + }) + }; + let set = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Set) + .right() + } + }) + }; + let push = || { + int().then_partial(move |&mut length| { + if length <= 0 { + combine::produce(|| InternalValue::Push { + kind: PushKind::Other("".to_string()), + data: vec![], + }) + .left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .and_then(|result: Vec| { + let mut it = result.into_iter(); + let first = it.next().unwrap_or(InternalValue::Nil); + if let InternalValue::BulkString(kind) = first { + let push_kind = String::from_utf8(kind) + .map_err(StreamErrorFor::::other)?; + Ok(InternalValue::Push { + kind: get_push_kind(push_kind), + data: it.collect(), + }) + } else if let InternalValue::SimpleString(kind) = first { + Ok(InternalValue::Push { + kind: get_push_kind(kind), + data: it.collect(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding push", + )) + } + }) + .right() + } + }) + }; + let null = || line().map(|_| InternalValue::Nil); + let double = || { + line().and_then(|line| { + line.trim() + .parse::() + .map_err(StreamErrorFor::::other) + }) + }; + let boolean = || { + line().and_then(|line: &str| match line { + "t" => Ok(true), + "f" => Ok(false), + _ => Err(StreamErrorFor::::message_static_message( + "Expected boolean, got garbage", + )), + }) + }; + let blob_error = || blob().map(|line| err_parser(&line)); + let verbatim = || { + blob().and_then(|line| { + if let Some((format, text)) = line.split_once(':') { + let format = match format { + "txt" => VerbatimFormat::Text, + "mkd" => VerbatimFormat::Markdown, + x => VerbatimFormat::Unknown(x.to_string()), + }; + Ok(InternalValue::VerbatimString { + format, + text: text.to_string(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding verbatim string", + )) + } + }) + }; + let big_number = || { + line().and_then(|line| { + BigInt::parse_bytes(line.as_bytes(), 10).ok_or_else(|| { + StreamErrorFor::::message_static_message( + "Expected bigint, got garbage", + ) + }) + }) + }; + combine::dispatch!(b; + b'+' => simple_string(), + b':' => int().map(InternalValue::Int), + b'$' => bulk_string(), + b'*' => array(), + b'%' => map(), + b'|' => attribute(), + b'~' => set(), + b'-' => error().map(InternalValue::ServerError), + b'_' => null(), + b',' => double().map(InternalValue::Double), + b'#' => boolean().map(InternalValue::Boolean), + b'!' 
=> blob_error().map(InternalValue::ServerError), + b'=' => verbatim(), + b'(' => big_number().map(InternalValue::BigNumber), + b'>' => push(), + b => combine::unexpected_any(combine::error::Token(b)) + ) + }) + )) +} + +#[cfg(feature = "aio")] +mod aio_support { + use super::*; + + use bytes::{Buf, BytesMut}; + use tokio::io::AsyncRead; + use tokio_util::codec::{Decoder, Encoder}; + + #[derive(Default)] + pub struct ValueCodec { + state: AnySendSyncPartialState, + } + + impl ValueCodec { + fn decode_stream( + &mut self, + bytes: &mut BytesMut, + eof: bool, + ) -> RedisResult>> { + let (opt, removed_len) = { + let buffer = &bytes[..]; + let mut stream = + combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof)); + match combine::stream::decode_tokio(value(None), &mut stream, &mut self.state) { + Ok(x) => x, + Err(err) => { + let err = err + .map_position(|pos| pos.translate_position(buffer)) + .map_range(|range| format!("{range:?}")) + .to_string(); + return Err(RedisError::from(( + ErrorKind::ParseError, + "parse error", + err, + ))); + } + } + }; + + bytes.advance(removed_len); + match opt { + Some(result) => Ok(Some(result.try_into())), + None => Ok(None), + } + } + } + + impl Encoder> for ValueCodec { + type Error = RedisError; + fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.extend_from_slice(item.as_ref()); + Ok(()) + } + } + + impl Decoder for ValueCodec { + type Item = RedisResult; + type Error = RedisError; + + fn decode(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, false) + } + + fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, true) + } + } + + /// Parses a redis value asynchronously. + pub async fn parse_redis_value_async( + decoder: &mut combine::stream::Decoder>, + read: &mut R, + ) -> RedisResult + where + R: AsyncRead + std::marker::Unpin, + { + let result = combine::decode_tokio!(*decoder, *read, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use self::aio_support::*; + +/// The internal redis response parser. +pub struct Parser { + decoder: combine::stream::decoder::Decoder>, +} + +impl Default for Parser { + fn default() -> Self { + Parser::new() + } +} + +/// The parser can be used to parse redis responses into values. Generally +/// you normally do not use this directly as it's already done for you by +/// the client but in some more complex situations it might be useful to be +/// able to parse the redis responses. +impl Parser { + /// Creates a new parser that parses the data behind the reader. More + /// than one value can be behind the reader in which case the parser can + /// be invoked multiple times. In other words: the stream does not have + /// to be terminated. 
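+    ///
+    /// A short sketch of parsing two values out of one buffer (the byte
+    /// string below is illustrative):
+    ///
+    /// ```rust,no_run
+    /// let mut parser = redis::Parser::new();
+    /// let mut input: &[u8] = b"+OK\r\n:42\r\n";
+    /// let first = parser.parse_value(&mut input).unwrap();  // Value::Okay
+    /// let second = parser.parse_value(&mut input).unwrap(); // Value::Int(42)
+    /// ```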
+ pub fn new() -> Parser { + Parser { + decoder: combine::stream::decoder::Decoder::new(), + } + } + + // public api + + /// Parses synchronously into a single value from the reader. + pub fn parse_value(&mut self, mut reader: T) -> RedisResult { + let mut decoder = &mut self.decoder; + let result = combine::decode!(decoder, reader, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +/// Parses bytes into a redis value. +/// +/// This is the most straightforward way to parse something into a low +/// level redis value instead of having to use a whole parser. +pub fn parse_redis_value(bytes: &[u8]) -> RedisResult { + let mut parser = Parser::new(); + parser.parse_value(bytes) +} + +#[cfg(test)] +mod tests { + use crate::types::make_extension_error; + + use super::*; + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_none_at_eof() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = bytes::BytesMut::from(&b"+GET 123\r\n"[..]); + assert_eq!( + codec.decode_eof(&mut bytes), + Ok(Some(Ok(parse_redis_value(b"+GET 123\r\n").unwrap()))) + ); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + } + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_error_inside_array_and_can_parse_more_inputs() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = + bytes::BytesMut::from(b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let mut bytes = bytes::BytesMut::from(b"+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!(result, Ok(Value::Okay)); + } + + #[test] + fn parse_nested_error_and_handle_more_inputs() { + // from https://redis.io/docs/interact/transactions/ - + // "EXEC returned two-element bulk string reply where one is an OK code and the other an error reply. It's up to the client library to find a sensible way to provide the error to the user." 
+ + let bytes = b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n"; + let result = parse_redis_value(bytes); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let result = parse_redis_value(b"+OK\r\n").unwrap(); + + assert_eq!(result, Value::Okay); + } + + #[test] + fn decode_resp3_double() { + let val = parse_redis_value(b",1.23\r\n").unwrap(); + assert_eq!(val, Value::Double(1.23)); + let val = parse_redis_value(b",nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_positive()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + // -nan is supported prior to redis 7.2 + let val = parse_redis_value(b",-nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_negative()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + //Allow doubles in scientific E notation + let val = parse_redis_value(b",2.67923e+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",-2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(-267923000.0)); + let val = parse_redis_value(b",2.1E-2\r\n").unwrap(); + assert_eq!(val, Value::Double(0.021)); + + let val = parse_redis_value(b",-inf\r\n").unwrap(); + assert_eq!(val, Value::Double(-f64::INFINITY)); + let val = parse_redis_value(b",inf\r\n").unwrap(); + assert_eq!(val, Value::Double(f64::INFINITY)); + } + + #[test] + fn decode_resp3_map() { + let val = parse_redis_value(b"%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n").unwrap(); + let mut v = val.as_map_iter().unwrap(); + assert_eq!( + (&Value::SimpleString("first".to_string()), &Value::Int(1)), + v.next().unwrap() + ); + assert_eq!( + (&Value::SimpleString("second".to_string()), &Value::Int(2)), + v.next().unwrap() + ); + } + + #[test] + fn decode_resp3_boolean() { + let val = parse_redis_value(b"#t\r\n").unwrap(); + assert_eq!(val, Value::Boolean(true)); + let val = parse_redis_value(b"#f\r\n").unwrap(); + assert_eq!(val, Value::Boolean(false)); + let val = parse_redis_value(b"#x\r\n"); + assert!(val.is_err()); + let val = parse_redis_value(b"#\r\n"); + assert!(val.is_err()); + } + + #[test] + fn decode_resp3_blob_error() { + let val = parse_redis_value(b"!21\r\nSYNTAX invalid syntax\r\n"); + assert_eq!( + val.err(), + Some(make_extension_error( + "SYNTAX".to_string(), + Some("invalid syntax".to_string()) + )) + ) + } + + #[test] + fn decode_resp3_big_number() { + let val = parse_redis_value(b"(3492890328409238509324850943850943825024385\r\n").unwrap(); + assert_eq!( + val, + Value::BigNumber( + BigInt::parse_bytes(b"3492890328409238509324850943850943825024385", 10).unwrap() + ) + ); + } + + #[test] + fn decode_resp3_set() { + let val = parse_redis_value(b"~5\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n").unwrap(); + let v = val.as_sequence().unwrap(); + assert_eq!(Value::SimpleString("orange".to_string()), v[0]); + assert_eq!(Value::SimpleString("apple".to_string()), v[1]); + assert_eq!(Value::Boolean(true), v[2]); + assert_eq!(Value::Int(100), v[3]); + assert_eq!(Value::Int(999), v[4]); + } + + #[test] + fn decode_resp3_push() { + let val = parse_redis_value(b">3\r\n+message\r\n+some_channel\r\n+this is the message\r\n") + .unwrap(); + if let Value::Push { ref kind, ref data } = val { + assert_eq!(&PushKind::Message, kind); + 
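+            // The remaining elements are the channel name and the payload.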
assert_eq!(Value::SimpleString("some_channel".to_string()), data[0]); + assert_eq!( + Value::SimpleString("this is the message".to_string()), + data[1] + ); + } else { + panic!("Expected Value::Push") + } + } + + #[test] + fn test_max_recursion_depth() { + let bytes = b"*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1
\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n"; + match parse_redis_value(bytes) { + Ok(_) => panic!("Expected Err"), + Err(e) => assert!(matches!(e.kind(), ErrorKind::ParseError)), + } + } +} diff --git a/glide-core/redis-rs/redis/src/pipeline.rs b/glide-core/redis-rs/redis/src/pipeline.rs new file mode 100644 index 0000000000..babb57a1ff --- /dev/null +++ b/glide-core/redis-rs/redis/src/pipeline.rs @@ -0,0 +1,324 @@ +#![macro_use] + +use crate::cmd::{cmd, cmd_len, Cmd}; +use crate::connection::ConnectionLike; +use crate::types::{ + from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value, +}; + +/// Represents a redis command pipeline. +#[derive(Clone)] +pub struct Pipeline { + commands: Vec, + transaction_mode: bool, + ignored_commands: HashSet, +} + +/// A pipeline allows you to send multiple commands in one go to the +/// redis server. API wise it's very similar to just using a command +/// but it allows multiple commands to be chained and some features such +/// as iteration are not available. +/// +/// Basic example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let ((k1, k2),) : ((i32, i32),) = redis::pipe() +/// .cmd("SET").arg("key_1").arg(42).ignore() +/// .cmd("SET").arg("key_2").arg(43).ignore() +/// .cmd("MGET").arg(&["key_1", "key_2"]).query(&mut con).unwrap(); +/// ``` +/// +/// As you can see with `cmd` you can start a new command. By default +/// each command produces a value but for some you can ignore them by +/// calling `ignore` on the command. That way it will be skipped in the +/// return value which is useful for `SET` commands and others, which +/// do not have a useful return value. +impl Pipeline { + /// Creates an empty pipeline. For consistency with the `cmd` + /// api a `pipe` function is provided as alias. + pub fn new() -> Pipeline { + Self::with_capacity(0) + } + + /// Creates an empty pipeline with pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> Pipeline { + Pipeline { + commands: Vec::with_capacity(capacity), + transaction_mode: false, + ignored_commands: HashSet::new(), + } + } + + /// This enables atomic mode. In atomic mode the whole pipeline is + /// enclosed in `MULTI`/`EXEC`. From the user's point of view nothing + /// changes however. This is easier than using `MULTI`/`EXEC` yourself + /// as the format does not change. + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let (k1, k2) : (i32, i32) = redis::pipe() + /// .atomic() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn atomic(&mut self) -> &mut Pipeline { + self.transaction_mode = true; + self + } + + /// Returns the encoded pipeline commands. 
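+    ///
+    /// For example, a pipeline holding a single `PING` encodes to the RESP
+    /// bytes `*1\r\n$4\r\nPING\r\n` (a sketch; with `atomic()` the commands
+    /// are additionally wrapped in `MULTI`/`EXEC`).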
+    pub fn get_packed_pipeline(&self) -> Vec<u8> {
+        encode_pipeline(&self.commands, self.transaction_mode)
+    }
+
+    #[cfg(feature = "aio")]
+    pub(crate) fn write_packed_pipeline(&self, out: &mut Vec<u8>) {
+        write_pipeline(out, &self.commands, self.transaction_mode)
+    }
+
+    fn execute_pipelined(&self, con: &mut dyn ConnectionLike) -> RedisResult<Value> {
+        Ok(self.make_pipeline_results(con.req_packed_commands(
+            &encode_pipeline(&self.commands, false),
+            0,
+            self.commands.len(),
+        )?))
+    }
+
+    fn execute_transaction(&self, con: &mut dyn ConnectionLike) -> RedisResult<Value> {
+        let mut resp = con.req_packed_commands(
+            &encode_pipeline(&self.commands, true),
+            self.commands.len() + 1,
+            1,
+        )?;
+        match resp.pop() {
+            Some(Value::Nil) => Ok(Value::Nil),
+            Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)),
+            _ => fail!((
+                ErrorKind::ResponseError,
+                "Invalid response when parsing multi response"
+            )),
+        }
+    }
+
+    /// Executes the pipeline and fetches the return values. Since most
+    /// pipelines return different types it's recommended to use tuple
+    /// matching to process the results:
+    ///
+    /// ```rust,no_run
+    /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    /// # let mut con = client.get_connection(None).unwrap();
+    /// let (k1, k2) : (i32, i32) = redis::pipe()
+    ///     .cmd("SET").arg("key_1").arg(42).ignore()
+    ///     .cmd("SET").arg("key_2").arg(43).ignore()
+    ///     .cmd("GET").arg("key_1")
+    ///     .cmd("GET").arg("key_2").query(&mut con).unwrap();
+    /// ```
+    ///
+    /// NOTE: A Pipeline object may be reused after `query()`; it still holds
+    /// all the commands that were inserted into it. To clear a Pipeline object
+    /// while releasing and allocating as little memory as possible, call
+    /// `clear()` before inserting new commands.
+    #[inline]
+    pub fn query<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
+        if !con.supports_pipelining() {
+            fail!((
+                ErrorKind::ResponseError,
+                "This connection does not support pipelining."
+            ));
+        }
+        from_owned_redis_value(if self.commands.is_empty() {
+            Value::Array(vec![])
+        } else if self.transaction_mode {
+            self.execute_transaction(con)?
+        } else {
+            self.execute_pipelined(con)?
+        })
+    }
+
+    #[cfg(feature = "aio")]
+    async fn execute_pipelined_async<C>(&self, con: &mut C) -> RedisResult<Value>
+    where
+        C: crate::aio::ConnectionLike,
+    {
+        let value = con
+            .req_packed_commands(self, 0, self.commands.len())
+            .await?;
+        Ok(self.make_pipeline_results(value))
+    }
+
+    #[cfg(feature = "aio")]
+    async fn execute_transaction_async<C>(&self, con: &mut C) -> RedisResult<Value>
+    where
+        C: crate::aio::ConnectionLike,
+    {
+        let mut resp = con
+            .req_packed_commands(self, self.commands.len() + 1, 1)
+            .await?;
+        match resp.pop() {
+            Some(Value::Nil) => Ok(Value::Nil),
+            Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)),
+            _ => Err((
+                ErrorKind::ResponseError,
+                "Invalid response when parsing multi response",
+            )
+                .into()),
+        }
+    }
+
+    /// Async version of `query`.
+    #[inline]
+    #[cfg(feature = "aio")]
+    pub async fn query_async<C, T: FromRedisValue>(&self, con: &mut C) -> RedisResult<T>
+    where
+        C: crate::aio::ConnectionLike,
+    {
+        let v = if self.commands.is_empty() {
+            return from_owned_redis_value(Value::Array(vec![]));
+        } else if self.transaction_mode {
+            self.execute_transaction_async(con).await?
+        } else {
+            self.execute_pipelined_async(con).await?
+        };
+        from_owned_redis_value(v)
+    }
+
+    /// This is a shortcut to `query()` that does not return a value and
+    /// will fail the task if the query of the pipeline fails.
+    ///
+    /// This is equivalent to a call of query like this:
+    ///
+    /// ```rust,no_run
+    /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+    /// # let mut con = client.get_connection(None).unwrap();
+    /// let _ : () = redis::pipe().cmd("PING").query(&mut con).unwrap();
+    /// ```
+    ///
+    /// NOTE: A Pipeline object may be reused after `query()`; it still holds
+    /// all the commands that were inserted into it. To clear a Pipeline object
+    /// while releasing and allocating as little memory as possible, call
+    /// `clear()` before inserting new commands.
+    #[inline]
+    pub fn execute(&self, con: &mut dyn ConnectionLike) {
+        self.query::<()>(con).unwrap();
+    }
+}
+
+fn encode_pipeline(cmds: &[Cmd], atomic: bool) -> Vec<u8> {
+    let mut rv = vec![];
+    write_pipeline(&mut rv, cmds, atomic);
+    rv
+}
+
+fn write_pipeline(rv: &mut Vec<u8>, cmds: &[Cmd], atomic: bool) {
+    let cmds_len = cmds.iter().map(cmd_len).sum();
+
+    if atomic {
+        let multi = cmd("MULTI");
+        let exec = cmd("EXEC");
+        rv.reserve(cmd_len(&multi) + cmd_len(&exec) + cmds_len);
+
+        multi.write_packed_command_preallocated(rv);
+        for cmd in cmds {
+            cmd.write_packed_command_preallocated(rv);
+        }
+        exec.write_packed_command_preallocated(rv);
+    } else {
+        rv.reserve(cmds_len);
+
+        for cmd in cmds {
+            cmd.write_packed_command_preallocated(rv);
+        }
+    }
+}
+
+// Macro to implement shared methods between Pipeline and ClusterPipeline
+macro_rules! implement_pipeline_commands {
+    ($struct_name:ident) => {
+        impl $struct_name {
+            /// Adds a command to the pipeline.
+            #[inline]
+            pub fn add_command(&mut self, cmd: Cmd) -> &mut Self {
+                self.commands.push(cmd);
+                self
+            }
+
+            /// Starts a new command. Functions such as `arg` then become
+            /// available to add more arguments to that command.
+            #[inline]
+            pub fn cmd(&mut self, name: &str) -> &mut Self {
+                self.add_command(cmd(name))
+            }
+
+            /// Returns an iterator over all the commands currently in this pipeline
+            pub fn cmd_iter(&self) -> impl Iterator<Item = &Cmd> {
+                self.commands.iter()
+            }
+
+            /// Instructs the pipeline to ignore the return value of this command.
+            /// It will still be ensured that it is not an error, but any successful
+            /// result is just thrown away. This makes result processing through
+            /// tuples much easier because you do not need to handle all the items
+            /// you do not care about.
+            #[inline]
+            pub fn ignore(&mut self) -> &mut Self {
+                match self.commands.len() {
+                    0 => true,
+                    x => self.ignored_commands.insert(x - 1),
+                };
+                self
+            }
+
+            /// Adds an argument to the last started command. This works similar
+            /// to the `arg` method of the `Cmd` object.
+            ///
+            /// Note that this function fails the task if executed on an empty pipeline.
+            #[inline]
+            pub fn arg<T: ToRedisArgs>(&mut self, arg: T) -> &mut Self {
+                {
+                    let cmd = self.get_last_command();
+                    cmd.arg(arg);
+                }
+                self
+            }
+
+            /// Clear a pipeline object's internal data structure.
+            ///
+            /// This allows reusing a pipeline object while releasing and
+            /// reallocating as little memory as possible.
+ #[inline] + pub fn clear(&mut self) { + self.commands.clear(); + self.ignored_commands.clear(); + } + + #[inline] + fn get_last_command(&mut self) -> &mut Cmd { + let idx = match self.commands.len() { + 0 => panic!("No command on stack"), + x => x - 1, + }; + &mut self.commands[idx] + } + + fn make_pipeline_results(&self, resp: Vec) -> Value { + let mut rv = Vec::with_capacity(resp.len() - self.ignored_commands.len()); + for (idx, result) in resp.into_iter().enumerate() { + if !self.ignored_commands.contains(&idx) { + rv.push(result); + } + } + Value::Array(rv) + } + } + + impl Default for $struct_name { + fn default() -> Self { + Self::new() + } + } + }; +} + +implement_pipeline_commands!(Pipeline); diff --git a/glide-core/redis-rs/redis/src/push_manager.rs b/glide-core/redis-rs/redis/src/push_manager.rs new file mode 100644 index 0000000000..8a22e06a57 --- /dev/null +++ b/glide-core/redis-rs/redis/src/push_manager.rs @@ -0,0 +1,234 @@ +use crate::{PushKind, RedisResult, Value}; +use arc_swap::ArcSwap; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Holds information about received Push data +#[derive(Debug, Clone)] +pub struct PushInfo { + /// Push Kind + pub kind: PushKind, + /// Data from push message + pub data: Vec, +} + +/// Manages Push messages for single tokio channel +#[derive(Clone, Default)] +pub struct PushManager { + sender: Arc>>>, +} +impl PushManager { + /// It checks if value's type is Push + /// then invokes `try_send_raw` method + pub(crate) fn try_send(&self, value: &RedisResult) { + if let Ok(value) = &value { + self.try_send_raw(value); + } + } + + /// It checks if value's type is Push and there is a provided sender + /// then creates PushInfo and invokes `send` method of sender + pub(crate) fn try_send_raw(&self, value: &Value) { + if let Value::Push { kind, data } = value { + let guard = self.sender.load(); + if let Some(sender) = guard.as_ref() { + let push_info = PushInfo { + kind: kind.clone(), + data: data.clone(), + }; + if sender.send(push_info).is_err() { + self.sender.compare_and_swap(guard, Arc::new(None)); + } + } + } + } + /// Replace mpsc channel of `PushManager` with provided sender. 
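+    ///
+    /// A minimal sketch of wiring up a receiver (names are illustrative):
+    ///
+    /// ```rust,no_run
+    /// let manager = redis::PushManager::new();
+    /// let (tx, mut _rx) = tokio::sync::mpsc::unbounded_channel();
+    /// manager.replace_sender(tx);
+    /// // Push messages handed to the manager are now forwarded to `_rx`.
+    /// ```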
+    pub fn replace_sender(&self, sender: mpsc::UnboundedSender<PushInfo>) {
+        self.sender.store(Arc::new(Some(sender)));
+    }
+
+    /// Creates a new `PushManager`
+    pub fn new() -> Self {
+        PushManager {
+            sender: Arc::from(ArcSwap::from(Arc::new(None))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_send_and_receive_push_info() {
+        let push_manager = PushManager::new();
+        let (tx, mut rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+
+        let value = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        });
+
+        push_manager.try_send(&value);
+
+        let push_info = rx.try_recv().unwrap();
+        assert_eq!(push_info.kind, PushKind::Message);
+        assert_eq!(
+            push_info.data,
+            vec![Value::BulkString("hello".to_string().into_bytes())]
+        );
+    }
+
+    #[test]
+    fn test_push_manager_receiver_dropped() {
+        let push_manager = PushManager::new();
+        let (tx, rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+
+        let value = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        });
+
+        drop(rx);
+
+        push_manager.try_send(&value);
+        push_manager.try_send(&value);
+        push_manager.try_send(&value);
+    }
+
+    #[test]
+    fn test_push_manager_without_sender() {
+        let push_manager = PushManager::new();
+
+        push_manager.try_send(&Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello".to_string().into_bytes())],
+        })); // nothing happens!
+
+        let (tx, mut rx) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx);
+        push_manager.try_send(&Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::BulkString("hello2".to_string().into_bytes())],
+        }));
+
+        assert_eq!(
+            rx.try_recv().unwrap().data,
+            vec![Value::BulkString("hello2".to_string().into_bytes())]
+        );
+    }
+
+    #[test]
+    fn test_push_manager_multiple_channels_and_messages() {
+        let push_manager = PushManager::new();
+        let (tx1, mut rx1) = mpsc::unbounded_channel();
+        let (tx2, mut rx2) = mpsc::unbounded_channel();
+        push_manager.replace_sender(tx1);
+
+        let value1 = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::Int(1)],
+        });
+
+        let value2 = Ok(Value::Push {
+            kind: PushKind::Message,
+            data: vec![Value::Int(2)],
+        });
+
+        push_manager.try_send(&value1);
+        push_manager.try_send(&value2);
+
+        assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(1)]);
+        assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(2)]);
+
+        push_manager.replace_sender(tx2);
+        // Make sure rx1 is disconnected after replacing tx1 with tx2.
+        assert_eq!(
+            rx1.try_recv().err().unwrap(),
+            mpsc::error::TryRecvError::Disconnected
+        );
+
+        push_manager.try_send(&value1);
+        push_manager.try_send(&value2);
+
+        assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(1)]);
+        assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(2)]);
+    }
+
+    #[tokio::test]
+    async fn test_push_manager_multi_threaded() {
+        // In this test we create 4 channels and send 1000 messages, switching
+        // channels for each message sent. We then check that all messages were
+        // received and that their sum equals the expected sum.
+        // We also check that all channels were used.
+ let push_manager = PushManager::new(); + let (tx1, mut rx1) = mpsc::unbounded_channel(); + let (tx2, mut rx2) = mpsc::unbounded_channel(); + let (tx3, mut rx3) = mpsc::unbounded_channel(); + let (tx4, mut rx4) = mpsc::unbounded_channel(); + + let mut handles = vec![]; + let txs = [tx1, tx2, tx3, tx4]; + let mut expected_sum = 0; + for i in 0..1000 { + expected_sum += i; + let push_manager_clone = push_manager.clone(); + let new_tx = txs[(i % 4) as usize].clone(); + let value = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::Int(i)], + }); + let handle = tokio::spawn(async move { + push_manager_clone.replace_sender(new_tx); + push_manager_clone.try_send(&value); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + let mut count1 = 0; + let mut count2 = 0; + let mut count3 = 0; + let mut count4 = 0; + let mut received_sum = 0; + while let Ok(push_info) = rx1.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count1 += 1; + } + while let Ok(push_info) = rx2.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count2 += 1; + } + + while let Ok(push_info) = rx3.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count3 += 1; + } + + while let Ok(push_info) = rx4.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count4 += 1; + } + + assert_ne!(count1, 0); + assert_ne!(count2, 0); + assert_ne!(count3, 0); + assert_ne!(count4, 0); + + assert_eq!(count1 + count2 + count3 + count4, 1000); + assert_eq!(received_sum, expected_sum); + } +} diff --git a/glide-core/redis-rs/redis/src/r2d2.rs b/glide-core/redis-rs/redis/src/r2d2.rs new file mode 100644 index 0000000000..e34d2c7bb9 --- /dev/null +++ b/glide-core/redis-rs/redis/src/r2d2.rs @@ -0,0 +1,36 @@ +use std::io; + +use crate::{ConnectionLike, RedisError}; + +macro_rules! impl_manage_connection { + ($client:ty, $connection:ty) => { + impl r2d2::ManageConnection for $client { + type Connection = $connection; + type Error = RedisError; + + fn connect(&self) -> Result { + self.get_connection(None) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + if conn.check_connection() { + Ok(()) + } else { + Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + } + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + !conn.is_open() + } + } + }; +} + +impl_manage_connection!(crate::Client, crate::Connection); + +#[cfg(feature = "cluster")] +impl_manage_connection!( + crate::cluster::ClusterClient, + crate::cluster::ClusterConnection +); diff --git a/glide-core/redis-rs/redis/src/script.rs b/glide-core/redis-rs/redis/src/script.rs new file mode 100644 index 0000000000..c62d2344ae --- /dev/null +++ b/glide-core/redis-rs/redis/src/script.rs @@ -0,0 +1,255 @@ +#![cfg(feature = "script")] +use sha1_smol::Sha1; + +use crate::cmd::cmd; +use crate::connection::ConnectionLike; +use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs}; +use crate::Cmd; + +/// Represents a lua script. +#[derive(Debug, Clone)] +pub struct Script { + code: String, + hash: String, +} + +/// The script object represents a lua script that can be executed on the +/// redis server. 
The object itself takes care of automatic uploading and +/// execution. The script object itself can be shared and is immutable. +/// +/// Example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let script = redis::Script::new(r" +/// return tonumber(ARGV[1]) + tonumber(ARGV[2]); +/// "); +/// let result = script.arg(1).arg(2).invoke(&mut con); +/// assert_eq!(result, Ok(3)); +/// ``` +impl Script { + /// Creates a new script object. + pub fn new(code: &str) -> Script { + let mut hash = Sha1::new(); + hash.update(code.as_bytes()); + Script { + code: code.to_string(), + hash: hash.digest().to_string(), + } + } + + /// Returns the script's SHA1 hash in hexadecimal format. + pub fn get_hash(&self) -> &str { + &self.hash + } + + /// Creates a script invocation object with a key filled in. + #[inline] + pub fn key(&self, key: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: key.to_redis_args(), + } + } + + /// Creates a script invocation object with an argument filled in. + #[inline] + pub fn arg(&self, arg: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: arg.to_redis_args(), + keys: vec![], + } + } + + /// Returns an empty script invocation object. This is primarily useful + /// for programmatically adding arguments and keys because the type will + /// not change. Normally you can use `arg` and `key` directly. + #[inline] + pub fn prepare_invoke(&self) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + } + + /// Invokes the script directly without arguments. + #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke(con) + } + + /// Asynchronously invokes the script without arguments. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke_async(con) + .await + } +} + +/// Represents a prepared script call. +pub struct ScriptInvocation<'a> { + script: &'a Script, + args: Vec>, + keys: Vec>, +} + +/// This type collects keys and other arguments for the script so that it +/// can be then invoked. While the `Script` type itself holds the script, +/// the `ScriptInvocation` holds the arguments that should be invoked until +/// it's sent to the server. +impl<'a> ScriptInvocation<'a> { + /// Adds a regular argument to the invocation. This ends up as `ARGV[i]` + /// in the script. + #[inline] + pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + arg.write_redis_args(&mut self.args); + self + } + + /// Adds a key argument to the invocation. This ends up as `KEYS[i]` + /// in the script. + #[inline] + pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + key.write_redis_args(&mut self.keys); + self + } + + /// Invokes the script and returns the result. 
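+ ///
+ /// This first issues `EVALSHA` with the script's hash and, on a
+ /// `NoScriptError`, loads the script via `SCRIPT LOAD` and retries
+ /// (see the body below). A short usage sketch:
+ ///
+ /// ```rust,no_run
+ /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
+ /// # let mut con = client.get_connection(None).unwrap();
+ /// let script = redis::Script::new("return tonumber(ARGV[1]) * 2;");
+ /// let doubled: i64 = script.arg(21).invoke(&mut con).unwrap();
+ /// assert_eq!(doubled, 42);
+ /// ```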
+ #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query(con) { + Ok(val) => Ok(val), + Err(err) => { + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query(con)?; + eval_cmd.query(con) + } else { + Err(err) + } + } + } + } + + /// Asynchronously invokes the script and returns the result. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query_async(con).await { + Ok(val) => { + // Return the value from the script evaluation + Ok(val) + } + Err(err) => { + // Load the script into Redis if the script hash wasn't there already + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query_async(con).await?; + eval_cmd.query_async(con).await + } else { + Err(err) + } + } + } + } + + /// Loads the script and returns the SHA1 of it. + #[inline] + pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let hash: String = self.load_cmd().query(con)?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + /// Asynchronously loads the script and returns the SHA1 of it. + #[inline] + #[cfg(feature = "aio")] + pub async fn load_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let hash: String = self.load_cmd().query_async(con).await?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + fn load_cmd(&self) -> Cmd { + let mut cmd = cmd("SCRIPT"); + cmd.arg("LOAD").arg(self.script.code.as_bytes()); + cmd + } + + fn estimate_buflen(&self) -> usize { + self + .keys + .iter() + .chain(self.args.iter()) + .fold(0, |acc, e| acc + e.len()) + + 7 /* "EVALSHA".len() */ + + self.script.hash.len() + + 4 /* Slots reserved for the length of keys. */ + } + + fn eval_cmd(&self) -> Cmd { + let args_len = 3 + self.keys.len() + self.args.len(); + let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen()); + cmd.arg("EVALSHA") + .arg(self.script.hash.as_bytes()) + .arg(self.keys.len()) + .arg(&*self.keys) + .arg(&*self.args); + cmd + } +} + +#[cfg(test)] +mod tests { + use super::Script; + + #[test] + fn script_eval_should_work() { + let script = Script::new("return KEYS[1]"); + let invocation = script.key("dummy"); + let estimated_buflen = invocation.estimate_buflen(); + let cmd = invocation.eval_cmd(); + assert!(estimated_buflen >= cmd.capacity().1); + let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n"; + assert_eq!( + expected, + std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap() + ); + } +} diff --git a/glide-core/redis-rs/redis/src/sentinel.rs b/glide-core/redis-rs/redis/src/sentinel.rs new file mode 100644 index 0000000000..569ab2fe0f --- /dev/null +++ b/glide-core/redis-rs/redis/src/sentinel.rs @@ -0,0 +1,777 @@ +//! Defines a Sentinel type that connects to Redis sentinels and creates clients to +//! master or replica nodes. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::Sentinel; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! let mut master = sentinel.master_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! 
let mut replica = sentinel.replica_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! +//! let _: () = master.set("test", "test_data").unwrap(); +//! let rv: String = replica.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! There is also a SentinelClient which acts like a regular Client, providing the +//! `get_connection` and `get_async_connection` methods, internally using the Sentinel +//! type to create clients on demand for the desired node type (Master or Replica). +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::{ SentinelServerType, SentinelClient }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build(nodes.clone(), String::from("master_name"), None, SentinelServerType::Master).unwrap(); +//! let mut replica_client = SentinelClient::build(nodes, String::from("master_name"), None, SentinelServerType::Replica).unwrap(); +//! let mut master_conn = master_client.get_connection().unwrap(); +//! let mut replica_conn = replica_client.get_connection().unwrap(); +//! +//! let _: () = master_conn.set("test", "test_data").unwrap(); +//! let rv: String = replica_conn.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! If the sentinel's nodes are using TLS or require authentication, a full +//! SentinelNodeConnectionInfo struct may be used instead of just the master's name: +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ Sentinel, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! +//! let mut master_with_auth = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: None, +//! redis_connection_info: Some(RedisConnectionInfo { +//! db: 1, +//! username: Some(String::from("foo")), +//! password: Some(String::from("bar")), +//! ..Default::default() +//! }), +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! +//! let mut replica_with_tls = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Secure), +//! redis_connection_info: None, +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! ``` +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ SentinelServerType, SentinelClient, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build( +//! nodes, +//! String::from("master1"), +//! Some(SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Insecure), +//! redis_connection_info: Some(RedisConnectionInfo { +//! username: Some(String::from("user")), +//! password: Some(String::from("pass")), +//! ..Default::default() +//! }), +//! }), +//! redis::sentinel::SentinelServerType::Master, +//! ) +//! .unwrap(); +//! ``` +//! 
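+//! # Example
+//!
+//! A sketch of the rotating replica variant defined below (addresses and
+//! master name are placeholders):
+//!
+//! ```rust,no_run
+//! use redis::sentinel::Sentinel;
+//!
+//! let nodes = vec!["redis://127.0.0.1:26379/", "redis://127.0.0.1:26378/"];
+//! let mut sentinel = Sentinel::build(nodes).unwrap();
+//! // Each call advances an internal start index, spreading connections
+//! // across the replica set over successive calls.
+//! let replica = sentinel.replica_rotate_for("master_name", None).unwrap();
+//! let mut conn = replica.get_connection(None).unwrap();
+//! let pong: String = redis::cmd("PING").query(&mut conn).unwrap();
+//! ```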
+
+use std::{collections::HashMap, num::NonZeroUsize};
+
+#[cfg(feature = "aio")]
+use futures_util::StreamExt;
+use rand::Rng;
+
+#[cfg(feature = "aio")]
+use crate::aio::MultiplexedConnection as AsyncConnection;
+
+use crate::{
+ client::GlideConnectionOptions, connection::ConnectionInfo, types::RedisResult, Client, Cmd,
+ Connection, ErrorKind, FromRedisValue, IntoConnectionInfo, RedisConnectionInfo, TlsMode, Value,
+};
+
+/// The Sentinel type serves as a special-purpose client which builds other clients on
+/// demand.
+pub struct Sentinel {
+ sentinels_connection_info: Vec<ConnectionInfo>,
+ connections_cache: Vec<Option<Connection>>,
+ #[cfg(feature = "aio")]
+ async_connections_cache: Vec<Option<AsyncConnection>>,
+ replica_start_index: usize,
+}
+
+/// Holds the connection information that a sentinel should use when connecting to the
+/// servers (masters and replicas) belonging to it.
+#[derive(Clone, Default)]
+pub struct SentinelNodeConnectionInfo {
+ /// The TLS mode of the connection, or None if we do not want to connect using TLS
+ /// (just a plain TCP connection).
+ pub tls_mode: Option<TlsMode>,
+
+ /// The Redis specific/connection independent information to be used.
+ pub redis_connection_info: Option<RedisConnectionInfo>,
+}
+
+impl SentinelNodeConnectionInfo {
+ fn create_connection_info(&self, ip: String, port: u16) -> ConnectionInfo {
+ let addr = match self.tls_mode {
+ None => crate::ConnectionAddr::Tcp(ip, port),
+ Some(TlsMode::Secure) => crate::ConnectionAddr::TcpTls {
+ host: ip,
+ port,
+ insecure: false,
+ tls_params: None,
+ },
+ Some(TlsMode::Insecure) => crate::ConnectionAddr::TcpTls {
+ host: ip,
+ port,
+ insecure: true,
+ tls_params: None,
+ },
+ };
+
+ ConnectionInfo {
+ addr,
+ redis: self.redis_connection_info.clone().unwrap_or_default(),
+ }
+ }
+}
+
+impl Default for &SentinelNodeConnectionInfo {
+ fn default() -> Self {
+ static DEFAULT_VALUE: SentinelNodeConnectionInfo = SentinelNodeConnectionInfo {
+ tls_mode: None,
+ redis_connection_info: None,
+ };
+ &DEFAULT_VALUE
+ }
+}
+
+fn sentinel_masters_cmd() -> crate::Cmd {
+ let mut cmd = crate::cmd("SENTINEL");
+ cmd.arg("MASTERS");
+ cmd
+}
+
+fn sentinel_replicas_cmd(master_name: &str) -> crate::Cmd {
+ let mut cmd = crate::cmd("SENTINEL");
+ cmd.arg("SLAVES"); // For compatibility with older redis versions
+ cmd.arg(master_name);
+ cmd
+}
+
+fn is_master_valid(master_info: &HashMap<String, String>, service_name: &str) -> bool {
+ master_info.get("name").map(|s| s.as_str()) == Some(service_name)
+ && master_info.contains_key("ip")
+ && master_info.contains_key("port")
+ && master_info.get("flags").map_or(false, |flags| {
+ flags.contains("master") && !flags.contains("s_down") && !flags.contains("o_down")
+ })
+ && master_info["port"].parse::<u16>().is_ok()
+}
+
+fn is_replica_valid(replica_info: &HashMap<String, String>) -> bool {
+ replica_info.contains_key("ip")
+ && replica_info.contains_key("port")
+ && replica_info.get("flags").map_or(false, |flags| {
+ !flags.contains("s_down") && !flags.contains("o_down")
+ })
+ && replica_info["port"].parse::<u16>().is_ok()
+}
+
+/// Generates a random value in the 0..max range.
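+/// Used for picking the initial replica index, so that independent clients
+/// spread their first replica connections across the set.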
+fn random_replica_index(max: NonZeroUsize) -> usize {
+ rand::thread_rng().gen_range(0..max.into())
+}
+
+fn try_connect_to_first_replica(
+ addresses: &[ConnectionInfo],
+ start_index: Option<usize>,
+) -> RedisResult<Client> {
+ if addresses.is_empty() {
+ fail!((
+ ErrorKind::NoValidReplicasFoundBySentinel,
+ "No valid replica found in sentinel for given name",
+ ));
+ }
+
+ let start_index = start_index.unwrap_or(0);
+
+ let mut last_err = None;
+ for i in 0..addresses.len() {
+ let index = (i + start_index) % addresses.len();
+ match Client::open(addresses[index].clone()) {
+ Ok(client) => return Ok(client),
+ Err(err) => last_err = Some(err),
+ }
+ }
+
+ // We can unwrap here because we know there is at least one error, since there is at
+ // least one address, so we'll either return a client for it or store an error in
+ // last_err.
+ Err(last_err.expect("There should be an error because there should be at least one address"))
+}
+
+fn valid_addrs<'a>(
+ servers_info: Vec<HashMap<String, String>>,
+ validate: impl Fn(&HashMap<String, String>) -> bool + 'a,
+) -> impl Iterator<Item = (String, u16)> + 'a {
+ servers_info
+ .into_iter()
+ .filter(move |info| validate(info))
+ .map(|mut info| {
+ // We can unwrap here because we already checked everything
+ let ip = info.remove("ip").unwrap();
+ let port = info["port"].parse::<u16>().unwrap();
+ (ip, port)
+ })
+}
+
+fn check_role_result(result: &RedisResult<Vec<Value>>, target_role: &str) -> bool {
+ if let Ok(values) = result {
+ if !values.is_empty() {
+ if let Ok(role) = String::from_redis_value(&values[0]) {
+ return role.to_ascii_lowercase() == target_role;
+ }
+ }
+ }
+ false
+}
+
+fn check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool {
+ if let Ok(client) = Client::open(connection_info.clone()) {
+ if let Ok(mut conn) = client.get_connection(None) {
+ let result: RedisResult<Vec<Value>> = crate::cmd("ROLE").query(&mut conn);
+ return check_role_result(&result, target_role);
+ }
+ }
+ false
+}
+
+/// Searches for a valid master with the given name in the list of masters returned by
+/// a sentinel. A valid master is one which has a role of "master" (checked by running
+/// the `ROLE` command and by seeing if its flags contain the "master" flag) and which
+/// does not have the flags s_down or o_down set on it (these flags are returned by the
+/// `SENTINEL MASTERS` command, and we expect the `masters` parameter to be the result of
+/// that command).
+fn find_valid_master(
+ masters: Vec<HashMap<String, String>>,
+ service_name: &str,
+ node_connection_info: &SentinelNodeConnectionInfo,
+) -> RedisResult<ConnectionInfo> {
+ for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) {
+ let connection_info = node_connection_info.create_connection_info(ip, port);
+ if check_role(&connection_info, "master") {
+ return Ok(connection_info);
+ }
+ }
+
+ fail!((
+ ErrorKind::MasterNameNotFoundBySentinel,
+ "Master with given name not found in sentinel",
+ ))
+}
+
+#[cfg(feature = "aio")]
+async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool {
+ if let Ok(client) = Client::open(connection_info.clone()) {
+ if let Ok(mut conn) = client
+ .get_multiplexed_async_connection(GlideConnectionOptions::default())
+ .await
+ {
+ let result: RedisResult<Vec<Value>> = crate::cmd("ROLE").query_async(&mut conn).await;
+ return check_role_result(&result, target_role);
+ }
+ }
+ false
+}
+
+/// Async version of [find_valid_master].
+#[cfg(feature = "aio")] +async fn async_find_valid_master( + masters: Vec>, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, +) -> RedisResult { + for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) { + let connection_info = node_connection_info.create_connection_info(ip, port); + if async_check_role(&connection_info, "master").await { + return Ok(connection_info); + } + } + + fail!(( + ErrorKind::MasterNameNotFoundBySentinel, + "Master with given name not found in sentinel", + )) +} + +fn get_valid_replicas_addresses( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + valid_addrs(replicas, is_replica_valid) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter(|connection_info| check_role(connection_info, "slave")) + .collect() +} + +#[cfg(feature = "aio")] +async fn async_get_valid_replicas_addresses<'a>( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + async fn is_replica_role_valid(connection_info: ConnectionInfo) -> Option { + if async_check_role(&connection_info, "slave").await { + Some(connection_info) + } else { + None + } + } + + futures_util::stream::iter(valid_addrs(replicas, is_replica_valid)) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter_map(is_replica_role_valid) + .collect() + .await +} + +#[cfg(feature = "aio")] +async fn async_reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + connection.replace(new_connection); + Ok(()) +} + +#[cfg(feature = "aio")] +async fn async_try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + async_reconnect(cached_connection, connection_info).await?; + } + + let result = cmd.query_async(cached_connection.as_mut().unwrap()).await; + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + async_reconnect(cached_connection, connection_info).await?; + cmd.query_async(cached_connection.as_mut().unwrap()).await + } else { + Err(err) + } + } else { + result + } +} + +fn reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client.get_connection(None)?; + connection.replace(new_connection); + Ok(()) +} + +fn try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + reconnect(cached_connection, connection_info)?; + } + + let result = cmd.query(cached_connection.as_mut().unwrap()); + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + reconnect(cached_connection, connection_info)?; + cmd.query(cached_connection.as_mut().unwrap()) + } else { + Err(err) + } + } else { + result + } +} + +// non-async methods +impl Sentinel { + /// Creates a Sentinel client performing some basic + /// checks on the URLs that might make the operation fail. 
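+ ///
+ /// A minimal sketch (sentinel addresses are placeholders):
+ ///
+ /// ```rust,no_run
+ /// use redis::sentinel::Sentinel;
+ ///
+ /// let sentinel = Sentinel::build(vec!["redis://127.0.0.1:26379/"]).unwrap();
+ /// ```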
+ pub fn build(params: Vec) -> RedisResult { + if params.is_empty() { + fail!(( + ErrorKind::EmptySentinelList, + "At least one sentinel is required", + )) + } + + let sentinels_connection_info = params + .into_iter() + .map(|p| p.into_connection_info()) + .collect::>>()?; + + let mut connections_cache = vec![]; + connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + #[cfg(feature = "aio")] + { + let mut async_connections_cache = vec![]; + async_connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + async_connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + + #[cfg(not(feature = "aio"))] + { + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + } + + /// Try to execute the given command in each sentinel, returning the result of the + /// first one that executes without errors. If all return errors, we return the + /// error of the last attempt. + /// + /// For each sentinel, we first check if there is a cached connection, and if not + /// we attempt to connect to it (skipping that sentinel if there is an error during + /// the connection). Then, we attempt to execute the given command with the cached + /// connection. If there is an error indicating that the connection is invalid, we + /// reconnect and try one more time in the new connection. + /// + fn try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.connections_cache.iter_mut()) + { + match try_single_sentinel(cmd.clone(), connection_info, cached_connection) { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + /// Get a list of all masters (using the command SENTINEL MASTERS) from the + /// sentinels. + fn get_sentinel_masters(&mut self) -> RedisResult>> { + self.try_all_sentinels(sentinel_masters_cmd()) + } + + fn get_sentinel_replicas( + &mut self, + service_name: &str, + ) -> RedisResult>> { + self.try_all_sentinels(sentinel_replicas_cmd(service_name)) + } + + fn find_master_address( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.get_sentinel_masters()?; + find_valid_master(masters, service_name, node_connection_info) + } + + fn find_valid_replica_addresses( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.get_sentinel_replicas(service_name)?; + Ok(get_valid_replicas_addresses(replicas, node_connection_info)) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub fn master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let connection_info = + self.find_master_address(service_name, node_connection_info.unwrap_or_default())?; + Client::open(connection_info) + } + + /// Connects to a randomly chosen replica of the given master name. 
+ pub fn replica_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index); + try_connect_to_first_replica(&addresses, start_index) + } + + /// Attempts to connect to a different replica of the given master name each time. + /// There is no guarantee that we'll actually be connecting to a different replica + /// in the next call, but in a static set of replicas (no replicas added or + /// removed), on average we'll choose each replica the same number of times. + pub fn replica_rotate_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + if !addresses.is_empty() { + self.replica_start_index = (self.replica_start_index + 1) % addresses.len(); + } + try_connect_to_first_replica(&addresses, Some(self.replica_start_index)) + } +} + +// Async versions of the public methods above, along with async versions of private +// methods required for the public methods. +#[cfg(feature = "aio")] +impl Sentinel { + async fn async_try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.async_connections_cache.iter_mut()) + { + match async_try_single_sentinel(cmd.clone(), connection_info, cached_connection).await { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + async fn async_get_sentinel_masters(&mut self) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_masters_cmd()).await + } + + async fn async_get_sentinel_replicas<'a>( + &mut self, + service_name: &'a str, + ) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_replicas_cmd(service_name)) + .await + } + + async fn async_find_master_address<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.async_get_sentinel_masters().await?; + async_find_valid_master(masters, service_name, node_connection_info).await + } + + async fn async_find_valid_replica_addresses<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.async_get_sentinel_replicas(service_name).await?; + Ok(async_get_valid_replicas_addresses(replicas, node_connection_info).await) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub async fn async_master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let address = self + .async_find_master_address(service_name, node_connection_info.unwrap_or_default()) + .await?; + Client::open(address) + } + + /// Connects to a randomly chosen replica of the given master name. 
+ pub async fn async_replica_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .async_find_valid_replica_addresses( + service_name, + node_connection_info.unwrap_or_default(), + ) + .await?; + let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index); + try_connect_to_first_replica(&addresses, start_index) + } + + /// Attempts to connect to a different replica of the given master name each time. + /// There is no guarantee that we'll actually be connecting to a different replica + /// in the next call, but in a static set of replicas (no replicas added or + /// removed), on average we'll choose each replica the same number of times. + pub async fn async_replica_rotate_for<'a>( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .async_find_valid_replica_addresses( + service_name, + node_connection_info.unwrap_or_default(), + ) + .await?; + if !addresses.is_empty() { + self.replica_start_index = (self.replica_start_index + 1) % addresses.len(); + } + try_connect_to_first_replica(&addresses, Some(self.replica_start_index)) + } +} + +/// Enum defining the server types from a sentinel's point of view. +#[derive(Debug, Clone)] +pub enum SentinelServerType { + /// Master connections only + Master, + /// Replica connections only + Replica, +} + +/// An alternative to the Client type which creates connections from clients created +/// on-demand based on information fetched from the sentinels. Uses the Sentinel type +/// internally. This is basic an utility to help make it easier to use sentinels but +/// with an interface similar to the client (`get_connection` and +/// `get_async_connection`). The type of server (master or replica) and name of the +/// desired master are specified when constructing an instance, so it will always +/// return connections to the same target (for example, always to the master with name +/// "mymaster123", or always to replicas of the master "another-master-abc"). +pub struct SentinelClient { + sentinel: Sentinel, + service_name: String, + node_connection_info: SentinelNodeConnectionInfo, + server_type: SentinelServerType, +} + +impl SentinelClient { + /// Creates a SentinelClient performing some basic checks on the URLs that might + /// result in an error. + pub fn build( + params: Vec, + service_name: String, + node_connection_info: Option, + server_type: SentinelServerType, + ) -> RedisResult { + Ok(SentinelClient { + sentinel: Sentinel::build(params)?, + service_name, + node_connection_info: node_connection_info.unwrap_or_default(), + server_type, + }) + } + + fn get_client(&mut self) -> RedisResult { + match self.server_type { + SentinelServerType::Master => self + .sentinel + .master_for(self.service_name.as_str(), Some(&self.node_connection_info)), + SentinelServerType::Replica => self + .sentinel + .replica_for(self.service_name.as_str(), Some(&self.node_connection_info)), + } + } + + /// Creates a new connection to the desired type of server (based on the + /// service/master name, and the server type). We use a Sentinel to create a client + /// for the target type of server, and then create a connection using that client. 
+ pub fn get_connection(&mut self) -> RedisResult { + let client = self.get_client()?; + client.get_connection(None) + } +} + +/// To enable async support you need to enable the feature: `tokio-comp` +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +impl SentinelClient { + async fn async_get_client(&mut self) -> RedisResult { + match self.server_type { + SentinelServerType::Master => { + self.sentinel + .async_master_for(self.service_name.as_str(), Some(&self.node_connection_info)) + .await + } + SentinelServerType::Replica => { + self.sentinel + .async_replica_for(self.service_name.as_str(), Some(&self.node_connection_info)) + .await + } + } + } + + /// Returns an async connection from the client, using the same logic from + /// `SentinelClient::get_connection`. + #[cfg(feature = "tokio-comp")] + pub async fn get_async_connection(&mut self) -> RedisResult { + let client = self.async_get_client().await?; + client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + } +} diff --git a/glide-core/redis-rs/redis/src/streams.rs b/glide-core/redis-rs/redis/src/streams.rs new file mode 100644 index 0000000000..62505d6d75 --- /dev/null +++ b/glide-core/redis-rs/redis/src/streams.rs @@ -0,0 +1,670 @@ +//! Defines types to use with the streams commands. + +use crate::{ + from_redis_value, types::HashMap, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs, Value, +}; + +use std::io::{Error, ErrorKind}; + +// Stream Maxlen Enum + +/// Utility enum for passing `MAXLEN [= or ~] [COUNT]` +/// arguments into `StreamCommands`. +/// The enum value represents the count. +#[derive(PartialEq, Eq, Clone, Debug, Copy)] +pub enum StreamMaxlen { + /// Match an exact count + Equals(usize), + /// Match an approximate count + Approx(usize), +} + +impl ToRedisArgs for StreamMaxlen { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let (ch, val) = match *self { + StreamMaxlen::Equals(v) => ("=", v), + StreamMaxlen::Approx(v) => ("~", v), + }; + out.write_arg(b"MAXLEN"); + out.write_arg(ch.as_bytes()); + val.write_redis_args(out); + } +} + +/// Builder options for [`xclaim_options`] command. +/// +/// [`xclaim_options`]: ../trait.Commands.html#method.xclaim_options +/// +#[derive(Default, Debug)] +pub struct StreamClaimOptions { + /// Set `IDLE ` cmd arg. + idle: Option, + /// Set `TIME ` cmd arg. + time: Option, + /// Set `RETRYCOUNT ` cmd arg. + retry: Option, + /// Set `FORCE` cmd arg. + force: bool, + /// Set `JUSTID` cmd arg. Be advised: the response + /// type changes with this option. + justid: bool, +} + +impl StreamClaimOptions { + /// Set `IDLE ` cmd arg. + pub fn idle(mut self, ms: usize) -> Self { + self.idle = Some(ms); + self + } + + /// Set `TIME ` cmd arg. + pub fn time(mut self, ms_time: usize) -> Self { + self.time = Some(ms_time); + self + } + + /// Set `RETRYCOUNT ` cmd arg. + pub fn retry(mut self, count: usize) -> Self { + self.retry = Some(count); + self + } + + /// Set `FORCE` cmd arg to true. + pub fn with_force(mut self) -> Self { + self.force = true; + self + } + + /// Set `JUSTID` cmd arg to true. Be advised: the response + /// type changes with this option. 
+ pub fn with_justid(mut self) -> Self { + self.justid = true; + self + } +} + +impl ToRedisArgs for StreamClaimOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref ms) = self.idle { + out.write_arg(b"IDLE"); + out.write_arg(format!("{ms}").as_bytes()); + } + if let Some(ref ms_time) = self.time { + out.write_arg(b"TIME"); + out.write_arg(format!("{ms_time}").as_bytes()); + } + if let Some(ref count) = self.retry { + out.write_arg(b"RETRYCOUNT"); + out.write_arg(format!("{count}").as_bytes()); + } + if self.force { + out.write_arg(b"FORCE"); + } + if self.justid { + out.write_arg(b"JUSTID"); + } + } +} + +/// Argument to `StreamReadOptions` +/// Represents the Redis `GROUP ` cmd arg. +/// This option will toggle the cmd from `XREAD` to `XREADGROUP` +type SRGroup = Option<(Vec>, Vec>)>; +/// Builder options for [`xread_options`] command. +/// +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug)] +pub struct StreamReadOptions { + /// Set the `BLOCK ` cmd arg. + block: Option, + /// Set the `COUNT ` cmd arg. + count: Option, + /// Set the `NOACK` cmd arg. + noack: Option, + /// Set the `GROUP ` cmd arg. + /// This option will toggle the cmd from XREAD to XREADGROUP. + group: SRGroup, +} + +impl StreamReadOptions { + /// Indicates whether the command is participating in a group + /// and generating ACKs + pub fn read_only(&self) -> bool { + self.group.is_none() + } + + /// Sets the command so that it avoids adding the message + /// to the PEL in cases where reliability is not a requirement + /// and the occasional message loss is acceptable. + pub fn noack(mut self) -> Self { + self.noack = Some(true); + self + } + + /// Sets the block time in milliseconds. + pub fn block(mut self, ms: usize) -> Self { + self.block = Some(ms); + self + } + + /// Sets the maximum number of elements to return per stream. + pub fn count(mut self, n: usize) -> Self { + self.count = Some(n); + self + } + + /// Sets the name of a consumer group associated to the stream. + pub fn group( + mut self, + group_name: GN, + consumer_name: CN, + ) -> Self { + self.group = Some(( + ToRedisArgs::to_redis_args(&group_name), + ToRedisArgs::to_redis_args(&consumer_name), + )); + self + } +} + +impl ToRedisArgs for StreamReadOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref group) = self.group { + out.write_arg(b"GROUP"); + for i in &group.0 { + out.write_arg(i); + } + for i in &group.1 { + out.write_arg(i); + } + } + + if let Some(ref ms) = self.block { + out.write_arg(b"BLOCK"); + out.write_arg(format!("{ms}").as_bytes()); + } + + if let Some(ref n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg(format!("{n}").as_bytes()); + } + + if self.group.is_some() { + // noack is only available w/ xreadgroup + if self.noack == Some(true) { + out.write_arg(b"NOACK"); + } + } + } +} + +/// Reply type used with [`xread`] or [`xread_options`] commands. +/// +/// [`xread`]: ../trait.Commands.html#method.xread +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug, Clone)] +pub struct StreamReadReply { + /// Complex data structure containing a payload for each key in this array + pub keys: Vec, +} + +/// Reply type used with [`xrange`], [`xrange_count`], [`xrange_all`], [`xrevrange`], [`xrevrange_count`], [`xrevrange_all`] commands. +/// +/// Represents stream entries matching a given range of `id`'s. 
+/// +/// [`xrange`]: ../trait.Commands.html#method.xrange +/// [`xrange_count`]: ../trait.Commands.html#method.xrange_count +/// [`xrange_all`]: ../trait.Commands.html#method.xrange_all +/// [`xrevrange`]: ../trait.Commands.html#method.xrevrange +/// [`xrevrange_count`]: ../trait.Commands.html#method.xrevrange_count +/// [`xrevrange_all`]: ../trait.Commands.html#method.xrevrange_all +/// +#[derive(Default, Debug, Clone)] +pub struct StreamRangeReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xclaim`] command. +/// +/// Represents that ownership of the specified messages was changed. +/// +/// [`xclaim`]: ../trait.Commands.html#method.xclaim +/// +#[derive(Default, Debug, Clone)] +pub struct StreamClaimReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xpending`] command. +/// +/// Data returned here were fetched from the stream without +/// having been acknowledged. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +/// +#[derive(Debug, Clone, Default)] +pub enum StreamPendingReply { + /// The stream is empty. + #[default] + Empty, + /// Data with payload exists in the stream. + Data(StreamPendingData), +} + +impl StreamPendingReply { + /// Returns how many records are in the reply. + pub fn count(&self) -> usize { + match self { + StreamPendingReply::Empty => 0, + StreamPendingReply::Data(x) => x.count, + } + } +} + +/// Inner reply type when an [`xpending`] command has data. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingData { + /// Limit on the number of messages to return per call. + pub count: usize, + /// ID for the first pending record. + pub start_id: String, + /// ID for the final pending record. + pub end_id: String, + /// Every consumer in the consumer group with at + /// least one pending message, + /// and the number of pending messages it has. + pub consumers: Vec, +} + +/// Reply type used with [`xpending_count`] and +/// [`xpending_consumer_count`] commands. +/// +/// Data returned here have been fetched from the stream without +/// any acknowledgement. +/// +/// [`xpending_count`]: ../trait.Commands.html#method.xpending_count +/// [`xpending_consumer_count`]: ../trait.Commands.html#method.xpending_consumer_count +/// +#[derive(Default, Debug, Clone)] +pub struct StreamPendingCountReply { + /// An array of structs containing information about + /// message IDs yet to be acknowledged by various consumers, + /// time since last ack, and total number of acks by that consumer. + pub ids: Vec, +} + +/// Reply type used with [`xinfo_stream`] command, containing +/// general information about the stream stored at the specified key. +/// +/// The very first and last IDs in the stream are shown, +/// in order to give some sense about what is the stream content. +/// +/// [`xinfo_stream`]: ../trait.Commands.html#method.xinfo_stream +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoStreamReply { + /// The last generated ID that may not be the same as the last + /// entry ID in case some entry was deleted. + pub last_generated_id: String, + /// Details about the radix tree representing the stream mostly + /// useful for optimization and debugging tasks. + pub radix_tree_keys: usize, + /// The number of consumer groups associated with the stream. + pub groups: usize, + /// Number of elements of the stream. 
+ pub length: usize, + /// The very first entry in the stream. + pub first_entry: StreamId, + /// The very last entry in the stream. + pub last_entry: StreamId, +} + +/// Reply type used with [`xinfo_consumer`] command, an array of every +/// consumer in a specific consumer group. +/// +/// [`xinfo_consumer`]: ../trait.Commands.html#method.xinfo_consumer +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumersReply { + /// An array of every consumer in a specific consumer group. + pub consumers: Vec, +} + +/// Reply type used with [`xinfo_groups`] command. +/// +/// This output represents all the consumer groups associated with +/// the stream. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroupsReply { + /// All the consumer groups associated with the stream. + pub groups: Vec, +} + +/// A consumer parsed from [`xinfo_consumers`] command. +/// +/// [`xinfo_consumers`]: ../trait.Commands.html#method.xinfo_consumers +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumer { + /// Name of the consumer group. + pub name: String, + /// Number of pending messages for this specific consumer. + pub pending: usize, + /// This consumer's idle time in milliseconds. + pub idle: usize, +} + +/// A group parsed from [`xinfo_groups`] command. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroup { + /// The group name. + pub name: String, + /// Number of consumers known in the group. + pub consumers: usize, + /// Number of pending messages (delivered but not yet acknowledged) in the group. + pub pending: usize, + /// Last ID delivered to this group. + pub last_delivered_id: String, +} + +/// Represents a pending message parsed from [`xpending`] methods. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingId { + /// The ID of the message. + pub id: String, + /// The name of the consumer that fetched the message and has + /// still to acknowledge it. We call it the current owner + /// of the message. + pub consumer: String, + /// The number of milliseconds that elapsed since the + /// last time this message was delivered to this consumer. + pub last_delivered_ms: usize, + /// The number of times this message was delivered. + pub times_delivered: usize, +} + +/// Represents a stream `key` and its `id`'s parsed from `xread` methods. +#[derive(Default, Debug, Clone)] +pub struct StreamKey { + /// The stream `key`. + pub key: String, + /// The parsed stream `id`'s. + pub ids: Vec, +} + +/// Represents a stream `id` and its field/values as a `HashMap` +#[derive(Default, Debug, Clone)] +pub struct StreamId { + /// The stream `id` (entry ID) of this particular message. + pub id: String, + /// All fields in this message, associated with their respective values. + pub map: HashMap, +} + +impl StreamId { + /// Converts a `Value::Array` into a `StreamId`. + fn from_array_value(v: &Value) -> RedisResult { + let mut stream_id = StreamId::default(); + if let Value::Array(ref values) = *v { + if let Some(v) = values.first() { + stream_id.id = from_redis_value(v)?; + } + if let Some(v) = values.get(1) { + stream_id.map = from_redis_value(v)?; + } + } + + Ok(stream_id) + } + + /// Fetches value of a given field and converts it to the specified + /// type. 
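+ ///
+ /// A minimal sketch:
+ ///
+ /// ```rust,no_run
+ /// use redis::streams::StreamId;
+ ///
+ /// let entry = StreamId::default();
+ /// // `None` when the field is absent or the conversion fails.
+ /// let count: Option<i64> = entry.get("count");
+ /// ```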
+ pub fn get(&self, key: &str) -> Option { + match self.map.get(key) { + Some(x) => from_redis_value(x).ok(), + None => None, + } + } + + /// Does the message contain a particular field? + pub fn contains_key(&self, key: &str) -> bool { + self.map.contains_key(key) + } + + /// Returns how many field/value pairs exist in this message. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns true if there are no field/value pairs in this message. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +type SRRows = Vec>>>>; +impl FromRedisValue for StreamReadReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: SRRows = from_redis_value(v)?; + let keys = rows + .into_iter() + .flat_map(|row| { + row.into_iter().map(|(key, entry)| { + let ids = entry + .into_iter() + .flat_map(|id_row| id_row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + StreamKey { key, ids } + }) + }) + .collect(); + Ok(StreamReadReply { keys }) + } +} + +impl FromRedisValue for StreamRangeReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamRangeReply { ids }) + } +} + +impl FromRedisValue for StreamClaimReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamClaimReply { ids }) + } +} + +type SPRInner = ( + usize, + Option, + Option, + Vec>, +); +impl FromRedisValue for StreamPendingReply { + fn from_redis_value(v: &Value) -> RedisResult { + let (count, start, end, consumer_data): SPRInner = from_redis_value(v)?; + + if count == 0 { + Ok(StreamPendingReply::Empty) + } else { + let mut result = StreamPendingData::default(); + + let start_id = start.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects start id", + ) + })?; + + let end_id = end.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects end id", + ) + })?; + + result.count = count; + result.start_id = start_id; + result.end_id = end_id; + + result.consumers = consumer_data + .into_iter() + .flatten() + .map(|(name, pending)| StreamInfoConsumer { + name, + pending: pending.parse().unwrap_or_default(), + ..Default::default() + }) + .collect(); + + Ok(StreamPendingReply::Data(result)) + } + } +} + +impl FromRedisValue for StreamPendingCountReply { + fn from_redis_value(v: &Value) -> RedisResult { + let mut reply = StreamPendingCountReply::default(); + match v { + Value::Array(outer_tuple) => { + for outer in outer_tuple { + match outer { + Value::Array(inner_tuple) => match &inner_tuple[..] 
{ + [Value::BulkString(id_bytes), Value::BulkString(consumer_bytes), Value::Int(last_delivered_ms_u64), Value::Int(times_delivered_u64)] => + { + let id = String::from_utf8(id_bytes.to_vec())?; + let consumer = String::from_utf8(consumer_bytes.to_vec())?; + let last_delivered_ms = *last_delivered_ms_u64 as usize; + let times_delivered = *times_delivered_u64 as usize; + reply.ids.push(StreamPendingId { + id, + consumer, + last_delivered_ms, + times_delivered, + }); + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (3)" + )), + }, + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (2)" + )), + } + } + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (1)" + )), + }; + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoStreamReply { + fn from_redis_value(v: &Value) -> RedisResult { + let map: HashMap = from_redis_value(v)?; + let mut reply = StreamInfoStreamReply::default(); + if let Some(v) = &map.get("last-generated-id") { + reply.last_generated_id = from_redis_value(v)?; + } + if let Some(v) = &map.get("radix-tree-nodes") { + reply.radix_tree_keys = from_redis_value(v)?; + } + if let Some(v) = &map.get("groups") { + reply.groups = from_redis_value(v)?; + } + if let Some(v) = &map.get("length") { + reply.length = from_redis_value(v)?; + } + if let Some(v) = &map.get("first-entry") { + reply.first_entry = StreamId::from_array_value(v)?; + } + if let Some(v) = &map.get("last-entry") { + reply.last_entry = StreamId::from_array_value(v)?; + } + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoConsumersReply { + fn from_redis_value(v: &Value) -> RedisResult { + let consumers: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoConsumersReply::default(); + for map in consumers { + let mut c = StreamInfoConsumer::default(); + if let Some(v) = &map.get("name") { + c.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + c.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("idle") { + c.idle = from_redis_value(v)?; + } + reply.consumers.push(c); + } + + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoGroupsReply { + fn from_redis_value(v: &Value) -> RedisResult { + let groups: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoGroupsReply::default(); + for map in groups { + let mut g = StreamInfoGroup::default(); + if let Some(v) = &map.get("name") { + g.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + g.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("consumers") { + g.consumers = from_redis_value(v)?; + } + if let Some(v) = &map.get("last-delivered-id") { + g.last_delivered_id = from_redis_value(v)?; + } + reply.groups.push(g); + } + Ok(reply) + } +} diff --git a/glide-core/redis-rs/redis/src/tls.rs b/glide-core/redis-rs/redis/src/tls.rs new file mode 100644 index 0000000000..6886efb836 --- /dev/null +++ b/glide-core/redis-rs/redis/src/tls.rs @@ -0,0 +1,142 @@ +use std::io::{BufRead, Error, ErrorKind as IOErrorKind}; + +use rustls::RootCertStore; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + +use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult}; + +/// Structure to hold mTLS client _certificate_ and _key_ binaries in PEM format +/// +#[derive(Clone)] +pub struct ClientTlsConfig { + /// client certificate byte stream in PEM format + pub client_cert: Vec, + /// client key byte stream in PEM format + pub client_key: Vec, +} + +/// Structure to hold TLS 
certificates +/// - `client_tls`: binaries of clientkey and certificate within a `ClientTlsConfig` structure if mTLS is used +/// - `root_cert`: binary CA certificate in PEM format if CA is not in local truststore +/// +#[derive(Clone)] +pub struct TlsCertificates { + /// 'ClientTlsConfig' containing client certificate and key if mTLS is to be used + pub client_tls: Option, + /// root certificate byte stream in PEM format if the local truststore is *not* to be used + pub root_cert: Option>, +} + +pub(crate) fn inner_build_with_tls( + mut connection_info: ConnectionInfo, + certificates: TlsCertificates, +) -> RedisResult { + let tls_params = retrieve_tls_certificates(certificates)?; + + connection_info.addr = if let ConnectionAddr::TcpTls { + host, + port, + insecure, + .. + } = connection_info.addr + { + ConnectionAddr::TcpTls { + host, + port, + insecure, + tls_params: Some(tls_params), + } + } else { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Constructing a TLS client requires a URL with the `rediss://` scheme", + ))); + }; + + Ok(Client { connection_info }) +} + +pub(crate) fn retrieve_tls_certificates( + certificates: TlsCertificates, +) -> RedisResult { + let TlsCertificates { + client_tls, + root_cert, + } = certificates; + + let client_tls_params = if let Some(ClientTlsConfig { + client_cert, + client_key, + }) = client_tls + { + let buf = &mut client_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let client_cert_chain = certs.collect::, _>>()?; + + let client_key = + rustls_pemfile::private_key(&mut client_key.as_slice() as &mut dyn BufRead)? + .ok_or_else(|| { + Error::new( + IOErrorKind::Other, + "Unable to extract private key from PEM file", + ) + })?; + + Some(ClientTlsParams { + client_cert_chain, + client_key, + }) + } else { + None + }; + + let root_cert_store = if let Some(root_cert) = root_cert { + let buf = &mut root_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let mut root_cert_store = RootCertStore::empty(); + for result in certs { + if root_cert_store.add(result?.to_owned()).is_err() { + return Err( + Error::new(IOErrorKind::Other, "Unable to parse TLS trust anchors").into(), + ); + } + } + + Some(root_cert_store) + } else { + None + }; + + Ok(TlsConnParams { + client_tls_params, + root_cert_store, + }) +} + +#[derive(Debug)] +pub struct ClientTlsParams { + pub(crate) client_cert_chain: Vec>, + pub(crate) client_key: PrivateKeyDer<'static>, +} + +/// [`PrivateKeyDer`] does not implement `Clone` so we need to implement it manually. 
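+/// Each key variant is copied by round-tripping its secret DER bytes into a
+/// newly owned key.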
+impl Clone for ClientTlsParams { + fn clone(&self) -> Self { + use PrivateKeyDer::*; + Self { + client_cert_chain: self.client_cert_chain.clone(), + client_key: match &self.client_key { + Pkcs1(key) => Pkcs1(key.secret_pkcs1_der().to_vec().into()), + Pkcs8(key) => Pkcs8(key.secret_pkcs8_der().to_vec().into()), + Sec1(key) => Sec1(key.secret_sec1_der().to_vec().into()), + _ => unreachable!(), + }, + } + } +} + +#[derive(Debug, Clone)] +pub struct TlsConnParams { + pub(crate) client_tls_params: Option, + pub(crate) root_cert_store: Option, +} diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs new file mode 100644 index 0000000000..2d8035d697 --- /dev/null +++ b/glide-core/redis-rs/redis/src/types.rs @@ -0,0 +1,2491 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::default::Default; +use std::error; +use std::ffi::{CString, NulError}; +use std::fmt; +use std::hash::{BuildHasher, Hash}; +use std::io; +use std::str::{from_utf8, Utf8Error}; +use std::string::FromUtf8Error; + +use num_bigint::BigInt; +pub(crate) use std::collections::{HashMap, HashSet}; +use std::ops::Deref; + +macro_rules! invalid_type_error { + ($v:expr, $det:expr) => {{ + fail!(invalid_type_error_inner!($v, $det)) + }}; +} + +macro_rules! invalid_type_error_inner { + ($v:expr, $det:expr) => { + RedisError::from(( + ErrorKind::TypeError, + "Response was of incompatible type", + format!("{:?} (response was {:?})", $det, $v), + )) + }; +} + +/// Helper enum that is used to define expiry time +pub enum Expiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// PERSIST -- Remove the time to live associated with the key. + PERSIST, +} + +/// Helper enum that is used to define expiry time for SET command +#[derive(Clone, Copy)] +pub enum SetExpiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// KEEPTTL -- Retain the time to live associated with the key. + KEEPTTL, +} + +/// Helper enum that is used to define existence checks +#[derive(Clone, Copy)] +pub enum ExistenceCheck { + /// NX -- Only set the key if it does not already exist. + NX, + /// XX -- Only set the key if it already exists. + XX, +} + +/// Helper enum that is used in some situations to describe +/// the behavior of arguments in a numeric context. +#[derive(PartialEq, Eq, Clone, Debug, Copy)] +pub enum NumericBehavior { + /// This argument is not numeric. + NonNumeric, + /// This argument is an integer. + NumberIsInteger, + /// This argument is a floating point value. + NumberIsFloat, +} + +/// An enum of all error kinds. +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +#[non_exhaustive] +pub enum ErrorKind { + /// The server generated an invalid response. + ResponseError, + /// The parser failed to parse the server response. + ParseError, + /// The authentication with the server failed. 
+/// An enum of all error kinds.
+#[derive(PartialEq, Eq, Copy, Clone, Debug)]
+#[non_exhaustive]
+pub enum ErrorKind {
+    /// The server generated an invalid response.
+    ResponseError,
+    /// The parser failed to parse the server response.
+    ParseError,
+    /// The authentication with the server failed.
+    AuthenticationFailed,
+    /// Operation failed because of a type mismatch.
+    TypeError,
+    /// A script execution was aborted.
+    ExecAbortError,
+    /// The server cannot respond because it's loading a dump.
+    BusyLoadingError,
+    /// A script that was requested does not actually exist.
+    NoScriptError,
+    /// An error that was caused because the parameters
+    /// passed to the client were wrong.
+    InvalidClientConfig,
+    /// Raised if a key moved to a different node.
+    Moved,
+    /// Raised if a key moved to a different node but we need to ask.
+    Ask,
+    /// Raised if a request needs to be retried.
+    TryAgain,
+    /// Raised if a redis cluster is down.
+    ClusterDown,
+    /// A request spans multiple slots
+    CrossSlot,
+    /// A cluster master is unavailable.
+    MasterDown,
+    /// This kind is returned if the redis error is one that is
+    /// not native to the system. This is usually the case if
+    /// the cause is another error.
+    IoError,
+    /// An error indicating that a fatal error occurred while attempting to send a request to the server,
+    /// meaning the connection was closed before the request was transmitted. Since the server did not process the request,
+    /// it is safe to retry the request.
+    FatalSendError,
+    /// An error indicating that a fatal error occurred while trying to receive a response,
+    /// likely due to the closure of the underlying connection. It is unclear whether
+    /// the server processed the request, making it unsafe to retry the request.
+    FatalReceiveError,
+    /// An error raised that was identified on the client before execution.
+    ClientError,
+    /// An extension error. This is an error created by the server
+    /// that is not directly understood by the library.
+    ExtensionError,
+    /// Attempt to write to a read-only server
+    ReadOnly,
+    /// Requested name not found among masters returned by the sentinels
+    MasterNameNotFoundBySentinel,
+    /// No valid replicas found in the sentinels, for a given master name
+    NoValidReplicasFoundBySentinel,
+    /// At least one sentinel connection info is required
+    EmptySentinelList,
+    /// Attempted to kill a script/function while they weren't executing
+    NotBusy,
+    /// Used when no valid node connections remain in the cluster connection
+    AllConnectionsUnavailable,
+    /// Used when a connection is not found for the specified route.
+    ConnectionNotFoundForRoute,
+
+    #[cfg(feature = "json")]
+    /// Error serializing a struct to JSON form
+    Serialize,
+
+    /// Redis servers prior to v6.0.0 don't support RESP3.
+    /// Try disabling the resp3 option
+    RESP3NotSupported,
+
+    /// Not all slots are covered by the cluster
+    NotAllSlotsCovered,
+
+    /// Used when an error occurs due to wrong usage of a management operation,
+    /// e.g. a configuration change that is not allowed.
+    UserOperationError,
+}
+
+#[derive(PartialEq, Debug)]
+pub(crate) enum ServerErrorKind {
+    ResponseError,
+    ExecAbortError,
+    BusyLoadingError,
+    NoScriptError,
+    Moved,
+    Ask,
+    TryAgain,
+    ClusterDown,
+    CrossSlot,
+    MasterDown,
+    ReadOnly,
+    NotBusy,
+}
+
+#[derive(PartialEq, Debug)]
+pub(crate) enum ServerError {
+    ExtensionError {
+        code: String,
+        detail: Option<String>,
+    },
+    KnownError {
+        kind: ServerErrorKind,
+        detail: Option<String>,
+    },
+}
+
+impl From<tokio::time::error::Elapsed> for RedisError {
+    fn from(_: tokio::time::error::Elapsed) -> Self {
+        RedisError::from((ErrorKind::IoError, "Operation timed out"))
+    }
+}
+
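
A short sketch of how calling code might triage these kinds through `RedisError::kind()`. The retry set chosen here is illustrative only; the crate's authoritative policy is `RedisError::retry_method`, defined later in this file:

```rust
use redis::{ErrorKind, RedisError};

/// Illustrative only: treat kinds documented above as transient,
/// plus the explicitly safe-to-retry `FatalSendError`, as retryable.
fn is_safe_to_retry(err: &RedisError) -> bool {
    matches!(
        err.kind(),
        ErrorKind::TryAgain
            | ErrorKind::BusyLoadingError
            | ErrorKind::MasterDown
            | ErrorKind::FatalSendError
    )
}
```
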
+impl From<ServerError> for RedisError {
+    fn from(value: ServerError) -> Self {
+        // TODO - Consider changing RedisError to explicitly represent whether an error came from the server or not. Today it is only implied.
+        match value {
+            ServerError::ExtensionError { code, detail } => make_extension_error(code, detail),
+            ServerError::KnownError { kind, detail } => {
+                let desc = "An error was signalled by the server";
+                let kind = match kind {
+                    ServerErrorKind::ResponseError => ErrorKind::ResponseError,
+                    ServerErrorKind::ExecAbortError => ErrorKind::ExecAbortError,
+                    ServerErrorKind::BusyLoadingError => ErrorKind::BusyLoadingError,
+                    ServerErrorKind::NoScriptError => ErrorKind::NoScriptError,
+                    ServerErrorKind::Moved => ErrorKind::Moved,
+                    ServerErrorKind::Ask => ErrorKind::Ask,
+                    ServerErrorKind::TryAgain => ErrorKind::TryAgain,
+                    ServerErrorKind::ClusterDown => ErrorKind::ClusterDown,
+                    ServerErrorKind::CrossSlot => ErrorKind::CrossSlot,
+                    ServerErrorKind::MasterDown => ErrorKind::MasterDown,
+                    ServerErrorKind::ReadOnly => ErrorKind::ReadOnly,
+                    ServerErrorKind::NotBusy => ErrorKind::NotBusy,
+                };
+                match detail {
+                    Some(detail) => RedisError::from((kind, desc, detail)),
+                    None => RedisError::from((kind, desc)),
+                }
+            }
+        }
+    }
+}
+
+/// Internal low-level redis value enum.
+#[derive(PartialEq, Debug)]
+pub(crate) enum InternalValue {
+    /// A nil response from the server.
+    Nil,
+    /// An integer response. Note that there are a few situations
+    /// in which redis actually returns a string for an integer which
+    /// is why this library generally treats integers and strings
+    /// the same for all numeric responses.
+    Int(i64),
+    /// An arbitrary binary data, usually represents a binary-safe string.
+    BulkString(Vec<u8>),
+    /// A response containing an array with more data. This is generally used by redis
+    /// to express nested structures.
+    Array(Vec<InternalValue>),
+    /// A simple string response, without line breaks and not binary safe.
+    SimpleString(String),
+    /// A status response which represents the string "OK".
+    Okay,
+    /// Unordered key,value list from the server. Use `as_map_iter` function.
+    Map(Vec<(InternalValue, InternalValue)>),
+    /// Attribute value from the server. Client will give data instead of whole Attribute type.
+    Attribute {
+        /// Data that attributes belong to.
+        data: Box<InternalValue>,
+        /// Key,Value list of attributes.
+        attributes: Vec<(InternalValue, InternalValue)>,
+    },
+    /// Unordered set value from the server.
+    Set(Vec<InternalValue>),
+    /// A floating number response from the server.
+    Double(f64),
+    /// A boolean response from the server.
+    Boolean(bool),
+    /// First String is format and other is the string
+    VerbatimString {
+        /// Text's format type
+        format: VerbatimFormat,
+        /// Remaining string check format before using!
+        text: String,
+    },
+    /// Very large number that is out of the range of signed 64 bit numbers
+    BigNumber(BigInt),
+    /// Push data from the server.
+ Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, + ServerError(ServerError), +} + +impl InternalValue { + pub(crate) fn try_into(self) -> RedisResult { + match self { + InternalValue::Nil => Ok(Value::Nil), + InternalValue::Int(val) => Ok(Value::Int(val)), + InternalValue::BulkString(val) => Ok(Value::BulkString(val)), + InternalValue::Array(val) => Ok(Value::Array(Self::try_into_vec(val)?)), + InternalValue::SimpleString(val) => Ok(Value::SimpleString(val)), + InternalValue::Okay => Ok(Value::Okay), + InternalValue::Map(map) => Ok(Value::Map(Self::try_into_map(map)?)), + InternalValue::Attribute { data, attributes } => { + let data = Box::new((*data).try_into()?); + let attributes = Self::try_into_map(attributes)?; + Ok(Value::Attribute { data, attributes }) + } + InternalValue::Set(set) => Ok(Value::Set(Self::try_into_vec(set)?)), + InternalValue::Double(double) => Ok(Value::Double(double)), + InternalValue::Boolean(boolean) => Ok(Value::Boolean(boolean)), + InternalValue::VerbatimString { format, text } => { + Ok(Value::VerbatimString { format, text }) + } + InternalValue::BigNumber(number) => Ok(Value::BigNumber(number)), + InternalValue::Push { kind, data } => Ok(Value::Push { + kind, + data: Self::try_into_vec(data)?, + }), + + InternalValue::ServerError(err) => Err(err.into()), + } + } + + fn try_into_vec(vec: Vec) -> RedisResult> { + vec.into_iter() + .map(InternalValue::try_into) + .collect::>>() + } + + fn try_into_map(map: Vec<(InternalValue, InternalValue)>) -> RedisResult> { + let mut vec = Vec::with_capacity(map.len()); + for (key, value) in map.into_iter() { + vec.push((key.try_into()?, value.try_into()?)); + } + Ok(vec) + } +} + +/// Internal low-level redis value enum. +#[derive(PartialEq, Clone)] +pub enum Value { + /// A nil response from the server. + Nil, + /// An integer response. Note that there are a few situations + /// in which redis actually returns a string for an integer which + /// is why this library generally treats integers and strings + /// the same for all numeric responses. + Int(i64), + /// An arbitrary binary data, usually represents a binary-safe string. + BulkString(Vec), + /// A response containing an array with more data. This is generally used by redis + /// to express nested structures. + Array(Vec), + /// A simple string response, without line breaks and not binary safe. + SimpleString(String), + /// A status response which represents the string "OK". + Okay, + /// Unordered key,value list from the server. Use `as_map_iter` function. + Map(Vec<(Value, Value)>), + /// Attribute value from the server. Client will give data instead of whole Attribute type. + Attribute { + /// Data that attributes belong to. + data: Box, + /// Key,Value list of attributes. + attributes: Vec<(Value, Value)>, + }, + /// Unordered set value from the server. + Set(Vec), + /// A floating number response from the server. + Double(f64), + /// A boolean response from the server. + Boolean(bool), + /// First String is format and other is the string + VerbatimString { + /// Text's format type + format: VerbatimFormat, + /// Remaining string check format before using! + text: String, + }, + /// Very large number that out of the range of the signed 64 bit numbers + BigNumber(BigInt), + /// Push data from the server. 
+ Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, +} + +/// `VerbatimString`'s format types defined by spec +#[derive(PartialEq, Clone, Debug)] +pub enum VerbatimFormat { + /// Unknown type to catch future formats. + Unknown(String), + /// `mkd` format + Markdown, + /// `txt` format + Text, +} + +/// `Push` type's currently known kinds. +#[derive(PartialEq, Clone, Debug)] +pub enum PushKind { + /// `Disconnection` is sent from the **library** when connection is closed. + Disconnection, + /// Other kind to catch future kinds. + Other(String), + /// `invalidate` is received when a key is changed/deleted. + Invalidate, + /// `message` is received when pubsub message published by another client. + Message, + /// `pmessage` is received when pubsub message published by another client and client subscribed to topic via pattern. + PMessage, + /// `smessage` is received when pubsub message published by another client and client subscribed to it with sharding. + SMessage, + /// `unsubscribe` is received when client unsubscribed from a channel. + Unsubscribe, + /// `punsubscribe` is received when client unsubscribed from a pattern. + PUnsubscribe, + /// `sunsubscribe` is received when client unsubscribed from a shard channel. + SUnsubscribe, + /// `subscribe` is received when client subscribed to a channel. + Subscribe, + /// `psubscribe` is received when client subscribed to a pattern. + PSubscribe, + /// `ssubscribe` is received when client subscribed to a shard channel. + SSubscribe, +} + +impl PushKind { + #[cfg(feature = "aio")] + pub(crate) fn has_reply(&self) -> bool { + matches!( + self, + &PushKind::Unsubscribe + | &PushKind::PUnsubscribe + | &PushKind::SUnsubscribe + | &PushKind::Subscribe + | &PushKind::PSubscribe + | &PushKind::SSubscribe + ) + } +} + +impl fmt::Display for VerbatimFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VerbatimFormat::Markdown => write!(f, "mkd"), + VerbatimFormat::Unknown(val) => write!(f, "{val}"), + VerbatimFormat::Text => write!(f, "txt"), + } + } +} + +impl fmt::Display for PushKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PushKind::Other(kind) => write!(f, "{}", kind), + PushKind::Invalidate => write!(f, "invalidate"), + PushKind::Message => write!(f, "message"), + PushKind::PMessage => write!(f, "pmessage"), + PushKind::SMessage => write!(f, "smessage"), + PushKind::Unsubscribe => write!(f, "unsubscribe"), + PushKind::PUnsubscribe => write!(f, "punsubscribe"), + PushKind::SUnsubscribe => write!(f, "sunsubscribe"), + PushKind::Subscribe => write!(f, "subscribe"), + PushKind::PSubscribe => write!(f, "psubscribe"), + PushKind::SSubscribe => write!(f, "ssubscribe"), + PushKind::Disconnection => write!(f, "disconnection"), + } + } +} + +pub enum MapIter<'a> { + Array(std::slice::Iter<'a, Value>), + Map(std::slice::Iter<'a, (Value, Value)>), +} + +impl<'a> Iterator for MapIter<'a> { + type Item = (&'a Value, &'a Value); + + fn next(&mut self) -> Option { + match self { + MapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + MapIter::Map(iter) => { + let (k, v) = iter.next()?; + Some((k, v)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MapIter::Array(iter) => iter.size_hint(), + MapIter::Map(iter) => iter.size_hint(), + } + } +} + +pub enum OwnedMapIter { + Array(std::vec::IntoIter), + Map(std::vec::IntoIter<(Value, Value)>), +} + +impl Iterator for OwnedMapIter { + type Item = (Value, Value); + + fn 
next(&mut self) -> Option { + match self { + OwnedMapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + OwnedMapIter::Map(iter) => iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + OwnedMapIter::Array(iter) => { + let (low, high) = iter.size_hint(); + (low / 2, high.map(|h| h / 2)) + } + OwnedMapIter::Map(iter) => iter.size_hint(), + } + } +} + +/// Values are generally not used directly unless you are using the +/// more low level functionality in the library. For the most part +/// this is hidden with the help of the `FromRedisValue` trait. +/// +/// While on the redis protocol there is an error type this is already +/// separated at an early point so the value only holds the remaining +/// types. +impl Value { + /// Checks if the return value looks like it fulfils the cursor + /// protocol. That means the result is an array item of length + /// two with the first one being a cursor and the second an + /// array response. + pub fn looks_like_cursor(&self) -> bool { + match *self { + Value::Array(ref items) => { + if items.len() != 2 { + return false; + } + matches!(items[0], Value::BulkString(_)) && matches!(items[1], Value::Array(_)) + } + _ => false, + } + } + + /// Returns an `&[Value]` if `self` is compatible with a sequence type + pub fn as_sequence(&self) -> Option<&[Value]> { + match self { + Value::Array(items) => Some(&items[..]), + Value::Set(items) => Some(&items[..]), + Value::Nil => Some(&[]), + _ => None, + } + } + + /// Returns a `Vec` if `self` is compatible with a sequence type, + /// otherwise returns `Err(self)`. + pub fn into_sequence(self) -> Result, Value> { + match self { + Value::Array(items) => Ok(items), + Value::Set(items) => Ok(items), + Value::Nil => Ok(vec![]), + _ => Err(self), + } + } + + /// Returns an iterator of `(&Value, &Value)` if `self` is compatible with a map type + pub fn as_map_iter(&self) -> Option> { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Some(MapIter::Array(items.iter())) + } else { + None + } + } + Value::Map(items) => Some(MapIter::Map(items.iter())), + _ => None, + } + } + + /// Returns an iterator of `(Value, Value)` if `self` is compatible with a map type. + /// If not, returns `Err(self)`. 
+ pub fn into_map_iter(self) -> Result { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Ok(OwnedMapIter::Array(items.into_iter())) + } else { + Err(Value::Array(items)) + } + } + Value::Map(items) => Ok(OwnedMapIter::Map(items.into_iter())), + _ => Err(self), + } + } +} + +impl fmt::Debug for Value { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Value::Nil => write!(fmt, "nil"), + Value::Int(val) => write!(fmt, "int({val:?})"), + Value::BulkString(ref val) => match from_utf8(val) { + Ok(x) => write!(fmt, "bulk-string('{x:?}')"), + Err(_) => write!(fmt, "binary-data({val:?})"), + }, + Value::Array(ref values) => write!(fmt, "array({values:?})"), + Value::Push { ref kind, ref data } => write!(fmt, "push({kind:?}, {data:?})"), + Value::Okay => write!(fmt, "ok"), + Value::SimpleString(ref s) => write!(fmt, "simple-string({s:?})"), + Value::Map(ref values) => write!(fmt, "map({values:?})"), + Value::Attribute { + ref data, + attributes: _, + } => write!(fmt, "attribute({data:?})"), + Value::Set(ref values) => write!(fmt, "set({values:?})"), + Value::Double(ref d) => write!(fmt, "double({d:?})"), + Value::Boolean(ref b) => write!(fmt, "boolean({b:?})"), + Value::VerbatimString { + ref format, + ref text, + } => { + write!(fmt, "verbatim-string({:?},{:?})", format, text) + } + Value::BigNumber(ref m) => write!(fmt, "big-number({:?})", m), + } + } +} + +/// Represents a redis error. For the most part you should be using +/// the Error trait to interact with this rather than the actual +/// struct. +pub struct RedisError { + repr: ErrorRepr, +} + +#[cfg(feature = "json")] +impl From for RedisError { + fn from(serde_err: serde_json::Error) -> RedisError { + RedisError::from(( + ErrorKind::Serialize, + "Serialization Error", + format!("{serde_err}"), + )) + } +} + +#[derive(Debug)] +enum ErrorRepr { + WithDescription(ErrorKind, &'static str), + WithDescriptionAndDetail(ErrorKind, &'static str, String), + ExtensionError(String, String), + IoError(io::Error), +} + +impl PartialEq for RedisError { + fn eq(&self, other: &RedisError) -> bool { + match (&self.repr, &other.repr) { + (&ErrorRepr::WithDescription(kind_a, _), &ErrorRepr::WithDescription(kind_b, _)) => { + kind_a == kind_b + } + ( + &ErrorRepr::WithDescriptionAndDetail(kind_a, _, _), + &ErrorRepr::WithDescriptionAndDetail(kind_b, _, _), + ) => kind_a == kind_b, + (ErrorRepr::ExtensionError(a, _), ErrorRepr::ExtensionError(b, _)) => *a == *b, + _ => false, + } + } +} + +impl From for RedisError { + fn from(err: io::Error) -> RedisError { + RedisError { + repr: ErrorRepr::IoError(err), + } + } +} + +impl From for RedisError { + fn from(_: Utf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Invalid UTF-8"), + } + } +} + +impl From for RedisError { + fn from(err: NulError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value contains interior nul terminator", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-native-tls")] +impl From for RedisError { + fn from(err: native_tls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + 
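
Looping back to `Value::as_map_iter`/`into_map_iter` defined above, a usage sketch relying only on the public `Value` type: RESP2 servers encode maps as flat even-length arrays while RESP3 uses `Value::Map`, and these accessors normalize both shapes:

```rust
use redis::Value;

/// Print the key/value pairs of a map-shaped response.
fn print_pairs(response: Value) {
    match response.into_map_iter() {
        Ok(pairs) => {
            for (key, value) in pairs {
                println!("{key:?} => {value:?}");
            }
        }
        Err(other) => println!("not map-shaped: {other:?}"),
    }
}

fn main() {
    // A RESP2-style flat array: ["role", "master"]
    let v = Value::Array(vec![
        Value::SimpleString("role".into()),
        Value::SimpleString("master".into()),
    ]);
    print_pairs(v);
}
```
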
+#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls_pki_types::InvalidDnsNameError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS Error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "uuid")] +impl From for RedisError { + fn from(err: uuid::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value is not a valid UUID", + err.to_string(), + ), + } + } +} + +impl From for RedisError { + fn from(_: FromUtf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Cannot convert from UTF-8"), + } + } +} + +impl From<(ErrorKind, &'static str)> for RedisError { + fn from((kind, desc): (ErrorKind, &'static str)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(kind, desc), + } + } +} + +impl From<(ErrorKind, &'static str, String)> for RedisError { + fn from((kind, desc, detail): (ErrorKind, &'static str, String)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail(kind, desc, detail), + } + } +} + +impl error::Error for RedisError { + #[allow(deprecated)] + fn description(&self) -> &str { + match self.repr { + ErrorRepr::WithDescription(_, desc) => desc, + ErrorRepr::WithDescriptionAndDetail(_, desc, _) => desc, + ErrorRepr::ExtensionError(_, _) => "extension error", + ErrorRepr::IoError(ref err) => err.description(), + } + } + + fn cause(&self) -> Option<&dyn error::Error> { + match self.repr { + ErrorRepr::IoError(ref err) => Some(err as &dyn error::Error), + _ => None, + } + } +} + +impl fmt::Display for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.repr { + ErrorRepr::WithDescription(kind, desc) => { + desc.fmt(f)?; + f.write_str("- ")?; + fmt::Debug::fmt(&kind, f) + } + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + desc.fmt(f)?; + f.write_str(" - ")?; + fmt::Debug::fmt(&kind, f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + code.fmt(f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::IoError(ref err) => err.fmt(f), + } + } +} + +impl fmt::Debug for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt::Display::fmt(self, f) + } +} + +pub(crate) enum RetryMethod { + Reconnect, + ReconnectAndRetry, + NoRetry, + RetryImmediately, + WaitAndRetry, + AskRedirect, + MovedRedirect, + WaitAndRetryOnPrimaryRedirectOnReplica, +} + +/// Indicates a general failure in the library. +impl RedisError { + /// Returns the kind of the error. + pub fn kind(&self) -> ErrorKind { + match self.repr { + ErrorRepr::WithDescription(kind, _) + | ErrorRepr::WithDescriptionAndDetail(kind, _, _) => kind, + ErrorRepr::ExtensionError(_, _) => ErrorKind::ExtensionError, + ErrorRepr::IoError(_) => ErrorKind::IoError, + } + } + + /// Returns the error detail. + pub fn detail(&self) -> Option<&str> { + match self.repr { + ErrorRepr::WithDescriptionAndDetail(_, _, ref detail) + | ErrorRepr::ExtensionError(_, ref detail) => Some(detail.as_str()), + _ => None, + } + } + + /// Returns the raw error code if available. 
+ pub fn code(&self) -> Option<&str> { + match self.kind() { + ErrorKind::ResponseError => Some("ERR"), + ErrorKind::ExecAbortError => Some("EXECABORT"), + ErrorKind::BusyLoadingError => Some("LOADING"), + ErrorKind::NoScriptError => Some("NOSCRIPT"), + ErrorKind::Moved => Some("MOVED"), + ErrorKind::Ask => Some("ASK"), + ErrorKind::TryAgain => Some("TRYAGAIN"), + ErrorKind::ClusterDown => Some("CLUSTERDOWN"), + ErrorKind::CrossSlot => Some("CROSSSLOT"), + ErrorKind::MasterDown => Some("MASTERDOWN"), + ErrorKind::ReadOnly => Some("READONLY"), + ErrorKind::NotBusy => Some("NOTBUSY"), + _ => match self.repr { + ErrorRepr::ExtensionError(ref code, _) => Some(code), + _ => None, + }, + } + } + + /// Returns the name of the error category for display purposes. + pub fn category(&self) -> &str { + match self.kind() { + ErrorKind::ResponseError => "response error", + ErrorKind::AuthenticationFailed => "authentication failed", + ErrorKind::TypeError => "type error", + ErrorKind::ExecAbortError => "script execution aborted", + ErrorKind::BusyLoadingError => "busy loading", + ErrorKind::NoScriptError => "no script", + ErrorKind::InvalidClientConfig => "invalid client config", + ErrorKind::Moved => "key moved", + ErrorKind::Ask => "key moved (ask)", + ErrorKind::TryAgain => "try again", + ErrorKind::ClusterDown => "cluster down", + ErrorKind::CrossSlot => "cross-slot", + ErrorKind::MasterDown => "master down", + ErrorKind::IoError => "I/O error", + ErrorKind::FatalSendError => { + "failed to send the request to the server due to a fatal error - the request was not transmitted" + } + ErrorKind::FatalReceiveError => "a fatal error occurred while attempting to receive a response from the server", + ErrorKind::ExtensionError => "extension error", + ErrorKind::ClientError => "client error", + ErrorKind::ReadOnly => "read-only", + ErrorKind::MasterNameNotFoundBySentinel => "master name not found by sentinel", + ErrorKind::NoValidReplicasFoundBySentinel => "no valid replicas found by sentinel", + ErrorKind::EmptySentinelList => "empty sentinel list", + ErrorKind::NotBusy => "not busy", + ErrorKind::AllConnectionsUnavailable => "no valid connections remain in the cluster", + ErrorKind::ConnectionNotFoundForRoute => "No connection found for the requested route", + #[cfg(feature = "json")] + ErrorKind::Serialize => "serializing", + ErrorKind::RESP3NotSupported => "resp3 is not supported by server", + ErrorKind::ParseError => "parse error", + ErrorKind::NotAllSlotsCovered => "not all slots are covered", + ErrorKind::UserOperationError => "Wrong usage of management operation", + } + } + + /// Indicates that this failure is an IO failure. + pub fn is_io_error(&self) -> bool { + self.as_io_error().is_some() + } + + pub(crate) fn as_io_error(&self) -> Option<&io::Error> { + match &self.repr { + ErrorRepr::IoError(e) => Some(e), + _ => None, + } + } + + /// Indicates that this is a cluster error. + pub fn is_cluster_error(&self) -> bool { + matches!( + self.kind(), + ErrorKind::Moved | ErrorKind::Ask | ErrorKind::TryAgain | ErrorKind::ClusterDown + ) + } + + /// Returns true if this error indicates that the connection was + /// refused. You should generally not rely much on this function + /// unless you are writing unit tests that want to detect if a + /// local server is available. 
+ pub fn is_connection_refusal(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => { + #[allow(clippy::match_like_matches_macro)] + match err.kind() { + io::ErrorKind::ConnectionRefused => true, + // if we connect to a unix socket and the file does not + // exist yet, then we want to treat this as if it was a + // connection refusal. + io::ErrorKind::NotFound => cfg!(unix), + _ => false, + } + } + _ => false, + } + } + + /// Returns true if error was caused by I/O time out. + /// Note that this may not be accurate depending on platform. + pub fn is_timeout(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => matches!( + err.kind(), + io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock + ), + _ => false, + } + } + + /// Returns true if error was caused by a dropped connection. + pub fn is_connection_dropped(&self) -> bool { + if matches!( + self.kind(), + ErrorKind::FatalSendError | ErrorKind::FatalReceiveError + ) { + return true; + } + match self.repr { + ErrorRepr::IoError(ref err) => matches!( + err.kind(), + io::ErrorKind::BrokenPipe + | io::ErrorKind::ConnectionReset + | io::ErrorKind::UnexpectedEof + ), + _ => false, + } + } + + /// Returns true if the error is likely to not be recoverable, and the connection must be replaced. + pub fn is_unrecoverable_error(&self) -> bool { + match self.retry_method() { + RetryMethod::Reconnect => true, + RetryMethod::ReconnectAndRetry => true, + RetryMethod::NoRetry => false, + RetryMethod::RetryImmediately => false, + RetryMethod::WaitAndRetry => false, + RetryMethod::AskRedirect => false, + RetryMethod::MovedRedirect => false, + RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => false, + } + } + + /// Returns the node the error refers to. + /// + /// This returns `(addr, slot_id)`. + pub fn redirect_node(&self) -> Option<(&str, u16)> { + match self.kind() { + ErrorKind::Ask | ErrorKind::Moved => (), + _ => return None, + } + let mut iter = self.detail()?.split_ascii_whitespace(); + let slot_id: u16 = iter.next()?.parse().ok()?; + let addr = iter.next()?; + Some((addr, slot_id)) + } + + /// Returns the extension error code. + /// + /// This method should not be used because every time the redis library + /// adds support for a new error code it would disappear form this method. + /// `code()` always returns the code. + #[deprecated(note = "use code() instead")] + pub fn extension_error_code(&self) -> Option<&str> { + match self.repr { + ErrorRepr::ExtensionError(ref code, _) => Some(code), + _ => None, + } + } + + /// Clone the `RedisError`, throwing away non-cloneable parts of an `IoError`. + /// + /// Deriving `Clone` is not possible because the wrapped `io::Error` is not + /// cloneable. + /// + /// The `ioerror_description` parameter will be prepended to the message in + /// case an `IoError` is found. 
+ #[cfg(feature = "connection-manager")] // Used to avoid "unused method" warning + pub(crate) fn clone_mostly(&self, ioerror_description: &'static str) -> Self { + let repr = match self.repr { + ErrorRepr::WithDescription(kind, desc) => ErrorRepr::WithDescription(kind, desc), + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + ErrorRepr::WithDescriptionAndDetail(kind, desc, detail.clone()) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + ErrorRepr::ExtensionError(code.clone(), detail.clone()) + } + ErrorRepr::IoError(ref e) => ErrorRepr::IoError(io::Error::new( + e.kind(), + format!("{ioerror_description}: {e}"), + )), + }; + Self { repr } + } + + pub(crate) fn retry_method(&self) -> RetryMethod { + match self.kind() { + ErrorKind::Moved => RetryMethod::MovedRedirect, + ErrorKind::Ask => RetryMethod::AskRedirect, + + ErrorKind::TryAgain => RetryMethod::WaitAndRetry, + ErrorKind::MasterDown => RetryMethod::WaitAndRetry, + ErrorKind::ClusterDown => RetryMethod::WaitAndRetry, + ErrorKind::MasterNameNotFoundBySentinel => RetryMethod::WaitAndRetry, + ErrorKind::NoValidReplicasFoundBySentinel => RetryMethod::WaitAndRetry, + + ErrorKind::BusyLoadingError => RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica, + + ErrorKind::ResponseError => RetryMethod::NoRetry, + ErrorKind::ReadOnly => RetryMethod::NoRetry, + ErrorKind::ExtensionError => RetryMethod::NoRetry, + ErrorKind::ExecAbortError => RetryMethod::NoRetry, + ErrorKind::TypeError => RetryMethod::NoRetry, + ErrorKind::NoScriptError => RetryMethod::NoRetry, + ErrorKind::InvalidClientConfig => RetryMethod::NoRetry, + ErrorKind::CrossSlot => RetryMethod::NoRetry, + ErrorKind::ClientError => RetryMethod::NoRetry, + ErrorKind::EmptySentinelList => RetryMethod::NoRetry, + ErrorKind::NotBusy => RetryMethod::NoRetry, + #[cfg(feature = "json")] + ErrorKind::Serialize => RetryMethod::NoRetry, + ErrorKind::RESP3NotSupported => RetryMethod::NoRetry, + + ErrorKind::ParseError => RetryMethod::Reconnect, + ErrorKind::AuthenticationFailed => RetryMethod::Reconnect, + ErrorKind::AllConnectionsUnavailable => RetryMethod::Reconnect, + ErrorKind::ConnectionNotFoundForRoute => RetryMethod::Reconnect, + + ErrorKind::IoError => match &self.repr { + ErrorRepr::IoError(err) => match err.kind() { + io::ErrorKind::ConnectionRefused => RetryMethod::Reconnect, + io::ErrorKind::NotFound => RetryMethod::Reconnect, + io::ErrorKind::ConnectionReset => RetryMethod::Reconnect, + io::ErrorKind::ConnectionAborted => RetryMethod::Reconnect, + io::ErrorKind::NotConnected => RetryMethod::Reconnect, + io::ErrorKind::BrokenPipe => RetryMethod::Reconnect, + io::ErrorKind::UnexpectedEof => RetryMethod::Reconnect, + + io::ErrorKind::PermissionDenied => RetryMethod::NoRetry, + io::ErrorKind::Unsupported => RetryMethod::NoRetry, + io::ErrorKind::TimedOut => RetryMethod::NoRetry, + + _ => RetryMethod::RetryImmediately, + }, + _ => RetryMethod::RetryImmediately, + }, + ErrorKind::NotAllSlotsCovered => RetryMethod::NoRetry, + ErrorKind::FatalReceiveError => RetryMethod::Reconnect, + ErrorKind::FatalSendError => RetryMethod::ReconnectAndRetry, + ErrorKind::UserOperationError => RetryMethod::NoRetry, + } + } +} + +pub fn make_extension_error(code: String, detail: Option) -> RedisError { + RedisError { + repr: ErrorRepr::ExtensionError( + code, + match detail { + Some(x) => x, + None => "Unknown extension error encountered".to_string(), + }, + ), + } +} + +/// Library generic result type. +pub type RedisResult = Result; + +/// Library generic future type. 
+#[cfg(feature = "aio")] +pub type RedisFuture<'a, T> = futures_util::future::BoxFuture<'a, RedisResult>; + +/// An info dictionary type. +#[derive(Debug, Clone)] +pub struct InfoDict { + map: HashMap, +} + +/// This type provides convenient access to key/value data returned by +/// the "INFO" command. It acts like a regular mapping but also has +/// a convenience method `get` which can return data in the appropriate +/// type. +/// +/// For instance this can be used to query the server for the role it's +/// in (master, slave) etc: +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let info : redis::InfoDict = redis::cmd("INFO").query(&mut con)?; +/// let role : Option = info.get("role"); +/// # Ok(()) } +/// ``` +impl InfoDict { + /// Creates a new info dictionary from a string in the response of + /// the INFO command. Each line is a key, value pair with the + /// key and value separated by a colon (`:`). Lines starting with a + /// hash (`#`) are ignored. + pub fn new(key_val_pairs: &str) -> InfoDict { + let mut map = HashMap::new(); + for line in key_val_pairs.lines() { + if line.is_empty() || line.starts_with('#') { + continue; + } + let mut p = line.splitn(2, ':'); + let (k, v) = match (p.next(), p.next()) { + (Some(k), Some(v)) => (k.to_string(), v.to_string()), + _ => continue, + }; + map.insert(k, Value::SimpleString(v)); + } + InfoDict { map } + } + + /// Fetches a value by key and converts it into the given type. + /// Typical types are `String`, `bool` and integer types. + pub fn get(&self, key: &str) -> Option { + match self.find(&key) { + Some(x) => from_redis_value(x).ok(), + None => None, + } + } + + /// Looks up a key in the info dict. + pub fn find(&self, key: &&str) -> Option<&Value> { + self.map.get(*key) + } + + /// Checks if a key is contained in the info dict. + pub fn contains_key(&self, key: &&str) -> bool { + self.find(key).is_some() + } + + /// Returns the size of the info dict. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Checks if the dict is empty. + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +impl Deref for InfoDict { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +/// Abstraction trait for redis command abstractions. +pub trait RedisWrite { + /// Accepts a serialized redis command. + fn write_arg(&mut self, arg: &[u8]); + + /// Accepts a serialized redis command. + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + self.write_arg(arg.to_string().as_bytes()) + } +} + +impl RedisWrite for Vec> { + fn write_arg(&mut self, arg: &[u8]) { + self.push(arg.to_owned()); + } + + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + self.push(arg.to_string().into_bytes()) + } +} + +/// Used to convert a value into one or multiple redis argument +/// strings. Most values will produce exactly one item but in +/// some cases it might make sense to produce more than one. +pub trait ToRedisArgs: Sized { + /// This converts the value into a vector of bytes. Each item + /// is a single argument. Most items generate a vector of a + /// single item. + /// + /// The exception to this rule currently are vectors of items. + fn to_redis_args(&self) -> Vec> { + let mut out = Vec::new(); + self.write_redis_args(&mut out); + out + } + + /// This writes the value into a vector of bytes. Each item + /// is a single argument. 
Most items generate a single item. + /// + /// The exception to this rule currently are vectors of items. + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite; + + /// Returns an information about the contained value with regards + /// to it's numeric behavior in a redis context. This is used in + /// some high level concepts to switch between different implementations + /// of redis functions (for instance `INCR` vs `INCRBYFLOAT`). + fn describe_numeric_behavior(&self) -> NumericBehavior { + NumericBehavior::NonNumeric + } + + /// Returns an indication if the value contained is exactly one + /// argument. It returns false if it's zero or more than one. This + /// is used in some high level functions to intelligently switch + /// between `GET` and `MGET` variants. + fn is_single_arg(&self) -> bool { + true + } + + /// This only exists internally as a workaround for the lack of + /// specialization. + #[doc(hidden)] + fn write_args_from_slice(items: &[Self], out: &mut W) + where + W: ?Sized + RedisWrite, + { + Self::make_arg_iter_ref(items.iter(), out) + } + + /// This only exists internally as a workaround for the lack of + /// specialization. + #[doc(hidden)] + fn make_arg_iter_ref<'a, I, W>(items: I, out: &mut W) + where + W: ?Sized + RedisWrite, + I: Iterator, + Self: 'a, + { + for item in items { + item.write_redis_args(out); + } + } + + #[doc(hidden)] + fn is_single_vec_arg(items: &[Self]) -> bool { + items.len() == 1 && items[0].is_single_arg() + } +} + +macro_rules! itoa_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +macro_rules! non_zero_itoa_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(self.get()); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +macro_rules! 
ryu_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::ryu::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +impl ToRedisArgs for u8 { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn write_args_from_slice(items: &[u8], out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(items); + } + + fn is_single_vec_arg(_items: &[u8]) -> bool { + true + } +} + +itoa_based_to_redis_impl!(i8, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(isize, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(usize, NumericBehavior::NumberIsInteger); + +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroUsize, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroIsize, NumericBehavior::NumberIsInteger); + +ryu_based_to_redis_impl!(f32, NumericBehavior::NumberIsFloat); +ryu_based_to_redis_impl!(f64, NumericBehavior::NumberIsFloat); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! 
big_num_to_redis_impl { + ($t:ty) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(&self.to_string().into_bytes()) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +big_num_to_redis_impl!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +big_num_to_redis_impl!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +big_num_to_redis_impl!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +big_num_to_redis_impl!(num_bigint::BigUint); + +impl ToRedisArgs for bool { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(if *self { b"1" } else { b"0" }) + } +} + +impl ToRedisArgs for String { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl<'a> ToRedisArgs for &'a str { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl ToRedisArgs for Vec { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(&self[..]) + } +} + +impl<'a, T: ToRedisArgs> ToRedisArgs for &'a [T] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self) + } +} + +impl ToRedisArgs for Option { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref x) = *self { + x.write_redis_args(out); + } + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + match *self { + Some(ref x) => x.describe_numeric_behavior(), + None => NumericBehavior::NonNumeric, + } + } + + fn is_single_arg(&self) -> bool { + match *self { + Some(ref x) => x.is_single_arg(), + None => false, + } + } +} + +impl ToRedisArgs for &T { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + (*self).write_redis_args(out) + } + + fn is_single_arg(&self) -> bool { + (*self).is_single_arg() + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs + for std::collections::HashSet +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +#[cfg(feature = "ahash")] +impl ToRedisArgs for ahash::AHashSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// this flattens BTreeMap into something that goes well with HMSET +/// @note: Redis cannot store empty sets so the application 
has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeMap { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + // otherwise things like HMSET will simply NOT work + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +impl ToRedisArgs + for std::collections::HashMap +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +macro_rules! to_redis_args_for_tuple { + () => (); + ($($name:ident,)+) => ( + #[doc(hidden)] + impl<$($name: ToRedisArgs),*> ToRedisArgs for ($($name,)*) { + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn write_redis_args(&self, out: &mut W) where W: ?Sized + RedisWrite { + let ($(ref $name,)*) = *self; + $($name.write_redis_args(out);)* + } + + #[allow(non_snake_case, unused_variables)] + fn is_single_arg(&self) -> bool { + let mut n = 0u32; + $(let $name = (); n += 1;)* + n == 1 + } + } + to_redis_args_for_tuple_peel!($($name,)*); + ) +} + +/// This chips of the leading one and recurses for the rest. So if the first +/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples +/// of size 1 (does not implement down to unit). +macro_rules! to_redis_args_for_tuple_peel { + ($name:ident, $($other:ident,)*) => (to_redis_args_for_tuple!($($other,)*);) +} + +to_redis_args_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, } + +impl ToRedisArgs for &[T; N] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self.as_slice(), out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self.as_slice()) + } +} + +fn vec_to_array(items: Vec, original_value: &Value) -> RedisResult<[T; N]> { + match items.try_into() { + Ok(array) => Ok(array), + Err(items) => { + let msg = format!( + "Response has wrong dimension, expected {N}, got {}", + items.len() + ); + invalid_type_error!(original_value, msg) + } + } +} + +impl FromRedisValue for [T; N] { + fn from_redis_value(value: &Value) -> RedisResult<[T; N]> { + match *value { + Value::BulkString(ref bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(items) => vec_to_array(items, value), + None => { + let msg = format!( + "Conversion to Array[{}; {N}] failed", + std::any::type_name::() + ); + invalid_type_error!(value, msg) + } + }, + Value::Array(ref items) => { + let items = FromRedisValue::from_redis_values(items)?; + vec_to_array(items, value) + } + Value::Nil => vec_to_array(vec![], value), + _ => invalid_type_error!(value, "Response type not array compatible"), + } + } +} + +/// This trait is used to convert a redis value into a more appropriate +/// type. While a redis `Value` can represent any response that comes +/// back from the redis server, usually you want to map this into something +/// that works better in rust. For instance you might want to convert the +/// return value into a `String` or an integer. 
+/// +/// This trait is well supported throughout the library and you can +/// implement it for your own types if you want. +/// +/// In addition to what you can see from the docs, this is also implemented +/// for tuples up to size 12 and for `Vec`. +pub trait FromRedisValue: Sized { + /// Given a redis `Value` this attempts to convert it into the given + /// destination type. If that fails because it's not compatible an + /// appropriate error is generated. + fn from_redis_value(v: &Value) -> RedisResult; + + /// Given a redis `Value` this attempts to convert it into the given + /// destination type. If that fails because it's not compatible an + /// appropriate error is generated. + fn from_owned_redis_value(v: Value) -> RedisResult { + // By default, fall back to `from_redis_value`. + // This function only needs to be implemented if it can benefit + // from taking `v` by value. + Self::from_redis_value(&v) + } + + /// Similar to `from_redis_value` but constructs a vector of objects + /// from another vector of values. This primarily exists internally + /// to customize the behavior for vectors of tuples. + fn from_redis_values(items: &[Value]) -> RedisResult> { + items.iter().map(FromRedisValue::from_redis_value).collect() + } + + /// The same as `from_redis_values`, but takes a `Vec` instead + /// of a `&[Value]`. + fn from_owned_redis_values(items: Vec) -> RedisResult> { + items + .into_iter() + .map(FromRedisValue::from_owned_redis_value) + .collect() + } + + /// Convert bytes to a single element vector. + fn from_byte_vec(_vec: &[u8]) -> Option> { + Self::from_owned_redis_value(Value::BulkString(_vec.into())) + .map(|rv| vec![rv]) + .ok() + } + + /// Convert bytes to a single element vector. + fn from_owned_byte_vec(_vec: Vec) -> RedisResult> { + Self::from_owned_redis_value(Value::BulkString(_vec)).map(|rv| vec![rv]) + } +} + +fn get_inner_value(v: &Value) -> &Value { + if let Value::Attribute { + data, + attributes: _, + } = v + { + data.as_ref() + } else { + v + } +} + +fn get_owned_inner_value(v: Value) -> Value { + if let Value::Attribute { + data, + attributes: _, + } = v + { + *data + } else { + v + } +} + +macro_rules! from_redis_value_for_num_internal { + ($t:ty, $v:expr) => {{ + let v = if let Value::Attribute { + data, + attributes: _, + } = $v + { + data + } else { + $v + }; + match *v { + Value::Int(val) => Ok(val as $t), + Value::SimpleString(ref s) => match s.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::Double(val) => Ok(val as $t), + _ => invalid_type_error!(v, "Response type not convertible to numeric."), + } + }}; +} + +macro_rules! from_redis_value_for_num { + ($t:ty) => { + impl FromRedisValue for $t { + fn from_redis_value(v: &Value) -> RedisResult<$t> { + from_redis_value_for_num_internal!($t, v) + } + } + }; +} + +impl FromRedisValue for u8 { + fn from_redis_value(v: &Value) -> RedisResult { + from_redis_value_for_num_internal!(u8, v) + } + + // this hack allows us to specialize Vec to work with binary data. 
+ fn from_byte_vec(vec: &[u8]) -> Option> { + Some(vec.to_vec()) + } + fn from_owned_byte_vec(vec: Vec) -> RedisResult> { + Ok(vec) + } +} + +from_redis_value_for_num!(i8); +from_redis_value_for_num!(i16); +from_redis_value_for_num!(u16); +from_redis_value_for_num!(i32); +from_redis_value_for_num!(u32); +from_redis_value_for_num!(i64); +from_redis_value_for_num!(u64); +from_redis_value_for_num!(i128); +from_redis_value_for_num!(u128); +from_redis_value_for_num!(f32); +from_redis_value_for_num!(f64); +from_redis_value_for_num!(isize); +from_redis_value_for_num!(usize); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum_internal { + ($t:ty, $v:expr) => {{ + let v = $v; + match *v { + Value::Int(val) => <$t>::try_from(val) + .map_err(|_| invalid_type_error_inner!(v, "Could not convert from integer.")), + Value::SimpleString(ref s) => match s.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + _ => invalid_type_error!(v, "Response type not convertible to numeric."), + } + }}; +} + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum { + ($t:ty) => { + impl FromRedisValue for $t { + fn from_redis_value(v: &Value) -> RedisResult<$t> { + from_redis_value_for_bignum_internal!($t, v) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +from_redis_value_for_bignum!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +from_redis_value_for_bignum!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigUint); + +impl FromRedisValue for bool { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(false), + Value::Int(val) => Ok(val != 0), + Value::SimpleString(ref s) => { + if &s[..] == "1" { + Ok(true) + } else if &s[..] 
== "0" { + Ok(false) + } else { + invalid_type_error!(v, "Response status not valid boolean"); + } + } + Value::BulkString(ref bytes) => { + if bytes == b"1" { + Ok(true) + } else if bytes == b"0" { + Ok(false) + } else { + invalid_type_error!(v, "Response type not bool compatible."); + } + } + Value::Boolean(b) => Ok(b), + Value::Okay => Ok(true), + _ => invalid_type_error!(v, "Response type not bool compatible."), + } + } +} + +impl FromRedisValue for CString { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(CString::new(bytes.as_slice())?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(ref val) => Ok(CString::new(val.as_bytes())?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(CString::new(bytes)?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(val) => Ok(CString::new(val)?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } +} + +impl FromRedisValue for String { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(from_utf8(bytes)?.to_string()), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(ref val) => Ok(val.to_string()), + Value::VerbatimString { + format: _, + ref text, + } => Ok(text.to_string()), + Value::Double(ref val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } + + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(String::from_utf8(bytes)?), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(val) => Ok(val), + Value::VerbatimString { format: _, text } => Ok(text), + Value::Double(val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } +} + +/// Implement `FromRedisValue` for `$Type` (which should use the generic parameter `$T`). +/// +/// The implementation parses the value into a vec, and then passes the value through `$convert`. +/// If `$convert` is omitted, it defaults to `Into::into`. +macro_rules! from_vec_from_redis_value { + (<$T:ident> $Type:ty) => { + from_vec_from_redis_value!(<$T> $Type; Into::into); + }; + + (<$T:ident> $Type:ty; $convert:expr) => { + impl<$T: FromRedisValue> FromRedisValue for $Type { + fn from_redis_value(v: &Value) -> RedisResult<$Type> { + match v { + // All binary data except u8 will try to parse into a single element vector. + // u8 has its own implementation of from_byte_vec. 
+ Value::BulkString(bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(x) => Ok($convert(x)), + None => invalid_type_error!( + v, + format!("Conversion to {} failed.", std::any::type_name::<$Type>()) + ), + }, + Value::Array(items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Set(ref items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Map(ref items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_redis_value(&Value::Map(vec![item.clone()])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult<$Type> { + match v { + // Binary data is parsed into a single-element vector, except + // for the element type `u8`, which directly consumes the entire + // array of bytes. + Value::BulkString(bytes) => FromRedisValue::from_owned_byte_vec(bytes).map($convert), + Value::Array(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Set(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Map(items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_owned_redis_value(Value::Map(vec![item])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + } + }; +} + +from_vec_from_redis_value!( Vec); +from_vec_from_redis_value!( std::sync::Arc<[T]>); +from_vec_from_redis_value!( Box<[T]>; Vec::into_boxed_slice); + +impl FromRedisValue + for std::collections::HashMap +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(Default::default()), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(Default::default()), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect(), + } + } +} + +#[cfg(feature = "ahash")] +impl FromRedisValue for ahash::AHashMap { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? 
+                .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?)))
+                .collect(),
+        }
+    }
+}
+
+impl<K: FromRedisValue + Eq + Hash, V: FromRedisValue> FromRedisValue for BTreeMap<K, V>
+where
+    K: Ord,
+{
+    fn from_redis_value(v: &Value) -> RedisResult<BTreeMap<K, V>> {
+        let v = get_inner_value(v);
+        v.as_map_iter()
+            .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreemap compatible"))?
+            .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?)))
+            .collect()
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<BTreeMap<K, V>> {
+        let v = get_owned_inner_value(v);
+        v.into_map_iter()
+            .map_err(|v| invalid_type_error_inner!(v, "Response type not btreemap compatible"))?
+            .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?)))
+            .collect()
+    }
+}
+
+impl<T: FromRedisValue + Eq + Hash, S: BuildHasher + Default> FromRedisValue
+    for std::collections::HashSet<T, S>
+{
+    fn from_redis_value(v: &Value) -> RedisResult<std::collections::HashSet<T, S>> {
+        let v = get_inner_value(v);
+        let items = v
+            .as_sequence()
+            .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?;
+        items.iter().map(|item| from_redis_value(item)).collect()
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<std::collections::HashSet<T, S>> {
+        let v = get_owned_inner_value(v);
+        let items = v
+            .into_sequence()
+            .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?;
+        items
+            .into_iter()
+            .map(|item| from_owned_redis_value(item))
+            .collect()
+    }
+}
+
+#[cfg(feature = "ahash")]
+impl<T: FromRedisValue + Eq + Hash> FromRedisValue for ahash::AHashSet<T> {
+    fn from_redis_value(v: &Value) -> RedisResult<ahash::AHashSet<T>> {
+        let v = get_inner_value(v);
+        let items = v
+            .as_sequence()
+            .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?;
+        items.iter().map(|item| from_redis_value(item)).collect()
+    }
+
+    fn from_owned_redis_value(v: Value) -> RedisResult<ahash::AHashSet<T>> {
+        let v = get_owned_inner_value(v);
+        let items = v
+            .into_sequence()
+            .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?;
+        items
+            .into_iter()
+            .map(|item| from_owned_redis_value(item))
+            .collect()
+    }
+}
+
+impl<T: FromRedisValue + Eq + Hash> FromRedisValue for BTreeSet<T>
+where
+    T: Ord,
+{
+    fn from_redis_value(v: &Value) -> RedisResult<BTreeSet<T>> {
+        let v = get_inner_value(v);
+        let items = v.as_sequence().ok_or_else(|| {
+            invalid_type_error_inner!(v, "Response type not btree_set compatible")
+        })?;
+        items.iter().map(|item| from_redis_value(item)).collect()
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<BTreeSet<T>> {
+        let v = get_owned_inner_value(v);
+        let items = v
+            .into_sequence()
+            .map_err(|v| invalid_type_error_inner!(v, "Response type not btree_set compatible"))?;
+        items
+            .into_iter()
+            .map(|item| from_owned_redis_value(item))
+            .collect()
+    }
+}
+
+impl FromRedisValue for Value {
+    fn from_redis_value(v: &Value) -> RedisResult<Value> {
+        Ok(v.clone())
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<Value> {
+        Ok(v)
+    }
+}
+
+impl FromRedisValue for () {
+    fn from_redis_value(_v: &Value) -> RedisResult<()> {
+        Ok(())
+    }
+}
+
+macro_rules! from_redis_value_for_tuple {
+    () => ();
+    ($($name:ident,)+) => (
+        #[doc(hidden)]
+        impl<$($name: FromRedisValue),*> FromRedisValue for ($($name,)*) {
+            // we have local variables named T1 as dummies and those
+            // variables are unused.
+            #[allow(non_snake_case, unused_variables)]
+            fn from_redis_value(v: &Value) -> RedisResult<($($name,)*)> {
+                let v = get_inner_value(v);
+                match *v {
+                    Value::Array(ref items) => {
+                        // hacky way to count the tuple size
+                        let mut n = 0;
+                        $(let $name = (); n += 1;)*
+                        if items.len() != n {
+                            invalid_type_error!(v, "Array response of wrong dimension")
+                        }
+
+                        // this is pretty ugly too. The { i += 1; i - 1} is rust's
+                        // postfix increment :)
+                        let mut i = 0;
+                        Ok(($({let $name = (); from_redis_value(
+                             &items[{ i += 1; i - 1 }])?},)*))
+                    }
+
+                    Value::Map(ref items) => {
+                        // hacky way to count the tuple size
+                        let mut n = 0;
+                        $(let $name = (); n += 1;)*
+                        if n != 2 {
+                            invalid_type_error!(v, "Map response of wrong dimension")
+                        }
+
+                        let mut flatten_items = vec![];
+                        for (k, v) in items {
+                            flatten_items.push(k);
+                            flatten_items.push(v);
+                        }
+
+                        // this is pretty ugly too. The { i += 1; i - 1} is rust's
+                        // postfix increment :)
+                        let mut i = 0;
+                        Ok(($({let $name = (); from_redis_value(
+                             &flatten_items[{ i += 1; i - 1 }])?},)*))
+                    }
+
+                    _ => invalid_type_error!(v, "Not an Array response")
+                }
+            }
+
+            // we have local variables named T1 as dummies and those
+            // variables are unused.
+            #[allow(non_snake_case, unused_variables)]
+            fn from_owned_redis_value(v: Value) -> RedisResult<($($name,)*)> {
+                let v = get_owned_inner_value(v);
+                match v {
+                    Value::Array(mut items) => {
+                        // hacky way to count the tuple size
+                        let mut n = 0;
+                        $(let $name = (); n += 1;)*
+                        if items.len() != n {
+                            invalid_type_error!(Value::Array(items), "Array response of wrong dimension")
+                        }
+
+                        // this is pretty ugly too. The { i += 1; i - 1} is rust's
+                        // postfix increment :)
+                        let mut i = 0;
+                        Ok(($({let $name = (); from_owned_redis_value(
+                            ::std::mem::replace(&mut items[{ i += 1; i - 1 }], Value::Nil)
+                        )?},)*))
+                    }
+
+                    Value::Map(items) => {
+                        // hacky way to count the tuple size
+                        let mut n = 0;
+                        $(let $name = (); n += 1;)*
+                        if n != 2 {
+                            invalid_type_error!(Value::Map(items), "Map response of wrong dimension")
+                        }
+
+                        let mut flatten_items = vec![];
+                        for (k, v) in items {
+                            flatten_items.push(k);
+                            flatten_items.push(v);
+                        }
+
+                        // this is pretty ugly too. The { i += 1; i - 1} is rust's
+                        // postfix increment :)
+                        let mut i = 0;
+                        Ok(($({let $name = (); from_redis_value(
+                             &flatten_items[{ i += 1; i - 1 }])?},)*))
+                    }
+
+                    _ => invalid_type_error!(v, "Not an Array response")
+                }
+            }
+
+            #[allow(non_snake_case, unused_variables)]
+            fn from_redis_values(items: &[Value]) -> RedisResult<Vec<($($name,)*)>> {
+                // hacky way to count the tuple size
+                let mut n = 0;
+                $(let $name = (); n += 1;)*
+                let mut rv = vec![];
+                if items.len() == 0 {
+                    return Ok(rv)
+                }
+                // It's uglier than before!
+                for item in items {
+                    match item {
+                        Value::Array(ch) => {
+                            if let [$($name),*] = &ch[..] {
+                                rv.push(($(from_redis_value(&$name)?),*),)
+                            } else {
+                                unreachable!()
+                            };
+                        },
+                        _ => {},
+                    }
+                }
+                if !rv.is_empty() {
+                    return Ok(rv);
+                }
+
+                if let [$($name),*] = items {
+                    rv.push(($(from_redis_value($name)?),*),);
+                    return Ok(rv);
+                }
+                for chunk in items.chunks_exact(n) {
+                    match chunk {
+                        [$($name),*] => rv.push(($(from_redis_value($name)?),*),),
+                        _ => {},
+                    }
+                }
+                Ok(rv)
+            }
+
+            #[allow(non_snake_case, unused_variables)]
+            fn from_owned_redis_values(mut items: Vec<Value>) -> RedisResult<Vec<($($name,)*)>> {
+                // hacky way to count the tuple size
+                let mut n = 0;
+                $(let $name = (); n += 1;)*
+
+                let mut rv = vec![];
+                if items.len() == 0 {
+                    return Ok(rv)
+                }
+                // It's uglier than before!
+                for item in items.iter() {
+                    match item {
+                        Value::Array(ch) => {
+                            // TODO - this copies when we could've used the owned value. need to find out how to do this.
+                            if let [$($name),*] = &ch[..] {
+                                rv.push(($(from_redis_value($name)?),*),)
+                            } else {
+                                unreachable!()
+                            };
+                        },
+                        _ => {},
+                    }
+                }
+                if !rv.is_empty() {
+                    return Ok(rv);
+                }
+
+                let mut rv = Vec::with_capacity(items.len() / n);
+                if items.len() == 0 {
+                    return Ok(rv)
+                }
+                for chunk in items.chunks_mut(n) {
+                    match chunk {
+                        // Take each element out of the chunk with `std::mem::replace`, leaving a `Value::Nil`
+                        // in its place. This allows each `Value` to be parsed without being copied.
+                        // Since `items` is consumed by this function and not used later, this replacement
+                        // is not observable to the rest of the code.
+                        [$($name),*] => rv.push(($(from_owned_redis_value(std::mem::replace($name, Value::Nil))?),*),),
+                        _ => unreachable!(),
+                    }
+                }
+                Ok(rv)
+            }
+        }
+        from_redis_value_for_tuple_peel!($($name,)*);
+    )
+}
+
+/// This chops off the leading one and recurses for the rest. So if the first
+/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples
+/// of size 1 (does not implement down to unit).
+macro_rules! from_redis_value_for_tuple_peel {
+    ($name:ident, $($other:ident,)*) => (from_redis_value_for_tuple!($($other,)*);)
+}
+
+from_redis_value_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, }
+
+impl FromRedisValue for InfoDict {
+    fn from_redis_value(v: &Value) -> RedisResult<InfoDict> {
+        let v = get_inner_value(v);
+        let s: String = from_redis_value(v)?;
+        Ok(InfoDict::new(&s))
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<InfoDict> {
+        let v = get_owned_inner_value(v);
+        let s: String = from_owned_redis_value(v)?;
+        Ok(InfoDict::new(&s))
+    }
+}
+
+impl<T: FromRedisValue> FromRedisValue for Option<T> {
+    fn from_redis_value(v: &Value) -> RedisResult<Option<T>> {
+        let v = get_inner_value(v);
+        if *v == Value::Nil {
+            return Ok(None);
+        }
+        Ok(Some(from_redis_value(v)?))
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<Option<T>> {
+        let v = get_owned_inner_value(v);
+        if v == Value::Nil {
+            return Ok(None);
+        }
+        Ok(Some(from_owned_redis_value(v)?))
+    }
+}
+
+#[cfg(feature = "bytes")]
+impl FromRedisValue for bytes::Bytes {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        let v = get_inner_value(v);
+        match v {
+            Value::BulkString(bytes_vec) => Ok(bytes::Bytes::copy_from_slice(bytes_vec.as_ref())),
+            _ => invalid_type_error!(v, "Not a bulk string"),
+        }
+    }
+    fn from_owned_redis_value(v: Value) -> RedisResult<Self> {
+        let v = get_owned_inner_value(v);
+        match v {
+            Value::BulkString(bytes_vec) => Ok(bytes_vec.into()),
+            _ => invalid_type_error!(v, "Not a bulk string"),
+        }
+    }
+}
+
+#[cfg(feature = "uuid")]
+impl FromRedisValue for uuid::Uuid {
+    fn from_redis_value(v: &Value) -> RedisResult<Self> {
+        match *v {
+            Value::BulkString(ref bytes) => Ok(uuid::Uuid::from_slice(bytes)?),
+            _ => invalid_type_error!(v, "Response type not uuid compatible."),
+        }
+    }
+}
+
+#[cfg(feature = "uuid")]
+impl ToRedisArgs for uuid::Uuid {
+    fn write_redis_args<W>(&self, out: &mut W)
+    where
+        W: ?Sized + RedisWrite,
+    {
+        out.write_arg(self.as_bytes());
+    }
+}
+
+/// A shortcut function to invoke `FromRedisValue::from_redis_value`
+/// to make the API slightly nicer.
+pub fn from_redis_value<T: FromRedisValue>(v: &Value) -> RedisResult<T> {
+    FromRedisValue::from_redis_value(v)
+}
+
+/// A shortcut function to invoke `FromRedisValue::from_owned_redis_value`
+/// to make the API slightly nicer.
+pub fn from_owned_redis_value<T: FromRedisValue>(v: Value) -> RedisResult<T> {
+    FromRedisValue::from_owned_redis_value(v)
+}
+
+/// Enum representing the communication protocol with the server.
This enum represents the types +/// of data that the server can send to the client, and the capabilities that the client can use. +#[derive(Clone, Eq, PartialEq, Default, Debug, Copy)] +pub enum ProtocolVersion { + /// + #[default] + RESP2, + /// + RESP3, +} diff --git a/glide-core/redis-rs/redis/tests/auth.rs b/glide-core/redis-rs/redis/tests/auth.rs new file mode 100644 index 0000000000..e48e37940d --- /dev/null +++ b/glide-core/redis-rs/redis/tests/auth.rs @@ -0,0 +1,303 @@ +mod support; + +#[cfg(test)] +mod auth { + use crate::support::*; + use redis::{ + aio::MultiplexedConnection, + cluster::ClusterClientBuilder, + cluster_async::ClusterConnection, + cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo}, + cmd, ConnectionInfo, GlideConnectionOptions, RedisConnectionInfo, RedisResult, Value, + }; + + const ALL_SUCCESS_ROUTE: RoutingInfo = RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::AllSucceeded), + )); + + const PASSWORD: &str = "password"; + const NEW_PASSWORD: &str = "new_password"; + + enum ConnectionType { + Cluster, + Standalone, + } + + enum Connection { + Cluster(ClusterConnection), + Standalone(MultiplexedConnection), + } + + async fn create_connection( + password: Option, + connection_type: ConnectionType, + cluster_context: Option<&TestClusterContext>, + standalone_context: Option<&TestContext>, + ) -> RedisResult { + match connection_type { + ConnectionType::Cluster => { + let cluster_context = + cluster_context.expect("ClusterContext is required for Cluster connection"); + let builder = get_builder(cluster_context, password); + let connection = builder.build()?.get_async_connection(None).await?; + Ok(Connection::Cluster(connection)) + } + ConnectionType::Standalone => { + let standalone_context = + standalone_context.expect("TestContext is required for Standalone connection"); + let info = get_connection_info(standalone_context, password); + let client = redis::Client::open(info)?; + let connection = client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await?; + Ok(Connection::Standalone(connection)) + } + } + } + + fn get_connection_info(cluster: &TestContext, password: Option) -> ConnectionInfo { + let addr = cluster.server.connection_info().addr.clone(); + ConnectionInfo { + addr, + redis: RedisConnectionInfo { + password, + ..Default::default() + }, + } + } + + fn get_builder(cluster: &TestClusterContext, password: Option) -> ClusterClientBuilder { + let mut builder = ClusterClientBuilder::new(cluster.nodes.clone()); + if let Some(password) = password { + builder = builder.password(password); + } + builder + } + + async fn set_password(password: &str, conn: &mut Connection) -> RedisResult<()> { + let mut set_auth_cmd = cmd("config"); + set_auth_cmd.arg("set").arg("requirepass").arg(password); + match conn { + Connection::Cluster(cluster_conn) => cluster_conn + .route_command(&set_auth_cmd, ALL_SUCCESS_ROUTE) + .await + .map(|_| ()), + Connection::Standalone(standalone_conn) => set_auth_cmd + .query_async::<_, ()>(standalone_conn) + .await + .map(|_| ()), + } + } + + async fn kill_non_management_connections(con: &mut Connection) { + let mut kill_cmd = cmd("client"); + kill_cmd.arg("kill").arg("type").arg("normal"); + match con { + Connection::Cluster(cluster_conn) => { + cluster_conn + .route_command(&kill_cmd, ALL_SUCCESS_ROUTE) + .await + .unwrap(); + } + Connection::Standalone(standalone_conn) => { + kill_cmd.arg("skipme").arg("no"); + kill_cmd + .query_async::<_, 
()>(standalone_conn) + .await + .unwrap(); + } + } + } + + #[tokio::test] + #[serial_test::serial] + async fn test_replace_password_cluster() { + let cluster_context = TestClusterContext::new(3, 0); + + // Create a management connection to set the password + let management_connection = + match create_connection(None, ConnectionType::Cluster, Some(&cluster_context), None) + .await + .unwrap() + { + Connection::Cluster(conn) => conn, + _ => panic!("Expected ClusterConnection"), + }; + + // Set the password using the unified function + let mut management_conn = Connection::Cluster(management_connection.clone()); + set_password(PASSWORD, &mut management_conn).await.unwrap(); + + // Test that we can't connect without password + let connection_should_fail = + create_connection(None, ConnectionType::Cluster, Some(&cluster_context), None).await; + assert!(connection_should_fail.is_err()); + let err = connection_should_fail.err().unwrap(); + assert!(err.to_string().contains("Authentication required.")); + + // Test that we can connect with password + let mut connection_should_succeed = match create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Cluster, + Some(&cluster_context), + None, + ) + .await + .unwrap() + { + Connection::Cluster(conn) => conn, + _ => panic!("Expected ClusterConnection"), + }; + + let res: RedisResult = cmd("set") + .arg("foo") + .arg("bar") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::Okay); + + // Verify that we can retrieve the set value + let res: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::BulkString(b"bar".to_vec())); + + // Kill the connection to force reconnection + kill_non_management_connections(&mut Connection::Cluster(management_connection.clone())) + .await; + + // Attempt to get the value again to ensure reconnection works + let should_be_ok: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(should_be_ok.unwrap(), Value::BulkString(b"bar".to_vec())); + + // Update the password in the connection + connection_should_succeed + .update_connection_password(Some(NEW_PASSWORD.to_string())) + .await + .unwrap(); + + // Update the password on the server + let mut management_conn = Connection::Cluster(management_connection.clone()); + set_password(NEW_PASSWORD, &mut management_conn) + .await + .unwrap(); + + // Test that we can't connect with the old password + let connection_should_fail = create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Cluster, + Some(&cluster_context), + None, + ) + .await; + assert!(connection_should_fail.is_err()); + let err = connection_should_fail.err().unwrap(); + assert!(err + .to_string() + .contains("Password authentication failed- AuthenticationFailed")); + + // Kill the connection to force reconnection + let mut management_conn = Connection::Cluster(management_connection); + kill_non_management_connections(&mut management_conn).await; + + // Verify that the connection with new password still works + let result_should_succeed: RedisResult = cmd("get") + .arg("foo") + .query_async(&mut connection_should_succeed) + .await; + assert!(result_should_succeed.is_ok()); + assert_eq!( + result_should_succeed.unwrap(), + Value::BulkString(b"bar".to_vec()) + ); + } + + #[tokio::test] + #[serial_test::serial] + async fn test_replace_password_standalone() { + let standalone_context = TestContext::new(); + + // Create a management 
connection to set the password + let management_connection = match create_connection( + None, + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await + .unwrap() + { + Connection::Standalone(conn) => conn, + _ => panic!("Expected Standalone connection"), + }; + + // Set the password using the unified function + let mut management_conn = Connection::Standalone(management_connection.clone()); + set_password(PASSWORD, &mut management_conn).await.unwrap(); + + // Test that we can't send commands with new connection without password + let connection_should_fail = create_connection( + None, + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await; + let res_should_fail: RedisResult = match connection_should_fail.unwrap() { + Connection::Cluster(mut conn) => cmd("get").arg("foo").query_async(&mut conn).await, + Connection::Standalone(mut conn) => cmd("get").arg("foo").query_async(&mut conn).await, + }; + assert!(res_should_fail.is_err()); + + // Test that we can connect with password + let mut connection_should_succeed = match create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await + .unwrap() + { + Connection::Standalone(conn) => conn, + _ => panic!("Expected Standalone connection"), + }; + + let res: RedisResult = cmd("set") + .arg("foo") + .arg("bar") + .query_async(&mut connection_should_succeed) + .await; + assert_eq!(res.unwrap(), Value::Okay); + + // Update the password in the connection + connection_should_succeed + .update_connection_password(Some(NEW_PASSWORD.to_string())) + .await + .unwrap(); + + // Update the password on the server + let mut management_conn = Connection::Standalone(management_connection.clone()); + set_password(NEW_PASSWORD, &mut management_conn) + .await + .unwrap(); + + // Reset the management connection + kill_non_management_connections(&mut management_conn).await; + + // Test that we can't connect with the old password + let connection_should_fail = create_connection( + Some(PASSWORD.to_string()), + ConnectionType::Standalone, + None, + Some(&standalone_context), + ) + .await; + assert!(connection_should_fail.is_err()); + } +} diff --git a/glide-core/redis-rs/redis/tests/parser.rs b/glide-core/redis-rs/redis/tests/parser.rs new file mode 100644 index 0000000000..c4083f44bd --- /dev/null +++ b/glide-core/redis-rs/redis/tests/parser.rs @@ -0,0 +1,195 @@ +use std::{io, pin::Pin}; + +use redis::Value; +use { + futures::{ + ready, + task::{self, Poll}, + }, + partial_io::{quickcheck_types::GenWouldBlock, quickcheck_types::PartialWithErrors, PartialOp}, + quickcheck::{quickcheck, Gen}, + tokio::io::{AsyncRead, ReadBuf}, +}; + +mod support; +use crate::support::{block_on_all, encode_value}; + +#[derive(Clone, Debug)] +struct ArbitraryValue(Value); + +impl ::quickcheck::Arbitrary for ArbitraryValue { + fn arbitrary(g: &mut Gen) -> Self { + let size = g.size(); + ArbitraryValue(arbitrary_value(g, size)) + } + + fn shrink(&self) -> Box> { + match self.0 { + Value::Nil | Value::Okay => Box::new(None.into_iter()), + Value::Int(i) => Box::new(i.shrink().map(Value::Int).map(ArbitraryValue)), + Value::BulkString(ref xs) => { + Box::new(xs.shrink().map(Value::BulkString).map(ArbitraryValue)) + } + Value::Array(ref xs) | Value::Set(ref xs) => { + let ys = xs + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + 
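+            // Maps don't shrink element-wise; they collapse straight to the empty map.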
Value::Map(ref _xs) => Box::new(vec![ArbitraryValue(Value::Map(vec![]))].into_iter()), + Value::Attribute { + ref data, + ref attributes, + } => Box::new( + vec![ArbitraryValue(Value::Attribute { + data: data.clone(), + attributes: attributes.clone(), + })] + .into_iter(), + ), + Value::Push { ref kind, ref data } => { + let mut ys = data + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + ys.insert(0, ArbitraryValue(Value::SimpleString(kind.to_string()))); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + Value::SimpleString(ref status) => { + Box::new(status.shrink().map(Value::SimpleString).map(ArbitraryValue)) + } + Value::Double(i) => Box::new(i.shrink().map(Value::Double).map(ArbitraryValue)), + Value::Boolean(i) => Box::new(i.shrink().map(Value::Boolean).map(ArbitraryValue)), + Value::BigNumber(ref i) => { + Box::new(vec![ArbitraryValue(Value::BigNumber(i.clone()))].into_iter()) + } + Value::VerbatimString { + ref format, + ref text, + } => Box::new( + vec![ArbitraryValue(Value::VerbatimString { + format: format.clone(), + text: text.clone(), + })] + .into_iter(), + ), + } + } +} + +fn arbitrary_value(g: &mut Gen, recursive_size: usize) -> Value { + use quickcheck::Arbitrary; + if recursive_size == 0 { + Value::Nil + } else { + match u8::arbitrary(g) % 6 { + 0 => Value::Nil, + 1 => Value::Int(Arbitrary::arbitrary(g)), + 2 => Value::BulkString(Arbitrary::arbitrary(g)), + 3 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + Value::Array( + (0..size) + .map(|_| arbitrary_value(g, recursive_size / size)) + .collect(), + ) + } + 4 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + + let mut string = String::with_capacity(size); + for _ in 0..size { + let c = char::arbitrary(g); + if c.is_ascii_alphabetic() { + string.push(c); + } + } + + if string == "OK" { + Value::Okay + } else { + Value::SimpleString(string) + } + } + 5 => Value::Okay, + _ => unreachable!(), + } + } +} + +struct PartialAsyncRead { + inner: R, + ops: Box + Send>, +} + +impl AsyncRead for PartialAsyncRead +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.ops.next() { + Some(PartialOp::Limited(n)) => { + let len = std::cmp::min(n, buf.remaining()); + buf.initialize_unfilled(); + let mut sub_buf = buf.take(len); + ready!(Pin::new(&mut self.inner).poll_read(cx, &mut sub_buf))?; + let filled = sub_buf.filled().len(); + buf.advance(filled); + Poll::Ready(Ok(())) + } + Some(PartialOp::Err(err)) => { + if err == io::ErrorKind::WouldBlock { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Err(io::Error::new( + err, + "error during read, generated by partial-io", + )) + .into() + } + } + Some(PartialOp::Unlimited) | None => Pin::new(&mut self.inner).poll_read(cx, buf), + } + } +} + +quickcheck! 
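+// Property under test: any `Value` that round-trips through the encoder must decode
+// back to itself, even when the reader hands out partial buffers or spurious
+// `WouldBlock` errors generated by `partial-io`.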
{ + fn partial_io_parse(input: ArbitraryValue, seq: PartialWithErrors) -> () { + + let mut encoded_input = Vec::new(); + encode_value(&input.0, &mut encoded_input).unwrap(); + + let mut reader = &encoded_input[..]; + let mut partial_reader = PartialAsyncRead { inner: &mut reader, ops: Box::new(seq.into_iter()) }; + let mut decoder = combine::stream::Decoder::new(); + + let result = block_on_all(redis::parse_redis_value_async(&mut decoder, &mut partial_reader)); + assert!(result.as_ref().is_ok(), "{}", result.unwrap_err()); + assert_eq!( + result.unwrap(), + input.0, + ); + } +} diff --git a/glide-core/redis-rs/redis/tests/support/cluster.rs b/glide-core/redis-rs/redis/tests/support/cluster.rs new file mode 100644 index 0000000000..991331cfca --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/cluster.rs @@ -0,0 +1,792 @@ +#![cfg(feature = "cluster")] +#![allow(dead_code)] + +use std::convert::identity; +use std::env; +use std::process; +use std::thread::sleep; +use std::time::Duration; + +use redis::cluster_routing::RoutingInfo; +use redis::cluster_routing::SingleNodeRoutingInfo; +use redis::from_redis_value; + +#[cfg(feature = "cluster-async")] +use redis::aio::ConnectionLike; +#[cfg(feature = "cluster-async")] +use redis::cluster_async::Connect; +use redis::ConnectionInfo; +use redis::ProtocolVersion; +use redis::PushInfo; +use redis::RedisResult; +use redis::Value; +use tempfile::TempDir; + +use crate::support::{build_keys_and_certs_for_tls, Module}; + +#[cfg(feature = "tls-rustls")] +use super::{build_single_client, load_certs_from_file}; + +use super::use_protocol; +use super::RedisServer; +use super::TlsFilePaths; +use tokio::sync::mpsc; + +const LOCALHOST: &str = "127.0.0.1"; + +enum ClusterType { + Tcp, + TcpTls, +} + +impl ClusterType { + fn get_intended() -> ClusterType { + match env::var("REDISRS_SERVER_TYPE") + .ok() + .as_ref() + .map(|x| &x[..]) + { + Some("tcp") => ClusterType::Tcp, + Some("tcp+tls") => ClusterType::TcpTls, + Some(val) => { + panic!("Unknown server type {val:?}"); + } + None => ClusterType::Tcp, + } + } + + fn build_addr(port: u16) -> redis::ConnectionAddr { + match ClusterType::get_intended() { + ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port), + ClusterType::TcpTls => redis::ConnectionAddr::TcpTls { + host: "127.0.0.1".into(), + port, + insecure: true, + tls_params: None, + }, + } + } +} + +fn port_in_use(addr: &str) -> bool { + let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address"); + let socket = socket2::Socket::new( + socket2::Domain::for_address(socket_addr), + socket2::Type::STREAM, + None, + ) + .expect("Failed to create socket"); + + socket.connect(&socket_addr.into()).is_ok() +} + +pub struct RedisCluster { + pub servers: Vec, + pub folders: Vec, + pub tls_paths: Option, +} + +impl RedisCluster { + pub fn username() -> &'static str { + "hello" + } + + pub fn password() -> &'static str { + "world" + } + + pub fn client_name() -> &'static str { + "test_cluster_client" + } + + pub fn new(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], true) + } + + pub fn with_modules( + nodes: u16, + replicas: u16, + modules: &[Module], + mtls_enabled: bool, + ) -> RedisCluster { + let mut servers = vec![]; + let mut folders = vec![]; + let mut addrs = vec![]; + let start_port = 7000; + let mut 
tls_paths = None; + + let mut is_tls = false; + + if let ClusterType::TcpTls = ClusterType::get_intended() { + // Create a shared set of keys in cluster mode + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let files = build_keys_and_certs_for_tls(&tempdir); + folders.push(tempdir); + tls_paths = Some(files); + is_tls = true; + } + + let max_attempts = 5; + + for node in 0..nodes { + let port = start_port + node; + + servers.push(RedisServer::new_with_addr_tls_modules_and_spawner( + ClusterType::build_addr(port), + None, + tls_paths.clone(), + mtls_enabled, + modules, + |cmd| { + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let acl_path = tempdir.path().join("users.acl"); + let acl_content = format!( + "user {} on allcommands allkeys >{}", + Self::username(), + Self::password() + ); + std::fs::write(&acl_path, acl_content).expect("failed to write acl file"); + cmd.arg("--cluster-enabled") + .arg("yes") + .arg("--cluster-config-file") + .arg(tempdir.path().join("nodes.conf")) + .arg("--cluster-node-timeout") + .arg("5000") + .arg("--appendonly") + .arg("yes") + .arg("--aclfile") + .arg(&acl_path); + if is_tls { + cmd.arg("--tls-cluster").arg("yes"); + if replicas > 0 { + cmd.arg("--tls-replication").arg("yes"); + } + } + let addr = format!("127.0.0.1:{port}"); + cmd.current_dir(tempdir.path()); + folders.push(tempdir); + addrs.push(addr.clone()); + + let mut cur_attempts = 0; + loop { + let mut process = cmd.spawn().unwrap(); + sleep(Duration::from_millis(100)); + + match process.try_wait() { + Ok(Some(status)) => { + let err = + format!("redis server creation failed with status {status:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + cur_attempts += 1; + } + Ok(None) => { + let max_attempts = 20; + let mut cur_attempts = 0; + loop { + if cur_attempts == max_attempts { + panic!("redis server creation failed: Port {port} closed") + } + if port_in_use(&addr) { + return process; + } + eprintln!("Waiting for redis process to initialize"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + Err(e) => { + panic!("Unexpected error in redis server creation {e}"); + } + } + } + }, + )); + } + + let mut cmd = process::Command::new("redis-cli"); + cmd.stdout(process::Stdio::null()) + .arg("--cluster") + .arg("create") + .args(&addrs); + if replicas > 0 { + cmd.arg("--cluster-replicas").arg(replicas.to_string()); + } + cmd.arg("--cluster-yes"); + + if is_tls { + if mtls_enabled { + if let Some(TlsFilePaths { + redis_crt, + redis_key, + ca_crt, + }) = &tls_paths + { + cmd.arg("--cert"); + cmd.arg(redis_crt); + cmd.arg("--key"); + cmd.arg(redis_key); + cmd.arg("--cacert"); + cmd.arg(ca_crt); + cmd.arg("--tls"); + } + } else { + cmd.arg("--tls").arg("--insecure"); + } + } + + let mut cur_attempts = 0; + loop { + let output = cmd.output().unwrap(); + if output.status.success() { + break; + } else { + let err = format!("Cluster creation failed: {output:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + + let cluster = RedisCluster { + servers, + folders, + tls_paths, + }; + if replicas > 0 { + cluster.wait_for_replicas(replicas, mtls_enabled); + } + + wait_for_status_ok(&cluster); + cluster + } + + // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active + #[allow(dead_code)] + 
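+    // Polls CLUSTER SLOTS on every node until each slot entry lists a primary plus
+    // the requested number of replicas, retrying up to ~500 rounds per node.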
fn wait_for_replicas(&self, replicas: u16, _mtls_enabled: bool) { + 'server: for server in &self.servers { + let conn_info = server.connection_info(); + eprintln!( + "waiting until {:?} knows required number of replicas", + conn_info.addr + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &self.tls_paths, _mtls_enabled) + .unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + + // retry 500 times + for _ in 1..500 { + let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap(); + let slots: Vec> = redis::from_owned_redis_value(value).unwrap(); + + // all slots should have following items: + // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ] + if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) { + continue 'server; + } + + sleep(Duration::from_millis(100)); + } + + panic!("failed to create enough replicas"); + } + } + + pub fn stop(&mut self) { + for server in &mut self.servers { + server.stop(); + } + } + + pub fn iter_servers(&self) -> impl Iterator { + self.servers.iter() + } +} + +fn wait_for_status_ok(cluster: &RedisCluster) { + 'server: for server in &cluster.servers { + let log_file = RedisServer::log_file(&server.tempdir); + + for _ in 1..500 { + let contents = + std::fs::read_to_string(&log_file).expect("Should have been able to read the file"); + + if contents.contains("Cluster state changed: ok") { + continue 'server; + } + sleep(Duration::from_millis(20)); + } + panic!("failed to reach state change: OK"); + } +} + +impl Drop for RedisCluster { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestClusterContext { + pub cluster: RedisCluster, + pub client: redis::cluster::ClusterClient, + pub mtls_enabled: bool, + pub nodes: Vec, + pub protocol: ProtocolVersion, +} + +impl TestClusterContext { + pub fn new(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, true) + } + + pub fn new_with_cluster_client_builder( + nodes: u16, + replicas: u16, + initializer: F, + mtls_enabled: bool, + ) -> TestClusterContext + where + F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, + { + let cluster = RedisCluster::new(nodes, replicas); + let initial_nodes: Vec = cluster + .iter_servers() + .map(RedisServer::connection_info) + .collect(); + let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes.clone()) + .use_protocol(use_protocol()); + + #[cfg(feature = "tls-rustls")] + if mtls_enabled { + if let Some(tls_file_paths) = &cluster.tls_paths { + builder = builder.certs(load_certs_from_file(tls_file_paths)); + } + } + + builder = initializer(builder); + + let client = builder.build().unwrap(); + + TestClusterContext { + cluster, + client, + mtls_enabled, + nodes: initial_nodes, + protocol: use_protocol(), + } + } + + pub fn connection(&self) -> redis::cluster::ClusterConnection { + self.client.get_connection(None).unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_connection( + &self, + push_sender: Option>, + ) -> redis::cluster_async::ClusterConnection { + self.client.get_async_connection(push_sender).await.unwrap() + } + + 
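+    // Like `async_connection`, but generic over the connection type so tests can
+    // substitute a mock transport (see `mock_cluster.rs`) for a real TCP connection.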
#[cfg(feature = "cluster-async")] + pub async fn async_generic_connection< + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, + >( + &self, + ) -> redis::cluster_async::ClusterConnection { + self.client + .get_async_generic_connection::() + .await + .unwrap() + } + + pub fn wait_for_cluster_up(&self) { + let mut con = self.connection(); + let mut c = redis::cmd("CLUSTER"); + c.arg("INFO"); + + for _ in 0..100 { + let r: String = c.query::(&mut con).unwrap(); + if r.starts_with("cluster_state:ok") { + return; + } + + sleep(Duration::from_millis(25)); + } + + panic!("failed waiting for cluster to be ready"); + } + + pub fn disable_default_user(&self) { + for server in &self.cluster.servers { + #[cfg(feature = "tls-rustls")] + let client = build_single_client( + server.connection_info(), + &self.cluster.tls_paths, + self.mtls_enabled, + ) + .unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg("default") + .arg("off") + .query(&mut con) + .unwrap(); + + // subsequent unauthenticated command should fail: + if let Ok(mut con) = client.get_connection(None) { + assert!(redis::cmd("PING").query::<()>(&mut con).is_err()); + } + } + } + + pub fn get_version(&self) -> super::Version { + let mut conn = self.connection(); + super::get_version(&mut conn) + } + + pub fn get_node_ids(&self) -> Vec { + let mut conn = self.connection(); + let nodes: Vec = redis::cmd("CLUSTER") + .arg("NODES") + .query::(&mut conn) + .unwrap() + .split('\n') + .map(|s| s.to_string()) + .collect(); + let node_ids: Vec = nodes + .iter() + .map(|node| node.split(' ').next().unwrap().to_string()) + .collect(); + node_ids + .iter() + .filter(|id| !id.is_empty()) + .cloned() + .collect() + } + + // Migrate half the slots from one node to another + pub async fn migrate_slots_from_node_to_another( + &self, + slot_distribution: Vec<(String, String, String, Vec>)>, + ) { + let slots_ranges_of_node_id = slot_distribution[0].3.clone(); + + let mut conn = self.async_connection(None).await; + + let from = slot_distribution[0].clone(); + let target = slot_distribution[1].clone(); + + let from_node_id = from.0.clone(); + let target_node_id = target.0.clone(); + + let from_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: from.1.clone(), + port: from.2.clone().parse::().unwrap(), + }); + let target_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: target.1.clone(), + port: target.2.clone().parse::().unwrap(), + }); + + // Migrate the slots + for range in slots_ranges_of_node_id { + let mut slots_of_nodes: std::ops::Range = range[0]..range[1]; + let number_of_slots = range[1] - range[0] + 1; + // Migrate half the slots + for _i in 0..(number_of_slots as f64 / 2.0).floor() as usize { + let slot = slots_of_nodes.next().unwrap(); + // Set the nodes to MIGRATING and IMPORTING + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("IMPORTING") + .arg(from_node_id.clone()); + let result: RedisResult = + conn.route_command(&set_cmd, target_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to IMPORTING with error {}", + slot, err + ); + } + } + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("MIGRATING") + .arg(target_node_id.clone()); + let result: RedisResult = + 
conn.route_command(&set_cmd, from_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to MIGRATING with error {}", + slot, err + ); + } + } + // Get a key from the slot + let mut get_key_cmd = redis::cmd("CLUSTER"); + get_key_cmd.arg("GETKEYSINSLOT").arg(slot).arg(1); + let result: RedisResult = + conn.route_command(&get_key_cmd, from_route.clone()).await; + let vec_string_result: Vec = match result { + Ok(val) => { + let val: Vec = from_redis_value(&val).unwrap(); + val + } + Err(err) => { + println!("Failed to get keys in slot {}: {:?}", slot, err); + continue; + } + }; + if vec_string_result.is_empty() { + continue; + } + let key = vec_string_result[0].clone(); + // Migrate the key, which will make the whole slot to move + let mut migrate_cmd = redis::cmd("MIGRATE"); + migrate_cmd + .arg(target.1.clone()) + .arg(target.2.clone()) + .arg(key.clone()) + .arg(0) + .arg(5000); + let result: RedisResult = + conn.route_command(&migrate_cmd, from_route.clone()).await; + + match result { + Ok(Value::Okay) => {} + Ok(Value::SimpleString(str)) => { + if str != "NOKEY" { + println!( + "Failed to migrate key {} to target node with status {}", + key, str + ); + } else { + println!("Key {} does not exist", key); + } + } + Ok(_) => {} + Err(err) => { + println!( + "Failed to migrate key {} to target node with error {}", + key, err + ); + } + } + // Tell the source and target nodes to propagate the slot change to the cluster + let mut setslot_cmd = redis::cmd("CLUSTER"); + setslot_cmd + .arg("SETSLOT") + .arg(slot) + .arg("NODE") + .arg(target_node_id.clone()); + let result: RedisResult = + conn.route_command(&setslot_cmd, target_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to target NODE with error {}", + slot, err + ); + } + }; + self.wait_for_connection_is_ready(&from_route) + .await + .unwrap(); + self.wait_for_connection_is_ready(&target_route) + .await + .unwrap(); + self.wait_for_cluster_up(); + } + } + } + + // Return the slots distribution of the cluster as a vector of tuples + // where the first element is the node id, seconed is host, third is port and the last element is a vector of slots ranges + pub fn get_slots_ranges_distribution( + &self, + cluster_nodes: &str, + ) -> Vec<(String, String, String, Vec>)> { + let nodes_string: Vec = cluster_nodes + .split('\n') + .map(|s| s.to_string()) + .filter(|s| !s.is_empty()) + .collect(); + let mut nodes: Vec> = vec![]; + for node in nodes_string { + let node_vec: Vec = node.split(' ').map(|s| s.to_string()).collect(); + if node_vec.last().unwrap() == "connected" || node_vec.last().unwrap() == "disconnected" + { + continue; + } else { + nodes.push(node_vec); + } + } + let mut slot_distribution = vec![]; + for node in &nodes { + let mut slots_ranges: Vec> = vec![]; + let mut slots_ranges_vec: Vec = vec![]; + let node_id = node[0].clone(); + let host_and_port: Vec = node[1].split(':').map(|s| s.to_string()).collect(); + let host = host_and_port[0].clone(); + let port = host_and_port[1].split('@').next().unwrap().to_string(); + let slots = node[8..].to_vec(); + for slot in slots { + if slot.contains("->") || slot.contains("<-") { + continue; + } + if slot.contains('-') { + let range: Vec = + slot.split('-').map(|s| s.parse::().unwrap()).collect(); + slots_ranges_vec.push(range[0]); + slots_ranges_vec.push(range[1]); + slots_ranges.push(slots_ranges_vec.clone()); + slots_ranges_vec.clear(); + } else { + let slot: u16 = 
slot.parse::().unwrap(); + slots_ranges_vec.push(slot); + slots_ranges_vec.push(slot); + slots_ranges.push(slots_ranges_vec.clone()); + slots_ranges_vec.clear(); + } + } + let parsed_node: (String, String, String, Vec>) = + (node_id, host, port, slots_ranges); + slot_distribution.push(parsed_node); + } + slot_distribution + } + + pub async fn get_masters(&self, cluster_nodes: &str) -> Vec> { + let mut masters = vec![]; + for line in cluster_nodes.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 3 { + continue; + } + if parts[2] == "master" || parts[2] == "myself,master" { + let id = parts[0]; + let host_and_port = parts[1].split(':'); + let host = host_and_port.clone().next().unwrap(); + let port = host_and_port + .clone() + .last() + .unwrap() + .split('@') + .next() + .unwrap(); + masters.push(vec![id.to_string(), host.to_string(), port.to_string()]); + } + } + masters + } + + pub async fn get_replicas(&self, cluster_nodes: &str) -> Vec> { + let mut replicas = vec![]; + for line in cluster_nodes.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 3 { + continue; + } + if parts[2] == "slave" || parts[2] == "myself,slave" { + let id = parts[0]; + let host_and_port = parts[1].split(':'); + let host = host_and_port.clone().next().unwrap(); + let port = host_and_port + .clone() + .last() + .unwrap() + .split('@') + .next() + .unwrap(); + replicas.push(vec![id.to_string(), host.to_string(), port.to_string()]); + } + } + replicas + } + + pub async fn get_cluster_nodes(&self) -> String { + let mut conn = self.async_connection(None).await; + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("NODES"); + let res: RedisResult = conn + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + let res: String = from_redis_value(&res.unwrap()).unwrap(); + res + } + + pub async fn wait_for_fail_to_finish(&self, route: &RoutingInfo) -> RedisResult<()> { + for _ in 0..500 { + let mut conn = self.async_connection(None).await; + let cmd = redis::cmd("PING"); + let res: RedisResult = conn.route_command(&cmd, route.clone()).await; + if res.is_err() { + return Ok(()); + } + sleep(Duration::from_millis(50)); + } + Err(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Failed to get connection", + ))) + } + + pub async fn wait_for_connection_is_ready(&self, route: &RoutingInfo) -> RedisResult<()> { + let mut i = 1; + while i < 1000 { + let mut conn = self.async_connection(None).await; + let cmd = redis::cmd("PING"); + let res: RedisResult = conn.route_command(&cmd, route.clone()).await; + if res.is_ok() { + return Ok(()); + } + sleep(Duration::from_millis(i * 10)); + i += 10; + } + Err(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Failed to get connection", + ))) + } +} diff --git a/glide-core/redis-rs/redis/tests/support/mock_cluster.rs b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs new file mode 100644 index 0000000000..ce91988cef --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs @@ -0,0 +1,487 @@ +use redis::{ + cluster::{self, ClusterClient, ClusterClientBuilder}, + ErrorKind, FromRedisValue, GlideConnectionOptions, RedisError, +}; + +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, + }, + time::Duration, +}; + +use { + once_cell::sync::Lazy, + redis::{IntoConnectionInfo, RedisResult, Value}, +}; + +#[cfg(feature = "cluster-async")] +use redis::{aio, cluster_async, 
RedisFuture}; + +#[cfg(feature = "cluster-async")] +use futures::future; + +#[cfg(feature = "cluster-async")] +use tokio::runtime::Runtime; + +type Handler = Arc Result<(), RedisResult> + Send + Sync>; + +pub struct MockConnectionBehavior { + pub id: String, + pub handler: Handler, + pub connection_id_provider: AtomicUsize, + pub returned_ip_type: ConnectionIPReturnType, + pub return_connection_err: ShouldReturnConnectionError, +} + +impl MockConnectionBehavior { + fn new(id: &str, handler: Handler) -> Self { + Self { + id: id.to_string(), + handler, + connection_id_provider: AtomicUsize::new(0), + returned_ip_type: ConnectionIPReturnType::default(), + return_connection_err: ShouldReturnConnectionError::default(), + } + } + + #[must_use] + pub fn register_new(id: &str, handler: Handler) -> RemoveHandler { + get_behaviors().insert(id.to_string(), Self::new(id, handler)); + RemoveHandler(vec![id.to_string()]) + } + + fn get_handler(&self) -> Handler { + self.handler.clone() + } +} + +pub fn modify_mock_connection_behavior(name: &str, func: impl FnOnce(&mut MockConnectionBehavior)) { + func( + get_behaviors() + .get_mut(name) + .expect("Handler `{name}` was not installed"), + ); +} + +pub fn get_mock_connection_handler(name: &str) -> Handler { + MOCK_CONN_BEHAVIORS + .read() + .unwrap() + .get(name) + .expect("Handler `{name}` was not installed") + .get_handler() +} + +pub fn get_mock_connection(name: &str, id: usize) -> MockConnection { + get_mock_connection_with_port(name, id, 6379) +} + +pub fn get_mock_connection_with_port(name: &str, id: usize, port: u16) -> MockConnection { + MockConnection { + id, + handler: get_mock_connection_handler(name), + port, + } +} + +static MOCK_CONN_BEHAVIORS: Lazy>> = + Lazy::new(Default::default); + +fn get_behaviors() -> std::sync::RwLockWriteGuard<'static, HashMap> +{ + MOCK_CONN_BEHAVIORS.write().unwrap() +} + +#[derive(Default)] +pub enum ConnectionIPReturnType { + /// New connections' IP will be returned as None + #[default] + None, + /// Creates connections with the specified IP + Specified(IpAddr), + /// Each new connection will be created with a different IP based on the passed atomic integer + Different(AtomicUsize), +} + +#[derive(Default)] +pub enum ShouldReturnConnectionError { + /// Don't return a connection error + #[default] + No, + /// Always return a connection error + Yes, + /// Return connection error when the internal index is an odd number + OnOddIdx(AtomicUsize), +} + +#[derive(Clone)] +pub struct MockConnection { + pub id: usize, + pub handler: Handler, + pub port: u16, +} + +#[cfg(feature = "cluster-async")] +impl cluster_async::Connect for MockConnection { + fn connect<'a, T>( + info: T, + _response_timeout: Duration, + _connection_timeout: Duration, + _socket_addr: Option, + _glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + let binding = MOCK_CONN_BEHAVIORS.read().unwrap(); + let conn_utils = binding + .get(name) + .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed")); + let conn_err = Box::pin(future::err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "mock-io-error", + )))); + match &conn_utils.return_connection_err { + ShouldReturnConnectionError::No => {} + ShouldReturnConnectionError::Yes => return 
conn_err, + ShouldReturnConnectionError::OnOddIdx(curr_idx) => { + if curr_idx.fetch_add(1, Ordering::SeqCst) % 2 != 0 { + // raise an error on each odd number + return conn_err; + } + } + } + + let ip = match &conn_utils.returned_ip_type { + ConnectionIPReturnType::Specified(ip) => Some(*ip), + ConnectionIPReturnType::Different(ip_getter) => { + let first_ip_num = ip_getter.fetch_add(1, Ordering::SeqCst) as u8; + Some(IpAddr::V4(Ipv4Addr::new(first_ip_num, 0, 0, 0))) + } + ConnectionIPReturnType::None => None, + }; + + Box::pin(future::ok(( + MockConnection { + id: conn_utils + .connection_id_provider + .fetch_add(1, Ordering::SeqCst), + handler: conn_utils.get_handler(), + port, + }, + ip, + ))) + } +} + +impl cluster::Connect for MockConnection { + fn connect<'a, T>(info: T, _timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + let binding = MOCK_CONN_BEHAVIORS.read().unwrap(); + let conn_utils = binding + .get(name) + .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed")); + Ok(MockConnection { + id: conn_utils + .connection_id_provider + .fetch_add(1, Ordering::SeqCst), + handler: conn_utils.get_handler(), + port, + }) + } + + fn send_packed_command(&mut self, _cmd: &[u8]) -> RedisResult<()> { + Ok(()) + } + + fn set_write_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn set_read_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn recv_response(&mut self) -> RedisResult { + Ok(Value::Nil) + } +} + +pub fn contains_slice(xs: &[u8], ys: &[u8]) -> bool { + for i in 0..xs.len() { + if xs[i..].starts_with(ys) { + return true; + } + } + false +} + +pub fn respond_startup(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct MockSlotRange { + pub primary_port: u16, + pub replica_ports: Vec, + pub slot_range: std::ops::Range, +} + +pub fn respond_startup_with_replica(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_replica_using_config(name, cmd, None) +} + +pub fn respond_startup_two_nodes(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, None, false) +} + +pub fn create_topology_from_config(name: &str, slots_config: Vec) -> Value { + let slots_vec = slots_config + .into_iter() + .map(|slot_config| { + let mut config = vec![ + Value::Int(slot_config.slot_range.start as i64), + Value::Int(slot_config.slot_range.end as i64), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(slot_config.primary_port as i64), + ]), + ]; + config.extend(slot_config.replica_ports.into_iter().map(|replica_port| { + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(replica_port as i64), + ]) + })); + Value::Array(config) + }) + .collect(); + Value::Array(slots_vec) +} + +pub fn 
respond_startup_with_replica_using_config( + name: &str, + cmd: &[u8], + slots_config: Option>, +) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, slots_config, true) +} + +/// If the configuration isn't provided, a configuration with two primary nodes, with or without replicas, will be used. +pub fn respond_startup_with_config( + name: &str, + cmd: &[u8], + slots_config: Option>, + with_replicas: bool, +) -> Result<(), RedisResult> { + let slots_config = slots_config.unwrap_or(if with_replicas { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8192..16383), + }, + ] + } else { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + slot_range: (8192..16383), + }, + ] + }); + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + let slots = create_topology_from_config(name, slots_config); + Err(Ok(slots)) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[cfg(feature = "cluster-async")] +impl aio::ConnectionLike for MockConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a redis::Cmd) -> RedisFuture<'a, Value> { + Box::pin(future::ready( + (self.handler)(&cmd.get_packed_command(), self.port) + .expect_err("Handler did not specify a response"), + )) + } + + fn req_packed_commands<'a>( + &'a mut self, + _pipeline: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> RedisFuture<'a, Vec> { + Box::pin(future::ok(vec![])) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +impl redis::ConnectionLike for MockConnection { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + (self.handler)(cmd, self.port).expect_err("Handler did not specify a response") + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + _count: usize, + ) -> RedisResult> { + let res = (self.handler)(cmd, self.port).expect_err("Handler did not specify a response"); + match res { + Err(err) => Err(err), + Ok(res) => { + if let Value::Array(results) = res { + match results.into_iter().nth(offset) { + Some(Value::Array(res)) => Ok(res), + _ => Err((ErrorKind::ResponseError, "non-array response").into()), + } + } else { + Err(( + ErrorKind::ResponseError, + "non-array response", + String::from_owned_redis_value(res).unwrap(), + ) + .into()) + } + } + } + } + + fn get_db(&self) -> i64 { + 0 + } + + fn check_connection(&mut self) -> bool { + true + } + + fn is_open(&self) -> bool { + true + } +} + +pub struct MockEnv { + #[cfg(feature = "cluster-async")] + pub runtime: Runtime, + pub client: redis::cluster::ClusterClient, + pub connection: redis::cluster::ClusterConnection, + #[cfg(feature = "cluster-async")] + pub async_connection: redis::cluster_async::ClusterConnection, + #[allow(unused)] + pub handler: RemoveHandler, +} + +pub struct RemoveHandler(Vec); + +impl Drop for RemoveHandler { + fn drop(&mut self) { + for id in &self.0 { + get_behaviors().remove(id); + } + } +} + +impl MockEnv { + pub fn new( + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + Self::with_client_builder( + 
ClusterClient::builder(vec![&*format!("redis://{id}")]), + id, + handler, + ) + } + + pub fn with_client_builder( + client_builder: ClusterClientBuilder, + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + #[cfg(feature = "cluster-async")] + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let id = id.to_string(); + let handler = MockConnectionBehavior::register_new( + &id, + Arc::new(move |cmd, port| handler(cmd, port)), + ); + let client = client_builder.build().unwrap(); + let connection = client.get_generic_connection(None).unwrap(); + #[cfg(feature = "cluster-async")] + let async_connection = runtime + .block_on(client.get_async_generic_connection()) + .unwrap(); + MockEnv { + #[cfg(feature = "cluster-async")] + runtime, + client, + connection, + #[cfg(feature = "cluster-async")] + async_connection, + handler, + } + } +} diff --git a/glide-core/redis-rs/redis/tests/support/mod.rs b/glide-core/redis-rs/redis/tests/support/mod.rs new file mode 100644 index 0000000000..72dc7c9a78 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/mod.rs @@ -0,0 +1,861 @@ +#![allow(dead_code)] + +use std::path::Path; +use std::{ + env, fs, io, net::SocketAddr, net::TcpListener, path::PathBuf, process, thread::sleep, + time::Duration, +}; +#[cfg(feature = "tls-rustls")] +use std::{ + fs::File, + io::{BufReader, Read}, +}; + +#[cfg(feature = "aio")] +use futures::Future; +use redis::{ConnectionAddr, InfoDict, Pipeline, ProtocolVersion, RedisConnectionInfo, Value}; + +#[cfg(feature = "tls-rustls")] +use redis::{ClientTlsConfig, TlsCertificates}; + +use socket2::{Domain, Socket, Type}; +use tempfile::TempDir; + +#[cfg(feature = "aio")] +use redis::GlideConnectionOptions; + +pub fn use_protocol() -> ProtocolVersion { + if env::var("PROTOCOL").unwrap_or_default() == "RESP3" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } +} + +pub fn current_thread_runtime() -> tokio::runtime::Runtime { + let mut builder = tokio::runtime::Builder::new_current_thread(); + + #[cfg(feature = "aio")] + builder.enable_io(); + + builder.enable_time(); + + builder.build().unwrap() +} + +#[cfg(feature = "aio")] +pub fn block_on_all(f: F) -> F::Output +where + F: Future>, +{ + use std::panic; + use std::sync::atomic::{AtomicBool, Ordering}; + + static CHECK: AtomicBool = AtomicBool::new(false); + + // TODO - this solution is purely single threaded, and won't work on multiple threads at the same time. + // This is needed because Tokio's Runtime silently ignores panics - https://users.rust-lang.org/t/tokio-runtime-what-happens-when-a-thread-panics/95819 + // Once Tokio stabilizes the `unhandled_panic` field on the runtime builder, it should be used instead. + panic::set_hook(Box::new(|panic| { + println!("Panic: {panic}"); + CHECK.store(true, Ordering::Relaxed); + })); + + // This continuously query the flag, in order to abort ASAP after a panic. + let check_future = futures_util::FutureExt::fuse(async { + loop { + if CHECK.load(Ordering::Relaxed) { + return Err((redis::ErrorKind::IoError, "panic was caught").into()); + } + futures_time::task::sleep(futures_time::time::Duration::from_millis(1)).await; + } + }); + let f = futures_util::FutureExt::fuse(f); + futures::pin_mut!(f, check_future); + + let res = current_thread_runtime().block_on(async { + futures::select! 
+    let res = current_thread_runtime().block_on(async {
+        futures::select! { res = f => res, err = check_future => err }
+    });
+
+    let _ = panic::take_hook();
+    if CHECK.swap(false, Ordering::Relaxed) {
+        panic!("Internal thread panicked");
+    }
+
+    res
+}
+
+#[cfg(any(feature = "cluster", feature = "cluster-async"))]
+mod cluster;
+
+#[cfg(any(feature = "cluster", feature = "cluster-async"))]
+mod mock_cluster;
+
+mod util;
+#[allow(unused_imports)]
+pub use self::util::*;
+
+#[cfg(any(feature = "cluster", feature = "cluster-async"))]
+#[allow(unused_imports)]
+pub use self::cluster::*;
+
+#[cfg(any(feature = "cluster", feature = "cluster-async"))]
+#[allow(unused_imports)]
+pub use self::mock_cluster::*;
+
+#[cfg(feature = "sentinel")]
+mod sentinel;
+
+#[cfg(feature = "sentinel")]
+#[allow(unused_imports)]
+pub use self::sentinel::*;
+
+#[derive(PartialEq)]
+enum ServerType {
+    Tcp { tls: bool },
+    Unix,
+}
+
+pub enum Module {
+    Json,
+}
+
+pub struct RedisServer {
+    pub process: process::Child,
+    pub(crate) tempdir: tempfile::TempDir,
+    pub(crate) addr: redis::ConnectionAddr,
+    pub(crate) tls_paths: Option<TlsFilePaths>,
+}
+
+impl ServerType {
+    fn get_intended() -> ServerType {
+        match env::var("REDISRS_SERVER_TYPE")
+            .ok()
+            .as_ref()
+            .map(|x| &x[..])
+        {
+            Some("tcp") => ServerType::Tcp { tls: false },
+            Some("tcp+tls") => ServerType::Tcp { tls: true },
+            Some("unix") => ServerType::Unix,
+            Some(val) => {
+                panic!("Unknown server type {val:?}");
+            }
+            None => ServerType::Tcp { tls: false },
+        }
+    }
+}
+
+impl RedisServer {
+    pub fn new() -> RedisServer {
+        RedisServer::with_modules(&[], false)
+    }
+
+    #[cfg(feature = "tls-rustls")]
+    pub fn new_with_mtls() -> RedisServer {
+        RedisServer::with_modules(&[], true)
+    }
+
+    pub fn get_addr(port: u16) -> ConnectionAddr {
+        let server_type = ServerType::get_intended();
+        match server_type {
+            ServerType::Tcp { tls } => {
+                if tls {
+                    redis::ConnectionAddr::TcpTls {
+                        host: "127.0.0.1".to_string(),
+                        port,
+                        insecure: true,
+                        tls_params: None,
+                    }
+                } else {
+                    redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port)
+                }
+            }
+            ServerType::Unix => {
+                let (a, b) = rand::random::<(u64, u64)>();
+                let path = format!("/tmp/redis-rs-test-{a}-{b}.sock");
+                redis::ConnectionAddr::Unix(PathBuf::from(&path))
+            }
+        }
+    }
+
+    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer {
+        // this is technically a race but we can't do better with
+        // the tools that redis gives us :(
+        let redis_port = get_random_available_port();
+        let addr = RedisServer::get_addr(redis_port);
+
+        RedisServer::new_with_addr_tls_modules_and_spawner(
+            addr,
+            None,
+            None,
+            mtls_enabled,
+            modules,
+            |cmd| {
+                cmd.spawn()
+                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
+            },
+        )
+    }
+
+    pub fn new_with_addr_and_modules(
+        addr: redis::ConnectionAddr,
+        modules: &[Module],
+        mtls_enabled: bool,
+    ) -> RedisServer {
+        RedisServer::new_with_addr_tls_modules_and_spawner(
+            addr,
+            None,
+            None,
+            mtls_enabled,
+            modules,
+            |cmd| {
+                cmd.spawn()
+                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
+            },
+        )
+    }
+
+    pub fn new_with_addr_tls_modules_and_spawner<
+        F: FnOnce(&mut process::Command) -> process::Child,
+    >(
+        addr: redis::ConnectionAddr,
+        config_file: Option<&Path>,
+        tls_paths: Option<TlsFilePaths>,
+        mtls_enabled: bool,
+        modules: &[Module],
+        spawner: F,
+    ) -> RedisServer {
+        let mut redis_cmd = process::Command::new("redis-server");
+
+        if let Some(config_path) = config_file {
+            redis_cmd.arg(config_path);
+        }
+
+        // Load Redis Modules
+        for module in modules {
+            match module {
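+                // Modules are loaded with `--loadmodule <path>`; the path comes from
+                // the environment so each CI job can point at its own build, e.g.
+                // (illustrative):
+                //
+                //     REDIS_RS_REDIS_JSON_PATH=/path/to/librejson.so cargo test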
.arg("--loadmodule") + .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect( + "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?", + )); + } + }; + } + + redis_cmd + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + redis_cmd.arg("--logfile").arg(Self::log_file(&tempdir)); + match addr { + redis::ConnectionAddr::Tcp(ref bind, server_port) => { + redis_cmd + .arg("--port") + .arg(server_port.to_string()) + .arg("--bind") + .arg(bind); + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + redis::ConnectionAddr::TcpTls { ref host, port, .. } => { + let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir)); + + let auth_client = if mtls_enabled { "yes" } else { "no" }; + + // prepare redis with TLS + redis_cmd + .arg("--tls-port") + .arg(port.to_string()) + .arg("--port") + .arg("0") + .arg("--tls-cert-file") + .arg(&tls_paths.redis_crt) + .arg("--tls-key-file") + .arg(&tls_paths.redis_key) + .arg("--tls-ca-cert-file") + .arg(&tls_paths.ca_crt) + .arg("--tls-auth-clients") + .arg(auth_client) + .arg("--bind") + .arg(host); + + // Insecure only disabled if `mtls` is enabled + let insecure = !mtls_enabled; + + let addr = redis::ConnectionAddr::TcpTls { + host: host.clone(), + port, + insecure, + tls_params: None, + }; + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: Some(tls_paths), + } + } + redis::ConnectionAddr::Unix(ref path) => { + redis_cmd + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(path); + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + } + } + + pub fn client_addr(&self) -> &redis::ConnectionAddr { + &self.addr + } + + pub fn connection_info(&self) -> redis::ConnectionInfo { + redis::ConnectionInfo { + addr: self.client_addr().clone(), + redis: RedisConnectionInfo { + protocol: use_protocol(), + ..Default::default() + }, + } + } + + pub fn stop(&mut self) { + let _ = self.process.kill(); + let _ = self.process.wait(); + if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() { + fs::remove_file(path).ok(); + } + } + + pub fn log_file(tempdir: &TempDir) -> PathBuf { + tempdir.path().join("redis.log") + } +} + +/// Finds a random open port available for listening at, by spawning a TCP server with +/// port "zero" (which prompts the OS to just use any available port). Between calling +/// this function and trying to bind to this port, the port may be given to another +/// process, so this must be used with care (since here we only use it for tests, it's +/// mostly okay). 
+/// Finds a random open port available for listening at, by spawning a TCP server with
+/// port "zero" (which prompts the OS to just use any available port). Between calling
+/// this function and trying to bind to this port, the port may be given to another
+/// process, so this must be used with care (since here we only use it for tests, it's
+/// mostly okay).
+pub fn get_random_available_port() -> u16 {
+    let addr = &"127.0.0.1:0".parse::<SocketAddr>().unwrap().into();
+    let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
+    socket.set_reuse_address(true).unwrap();
+    socket.bind(addr).unwrap();
+    socket.listen(1).unwrap();
+    let listener = TcpListener::from(socket);
+    listener.local_addr().unwrap().port()
+}
+
+impl Drop for RedisServer {
+    fn drop(&mut self) {
+        self.stop()
+    }
+}
+
+pub struct TestContext {
+    pub server: RedisServer,
+    pub client: redis::Client,
+    pub protocol: ProtocolVersion,
+}
+
+pub(crate) fn is_tls_enabled() -> bool {
+    cfg!(all(feature = "tls-rustls", not(feature = "tls-native-tls")))
+}
+
+impl TestContext {
+    pub fn new() -> TestContext {
+        TestContext::with_modules(&[], false)
+    }
+
+    #[cfg(feature = "tls-rustls")]
+    pub fn new_with_mtls() -> TestContext {
+        Self::with_modules(&[], true)
+    }
+
+    fn connect_with_retries(client: &redis::Client) {
+        let mut con;
+
+        let millisecond = Duration::from_millis(1);
+        let mut retries = 0;
+        loop {
+            match client.get_connection(None) {
+                Err(err) => {
+                    if err.is_connection_refusal() {
+                        sleep(millisecond);
+                        retries += 1;
+                        if retries > 100000 {
+                            panic!("Tried to connect too many times, last error: {err}");
+                        }
+                    } else {
+                        panic!("Could not connect: {err}");
+                    }
+                }
+                Ok(x) => {
+                    con = x;
+                    break;
+                }
+            }
+        }
+        redis::cmd("FLUSHDB").execute(&mut con);
+    }
+
+    pub fn with_tls(tls_files: TlsFilePaths, mtls_enabled: bool) -> TestContext {
+        let redis_port = get_random_available_port();
+        let addr: ConnectionAddr = RedisServer::get_addr(redis_port);
+
+        let server = RedisServer::new_with_addr_tls_modules_and_spawner(
+            addr,
+            None,
+            Some(tls_files),
+            mtls_enabled,
+            &[],
+            |cmd| {
+                cmd.spawn()
+                    .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}"))
+            },
+        );
+
+        #[cfg(feature = "tls-rustls")]
+        let client =
+            build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap();
+        #[cfg(not(feature = "tls-rustls"))]
+        let client = redis::Client::open(server.connection_info()).unwrap();
+
+        Self::connect_with_retries(&client);
+
+        TestContext {
+            server,
+            client,
+            protocol: use_protocol(),
+        }
+    }
+
+    pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> TestContext {
+        let server = RedisServer::with_modules(modules, mtls_enabled);
+
+        #[cfg(feature = "tls-rustls")]
+        let client =
+            build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap();
+        #[cfg(not(feature = "tls-rustls"))]
+        let client = redis::Client::open(server.connection_info()).unwrap();
+
+        Self::connect_with_retries(&client);
+
+        TestContext {
+            server,
+            client,
+            protocol: use_protocol(),
+        }
+    }
+
+    pub fn with_client_name(clientname: &str) -> TestContext {
+        let server = RedisServer::with_modules(&[], false);
+        let con_info = redis::ConnectionInfo {
+            addr: server.client_addr().clone(),
+            redis: redis::RedisConnectionInfo {
+                client_name: Some(clientname.to_string()),
+                ..Default::default()
+            },
+        };
+
+        #[cfg(feature = "tls-rustls")]
+        let client = build_single_client(con_info, &server.tls_paths, false).unwrap();
+        #[cfg(not(feature = "tls-rustls"))]
+        let client = redis::Client::open(con_info).unwrap();
+
+        Self::connect_with_retries(&client);
+
+        TestContext {
+            server,
+            client,
+            protocol: use_protocol(),
+        }
+    }
+
+    pub fn connection(&self) -> redis::Connection {
+        self.client.get_connection(None).unwrap()
+    }
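+
+    // A typical synchronous test (illustrative) spins up a fresh server, grabs a
+    // connection, and relies on `Drop` to kill the server process afterwards:
+    //
+    //     let ctx = TestContext::new();
+    //     let mut con = ctx.connection();
+    //     redis::cmd("SET").arg("k").arg("v").execute(&mut con);
+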
+    #[cfg(feature = "aio")]
+    pub async fn async_connection(&self) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
+        self.client
+            .get_multiplexed_async_connection(GlideConnectionOptions::default())
+            .await
+    }
+
+    #[cfg(feature = "aio")]
+    pub async fn async_pubsub(&self) -> redis::RedisResult<redis::aio::PubSub> {
+        self.client.get_async_pubsub().await
+    }
+
+    pub fn stop_server(&mut self) {
+        self.server.stop();
+    }
+
+    #[cfg(feature = "tokio-comp")]
+    pub async fn multiplexed_async_connection(
+        &self,
+    ) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
+        self.multiplexed_async_connection_tokio().await
+    }
+
+    #[cfg(feature = "tokio-comp")]
+    pub async fn multiplexed_async_connection_tokio(
+        &self,
+    ) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
+        self.client
+            .get_multiplexed_tokio_connection(GlideConnectionOptions::default())
+            .await
+    }
+
+    pub fn get_version(&self) -> Version {
+        let mut conn = self.connection();
+        get_version(&mut conn)
+    }
+}
+
+fn encode_iter<W>(values: &[Value], writer: &mut W, prefix: &str) -> io::Result<()>
+where
+    W: io::Write,
+{
+    write!(writer, "{}{}\r\n", prefix, values.len())?;
+    for val in values.iter() {
+        encode_value(val, writer)?;
+    }
+    Ok(())
+}
+fn encode_map<W>(values: &[(Value, Value)], writer: &mut W, prefix: &str) -> io::Result<()>
+where
+    W: io::Write,
+{
+    write!(writer, "{}{}\r\n", prefix, values.len())?;
+    for (k, v) in values.iter() {
+        encode_value(k, writer)?;
+        encode_value(v, writer)?;
+    }
+    Ok(())
+}
+pub fn encode_value<W>(value: &Value, writer: &mut W) -> io::Result<()>
+where
+    W: io::Write,
+{
+    #![allow(clippy::write_with_newline)]
+    match *value {
+        Value::Nil => write!(writer, "$-1\r\n"),
+        Value::Int(val) => write!(writer, ":{val}\r\n"),
+        Value::BulkString(ref val) => {
+            write!(writer, "${}\r\n", val.len())?;
+            writer.write_all(val)?;
+            writer.write_all(b"\r\n")
+        }
+        Value::Array(ref values) => encode_iter(values, writer, "*"),
+        Value::Okay => write!(writer, "+OK\r\n"),
+        Value::SimpleString(ref s) => write!(writer, "+{s}\r\n"),
+        Value::Map(ref values) => encode_map(values, writer, "%"),
+        Value::Attribute {
+            ref data,
+            ref attributes,
+        } => {
+            encode_map(attributes, writer, "|")?;
+            encode_value(data, writer)?;
+            Ok(())
+        }
+        Value::Set(ref values) => encode_iter(values, writer, "~"),
+        Value::Double(val) => write!(writer, ",{}\r\n", val),
+        Value::Boolean(v) => {
+            if v {
+                write!(writer, "#t\r\n")
+            } else {
+                write!(writer, "#f\r\n")
+            }
+        }
+        Value::VerbatimString {
+            ref format,
+            ref text,
+        } => {
+            // format is always 3 bytes
+            write!(writer, "={}\r\n{}:{}\r\n", 3 + text.len(), format, text)
+        }
+        Value::BigNumber(ref val) => write!(writer, "({}\r\n", val),
+        Value::Push { ref kind, ref data } => {
+            write!(writer, ">{}\r\n+{kind}\r\n", data.len() + 1)?;
+            for val in data.iter() {
+                encode_value(val, writer)?;
+            }
+            Ok(())
+        }
+    }
+}
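+
+// For reference, the wire forms `encode_value` produces for a few common cases:
+//
+//     Value::SimpleString("OK".into())   => "+OK\r\n"
+//     Value::Int(42)                     => ":42\r\n"
+//     Value::BulkString(b"foo".to_vec()) => "$3\r\nfoo\r\n"
+//     Value::Array(vec![Value::Int(1)])  => "*1\r\n:1\r\n"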
+
+#[derive(Clone, Debug)]
+pub struct TlsFilePaths {
+    pub(crate) redis_crt: PathBuf,
+    pub(crate) redis_key: PathBuf,
+    pub(crate) ca_crt: PathBuf,
+}
+
+pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths {
+    // Based on shell script in redis's server tests
+    // https://github.com/redis/redis/blob/8c291b97b95f2e011977b522acf77ead23e26f55/utils/gen-test-certs.sh
+    let ca_crt = tempdir.path().join("ca.crt");
+    let ca_key = tempdir.path().join("ca.key");
+    let ca_serial = tempdir.path().join("ca.txt");
+    let redis_crt = tempdir.path().join("redis.crt");
+    let redis_key = tempdir.path().join("redis.key");
+    let ext_file = tempdir.path().join("openssl.cnf");
+
+    fn make_key<S: AsRef<std::ffi::OsStr>>(name: S, size: usize) {
+        process::Command::new("openssl")
+            .arg("genrsa")
+            .arg("-out")
+            .arg(name)
+            .arg(format!("{size}"))
+            .stdout(process::Stdio::null())
+            .stderr(process::Stdio::null())
+            .spawn()
+            .expect("failed to spawn openssl")
+            .wait()
+            .expect("failed to create key");
+    }
+
+    // Build CA Key
+    make_key(&ca_key, 4096);
+
+    // Build redis key
+    make_key(&redis_key, 2048);
+
+    // Build CA Cert
+    process::Command::new("openssl")
+        .arg("req")
+        .arg("-x509")
+        .arg("-new")
+        .arg("-nodes")
+        .arg("-sha256")
+        .arg("-key")
+        .arg(&ca_key)
+        .arg("-days")
+        .arg("3650")
+        .arg("-subj")
+        .arg("/O=Redis Test/CN=Certificate Authority")
+        .arg("-out")
+        .arg(&ca_crt)
+        .stdout(process::Stdio::null())
+        .stderr(process::Stdio::null())
+        .spawn()
+        .expect("failed to spawn openssl")
+        .wait()
+        .expect("failed to create CA cert");
+
+    // Build x509v3 extensions file
+    fs::write(
+        &ext_file,
+        b"keyUsage = digitalSignature, keyEncipherment\n\
+        subjectAltName = @alt_names\n\
+        [alt_names]\n\
+        IP.1 = 127.0.0.1\n",
+    )
+    .expect("failed to create x509v3 extensions file");
+
+    // Read redis key
+    let mut key_cmd = process::Command::new("openssl")
+        .arg("req")
+        .arg("-new")
+        .arg("-sha256")
+        .arg("-subj")
+        .arg("/O=Redis Test/CN=Generic-cert")
+        .arg("-key")
+        .arg(&redis_key)
+        .stdout(process::Stdio::piped())
+        .stderr(process::Stdio::null())
+        .spawn()
+        .expect("failed to spawn openssl");
+
+    // build redis cert
+    process::Command::new("openssl")
+        .arg("x509")
+        .arg("-req")
+        .arg("-sha256")
+        .arg("-CA")
+        .arg(&ca_crt)
+        .arg("-CAkey")
+        .arg(&ca_key)
+        .arg("-CAserial")
+        .arg(&ca_serial)
+        .arg("-CAcreateserial")
+        .arg("-days")
+        .arg("365")
+        .arg("-extfile")
+        .arg(&ext_file)
+        .arg("-out")
+        .arg(&redis_crt)
+        .stdin(key_cmd.stdout.take().expect("should have stdout"))
+        .stdout(process::Stdio::null())
+        .stderr(process::Stdio::null())
+        .spawn()
+        .expect("failed to spawn openssl")
+        .wait()
+        .expect("failed to create redis cert");
+
+    key_cmd.wait().expect("failed to create redis key");
+
+    TlsFilePaths {
+        redis_crt,
+        redis_key,
+        ca_crt,
+    }
+}
+
+pub type Version = (u16, u16, u16);
+
+fn get_version(conn: &mut impl redis::ConnectionLike) -> Version {
+    let info: InfoDict = redis::Cmd::new().arg("INFO").query(conn).unwrap();
+    let version: String = info.get("redis_version").unwrap();
+    let versions: Vec<u16> = version
+        .split('.')
+        .map(|version| version.parse::<u16>().unwrap())
+        .collect();
+    assert_eq!(versions.len(), 3);
+    (versions[0], versions[1], versions[2])
+}
+
+pub fn is_major_version(expected_version: u16, version: Version) -> bool {
+    expected_version <= version.0
+}
+
+pub fn is_version(expected_major_minor: (u16, u16), version: Version) -> bool {
+    expected_major_minor.0 < version.0
+        || (expected_major_minor.0 == version.0 && expected_major_minor.1 <= version.1)
+}
+
+#[cfg(feature = "tls-rustls")]
+fn load_certs_from_file(tls_file_paths: &TlsFilePaths) -> TlsCertificates {
+    let ca_file = File::open(&tls_file_paths.ca_crt).expect("Cannot open CA cert file");
+    let mut root_cert_vec = Vec::new();
+    BufReader::new(ca_file)
+        .read_to_end(&mut root_cert_vec)
+        .expect("Unable to read CA cert file");
+
+    let cert_file = File::open(&tls_file_paths.redis_crt).expect("cannot open private cert file");
+    let mut client_cert_vec = Vec::new();
+    BufReader::new(cert_file)
+        .read_to_end(&mut client_cert_vec)
+        .expect("Unable to read client cert file");
+
+    let key_file = File::open(&tls_file_paths.redis_key).expect("Cannot open private key file");
+    let mut client_key_vec = Vec::new();
+    BufReader::new(key_file)
+        .read_to_end(&mut client_key_vec)
+        .expect("Unable to read client key file");
+
+    TlsCertificates {
+        client_tls: Some(ClientTlsConfig {
+            client_cert: client_cert_vec,
+            client_key: client_key_vec,
+        }),
+        root_cert: Some(root_cert_vec),
+    }
+}
+
+#[cfg(feature = "tls-rustls")]
+pub(crate) fn build_single_client<T: redis::IntoConnectionInfo>(
+    connection_info: T,
+    tls_file_params: &Option<TlsFilePaths>,
+    mtls_enabled: bool,
+) -> redis::RedisResult<redis::Client> {
+    if mtls_enabled && tls_file_params.is_some() {
+        redis::Client::build_with_tls(
+            connection_info,
+            load_certs_from_file(
+                tls_file_params
+                    .as_ref()
+                    .expect("Expected certificates when `tls-rustls` feature is enabled"),
+            ),
+        )
+    } else {
+        redis::Client::open(connection_info)
+    }
+}
+
+#[cfg(feature = "tls-rustls")]
+pub(crate) mod mtls_test {
+    use super::*;
+    use redis::{cluster::ClusterClient, ConnectionInfo, RedisError};
+
+    fn clean_node_info(nodes: &[ConnectionInfo]) -> Vec<ConnectionInfo> {
+        let nodes = nodes
+            .iter()
+            .map(|node| match node {
+                ConnectionInfo {
+                    addr: redis::ConnectionAddr::TcpTls { host, port, .. },
+                    redis,
+                } => ConnectionInfo {
+                    addr: redis::ConnectionAddr::TcpTls {
+                        host: host.to_owned(),
+                        port: *port,
+                        insecure: false,
+                        tls_params: None,
+                    },
+                    redis: redis.clone(),
+                },
+                _ => node.clone(),
+            })
+            .collect();
+        nodes
+    }
+
+    pub(crate) fn create_cluster_client_from_cluster(
+        cluster: &TestClusterContext,
+        mtls_enabled: bool,
+    ) -> Result<ClusterClient, RedisError> {
+        let server = cluster
+            .cluster
+            .servers
+            .first()
+            .expect("Expected at least 1 server");
+        let tls_paths = server.tls_paths.as_ref();
+        let nodes = clean_node_info(&cluster.nodes);
+        let builder = redis::cluster::ClusterClientBuilder::new(nodes);
+        if let Some(tls_paths) = tls_paths {
+            // server-side TLS available
+            if mtls_enabled {
+                builder.certs(load_certs_from_file(tls_paths))
+            } else {
+                builder
+            }
+        } else {
+            // server-side TLS NOT available
+            builder
+        }
+        .build()
+    }
+}
+
+pub fn build_simple_pipeline_for_invalidation() -> Pipeline {
+    let mut pipe = redis::pipe();
+    pipe.cmd("GET")
+        .arg("key_1")
+        .ignore()
+        .cmd("SET")
+        .arg("key_1")
+        .arg(42)
+        .ignore();
+    pipe
+}
diff --git a/glide-core/redis-rs/redis/tests/support/sentinel.rs b/glide-core/redis-rs/redis/tests/support/sentinel.rs
new file mode 100644
index 0000000000..d34d3dc88b
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/support/sentinel.rs
@@ -0,0 +1,404 @@
+use std::fs::File;
+use std::io::Write;
+use std::thread::sleep;
+use std::time::Duration;
+
+use redis::sentinel::SentinelNodeConnectionInfo;
+use redis::Client;
+use redis::ConnectionAddr;
+use redis::ConnectionInfo;
+use redis::FromRedisValue;
+use redis::RedisResult;
+use redis::TlsMode;
+use tempfile::TempDir;
+
+use crate::support::build_single_client;
+
+use super::build_keys_and_certs_for_tls;
+use super::get_random_available_port;
+use super::Module;
+use super::RedisServer;
+use super::TlsFilePaths;
+
+const LOCALHOST: &str = "127.0.0.1";
+const MTLS_NOT_ENABLED: bool = false;
+
+pub struct RedisSentinelCluster {
+    pub servers: Vec<RedisServer>,
+    pub sentinel_servers: Vec<RedisServer>,
+    pub folders: Vec<TempDir>,
+}
+
+fn get_addr(port: u16) -> ConnectionAddr {
+    let addr = RedisServer::get_addr(port);
+    if let ConnectionAddr::Unix(_) = addr {
+        ConnectionAddr::Tcp(String::from("127.0.0.1"), port)
+    } else {
+        addr
+    }
+}
+
+fn spawn_master_server(
+    port: u16,
+    dir: &TempDir,
+    tlspaths: &TlsFilePaths,
+    modules: &[Module],
+) -> RedisServer {
+    RedisServer::new_with_addr_tls_modules_and_spawner(
+        get_addr(port),
+        None,
+        Some(tlspaths.clone()),
+        MTLS_NOT_ENABLED,
+        modules,
+        |cmd| {
+            // Minimize startup delay
+            cmd.arg("--repl-diskless-sync-delay").arg("0");
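+            // The spawner closure is the customization point: each helper receives
+            // the preconfigured `process::Command` and may append flags before
+            // spawning. A caller could tune other settings the same way
+            // (illustrative):
+            //
+            //     RedisServer::new_with_addr_tls_modules_and_spawner(
+            //         addr, None, None, false, &[],
+            //         |cmd| cmd.arg("--maxmemory").arg("64mb").spawn().unwrap(),
+            //     );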
cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_replica_server( + port: u16, + master_port: u16, + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + File::create(&config_file_path).unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--replicaof") + .arg("127.0.0.1") + .arg(master_port.to_string()); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.arg("--appendonly").arg("yes"); + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_sentinel_server( + port: u16, + master_ports: &[u16], + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + let mut file = File::create(&config_file_path).unwrap(); + for (i, master_port) in master_ports.iter().enumerate() { + file.write_all( + format!("sentinel monitor master{} 127.0.0.1 {} 1\n", i, master_port).as_bytes(), + ) + .unwrap(); + } + file.flush().unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--sentinel"); + cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn wait_for_master_server( + mut get_client_fn: impl FnMut() -> RedisResult, +) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..100 { + let master_client = get_client_fn(); + match master_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + if role.starts_with("master") { + return Ok(()); + } else { + println!("failed check for master role - current role: {r:?}") + } + } + Err(err) => { + println!("failed to get master connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get master client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replica(mut get_client_fn: impl FnMut() -> RedisResult) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..200 { + let replica_client = get_client_fn(); + match replica_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + let state = String::from_redis_value(r.get(3).unwrap()).unwrap(); + if role.starts_with("slave") && state == "connected" { + return Ok(()); + } else { + println!("failed check for replica role - current role: {:?}", r) + } + } + Err(err) => { + println!("failed to get replica connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get replica client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replicas_to_sync(servers: &[RedisServer], masters: u16) { + let cluster_size = servers.len() / (masters as usize); + let 
+    let clusters = servers.len() / cluster_size;
+    let replicas = cluster_size - 1;
+
+    for cluster_index in 0..clusters {
+        let master_addr = servers[cluster_index * cluster_size].connection_info();
+        let tls_paths = &servers.first().unwrap().tls_paths;
+        let r = wait_for_master_server(|| {
+            Ok(build_single_client(master_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap())
+        });
+        if r.is_err() {
+            panic!("failed waiting for master to be ready");
+        }
+
+        for replica_index in 0..replicas {
+            let replica_addr =
+                servers[(cluster_index * cluster_size) + 1 + replica_index].connection_info();
+            let r = wait_for_replica(|| {
+                Ok(build_single_client(replica_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap())
+            });
+            if r.is_err() {
+                panic!("failed waiting for replica to be ready and in sync");
+            }
+        }
+    }
+}
+
+impl RedisSentinelCluster {
+    pub fn new(masters: u16, replicas_per_master: u16, sentinels: u16) -> RedisSentinelCluster {
+        RedisSentinelCluster::with_modules(masters, replicas_per_master, sentinels, &[])
+    }
+
+    pub fn with_modules(
+        masters: u16,
+        replicas_per_master: u16,
+        sentinels: u16,
+        modules: &[Module],
+    ) -> RedisSentinelCluster {
+        let mut servers = vec![];
+        let mut folders = vec![];
+        let mut master_ports = vec![];
+
+        let tempdir = tempfile::Builder::new()
+            .prefix("redistls")
+            .tempdir()
+            .expect("failed to create tempdir");
+        let tlspaths = build_keys_and_certs_for_tls(&tempdir);
+        folders.push(tempdir);
+
+        let required_number_of_sockets = masters * (replicas_per_master + 1) + sentinels;
+        let mut available_ports = std::collections::HashSet::new();
+        while available_ports.len() < required_number_of_sockets as usize {
+            available_ports.insert(get_random_available_port());
+        }
+        let mut available_ports: Vec<_> = available_ports.into_iter().collect();
+
+        for _ in 0..masters {
+            let port = available_ports.pop().unwrap();
+            let tempdir = tempfile::Builder::new()
+                .prefix("redis")
+                .tempdir()
+                .expect("failed to create tempdir");
+            servers.push(spawn_master_server(port, &tempdir, &tlspaths, modules));
+            folders.push(tempdir);
+            master_ports.push(port);
+
+            for _ in 0..replicas_per_master {
+                let replica_port = available_ports.pop().unwrap();
+                let tempdir = tempfile::Builder::new()
+                    .prefix("redis")
+                    .tempdir()
+                    .expect("failed to create tempdir");
+                servers.push(spawn_replica_server(
+                    replica_port,
+                    port,
+                    &tempdir,
+                    &tlspaths,
+                    modules,
+                ));
+                folders.push(tempdir);
+            }
+        }
+
+        // Wait for replicas to sync so that the sentinels discover them on the first try
+        wait_for_replicas_to_sync(&servers, masters);
+
+        let mut sentinel_servers = vec![];
+        for _ in 0..sentinels {
+            let port = available_ports.pop().unwrap();
+            let tempdir = tempfile::Builder::new()
+                .prefix("redis")
+                .tempdir()
+                .expect("failed to create tempdir");
+
+            sentinel_servers.push(spawn_sentinel_server(
+                port,
+                &master_ports,
+                &tempdir,
+                &tlspaths,
+                modules,
+            ));
+            folders.push(tempdir);
+        }
+
+        RedisSentinelCluster {
+            servers,
+            sentinel_servers,
+            folders,
+        }
+    }
+
+    pub fn stop(&mut self) {
+        for server in &mut self.servers {
+            server.stop();
+        }
+        for server in &mut self.sentinel_servers {
+            server.stop();
+        }
+    }
+
+    pub fn iter_sentinel_servers(&self) -> impl Iterator<Item = &RedisServer> {
+        self.sentinel_servers.iter()
+    }
+}
+
+impl Drop for RedisSentinelCluster {
+    fn drop(&mut self) {
+        self.stop()
+    }
+}
+
+pub struct TestSentinelContext {
+    pub cluster: RedisSentinelCluster,
+    pub sentinel: redis::sentinel::Sentinel,
+    pub sentinels_connection_info: Vec<ConnectionInfo>,
+    mtls_enabled: bool, // for future tests
+}
+
+impl TestSentinelContext {
+    pub fn new(nodes: u16, replicas: u16, sentinels: u16) -> TestSentinelContext {
+        Self::new_with_cluster_client_builder(nodes, replicas, sentinels)
+    }
+
+    pub fn new_with_cluster_client_builder(
+        nodes: u16,
+        replicas: u16,
+        sentinels: u16,
+    ) -> TestSentinelContext {
+        let cluster = RedisSentinelCluster::new(nodes, replicas, sentinels);
+        let initial_nodes: Vec<ConnectionInfo> = cluster
+            .iter_sentinel_servers()
+            .map(RedisServer::connection_info)
+            .collect();
+        let sentinel = redis::sentinel::Sentinel::build(initial_nodes.clone());
+        let sentinel = sentinel.unwrap();
+
+        let mut context = TestSentinelContext {
+            cluster,
+            sentinel,
+            sentinels_connection_info: initial_nodes,
+            mtls_enabled: MTLS_NOT_ENABLED,
+        };
+        context.wait_for_cluster_up();
+        context
+    }
+
+    pub fn sentinel(&self) -> &redis::sentinel::Sentinel {
+        &self.sentinel
+    }
+
+    pub fn sentinel_mut(&mut self) -> &mut redis::sentinel::Sentinel {
+        &mut self.sentinel
+    }
+
+    pub fn sentinels_connection_info(&self) -> &Vec<ConnectionInfo> {
+        &self.sentinels_connection_info
+    }
+
+    pub fn sentinel_node_connection_info(&self) -> SentinelNodeConnectionInfo {
+        SentinelNodeConnectionInfo {
+            tls_mode: if let ConnectionAddr::TcpTls { insecure, .. } =
+                self.cluster.servers[0].client_addr()
+            {
+                if *insecure {
+                    Some(TlsMode::Insecure)
+                } else {
+                    Some(TlsMode::Secure)
+                }
+            } else {
+                None
+            },
+            redis_connection_info: None,
+        }
+    }
+
+    pub fn wait_for_cluster_up(&mut self) {
+        let node_conn_info = self.sentinel_node_connection_info();
+        let con = self.sentinel_mut();
+
+        let r = wait_for_master_server(|| con.master_for("master1", Some(&node_conn_info)));
+        if r.is_err() {
+            panic!("failed waiting for sentinel master1 to be ready");
+        }
+
+        let r = wait_for_replica(|| con.replica_for("master1", Some(&node_conn_info)));
+        if r.is_err() {
+            panic!("failed waiting for sentinel master1 replica to be ready");
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/support/util.rs b/glide-core/redis-rs/redis/tests/support/util.rs
new file mode 100644
index 0000000000..4533146b67
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/support/util.rs
@@ -0,0 +1,35 @@
+use std::collections::HashMap;
+use versions::Versioning;
+
+use super::TestContext;
+
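+// `assert_args!` (defined below) checks the exact argument bytes a value encodes
+// to via `ToRedisArgs`. Illustrative use:
+//
+//     assert_args!(&["foo", "42"], "foo", "42");
+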
assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| std::str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } +} + +pub fn parse_client_info(client_info: &str) -> HashMap { + let mut res = HashMap::new(); + + for line in client_info.split(' ') { + let this_attr: Vec<&str> = line.split('=').collect(); + res.insert(this_attr[0].to_string(), this_attr[1].to_string()); + } + + res +} + +pub fn version_greater_or_equal(ctx: &TestContext, version: &str) -> bool { + // Get the server version + let (major, minor, patch) = ctx.get_version(); + let server_version = Versioning::new(format!("{major}.{minor}.{patch}")).unwrap(); + let compared_version = Versioning::new(version).unwrap(); + // Compare server version with the specified version + server_version >= compared_version +} diff --git a/glide-core/redis-rs/redis/tests/test_acl.rs b/glide-core/redis-rs/redis/tests/test_acl.rs new file mode 100644 index 0000000000..093774f3bc --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_acl.rs @@ -0,0 +1,156 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "acl")] + +use std::collections::HashSet; + +use redis::acl::{AclInfo, Rule}; +use redis::{Commands, Value}; + +mod support; +use crate::support::*; + +#[test] +fn test_acl_whoami() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!(con.acl_whoami(), Ok("default".to_owned())); +} + +#[test] +fn test_acl_help() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: Vec = con.acl_help().expect("Got help manual"); + assert!(!res.is_empty()); +} + +//TODO: do we need this test? +#[test] +#[ignore] +fn test_acl_getsetdel_users() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!( + con.acl_list(), + Ok(vec!["user default on nopass ~* +@all".to_owned()]) + ); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); + // bob + assert_eq!(con.acl_setuser("bob"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec!["bob".to_owned(), "default".to_owned()]) + ); + + // ACL SETUSER bob on ~redis:* +set + assert_eq!( + con.acl_setuser_rules( + "bob", + &[ + Rule::On, + Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + ), + Rule::Pattern("redis:*".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + ), + Ok(()) + ); + let acl_info: AclInfo = con.acl_getuser("bob").expect("Got user"); + assert_eq!( + acl_info, + AclInfo { + flags: vec![Rule::On], + passwords: vec![Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + )], + commands: vec![ + Rule::RemoveCategory("all".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + keys: vec![Rule::Pattern("redis:*".to_owned())], + } + ); + assert_eq!( + con.acl_list(), + Ok(vec![ + "user bob on #c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2 ~redis:* -@all +set".to_owned(), + "user default on nopass ~* +@all".to_owned(), + ]) + ); + + // ACL SETUSER eve + assert_eq!(con.acl_setuser("eve"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec![ + "bob".to_owned(), + "default".to_owned(), + "eve".to_owned() + ]) + ); + assert_eq!(con.acl_deluser(&["bob", "eve"]), Ok(2)); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); +} + +#[test] +fn test_acl_cat() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: HashSet = 
+    let res: HashSet<String> = con.acl_cat().expect("Got categories");
+    let expects = vec![
+        "keyspace",
+        "read",
+        "write",
+        "set",
+        "sortedset",
+        "list",
+        "hash",
+        "string",
+        "bitmap",
+        "hyperloglog",
+        "geo",
+        "stream",
+        "pubsub",
+        "admin",
+        "fast",
+        "slow",
+        "blocking",
+        "dangerous",
+        "connection",
+        "transaction",
+        "scripting",
+    ];
+    for cat in expects.iter() {
+        assert!(res.contains(*cat), "Category `{cat}` does not exist");
+    }
+
+    let expects = ["pfmerge", "pfcount", "pfselftest", "pfadd"];
+    let res: HashSet<String> = con
+        .acl_cat_categoryname("hyperloglog")
+        .expect("Got commands of a category");
+    for cmd in expects.iter() {
+        assert!(res.contains(*cmd), "Command `{cmd}` does not exist");
+    }
+}
+
+#[test]
+fn test_acl_genpass() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+    let pass: String = con.acl_genpass().expect("Got password");
+    assert_eq!(pass.len(), 64);
+
+    let pass: String = con.acl_genpass_bits(1024).expect("Got password");
+    assert_eq!(pass.len(), 256);
+}
+
+#[test]
+fn test_acl_log() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+    let logs: Vec<Value> = con.acl_log(1).expect("Got logs");
+    assert_eq!(logs.len(), 0);
+    assert_eq!(con.acl_log_reset(), Ok(()));
+}
diff --git a/glide-core/redis-rs/redis/tests/test_async.rs b/glide-core/redis-rs/redis/tests/test_async.rs
new file mode 100644
index 0000000000..73d14de022
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_async.rs
@@ -0,0 +1,1132 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+
+mod support;
+
+#[cfg(test)]
+mod basic_async {
+    use std::collections::HashMap;
+
+    use futures::{prelude::*, StreamExt};
+    use redis::{
+        aio::{ConnectionLike, MultiplexedConnection},
+        cmd, pipe, AsyncCommands, ErrorKind, GlideConnectionOptions, PushInfo, PushKind,
+        RedisResult, Value,
+    };
+    use tokio::sync::mpsc::error::TryRecvError;
+
+    use crate::support::*;
+
+    #[test]
+    fn test_args() {
+        let ctx = TestContext::new();
+        let connect = ctx.async_connection();
+
+        block_on_all(connect.and_then(|mut con| async move {
+            redis::cmd("SET")
+                .arg("key1")
+                .arg(b"foo")
+                .query_async(&mut con)
+                .await?;
+            redis::cmd("SET")
+                .arg(&["key2", "bar"])
+                .query_async(&mut con)
+                .await?;
+            let result = redis::cmd("MGET")
+                .arg(&["key1", "key2"])
+                .query_async(&mut con)
+                .await;
+            assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec())));
+            result
+        }))
+        .unwrap();
+    }
+
+    #[test]
+    fn test_nice_hash_api() {
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut connection = ctx.async_connection().await.unwrap();
+
+            assert_eq!(
+                connection
+                    .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)])
+                    .await,
+                Ok(())
+            );
+
+            let hm: HashMap<String, isize> = connection.hgetall("my_hash").await.unwrap();
+            assert_eq!(hm.len(), 4);
+            assert_eq!(hm.get("f1"), Some(&1));
+            assert_eq!(hm.get("f2"), Some(&2));
+            assert_eq!(hm.get("f3"), Some(&4));
+            assert_eq!(hm.get("f4"), Some(&8));
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_nice_hash_api_in_pipe() {
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut connection = ctx.async_connection().await.unwrap();
+
+            assert_eq!(
+                connection
+                    .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)])
+                    .await,
+                Ok(())
+            );
+
+            let mut pipe = redis::pipe();
+            pipe.cmd("HGETALL").arg("my_hash");
+            let mut vec: Vec<HashMap<String, isize>> =
+                pipe.query_async(&mut connection).await.unwrap();
+            assert_eq!(vec.len(), 1);
+            let hash = vec.pop().unwrap();
+            assert_eq!(hash.len(), 4);
assert_eq!(hash.get("f1"), Some(&1)); + assert_eq!(hash.get("f2"), Some(&2)); + assert_eq!(hash.get("f3"), Some(&4)); + assert_eq!(hash.get("f4"), Some(&8)); + + Ok(()) + }) + .unwrap(); + } + + #[test] + fn dont_panic_on_closed_multiplexed_connection() { + let ctx = TestContext::new(); + let client = ctx.client.clone(); + let connect = client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + drop(ctx); + + block_on_all(async move { + connect + .and_then(|con| async move { + let cmd = move || { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await + } + }; + let result: RedisResult<()> = cmd().await; + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + cmd().await + }) + .map(|result| { + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + }) + .await; + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_pipeline_transaction() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.async_connection().await?; + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]); + pipe.query_async(&mut con) + .map_ok(|((k1, k2),): ((i32, i32),)| { + assert_eq!(k1, 42); + assert_eq!(k2, 43); + }) + .await + }) + .unwrap(); + } + + #[test] + fn test_client_tracking_doesnt_block_execution() { + //It checks if the library distinguish a push-type message from the others and continues its normal operation. + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.async_connection().await.unwrap(); + let mut pipe = redis::pipe(); + pipe.cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .ignore() + .cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore(); + let _: RedisResult<()> = pipe.query_async(&mut con).await; + let num: i32 = con.get("key_1").await.unwrap(); + assert_eq!(num, 42); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_pipeline_transaction_with_errors() { + use redis::RedisError; + let ctx = TestContext::new(); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + con.set::<_, _, ()>("x", 42).await.unwrap(); + + // Make Redis a replica of a nonexistent master, thereby making it read-only. 
+ redis::cmd("slaveof") + .arg("1.1.1.1") + .arg("1") + .query_async::<_, ()>(&mut con) + .await + .unwrap(); + + // Ensure that a write command fails with a READONLY error + let err: RedisResult<()> = redis::pipe() + .atomic() + .set("x", 142) + .ignore() + .get("x") + .query_async(&mut con) + .await; + + assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly); + + let x: i32 = con.get("x").await.unwrap(); + assert_eq!(x, 42); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + fn test_cmd( + con: &MultiplexedConnection, + i: i32, + ) -> impl Future> + Send { + let mut con = con.clone(); + async move { + let key = format!("key{i}"); + let key_2 = key.clone(); + let key2 = format!("key{i}_2"); + let key2_2 = key2.clone(); + + let foo_val = format!("foo{i}"); + + redis::cmd("SET") + .arg(&key[..]) + .arg(foo_val.as_bytes()) + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&[&key2, "bar"]) + .query_async(&mut con) + .await?; + redis::cmd("MGET") + .arg(&[&key_2, &key2_2]) + .query_async(&mut con) + .map(|result| { + assert_eq!(Ok((foo_val, b"bar".to_vec())), result); + Ok(()) + }) + .await + } + } + + fn test_error(con: &MultiplexedConnection) -> impl Future> { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .query_async(&mut con) + .map(|result| match result { + Ok(()) => panic!("Expected redis to return an error"), + Err(_) => Ok(()), + }) + .await + } + } + + #[test] + fn test_pipe_over_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection().await?; + let mut pipe = pipe(); + pipe.zrange("zset", 0, 0); + pipe.zrange("zset", 0, 0); + let frames = con.send_packed_commands(&pipe, 0, 2).await?; + assert_eq!(frames.len(), 2); + assert!(matches!(frames[0], redis::Value::Array(_))); + assert!(matches!(frames[1], redis::Value::Array(_))); + RedisResult::Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_args_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| test_cmd(&con, i)); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_args_with_errors_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let con = con.clone(); + async move { + if i % 2 == 0 { + test_cmd(&con, i).await + } else { + test_error(&con).await + } + } + }); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_transaction_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let mut con = con.clone(); + async move { + let foo_val = i; + let bar_val = format!("bar{i}"); + + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key") + .arg(foo_val) + .ignore() + .cmd("SET") + .arg(&["key2", &bar_val[..]]) + .ignore() + .cmd("MGET") + .arg(&["key", "key2"]); + + pipe.query_async(&mut con) + .map(move |result| { + assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); + result + }) + .await + } + }); + future::try_join_all(cmds) + }) + .map_ok(|results| { + 
+                    assert_eq!(results.len(), 100);
+                })
+                .map_err(|err| panic!("{}", err))
+                .await
+        })
+        .unwrap();
+    }
+
+    fn test_async_scanning(batch_size: usize) {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            ctx.multiplexed_async_connection()
+                .and_then(|mut con| {
+                    async move {
+                        let mut unseen = std::collections::HashSet::new();
+
+                        for x in 0..batch_size {
+                            redis::cmd("SADD")
+                                .arg("foo")
+                                .arg(x)
+                                .query_async(&mut con)
+                                .await?;
+                            unseen.insert(x);
+                        }
+
+                        let mut iter = redis::cmd("SSCAN")
+                            .arg("foo")
+                            .cursor_arg(0)
+                            .clone()
+                            .iter_async(&mut con)
+                            .await
+                            .unwrap();
+
+                        while let Some(x) = iter.next_item().await {
+                            // type inference limitations
+                            let x: usize = x;
+                            // if this assertion fails, too many items were returned by the iterator.
+                            assert!(unseen.remove(&x));
+                        }
+
+                        assert_eq!(unseen.len(), 0);
+                        Ok(())
+                    }
+                })
+                .map_err(|err| panic!("{}", err))
+                .await
+        })
+        .unwrap();
+    }
+
+    #[test]
+    fn test_async_scanning_big_batch() {
+        test_async_scanning(1000)
+    }
+
+    #[test]
+    fn test_async_scanning_small_batch() {
+        test_async_scanning(2)
+    }
+
+    #[test]
+    fn test_response_timeout_multiplexed_connection() {
+        let ctx = TestContext::new();
+        block_on_all(async move {
+            let mut connection = ctx.multiplexed_async_connection().await.unwrap();
+            connection.set_response_timeout(std::time::Duration::from_millis(1));
+            let mut cmd = redis::Cmd::new();
+            cmd.arg("BLPOP").arg("foo").arg(0); // 0 timeout blocks indefinitely
+            let result = connection.req_packed_command(&cmd).await;
+            assert!(result.is_err());
+            assert!(result.unwrap_err().is_timeout());
+            Ok(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    #[cfg(feature = "script")]
+    fn test_script() {
+        use redis::RedisError;
+
+        // Note this test runs both scripts twice to test when they have already been loaded
+        // into Redis and when they need to be loaded in
+        let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])");
+        let script2 = redis::Script::new("return redis.call('GET', KEYS[1])");
+        let script3 = redis::Script::new("return redis.call('KEYS', '*')");
+
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut con = ctx.multiplexed_async_connection().await?;
+            script1
+                .key("key1")
+                .arg("foo")
+                .invoke_async(&mut con)
+                .await?;
+            let val: String = script2.key("key1").invoke_async(&mut con).await?;
+            assert_eq!(val, "foo");
+            let keys: Vec<String> = script3.invoke_async(&mut con).await?;
+            assert_eq!(keys, ["key1"]);
+            script1
+                .key("key1")
+                .arg("bar")
+                .invoke_async(&mut con)
+                .await?;
+            let val: String = script2.key("key1").invoke_async(&mut con).await?;
+            assert_eq!(val, "bar");
+            let keys: Vec<String> = script3.invoke_async(&mut con).await?;
+            assert_eq!(keys, ["key1"]);
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+    }
+
+    #[test]
+    #[cfg(feature = "script")]
+    fn test_script_load() {
+        let ctx = TestContext::new();
+        let script = redis::Script::new("return 'Hello World'");
+
+        block_on_all(async move {
+            let mut con = ctx.multiplexed_async_connection().await.unwrap();
+
+            let hash = script.prepare_invoke().load_async(&mut con).await.unwrap();
+            assert_eq!(hash, script.get_hash().to_string());
+            Ok(())
+        })
+        .unwrap();
+    }
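+
+    // `Script` hashes the body up front; `invoke_async` first attempts EVALSHA and
+    // only loads the script on a NOSCRIPT error, so the second invocation of each
+    // script in `test_script` above exercises the cached path. Conceptually
+    // (illustrative):
+    //
+    //     let script = redis::Script::new("return 1");
+    //     let n: i32 = script.invoke_async(&mut con).await?; // EVALSHA, loads on miss
+    //     let n: i32 = script.invoke_async(&mut con).await?; // EVALSHA cache hit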
assert_eq!(s, "hello"); + assert!(b); + }) + .await + }) + .unwrap(); + } + + // Allowing `nth(0)` for similarity with the following `nth(1)`. + // Allowing `let ()` as `query_async` requries the type it converts the result to. + #[allow(clippy::let_unit_value, clippy::iter_nth_zero)] + #[tokio::test] + async fn io_error_on_kill_issue_320() { + let ctx = TestContext::new(); + + let mut conn_to_kill = ctx.async_connection().await.unwrap(); + cmd("CLIENT") + .arg("SETNAME") + .arg("to-kill") + .query_async::<_, ()>(&mut conn_to_kill) + .await + .unwrap(); + + let client_list: String = cmd("CLIENT") + .arg("LIST") + .query_async(&mut conn_to_kill) + .await + .unwrap(); + + eprintln!("{client_list}"); + let client_to_kill = client_list + .split('\n') + .find(|line| line.contains("to-kill")) + .expect("line") + .split(' ') + .nth(0) + .expect("id") + .split('=') + .nth(1) + .expect("id value"); + + let mut killer_conn = ctx.async_connection().await.unwrap(); + let () = cmd("CLIENT") + .arg("KILL") + .arg("ID") + .arg(client_to_kill) + .query_async(&mut killer_conn) + .await + .unwrap(); + let mut killed_client = conn_to_kill; + + let err = loop { + match killed_client.get::<_, Option>("a").await { + // We are racing against the server being shutdown so try until we a get an io error + Ok(_) => tokio::time::sleep(std::time::Duration::from_millis(50)).await, + Err(err) => break err, + } + }; + assert_eq!(err.kind(), ErrorKind::FatalSendError); + } + + #[tokio::test] + async fn invalid_password_issue_343() { + let ctx = TestContext::new(); + let coninfo = redis::ConnectionInfo { + addr: ctx.server.client_addr().clone(), + redis: redis::RedisConnectionInfo { + password: Some("asdcasc".to_string()), + ..Default::default() + }, + }; + let client = redis::Client::open(coninfo).unwrap(); + + let err = client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await + .err() + .unwrap(); + assert_eq!( + err.kind(), + ErrorKind::AuthenticationFailed, + "Unexpected error: {err}", + ); + } + + // Test issue of Stream trait blocking if we try to iterate more than 10 items + // https://github.com/mitsuhiko/redis-rs/issues/537 and https://github.com/mitsuhiko/redis-rs/issues/583 + #[tokio::test] + async fn test_issue_stream_blocks() { + let ctx = TestContext::new(); + let mut con = ctx.multiplexed_async_connection().await.unwrap(); + for i in 0..20usize { + let _: () = con.append(format!("test/{i}"), i).await.unwrap(); + } + let values = con.scan_match::<&str, String>("test/*").await.unwrap(); + tokio::time::timeout(std::time::Duration::from_millis(100), async move { + let values: Vec<_> = values.collect().await; + assert_eq!(values.len(), 20); + }) + .await + .unwrap(); + } + + // Test issue of AsyncCommands::scan returning the wrong number of keys + // https://github.com/redis-rs/redis-rs/issues/759 + #[tokio::test] + async fn test_issue_async_commands_scan_broken() { + let ctx = TestContext::new(); + let mut con = ctx.async_connection().await.unwrap(); + let mut keys: Vec = (0..100).map(|k| format!("async-key{k}")).collect(); + keys.sort(); + for key in &keys { + let _: () = con.set(key, b"foo").await.unwrap(); + } + + let iter: redis::AsyncIter = con.scan().await.unwrap(); + let mut keys_from_redis: Vec<_> = iter.collect().await; + keys_from_redis.sort(); + assert_eq!(keys, keys_from_redis); + assert_eq!(keys.len(), 100); + } + + mod pub_sub { + use std::time::Duration; + + use redis::ProtocolVersion; + + use super::*; + + #[test] + fn pub_sub_subscription() { + use 
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe("phonewave").await?;
+                let mut pubsub_stream = pubsub_conn.on_message();
+                let mut publish_conn = ctx.async_connection().await?;
+                publish_conn.publish("phonewave", "banana").await?;
+
+                let msg_payload: String = pubsub_stream.next().await.unwrap().get_payload()?;
+                assert_eq!("banana".to_string(), msg_payload);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_unsubscription() {
+            use redis::RedisError;
+
+            const SUBSCRIPTION_KEY: &str = "phonewave-pub-sub-unsubscription";
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?;
+                pubsub_conn.unsubscribe(SUBSCRIPTION_KEY).await?;
+
+                let mut conn = ctx.async_connection().await?;
+                let subscriptions_counts: HashMap<String, u32> = redis::cmd("PUBSUB")
+                    .arg("NUMSUB")
+                    .arg(SUBSCRIPTION_KEY)
+                    .query_async(&mut conn)
+                    .await?;
+                let subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap();
+                assert_eq!(subscription_count, 0);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn automatic_unsubscription() {
+            use redis::RedisError;
+
+            const SUBSCRIPTION_KEY: &str = "phonewave-automatic-unsubscription";
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?;
+                drop(pubsub_conn);
+
+                let mut conn = ctx.async_connection().await?;
+                let mut subscription_count = 1;
+                // Allow for the unsubscription to occur within 5 seconds
+                for _ in 0..100 {
+                    let subscriptions_counts: HashMap<String, u32> = redis::cmd("PUBSUB")
+                        .arg("NUMSUB")
+                        .arg(SUBSCRIPTION_KEY)
+                        .query_async(&mut conn)
+                        .await?;
+                    subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap();
+                    if subscription_count == 0 {
+                        break;
+                    }
+
+                    std::thread::sleep(Duration::from_millis(50));
+                }
+                assert_eq!(subscription_count, 0);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_conn_reuse() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut pubsub_conn = ctx.async_pubsub().await?;
+                pubsub_conn.subscribe("phonewave").await?;
+                pubsub_conn.psubscribe("*").await?;
+
+                #[allow(deprecated)]
+                let mut conn = pubsub_conn.into_connection().await;
+                redis::cmd("SET")
+                    .arg("foo")
+                    .arg("bar")
+                    .query_async(&mut conn)
+                    .await?;
+
+                let res: String = redis::cmd("GET").arg("foo").query_async(&mut conn).await?;
+                assert_eq!(&res, "bar");
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pipe_errors_do_not_affect_subsequent_commands() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            block_on_all(async move {
+                let mut conn = ctx.multiplexed_async_connection().await?;
+
+                conn.lpush::<&str, &str, ()>("key", "value").await?;
+
+                let res: Result<(String, usize), redis::RedisError> = redis::pipe()
+                    .get("key") // WRONGTYPE
+                    .llen("key")
+                    .query_async(&mut conn)
+                    .await;
+
+                assert!(res.is_err());
+
+                let list: Vec<String> = conn.lrange("key", 0, -1).await?;
+
+                assert_eq!(list, vec!["value".to_owned()]);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn pub_sub_multiple() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            if ctx.protocol == ProtocolVersion::RESP2 {
+                return;
+            }
+            block_on_all(async move {
+                let mut conn = ctx.multiplexed_async_connection().await?;
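+                // Under RESP3 the server delivers subscription traffic as
+                // out-of-band push frames; `PushManager` forwards them into the
+                // mpsc channel created below, so ordering can be asserted with
+                // `rx.recv()` / `rx.try_recv()` (subscribe confirmations arrive as
+                // push messages as well).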
+                let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+                let pub_count = 10;
+                let channel_name = "phonewave".to_string();
+                conn.get_push_manager().replace_sender(tx.clone());
+                conn.subscribe(channel_name.clone()).await?;
+                rx.recv().await.unwrap(); //PASS SUBSCRIBE
+
+                let mut publish_conn = ctx.async_connection().await?;
+                for i in 0..pub_count {
+                    publish_conn
+                        .publish(channel_name.clone(), format!("banana {i}"))
+                        .await?;
+                }
+                for _ in 0..pub_count {
+                    rx.recv().await.unwrap();
+                }
+                assert!(rx.try_recv().is_err());
+
+                {
+                    // Let's test if unsubscribing from an individual channel subscription works
+                    publish_conn
+                        .publish(channel_name.clone(), "banana!")
+                        .await?;
+                    rx.recv().await.unwrap();
+                }
+                {
+                    // Giving none for the channel id should unsubscribe all subscriptions from that channel and send an unsubscribe command to the server.
+                    conn.unsubscribe(channel_name.clone()).await?;
+                    rx.recv().await.unwrap(); //PASS UNSUBSCRIBE
+                    publish_conn
+                        .publish(channel_name.clone(), "banana!")
+                        .await?;
+                    // Let's wait for 100ms to make sure there is nothing in the channel.
+                    tokio::time::sleep(Duration::from_millis(100)).await;
+                    assert!(rx.try_recv().is_err());
+                }
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn push_manager_active_context() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            if ctx.protocol == ProtocolVersion::RESP2 {
+                return;
+            }
+            block_on_all(async move {
+                let mut sub_conn = ctx.multiplexed_async_connection().await?;
+                let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+                let channel_name = "test_channel".to_string();
+                sub_conn.get_push_manager().replace_sender(tx.clone());
+                sub_conn.subscribe(channel_name.clone()).await?;
+
+                let rcv_msg = rx.recv().await.unwrap();
+                println!("Received PushInfo: {:?}", rcv_msg);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        fn push_manager_disconnection() {
+            use redis::RedisError;
+
+            let ctx = TestContext::new();
+            if ctx.protocol == ProtocolVersion::RESP2 {
+                return;
+            }
+            block_on_all(async move {
+                let mut conn = ctx.multiplexed_async_connection().await?;
+                let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+                conn.get_push_manager().replace_sender(tx.clone());
+
+                conn.set("A", "1").await?;
+                assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
+                drop(ctx);
+                let x: RedisResult<()> = conn.set("A", "1").await;
+                assert!(x.is_err());
+                assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection);
+
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+    }
+
+    #[test]
+    fn test_async_basic_pipe_with_parsing_error() {
+        // Tests a specific case involving repeated errors in transactions.
+        let ctx = TestContext::new();
+
+        block_on_all(async move {
+            let mut conn = ctx.multiplexed_async_connection().await?;
+
+            // create a transaction where 2 errors are returned.
+            // we call EVALSHA twice with no loaded script, thus triggering 2 errors.
+            redis::pipe()
+                .atomic()
+                .cmd("EVALSHA")
+                .arg("foobar")
+                .arg(0)
+                .cmd("EVALSHA")
+                .arg("foobar")
+                .arg(0)
+                .query_async::<_, ((), ())>(&mut conn)
+                .await
+                .expect_err("should return an error");
+
+            assert!(
+                // Arbitrary Redis command that should not return an error.
+                redis::cmd("SMEMBERS")
+                    .arg("nonexistent_key")
+                    .query_async::<_, Vec<String>>(&mut conn)
+                    .await
+                    .is_ok(),
+                "Failed transaction should not interfere with future calls."
+ ); + + Ok::<_, redis::RedisError>(()) + }) + .unwrap() + } + + #[cfg(feature = "connection-manager")] + async fn wait_for_server_to_become_ready(client: redis::Client) { + let millisecond = std::time::Duration::from_millis(1); + let mut retries = 0; + loop { + match client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { + Err(err) => { + if err.is_connection_refusal() { + tokio::time::sleep(millisecond).await; + retries += 1; + if retries > 100000 { + panic!("Tried to connect too many times, last error: {err}"); + } + } else { + panic!("Could not connect: {err}"); + } + } + Ok(mut con) => { + let _: RedisResult<()> = redis::cmd("FLUSHDB").query_async(&mut con).await; + break; + } + } + } + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_connection_manager_reconnect_after_delay() { + use redis::ProtocolVersion; + + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let tls_files = build_keys_and_certs_for_tls(&tempdir); + + let ctx = TestContext::with_tls(tls_files.clone(), false); + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let server = ctx.server; + let addr = server.client_addr().clone(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + drop(server); + + let _result: RedisResult = manager.set("foo", "bar").await; // one call is ignored because it's required to trigger the connection manager's reconnect. + if ctx.protocol != ProtocolVersion::RESP2 { + assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let _new_server = RedisServer::new_with_addr_and_modules(addr.clone(), &[], false); + wait_for_server_to_become_ready(ctx.client.clone()).await; + + let result: redis::Value = manager.set("foo", "bar").await.unwrap(); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + assert_eq!(result, redis::Value::Okay); + Ok(()) + }) + .unwrap(); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use super::*; + + #[test] + fn test_should_connect_mtls() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })) + .unwrap(); + } + + #[test] + fn test_should_not_connect_if_tls_active() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + let result = block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })); + + // depends on server type set (REDISRS_SERVER_TYPE) + match ctx.server.connection_info() { + redis::ConnectionInfo { + addr: 
redis::ConnectionAddr::TcpTls { .. }, + .. + } => { + if result.is_ok() { + panic!("Must NOT be able to connect without client credentials if server accepts TLS"); + } + } + _ => { + if result.is_err() { + panic!("Must be able to connect without client credentials if server does NOT accept TLS"); + } + } + } + } + } + + #[test] + fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + use redis::RedisError; + let ctx = TestContext::with_client_name(CLIENT_NAME); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + + let client_info: String = redis::cmd("CLIENT") + .arg("INFO") + .query_async(&mut con) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_push_manager_cm() { + use redis::ProtocolVersion; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + manager + .send_packed_command(cmd("CLIENT").arg("TRACKING").arg("ON")) + .await + .unwrap(); + let pipe = build_simple_pipeline_for_invalidation(); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(new_tx); + drop(rx); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = new_rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + assert_eq!(TryRecvError::Empty, new_rx.try_recv().err().unwrap()); + Ok(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs new file mode 100644 index 0000000000..356c5bfc8c --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs @@ -0,0 +1,563 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "cluster-async")] +mod support; + +use redis::{ + cluster_async::testing::{AsyncClusterNode, RefreshConnectionType}, + testing::ClusterParams, + ErrorKind, GlideConnectionOptions, +}; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use support::{ + get_mock_connection, get_mock_connection_with_port, modify_mock_connection_behavior, + respond_startup, ConnectionIPReturnType, MockConnection, MockConnectionBehavior, +}; + +mod test_connect_and_check { + use std::sync::atomic::AtomicUsize; + + use super::*; + use crate::support::{get_mock_connection_handler, ShouldReturnConnectionError}; + use 
redis::cluster_async::testing::{ + connect_and_check, ConnectAndCheckResult, ConnectionDetails, + }; + + fn assert_partial_result( + result: ConnectAndCheckResult, + ) -> (AsyncClusterNode, redis::RedisError) { + match result { + ConnectAndCheckResult::ManagementConnectionFailed { node, err } => (node, err), + ConnectAndCheckResult::Success(_) => { + panic!("Expected partial result, got full success") + } + ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"), + } + } + + fn assert_full_success( + result: ConnectAndCheckResult, + ) -> AsyncClusterNode { + match result { + ConnectAndCheckResult::Success(node) => node, + ConnectAndCheckResult::ManagementConnectionFailed { .. } => { + panic!("Expected full success, got partial success") + } + ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"), + } + } + + #[tokio::test] + async fn test_connect_and_check_connect_successfully() { + // Test that upon refreshing all connections, if both connections were successful, + // the returned node contains both user and management connection + let name = "test_connect_and_check_connect_successfully"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip) + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + assert_eq!(node.user_connection.ip, Some(ip)); + assert_eq!(node.management_connection.unwrap().ip, Some(ip)); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_one_connection_err_returns_only_user_conn() { + // Test that upon refreshing all connections, if only one of the new connections fail, + // the other successful connection will be used as the user connection, as a partial success. + let name = "all_connections_one_connection_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + // The second connection will fail + behavior.return_connection_err = + ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(0)) + }); + + let params = ClusterParams::default(); + + let result = connect_and_check::( + &format!("{name}:6379"), + params.clone(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + + modify_mock_connection_behavior(name, |behavior| { + // The first connection will fail + behavior.return_connection_err = + ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(1)); + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + params, + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_different_ip_returns_both_connections() { + // Test that node's connections (e.g. 
user and management) can have different IPs for the same DNS endpoint. + // It is relevant for cases where the DNS entry holds multiple IPs that routes to the same node, for example with load balancers. + // The test verifies that upon refreshing all connections, if the IPs of the new connections differ, + // the function uses all connections. + let name = "all_connections_different_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Different(AtomicUsize::new(0)); + }); + + // The first connection will have 0.0.0.0 IP, the second 1.0.0.0 + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + assert_eq!( + node.user_connection.ip, + Some(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))) + ); + assert_eq!( + node.management_connection.unwrap().ip, + Some(IpAddr::V4(Ipv4Addr::new(1, 0, 0, 0))) + ); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_both_conn_error_returns_err() { + // Test that when trying to refresh all connections and both connections fail, the function returns with an error + let name = "both_conn_error_returns_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.return_connection_err = ShouldReturnConnectionError::Yes + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let err = result.get_error().unwrap(); + assert!( + err.to_string() + .contains("Failed to refresh both connections") + && err.kind() == ErrorKind::IoError + ); + } + + #[tokio::test] + async fn test_connect_and_check_only_management_same_ip() { + // Test that when we refresh only the management connection and the new connection returned with the same IP as the user's, + // the returned node contains a new management connection and the user connection remains unchanged + let name = "only_management_same_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip) + }); + + let user_conn_id: usize = 1000; + let user_conn = MockConnection { + id: user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let node = AsyncClusterNode::new( + ConnectionDetails { + conn: user_conn, + ip: Some(ip), + } + .into_future(), + None, + ); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyManagementConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + // Confirm that the user connection remains unchanged + assert_eq!(node.user_connection.conn.await.id, user_conn_id); + } + + #[tokio::test] + async fn 
test_connect_and_check_only_management_connection_err() { + // Test that when we try to refresh only the management connection and it fails, we receive a partial success with the same node. + let name = "only_management_connection_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.return_connection_err = ShouldReturnConnectionError::Yes; + }); + + let user_conn_id: usize = 1000; + let user_conn = MockConnection { + id: user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let prev_ip = Some(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))); + let node = AsyncClusterNode::new( + ConnectionDetails { + conn: user_conn, + ip: prev_ip, + } + .into_future(), + None, + ); + + let result = connect_and_check::<MockConnection>( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyManagementConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + // Confirm that the user connection remained unchanged + assert_eq!(node.user_connection.conn.await.id, user_conn_id); + assert_eq!(node.user_connection.ip, prev_ip); + } + + #[tokio::test] + async fn test_connect_and_check_only_user_connection_same_ip() { + // Test that upon refreshing only the user connection, if the newly created connection shares the same IP as the existing management connection, + // the management connection remains unchanged + let name = "only_user_connection_same_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let prev_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(prev_ip); + }); + let old_user_conn_id: usize = 1000; + let management_conn_id: usize = 2000; + let old_user_conn = MockConnection { + id: old_user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let management_conn = MockConnection { + id: management_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + + let node = AsyncClusterNode::new( + ConnectionDetails { + conn: old_user_conn, + ip: Some(prev_ip), + } + .into_future(), + Some( + ConnectionDetails { + conn: management_conn, + ip: Some(prev_ip), + } + .into_future(), + ), + ); + + let result = connect_and_check::<MockConnection>( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyUserConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + // Confirm that a new user connection was created + assert_ne!(node.user_connection.conn.await.id, old_user_conn_id); + // Confirm that the management connection remains unchanged + assert_eq!( + node.management_connection.unwrap().conn.await.id, + management_conn_id + ); + } +} + +mod test_check_node_connections { + + use super::*; + use redis::cluster_async::testing::{check_node_connections, ConnectionDetails}; + fn create_node_with_all_connections(name: &str) -> AsyncClusterNode<MockConnection> { + let ip = None; + AsyncClusterNode::new( + ConnectionDetails { + conn: get_mock_connection_with_port(name, 1, 6380), + ip, + } + .into_future(), + Some( + ConnectionDetails { + conn: get_mock_connection_with_port(name, 2, 6381), + ip, + } + .into_future(), 
+ ), + ) + } + + #[tokio::test] + async fn test_check_node_connections_find_no_problem() { + // Test that when checking both connections, if both connections are healthy, no issue is returned. + let name = "test_check_node_connections_find_no_problem"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + let node = create_node_with_all_connections(name); + let response = check_node_connections::<MockConnection>( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, None); + } + + #[tokio::test] + async fn test_check_node_connections_find_management_connection_issue() { + // Test that upon checking both connections, if the management connection isn't responding to pings, `OnlyManagementConnection` will be returned. + let name = "test_check_node_connections_find_management_connection_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, port| { + if port == 6381 { + return Err(Err((ErrorKind::ClientError, "some error").into())); + } + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::<MockConnection>( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!( + response, + Some(RefreshConnectionType::OnlyManagementConnection) + ); + } + + #[tokio::test] + async fn test_check_node_connections_find_missing_management_connection() { + // Test that upon checking both connections, if the management connection isn't present, `OnlyManagementConnection` will be returned. + let name = "test_check_node_connections_find_missing_management_connection"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = None; + let node = AsyncClusterNode::new( + ConnectionDetails { + conn: get_mock_connection(name, 1), + ip, + } + .into_future(), + None, + ); + let response = check_node_connections::<MockConnection>( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!( + response, + Some(RefreshConnectionType::OnlyManagementConnection) + ); + } + + #[tokio::test] + async fn test_check_node_connections_find_both_connections_issue() { + // Test that upon checking both connections, if neither connection is responding to pings, `AllConnections` will be returned. + let name = "test_check_node_connections_find_both_connections_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|_, _| Err(Err((ErrorKind::ClientError, "some error").into()))), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::<MockConnection>( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, Some(RefreshConnectionType::AllConnections)); + } + + #[tokio::test] + async fn test_check_node_connections_find_user_connection_issue() { + // Test that upon checking both connections, if the user connection isn't responding to pings, `OnlyUserConnection` will be returned. 
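
Taken together, the `check_node_connections` cases in this module encode a small decision map from ping outcomes to a refresh scope, with a missing management connection counting the same as a failed one. A condensed, illustrative-only summary of what the assertions expect (not the real implementation, which also honors the `RefreshConnectionType` it was asked to check):

```rust
use redis::cluster_async::testing::RefreshConnectionType;

// Illustrative summary of the expectations these tests encode: which refresh
// scope should be requested for a given pair of health-check outcomes.
fn expected_refresh(user_ok: bool, management_ok: bool) -> Option<RefreshConnectionType> {
    match (user_ok, management_ok) {
        (true, true) => None,                                                    // find_no_problem
        (true, false) => Some(RefreshConnectionType::OnlyManagementConnection),  // issue or missing
        (false, true) => Some(RefreshConnectionType::OnlyUserConnection),        // user issue
        (false, false) => Some(RefreshConnectionType::AllConnections),           // both failing
    }
}
```
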
+ let name = "test_check_node_connections_find_user_connection_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, port| { + if port == 6380 { + return Err(Err((ErrorKind::ClientError, "some error").into())); + } + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, Some(RefreshConnectionType::OnlyUserConnection)); + } + + #[tokio::test] + async fn test_check_node_connections_ignore_missing_management_connection_when_refreshing_user() + { + // Test that upon checking only user connection, issues with management connection won't affect the result. + let name = + "test_check_node_connections_ignore_management_connection_issue_when_refreshing_user"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = AsyncClusterNode::new( + ConnectionDetails { + conn: get_mock_connection(name, 1), + ip: None, + } + .into_future(), + None, + ); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::OnlyUserConnection, + name, + ) + .await; + assert_eq!(response, None); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_basic.rs b/glide-core/redis-rs/redis/tests/test_basic.rs new file mode 100644 index 0000000000..fc359ff0ae --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_basic.rs @@ -0,0 +1,1652 @@ +#![allow(clippy::let_unit_value)] + +mod support; + +#[cfg(test)] +mod basic { + use redis::{cmd, ProtocolVersion, PushInfo}; + use redis::{ + Commands, ConnectionInfo, ConnectionLike, ControlFlow, ErrorKind, ExistenceCheck, Expiry, + PubSubCommands, PushKind, RedisResult, SetExpiry, SetOptions, ToRedisArgs, Value, + }; + use std::collections::{BTreeMap, BTreeSet}; + use std::collections::{HashMap, HashSet}; + use std::thread::{sleep, spawn}; + use std::time::Duration; + use std::vec; + use tokio::sync::mpsc::error::TryRecvError; + + use crate::{assert_args, support::*}; + + #[test] + #[serial_test::serial] + fn test_parse_redis_url() { + let redis_url = "redis://127.0.0.1:1234/0".to_string(); + redis::parse_redis_url(&redis_url).unwrap(); + redis::parse_redis_url("unix:/var/run/redis/redis.sock").unwrap(); + assert!(redis::parse_redis_url("127.0.0.1").is_none()); + } + + #[test] + #[serial_test::serial] + fn test_redis_url_fromstr() { + let _info: ConnectionInfo = "redis://127.0.0.1:1234/0".parse().unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_args() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("key1").arg(b"foo").execute(&mut con); + redis::cmd("SET").arg(&["key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET").arg(&["key1", "key2"]).query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + #[serial_test::serial] + fn test_getset() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + + redis::cmd("SET").arg("bar").arg("foo").execute(&mut con); + assert_eq!( + redis::cmd("GET").arg("bar").query(&mut con), + Ok(b"foo".to_vec()) + ); + } + + //unit test for key_type function + #[test] + #[serial_test::serial] + fn test_key_type() { + let ctx = 
TestContext::new(); + let mut con = ctx.connection(); + + //The key is a simple value + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + let string_key_type: String = con.key_type("foo").unwrap(); + assert_eq!(string_key_type, "string"); + + //The key is a list + redis::cmd("LPUSH") + .arg("list_bar") + .arg("foo") + .execute(&mut con); + let list_key_type: String = con.key_type("list_bar").unwrap(); + assert_eq!(list_key_type, "list"); + + //The key is a set + redis::cmd("SADD") + .arg("set_bar") + .arg("foo") + .execute(&mut con); + let set_key_type: String = con.key_type("set_bar").unwrap(); + assert_eq!(set_key_type, "set"); + + //The key is a sorted set + redis::cmd("ZADD") + .arg("sorted_set_bar") + .arg("1") + .arg("foo") + .execute(&mut con); + let zset_key_type: String = con.key_type("sorted_set_bar").unwrap(); + assert_eq!(zset_key_type, "zset"); + + //The key is a hash + redis::cmd("HSET") + .arg("hset_bar") + .arg("hset_key_1") + .arg("foo") + .execute(&mut con); + let hash_key_type: String = con.key_type("hset_bar").unwrap(); + assert_eq!(hash_key_type, "hash"); + } + + #[test] + #[serial_test::serial] + fn test_client_tracking_doesnt_block_execution() { + //It checks if the library distinguish a push-type message from the others and continues its normal operation. + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let (k1, k2): (i32, i32) = redis::pipe() + .cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .ignore() + .cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("GET") + .arg("key_1") + .cmd("GET") + .arg("key_2") + .cmd("SET") + .arg("key_1") + .arg(45) + .ignore() + .query(&mut con) + .unwrap(); + assert_eq!(k1, 42); + assert_eq!(k2, 43); + let num: i32 = con.get("key_1").unwrap(); + assert_eq!(num, 45); + } + + #[test] + #[serial_test::serial] + fn test_incr() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("INCR").arg("foo").query(&mut con), Ok(43usize)); + } + + #[test] + #[serial_test::serial] + fn test_getdel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + + assert_eq!(con.get_del("foo"), Ok(42usize)); + + assert_eq!( + redis::cmd("GET").arg("foo").query(&mut con), + Ok(None::) + ); + } + + #[test] + #[serial_test::serial] + fn test_getex() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42usize).execute(&mut con); + + // Return of get_ex must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::EX(1)).unwrap(); + assert_eq!(ret_value, 42usize); + + // Get before expiry time must also return value + sleep(Duration::from_millis(100)); + let delayed_get = con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 42usize); + + // Get after expiry time mustn't return value + sleep(Duration::from_secs(1)); + let after_expire_get = con.get::<_, Option>("foo").unwrap(); + assert_eq!(after_expire_get, None); + + // Persist option test prep + redis::cmd("SET").arg("foo").arg(420usize).execute(&mut con); + + // Return of get_ex with persist option must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::PERSIST).unwrap(); + assert_eq!(ret_value, 420usize); + + // Get after persist get_ex must return value + sleep(Duration::from_millis(200)); + let delayed_get = 
con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 420usize); + } + + #[test] + #[serial_test::serial] + fn test_info() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let info: redis::InfoDict = redis::cmd("INFO").query(&mut con).unwrap(); + assert_eq!( + info.find(&"role"), + Some(&redis::Value::SimpleString("master".to_string())) + ); + assert_eq!(info.get("role"), Some("master".to_string())); + assert_eq!(info.get("loading"), Some(false)); + assert!(!info.is_empty()); + assert!(info.contains_key(&"role")); + } + + #[test] + #[serial_test::serial] + fn test_hash_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("HSET") + .arg("foo") + .arg("key_1") + .arg(1) + .execute(&mut con); + redis::cmd("HSET") + .arg("foo") + .arg("key_2") + .arg(2) + .execute(&mut con); + + let h: HashMap = redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + + let h: BTreeMap = redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + } + + // Requires redis-server >= 4.0.0. + // Not supported with the current appveyor/windows binary deployed. + #[cfg(not(target_os = "windows"))] + #[test] + #[serial_test::serial] + fn test_unlink() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + assert_eq!(con.unlink("foo"), Ok(1)); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + redis::cmd("SET").arg("bar").arg(42).execute(&mut con); + assert_eq!(con.unlink(&["foo", "bar"]), Ok(2)); + } + + #[test] + #[serial_test::serial] + fn test_set_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let mut s: Vec = con.smembers("foo").unwrap(); + s.sort_unstable(); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + + let set: HashSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + + let set: BTreeSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + } + + #[test] + #[serial_test::serial] + fn test_scan() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let (cur, mut s): (i32, Vec) = redis::cmd("SSCAN") + .arg("foo") + .arg(0) + .query(&mut con) + .unwrap(); + s.sort_unstable(); + assert_eq!(cur, 0i32); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + } + + #[test] + #[serial_test::serial] + fn test_optionals() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(1).execute(&mut con); + + let (a, b): (Option, Option) = redis::cmd("MGET") + .arg("foo") + .arg("missing") + .query(&mut con) + .unwrap(); + assert_eq!(a, Some(1i32)); + assert_eq!(b, None); + + let a = redis::cmd("GET") + .arg("missing") + .query(&mut con) + .unwrap_or(0i32); + assert_eq!(a, 0i32); + } + + #[test] + #[serial_test::serial] + fn test_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..1000 
{ + redis::cmd("SADD").arg("foo").arg(x).execute(&mut con); + unseen.insert(x); + } + + let iter = redis::cmd("SSCAN") + .arg("foo") + .cursor_arg(0) + .clone() + .iter(&mut con) + .unwrap(); + + for x in iter { + // type inference limitations + let x: usize = x; + unseen.remove(&x); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + #[serial_test::serial] + fn test_filtered_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..3000 { + let _: () = con + .hset("foo", format!("key_{}_{}", x % 100, x), x) + .unwrap(); + if x % 100 == 0 { + unseen.insert(x); + } + } + + let iter = con + .hscan_match::<&str, &str, (String, usize)>("foo", "key_0_*") + .unwrap(); + + for (_field, value) in iter { + unseen.remove(&value); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + #[serial_test::serial] + fn test_pipeline() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ((k1, k2),): ((i32, i32),) = redis::pipe() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + #[serial_test::serial] + fn test_pipeline_with_err() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = redis::cmd("SET") + .arg("x") + .arg("x-value") + .query(&mut con) + .unwrap(); + let _: () = redis::cmd("SET") + .arg("y") + .arg("y-value") + .query(&mut con) + .unwrap(); + + let _: () = redis::cmd("SLAVEOF") + .arg("1.1.1.1") + .arg("99") + .query(&mut con) + .unwrap(); + + let res = redis::pipe() + .set("x", "another-x-value") + .ignore() + .get("y") + .query::<()>(&mut con); + assert!(res.is_err() && res.unwrap_err().kind() == ErrorKind::ReadOnly); + + // Make sure we don't get leftover responses from the pipeline ("y-value"). See #436. + let res = redis::cmd("GET") + .arg("x") + .query::(&mut con) + .unwrap(); + assert_eq!(res, "x-value"); + } + + #[test] + #[serial_test::serial] + fn test_empty_pipeline() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = redis::pipe().cmd("PING").ignore().query(&mut con).unwrap(); + + let _: () = redis::pipe().query(&mut con).unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_pipeline_transaction() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ((k1, k2),): ((i32, i32),) = redis::pipe() + .atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + #[serial_test::serial] + fn test_pipeline_transaction_with_errors() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set("x", 42).unwrap(); + + // Make Redis a replica of a nonexistent master, thereby making it read-only. 
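
The `slaveof` call below leaves the server read-only for the remainder of the test, which is harmless here because every `TestContext` owns a throwaway server. If a shared server were ever used, a cleanup step would be needed; a sketch, assuming a `con` connection like the one above:

```rust
// Re-promote the server to a writable primary: SLAVEOF NO ONE detaches it
// from the nonexistent master configured for this test.
let _: () = redis::cmd("SLAVEOF")
    .arg("NO")
    .arg("ONE")
    .query(&mut con)
    .unwrap();
```
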
+ let _: () = redis::cmd("slaveof") + .arg("1.1.1.1") + .arg("1") + .query(&mut con) + .unwrap(); + + // Ensure that a write command fails with a READONLY error + let err: RedisResult<()> = redis::pipe() + .atomic() + .set("x", 142) + .ignore() + .get("x") + .query(&mut con); + + assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly); + + let x: i32 = con.get("x").unwrap(); + assert_eq!(x, 42); + } + + #[test] + #[serial_test::serial] + fn test_pipeline_reuse_query() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let mut pl = redis::pipe(); + + let ((k1,),): ((i32,),) = pl + .cmd("SET") + .arg("pkey_1") + .arg(42) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + + redis::cmd("DEL").arg("pkey_1").execute(&mut con); + + // The internal commands vector of the pipeline still contains the previous commands. + let ((k1,), (k2, k3)): ((i32,), (i32, i32)) = pl + .cmd("SET") + .arg("pkey_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 42); + assert_eq!(k3, 43); + } + + #[test] + #[serial_test::serial] + fn test_pipeline_reuse_query_clear() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let mut pl = redis::pipe(); + + let ((k1,),): ((i32,),) = pl + .cmd("SET") + .arg("pkey_1") + .arg(44) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert_eq!(k1, 44); + + redis::cmd("DEL").arg("pkey_1").execute(&mut con); + + let ((k1, k2),): ((bool, i32),) = pl + .cmd("SET") + .arg("pkey_2") + .arg(45) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert!(!k1); + assert_eq!(k2, 45); + } + + #[test] + #[serial_test::serial] + fn test_real_transaction() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + loop { + let _: () = redis::cmd("WATCH").arg(key).query(&mut con).unwrap(); + let val: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap(); + let response: Option<(isize,)> = redis::pipe() + .atomic() + .cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(&mut con) + .unwrap(); + + match response { + None => { + continue; + } + Some(response) => { + assert_eq!(response, (43,)); + break; + } + } + } + } + + #[test] + #[serial_test::serial] + fn test_real_transaction_highlevel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + let response: (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { + let val: isize = redis::cmd("GET").arg(key).query(con)?; + pipe.cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(con) + }) + .unwrap(); + + assert_eq!(response, (43,)); + } + + #[test] + #[serial_test::serial] + fn test_pubsub() { + use std::sync::{Arc, Barrier}; + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // Connection for subscriber api + let mut pubsub_con = ctx.connection(); + + // Barrier is used to make test thread wait to publish + // until after the pubsub thread has subscribed. 
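
As a self-contained illustration of the handshake this comment describes (hypothetical code, independent of the test harness):

```rust
use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let barrier = Arc::new(Barrier::new(2)); // both parties must rendezvous
    let subscriber_barrier = barrier.clone();

    let subscriber = thread::spawn(move || {
        // ... subscribe to the channel here ...
        subscriber_barrier.wait(); // signal: the subscription is in place
        // ... block on get_message() ...
    });

    barrier.wait(); // the publisher proceeds only once the subscriber is ready
    // ... PUBLISH here ...
    subscriber.join().unwrap();
}
```
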
+ let barrier = Arc::new(Barrier::new(2)); + let pubsub_barrier = barrier.clone(); + + let thread = spawn(move || { + let mut pubsub = pubsub_con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + + let _ = pubsub_barrier.wait(); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(42)); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(23)); + }); + + let _ = barrier.wait(); + redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con); + // We can also call the command directly + assert_eq!(con.publish("foo", 23), Ok(1)); + + thread.join().expect("Something went wrong"); + } + + #[test] + #[serial_test::serial] + fn test_pubsub_unsubscribe() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + pubsub.subscribe("bar").unwrap(); + pubsub.subscribe("baz").unwrap(); + pubsub.psubscribe("foo*").unwrap(); + pubsub.psubscribe("bar*").unwrap(); + pubsub.psubscribe("baz*").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[serial_test::serial] + fn test_pubsub_subscribe_while_messages_are_sent() { + let ctx = TestContext::new(); + let mut conn_external = ctx.connection(); + let mut conn_internal = ctx.connection(); + let received = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); + let received_clone = received.clone(); + let (sender, receiver) = std::sync::mpsc::channel(); + // receive message from foo channel + let thread = std::thread::spawn(move || { + let mut pubsub = conn_internal.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + sender.send(()).unwrap(); + loop { + let msg = pubsub.get_message().unwrap(); + let channel = msg.get_channel_name(); + let content: i32 = msg.get_payload().unwrap(); + received + .lock() + .unwrap() + .push(format!("{channel}:{content}")); + if content == -1 { + return; + } + if content == 5 { + // subscribe bar channel using the same pubsub + pubsub.subscribe("bar").unwrap(); + sender.send(()).unwrap(); + } + } + }); + receiver.recv().unwrap(); + + // send message to foo channel after channel is ready. 
+ for index in 0..10 { + println!("publishing on foo {index}"); + redis::cmd("PUBLISH") + .arg("foo") + .arg(index) + .query::(&mut conn_external) + .unwrap(); + } + receiver.recv().unwrap(); + redis::cmd("PUBLISH") + .arg("bar") + .arg(-1) + .query::(&mut conn_external) + .unwrap(); + thread.join().unwrap(); + assert_eq!( + *received_clone.lock().unwrap(), + (0..10) + .map(|index| format!("foo:{}", index)) + .chain(std::iter::once("bar:-1".to_string())) + .collect::>() + ); + } + + #[test] + #[serial_test::serial] + fn test_pubsub_unsubscribe_no_subs() { + let ctx = TestContext::new(); + if version_greater_or_equal(&ctx, "7.2.4") { + // Skip for versions 7.2.4 and above + return; + } + let mut con = ctx.connection(); + + { + let _pubsub = con.as_pubsub(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[serial_test::serial] + fn test_pubsub_unsubscribe_one_sub() { + let ctx = TestContext::new(); + if version_greater_or_equal(&ctx, "7.2.4") { + // Skip for versions 7.2.4 and above + return; + } + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[serial_test::serial] + fn test_pubsub_unsubscribe_one_sub_one_psub() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + pubsub.psubscribe("foo*").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[serial_test::serial] + fn scoped_pubsub() { + let ctx = TestContext::new(); + if version_greater_or_equal(&ctx, "7.2.4") { + // Skip for versions 7.2.4 and above + return; + } + let mut con = ctx.connection(); + + // Connection for subscriber api + let mut pubsub_con = ctx.connection(); + + let thread = spawn(move || { + let mut count = 0; + pubsub_con + .subscribe(&["foo", "bar"], |msg| { + count += 1; + match count { + 1 => { + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(42)); + ControlFlow::Continue + } + 2 => { + assert_eq!(msg.get_channel(), Ok("bar".to_string())); + assert_eq!(msg.get_payload(), Ok(23)); + ControlFlow::Break(()) + } + _ => ControlFlow::Break(()), + } + }) + .unwrap(); + + pubsub_con + }); + + // Can't use a barrier in this case since there's no opportunity to run code + // between channel subscription and blocking for messages. 
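
The fixed 100 ms sleep that follows is the pragmatic workaround for that limitation. An alternative that avoids a hard-coded delay is to poll the server until it reports the expected subscriber count; a sketch, assuming a second connection `con` as in this test:

```rust
use std::collections::HashMap;

// Poll PUBSUB NUMSUB until the server confirms one subscriber on "foo",
// instead of sleeping for a fixed interval (hypothetical replacement).
loop {
    let counts: HashMap<String, usize> = redis::cmd("PUBSUB")
        .arg("NUMSUB")
        .arg("foo")
        .query(&mut con)
        .unwrap();
    if counts.get("foo").copied().unwrap_or(0) >= 1 {
        break;
    }
    std::thread::sleep(std::time::Duration::from_millis(10));
}
```
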
+ sleep(Duration::from_millis(100)); + + redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con); + assert_eq!(con.publish("bar", 23), Ok(1)); + + // Wait for thread + let mut pubsub_con = thread.join().expect("pubsub thread terminates ok"); + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = pubsub_con.set("foo", "bar").unwrap(); + let value: String = pubsub_con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[serial_test::serial] + #[cfg(feature = "script")] + fn test_script() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let script = redis::Script::new( + r" + return {redis.call('GET', KEYS[1]), ARGV[1]} + ", + ); + + let _: () = redis::cmd("SET") + .arg("my_key") + .arg("foo") + .query(&mut con) + .unwrap(); + let response = script.key("my_key").arg(42).invoke(&mut con); + + assert_eq!(response, Ok(("foo".to_string(), 42))); + } + + #[test] + #[serial_test::serial] + #[cfg(feature = "script")] + fn test_script_load() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let script = redis::Script::new("return 'Hello World'"); + + let hash = script.prepare_invoke().load(&mut con); + + assert_eq!(hash, Ok(script.get_hash().to_string())); + } + + #[test] + #[serial_test::serial] + fn test_tuple_args() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("HMSET") + .arg("my_key") + .arg(&[("field_1", 42), ("field_2", 23)]) + .execute(&mut con); + + assert_eq!( + redis::cmd("HGET") + .arg("my_key") + .arg("field_1") + .query(&mut con), + Ok(42) + ); + assert_eq!( + redis::cmd("HGET") + .arg("my_key") + .arg("field_2") + .query(&mut con), + Ok(23) + ); + } + + #[test] + #[serial_test::serial] + fn test_nice_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.set("my_key", 42), Ok(())); + assert_eq!(con.get("my_key"), Ok(42)); + + let (k1, k2): (i32, i32) = redis::pipe() + .atomic() + .set("key_1", 42) + .ignore() + .set("key_2", 43) + .ignore() + .get("key_1") + .get("key_2") + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + #[serial_test::serial] + fn test_auto_m_versions() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.mset(&[("key1", 1), ("key2", 2)]), Ok(())); + assert_eq!(con.get(&["key1", "key2"]), Ok((1, 2))); + assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2))); + assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2))); + } + + #[test] + #[serial_test::serial] + fn test_nice_hash_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!( + con.hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)]), + Ok(()) + ); + + let hm: HashMap = con.hgetall("my_hash").unwrap(); + assert_eq!(hm.get("f1"), Some(&1)); + assert_eq!(hm.get("f2"), Some(&2)); + assert_eq!(hm.get("f3"), Some(&4)); + assert_eq!(hm.get("f4"), Some(&8)); + assert_eq!(hm.len(), 4); + + let hm: BTreeMap = con.hgetall("my_hash").unwrap(); + assert_eq!(hm.get("f1"), Some(&1)); + assert_eq!(hm.get("f2"), Some(&2)); + assert_eq!(hm.get("f3"), Some(&4)); + assert_eq!(hm.get("f4"), Some(&8)); + assert_eq!(hm.len(), 4); + + let v: Vec<(String, isize)> = con.hgetall("my_hash").unwrap(); + assert_eq!( + v, + vec![ + ("f1".to_string(), 1), + ("f2".to_string(), 2), + ("f3".to_string(), 4), + ("f4".to_string(), 8), + ] + ); + + assert_eq!(con.hget("my_hash", &["f2", "f4"]), Ok((2, 8))); + assert_eq!(con.hincr("my_hash", "f1", 
1), Ok(2)); + assert_eq!(con.hincr("my_hash", "f2", 1.5f32), Ok(3.5f32)); + assert_eq!(con.hexists("my_hash", "f2"), Ok(true)); + assert_eq!(con.hdel("my_hash", &["f1", "f2"]), Ok(())); + assert_eq!(con.hexists("my_hash", "f2"), Ok(false)); + + let iter: redis::Iter<'_, (String, isize)> = con.hscan("my_hash").unwrap(); + let mut found = HashSet::new(); + for item in iter { + found.insert(item); + } + + assert_eq!(found.len(), 2); + assert!(found.contains(&("f3".to_string(), 4))); + assert!(found.contains(&("f4".to_string(), 8))); + } + + #[test] + #[serial_test::serial] + fn test_nice_list_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.rpush("my_list", &[1, 2, 3, 4]), Ok(4)); + assert_eq!(con.rpush("my_list", &[5, 6, 7, 8]), Ok(8)); + assert_eq!(con.llen("my_list"), Ok(8)); + + assert_eq!(con.lpop("my_list", Default::default()), Ok(1)); + assert_eq!(con.llen("my_list"), Ok(7)); + + assert_eq!(con.lrange("my_list", 0, 2), Ok((2, 3, 4))); + + assert_eq!(con.lset("my_list", 0, 4), Ok(true)); + assert_eq!(con.lrange("my_list", 0, 2), Ok((4, 3, 4))); + + #[cfg(not(windows))] + //Windows version of redis is limited to v3.x + { + let my_list: Vec = con.lrange("my_list", 0, 10).expect("To get range"); + assert_eq!( + con.lpop("my_list", core::num::NonZeroUsize::new(10)), + Ok(my_list) + ); + } + } + + #[test] + #[serial_test::serial] + fn test_tuple_decoding_regression() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.del("my_zset"), Ok(())); + assert_eq!(con.zadd("my_zset", "one", 1), Ok(1)); + assert_eq!(con.zadd("my_zset", "two", 2), Ok(1)); + + let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap(); + assert_eq!(vec.len(), 2); + + assert_eq!(con.del("my_zset"), Ok(1)); + + let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap(); + assert_eq!(vec.len(), 0); + } + + #[test] + #[serial_test::serial] + fn test_bit_operations() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.setbit("bitvec", 10, true), Ok(false)); + assert_eq!(con.getbit("bitvec", 10), Ok(true)); + } + + #[test] + #[serial_test::serial] + fn test_redis_server_down() { + let mut ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ping = redis::cmd("PING").query::(&mut con); + assert_eq!(ping, Ok("PONG".into())); + + ctx.stop_server(); + + let ping = redis::cmd("PING").query::(&mut con); + + assert!(ping.is_err()); + eprintln!("{}", ping.unwrap_err()); + assert!(!con.is_open()); + } + + #[test] + #[serial_test::serial] + fn test_zinterstore_weights() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con + .zadd_multiple("zset1", &[(1, "one"), (2, "two"), (4, "four")]) + .unwrap(); + let _: () = con + .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")]) + .unwrap(); + + // zinterstore_weights + assert_eq!( + con.zinterstore_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "5".to_string()), + ("two".to_string(), "10".to_string()) + ]) + ); + + // zinterstore_min_weights + assert_eq!( + con.zinterstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "2".to_string()), + ("two".to_string(), "4".to_string()), + ]) + ); + + // zinterstore_max_weights + assert_eq!( + 
con.zinterstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "3".to_string()), + ("two".to_string(), "6".to_string()), + ]) + ); + } + + #[test] + #[serial_test::serial] + fn test_zunionstore_weights() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con + .zadd_multiple("zset1", &[(1, "one"), (2, "two")]) + .unwrap(); + let _: () = con + .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")]) + .unwrap(); + + // zunionstore_weights + assert_eq!( + con.zunionstore_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "5".to_string()), + ("three".to_string(), "9".to_string()), + ("two".to_string(), "10".to_string()) + ]) + ); + // test converting to double + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), 5.0), + ("three".to_string(), 9.0), + ("two".to_string(), 10.0) + ]) + ); + + // zunionstore_min_weights + assert_eq!( + con.zunionstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "2".to_string()), + ("two".to_string(), "4".to_string()), + ("three".to_string(), "9".to_string()) + ]) + ); + + // zunionstore_max_weights + assert_eq!( + con.zunionstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "3".to_string()), + ("two".to_string(), "6".to_string()), + ("three".to_string(), "9".to_string()) + ]) + ); + } + + #[test] + #[serial_test::serial] + fn test_zrembylex() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myzset"; + assert_eq!( + con.zadd_multiple( + setname, + &[ + (0, "apple"), + (0, "banana"), + (0, "carrot"), + (0, "durian"), + (0, "eggplant"), + (0, "grapes"), + ], + ), + Ok(6) + ); + + // Will remove "banana", "carrot", "durian" and "eggplant" + let num_removed: u32 = con.zrembylex(setname, "[banana", "[eggplant").unwrap(); + assert_eq!(4, num_removed); + + let remaining: Vec = con.zrange(setname, 0, -1).unwrap(); + assert_eq!(remaining, vec!["apple".to_string(), "grapes".to_string()]); + } + + // Requires redis-server >= 6.2.0. + // Not supported with the current appveyor/windows binary deployed. 
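
For reference, the scores asserted in the weighted aggregation tests above follow directly from `aggregate(weight1 * score1, weight2 * score2)`. With weights (2, 3), member "one" (score 1 in both sets) yields 1*2 + 1*3 = 5 under SUM, min(2, 3) = 2 under MIN, and max(2, 3) = 3 under MAX; "two" (score 2 in both) yields 10, 4, and 6 respectively. In the union tests, "three" exists only in zset2 with score 3, so it scores 3*3 = 9 under every aggregation.
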
+ #[cfg(not(target_os = "windows"))] + #[test] + #[serial_test::serial] + fn test_zrandmember() { + use redis::ProtocolVersion; + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myzrandset"; + let () = con.zadd(setname, "one", 1).unwrap(); + + let result: String = con.zrandmember(setname, None).unwrap(); + assert_eq!(result, "one".to_string()); + + let result: Vec = con.zrandmember(setname, Some(1)).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], "one".to_string()); + + let result: Vec = con.zrandmember(setname, Some(2)).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], "one".to_string()); + + assert_eq!( + con.zadd_multiple( + setname, + &[(2, "two"), (3, "three"), (4, "four"), (5, "five")] + ), + Ok(4) + ); + + let results: Vec = con.zrandmember(setname, Some(5)).unwrap(); + assert_eq!(results.len(), 5); + + let results: Vec = con.zrandmember(setname, Some(-5)).unwrap(); + assert_eq!(results.len(), 5); + + if ctx.protocol == ProtocolVersion::RESP2 { + let results: Vec = con.zrandmember_withscores(setname, 5).unwrap(); + assert_eq!(results.len(), 10); + + let results: Vec = con.zrandmember_withscores(setname, -5).unwrap(); + assert_eq!(results.len(), 10); + } + + let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, 5).unwrap(); + assert_eq!(results.len(), 5); + + let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, -5).unwrap(); + assert_eq!(results.len(), 5); + } + + #[test] + #[serial_test::serial] + fn test_sismember() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myset"; + assert_eq!(con.sadd(setname, &["a"]), Ok(1)); + + let result: bool = con.sismember(setname, &["a"]).unwrap(); + assert!(result); + + let result: bool = con.sismember(setname, &["b"]).unwrap(); + assert!(!result); + } + + // Requires redis-server >= 6.2.0. + // Not supported with the current appveyor/windows binary deployed. 
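
A note on the counts asserted in `test_zrandmember` above: a positive count requests distinct members, so the result is capped at the set's cardinality (`Some(2)` on a one-element set returns a single member), while a negative count permits repetitions and always returns `|count|` members. Under RESP2 the WITHSCORES variants reply with a flat member/score array, which is why the expected length doubles to 10 in that branch before the typed `(String, f64)` pair decoding is exercised.
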
+ #[cfg(not(target_os = "windows"))] + #[test] + #[serial_test::serial] + fn test_smismember() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myset"; + assert_eq!(con.sadd(setname, &["a", "b", "c"]), Ok(3)); + let results: Vec = con.smismember(setname, &["0", "a", "b", "c", "x"]).unwrap(); + assert_eq!(results, vec![false, true, true, true, false]); + } + + #[test] + #[serial_test::serial] + fn test_object_commands() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set("object_key_str", "object_value_str").unwrap(); + let _: () = con.set("object_key_int", 42).unwrap(); + + assert_eq!( + con.object_encoding::<_, String>("object_key_str").unwrap(), + "embstr" + ); + + assert_eq!( + con.object_encoding::<_, String>("object_key_int").unwrap(), + "int" + ); + + assert!(con.object_idletime::<_, i32>("object_key_str").unwrap() <= 1); + assert_eq!(con.object_refcount::<_, i32>("object_key_str").unwrap(), 1); + + // Needed for OBJECT FREQ and can't be set before object_idletime + // since that will break getting the idletime before idletime adjuts + redis::cmd("CONFIG") + .arg("SET") + .arg(b"maxmemory-policy") + .arg("allkeys-lfu") + .execute(&mut con); + + let _: () = con.get("object_key_str").unwrap(); + // since maxmemory-policy changed, freq should reset to 1 since we only called + // get after that + assert_eq!(con.object_freq::<_, i32>("object_key_str").unwrap(), 1); + } + + #[test] + #[serial_test::serial] + fn test_mget() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let data: Vec = con.mget(&[1]).unwrap(); + assert_eq!(data, vec!["1"]); + + let _: () = con.set(2, "2").unwrap(); + let data: Vec = con.mget(&[1, 2]).unwrap(); + assert_eq!(data, vec!["1", "2"]); + + let data: Vec> = con.mget(&[4]).unwrap(); + assert_eq!(data, vec![None]); + + let data: Vec> = con.mget(&[2, 4]).unwrap(); + assert_eq!(data, vec![Some("2".to_string()), None]); + } + + #[test] + #[serial_test::serial] + fn test_variable_length_get() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let keys = vec![1]; + assert_eq!(keys.len(), 1); + let data: Vec = con.get(&keys).unwrap(); + assert_eq!(data, vec!["1"]); + } + + #[test] + #[serial_test::serial] + fn test_multi_generics() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd(b"set1", vec![5, 42]), Ok(2)); + assert_eq!(con.sadd(999_i64, vec![42, 123]), Ok(2)); + let _: () = con.rename(999_i64, b"set2").unwrap(); + assert_eq!(con.sunionstore("res", &[b"set1", b"set2"]), Ok(3)); + } + + #[test] + #[serial_test::serial] + fn test_set_options_with_get() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let opts = SetOptions::default().get(true); + let data: Option = con.set_options(1, "1", opts).unwrap(); + assert_eq!(data, None); + + let opts = SetOptions::default().get(true); + let data: Option = con.set_options(1, "1", opts).unwrap(); + assert_eq!(data, Some("1".to_string())); + } + + #[test] + #[serial_test::serial] + fn test_set_options_options() { + let empty = SetOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::NX) + .get(true) + .with_expiration(SetExpiry::PX(1000)); + + assert_args!(&opts, "NX", "GET", "PX", "1000"); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::XX) + 
+
+    #[test]
+    #[serial_test::serial]
+    fn test_set_options_options() {
+        let empty = SetOptions::default();
+        assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0);
+
+        let opts = SetOptions::default()
+            .conditional_set(ExistenceCheck::NX)
+            .get(true)
+            .with_expiration(SetExpiry::PX(1000));
+
+        assert_args!(&opts, "NX", "GET", "PX", "1000");
+
+        let opts = SetOptions::default()
+            .conditional_set(ExistenceCheck::XX)
+            .get(true)
+            .with_expiration(SetExpiry::PX(1000));
+
+        assert_args!(&opts, "XX", "GET", "PX", "1000");
+
+        let opts = SetOptions::default()
+            .conditional_set(ExistenceCheck::XX)
+            .with_expiration(SetExpiry::KEEPTTL);
+
+        assert_args!(&opts, "XX", "KEEPTTL");
+
+        let opts = SetOptions::default()
+            .conditional_set(ExistenceCheck::XX)
+            .with_expiration(SetExpiry::EXAT(100));
+
+        assert_args!(&opts, "XX", "EXAT", "100");
+
+        let opts = SetOptions::default().with_expiration(SetExpiry::EX(1000));
+
+        assert_args!(&opts, "EX", "1000");
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_blocking_sorted_set_api() {
+        let ctx = TestContext::new();
+        let mut con = ctx.connection();
+
+        // Set up the input data, then assert with the server version taken into account:
+        // BZPOPMIN & BZPOPMAX are available from Redis version 5.0.0,
+        // BZMPOP is available from Redis version 7.0.0.
+        let redis_version = ctx.get_version();
+        assert!(redis_version.0 >= 5);
+
+        assert_eq!(con.zadd("a", "1a", 1), Ok(()));
+        assert_eq!(con.zadd("b", "2b", 2), Ok(()));
+        assert_eq!(con.zadd("c", "3c", 3), Ok(()));
+        assert_eq!(con.zadd("d", "4d", 4), Ok(()));
+        assert_eq!(con.zadd("a", "5a", 5), Ok(()));
+        assert_eq!(con.zadd("b", "6b", 6), Ok(()));
+        assert_eq!(con.zadd("c", "7c", 7), Ok(()));
+        assert_eq!(con.zadd("d", "8d", 8), Ok(()));
+
+        let min = con.bzpopmin::<&str, (String, String, String)>("b", 0.0);
+        let max = con.bzpopmax::<&str, (String, String, String)>("b", 0.0);
+
+        assert_eq!(
+            min.unwrap(),
+            (String::from("b"), String::from("2b"), String::from("2"))
+        );
+        assert_eq!(
+            max.unwrap(),
+            (String::from("b"), String::from("6b"), String::from("6"))
+        );
+
+        if redis_version.0 >= 7 {
+            let min = con.bzmpop_min::<&str, (String, Vec<Vec<(String, String)>>)>(
+                0.0,
+                vec!["a", "b", "c", "d"].as_slice(),
+                1,
+            );
+            let max = con.bzmpop_max::<&str, (String, Vec<Vec<(String, String)>>)>(
+                0.0,
+                vec!["a", "b", "c", "d"].as_slice(),
+                1,
+            );
+
+            assert_eq!(
+                min.unwrap().1[0][0],
+                (String::from("1a"), String::from("1"))
+            );
+            assert_eq!(
+                max.unwrap().1[0][0],
+                (String::from("5a"), String::from("5"))
+            );
+        }
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_set_client_name_by_config() {
+        const CLIENT_NAME: &str = "TEST_CLIENT_NAME";
+
+        let ctx = TestContext::with_client_name(CLIENT_NAME);
+        let mut con = ctx.connection();
+
+        let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap();
+
+        let client_attrs = parse_client_info(&client_info);
+
+        assert!(
+            client_attrs.contains_key("name"),
+            "Could not detect the 'name' attribute in CLIENT INFO output"
+        );
+
+        assert_eq!(
+            client_attrs["name"], CLIENT_NAME,
+            "Incorrect client name, expecting: {}, got {}",
+            CLIENT_NAME, client_attrs["name"]
+        );
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_push_manager() {
+        let ctx = TestContext::new();
+        if ctx.protocol == ProtocolVersion::RESP2 {
+            return;
+        }
+        let mut con = ctx.connection();
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+        con.get_push_manager().replace_sender(tx.clone());
+        let _ = cmd("CLIENT")
+            .arg("TRACKING")
+            .arg("ON")
+            .query::<()>(&mut con)
+            .unwrap();
+        let pipe = build_simple_pipeline_for_invalidation();
+        for _ in 0..10 {
+            let _: RedisResult<()> = pipe.query(&mut con);
+            let _: i32 = con.get("key_1").unwrap();
+            let PushInfo { kind, data } = rx.try_recv().unwrap();
+            assert_eq!(
+                (
+                    PushKind::Invalidate,
+                    vec![Value::Array(vec![Value::BulkString(
+                        "key_1".as_bytes().to_vec()
+                    )])]
+                ),
+                (kind, data)
+            );
+        }
+        let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel();
+        con.get_push_manager().replace_sender(new_tx.clone());
+        drop(rx);
+        let _: RedisResult<()> = pipe.query(&mut con);
+        let _: i32 = con.get("key_1").unwrap();
+        let PushInfo { kind, data } = new_rx.try_recv().unwrap();
+        assert_eq!(
+            (
+                PushKind::Invalidate,
+                vec![Value::Array(vec![Value::BulkString(
+                    "key_1".as_bytes().to_vec()
+                )])]
+            ),
+            (kind, data)
+        );
+
+        {
+            drop(new_rx);
+            for _ in 0..10 {
+                let _: RedisResult<()> = pipe.query(&mut con);
+                let v: i32 = con.get("key_1").unwrap();
+                assert_eq!(v, 42);
+            }
+        }
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_push_manager_disconnection() {
+        let ctx = TestContext::new();
+        if ctx.protocol == ProtocolVersion::RESP2 {
+            return;
+        }
+        let mut con = ctx.connection();
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
+        con.get_push_manager().replace_sender(tx.clone());
+
+        let _: () = con.set("A", "1").unwrap();
+        assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
+        drop(ctx);
+        let x: RedisResult<()> = con.set("A", "1");
+        assert!(x.is_err());
+        assert_eq!(rx.try_recv().unwrap().kind, PushKind::Disconnection);
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_bignum.rs b/glide-core/redis-rs/redis/tests/test_bignum.rs
new file mode 100644
index 0000000000..20beefbc66
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_bignum.rs
@@ -0,0 +1,61 @@
+#![cfg(any(
+    feature = "rust_decimal",
+    feature = "bigdecimal",
+    feature = "num-bigint"
+))]
+use redis::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs, Value};
+use std::str::FromStr;
+
+fn test<T>(content: &str)
+where
+    T: FromRedisValue
+        + ToRedisArgs
+        + std::str::FromStr
+        + std::convert::From<u32>
+        + std::cmp::PartialEq
+        + std::fmt::Debug,
+    <T as FromStr>::Err: std::fmt::Debug,
+{
+    let v: RedisResult<T> =
+        FromRedisValue::from_redis_value(&Value::BulkString(Vec::from(content)));
+    assert_eq!(v, Ok(T::from_str(content).unwrap()));
+
+    let arg = ToRedisArgs::to_redis_args(&v.unwrap());
+    assert_eq!(arg[0], Vec::from(content));
+
+    let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Int(0));
+    assert_eq!(v.unwrap(), T::from(0u32));
+
+    let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Int(42));
+    assert_eq!(v.unwrap(), T::from(42u32));
+
+    let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Okay);
+    assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError);
+
+    let v: RedisResult<T> = FromRedisValue::from_redis_value(&Value::Nil);
+    assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError);
+}
+
+#[test]
+#[cfg(feature = "rust_decimal")]
+fn test_rust_decimal() {
+    test::<rust_decimal::Decimal>("-79228162514264.337593543950335");
+}
+
+#[test]
+#[cfg(feature = "bigdecimal")]
+fn test_bigdecimal() {
+    test::<bigdecimal::BigDecimal>("-14272476927059598810582859.69449495136382746623");
+}
+
+#[test]
+#[cfg(feature = "num-bigint")]
+fn test_bigint() {
+    test::<num_bigint::BigInt>("-1427247692705959881058285969449495136382746623");
+}
+
+#[test]
+#[cfg(feature = "num-bigint")]
+fn test_biguint() {
+    test::<num_bigint::BigUint>("1427247692705959881058285969449495136382746623");
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster.rs b/glide-core/redis-rs/redis/tests/test_cluster.rs
new file mode 100644
index 0000000000..38b3019edb
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster.rs
@@ -0,0 +1,1128 @@
+#![cfg(feature = "cluster")]
+mod support;
+
+#[cfg(test)]
+mod cluster {
+    use std::sync::{
+        atomic::{self, AtomicI32, Ordering},
+        Arc,
+    };
+
+    use crate::support::*;
+    use redis::{
+        cluster::{cluster_pipe, ClusterClient},
+        cmd, parse_redis_value, Commands,
ConnectionLike, ErrorKind, ProtocolVersion, RedisError, + Value, + }; + + #[test] + #[serial_test::serial] + fn test_cluster_basics() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_with_username_and_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password(RedisCluster::password().to_string()) + }, + false, + ); + cluster.disable_default_user(); + + let mut con = cluster.connection(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_with_bad_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password("not the right password".to_string()) + }, + false, + ); + assert!(cluster.client.get_connection(None).is_err()); + } + + #[test] + #[serial_test::serial] + fn test_cluster_read_from_replicas() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| builder.read_from_replicas(), + false, + ); + let mut con = cluster.connection(); + + // Write commands would go to the primary nodes + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + // Read commands would go to the replica nodes + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_eval() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let rv = redis::cmd("EVAL") + .arg( + r#" + redis.call("SET", KEYS[1], "1"); + redis.call("SET", KEYS[2], "2"); + return redis.call("MGET", KEYS[1], KEYS[2]); + "#, + ) + .arg("2") + .arg("{x}a") + .arg("{x}b") + .query(&mut con); + + assert_eq!(rv, Ok(("1".to_string(), "2".to_string()))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_resp3() { + if use_protocol() == ProtocolVersion::RESP2 { + return; + } + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let hello: std::collections::HashMap = + redis::cmd("HELLO").query(&mut connection).unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").unwrap(); + let result: Value = connection.hgetall("hash").unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_multi_shard_commands() { + let cluster = 
TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let res: String = connection + .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")]) + .unwrap(); + assert_eq!(res, "OK"); + let res: Vec = connection.mget(&["baz", "foo", "bar"]).unwrap(); + assert_eq!(res, vec!["bazz", "bar", "foo"]); + } + + #[test] + #[serial_test::serial] + #[cfg(feature = "script")] + fn test_cluster_script() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let script = redis::Script::new( + r#" + redis.call("SET", KEYS[1], "1"); + redis.call("SET", KEYS[2], "2"); + return redis.call("MGET", KEYS[1], KEYS[2]); + "#, + ); + + let rv = script.key("{x}a").key("{x}b").invoke(&mut con); + assert_eq!(rv, Ok(("1".to_string(), "2".to_string()))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_pipeline() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let resp = cluster_pipe() + .cmd("SET") + .arg("key_1") + .arg(42) + .query::>(&mut con) + .unwrap(); + + assert_eq!(resp, vec!["OK".to_string()]); + } + + #[test] + #[serial_test::serial] + fn test_cluster_pipeline_multiple_keys() { + use redis::FromRedisValue; + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let resp = cluster_pipe() + .cmd("HSET") + .arg("hash_1") + .arg("key_1") + .arg("value_1") + .cmd("ZADD") + .arg("zset") + .arg(1) + .arg("zvalue_2") + .query::>(&mut con) + .unwrap(); + + assert_eq!(resp, vec![1i64, 1i64]); + + let resp = cluster_pipe() + .cmd("HGET") + .arg("hash_1") + .arg("key_1") + .cmd("ZCARD") + .arg("zset") + .query::>(&mut con) + .unwrap(); + + let resp_1: String = FromRedisValue::from_redis_value(&resp[0]).unwrap(); + assert_eq!(resp_1, "value_1".to_string()); + + let resp_2: usize = FromRedisValue::from_redis_value(&resp[1]).unwrap(); + assert_eq!(resp_2, 1); + } + + #[test] + #[serial_test::serial] + fn test_cluster_pipeline_invalid_command() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let err = cluster_pipe() + .cmd("SET") + .arg("foo") + .arg(42) + .ignore() + .cmd(" SCRIPT kill ") + .query::<()>(&mut con) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "This command cannot be safely routed in cluster mode - ClientError: Command 'SCRIPT KILL' can't be executed in a cluster pipeline." + ); + + let err = cluster_pipe().keys("*").query::<()>(&mut con).unwrap_err(); + + assert_eq!( + err.to_string(), + "This command cannot be safely routed in cluster mode - ClientError: Command 'KEYS' can't be executed in a cluster pipeline." + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { + let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; + + let MockEnv { mut connection, .. 
} = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("".as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + #[serial_test::serial] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { + let name = + "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; + + let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![Value::Nil, Value::Int(6379)]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + #[serial_test::serial] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name( + ) { + let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name"; + + let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(7000), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(7001), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("?".as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + #[serial_test::serial] + fn test_cluster_pipeline_command_ordering() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + let mut pipe = cluster_pipe(); + + let mut queries = Vec::new(); + let mut expected = Vec::new(); + for i in 0..100 { + queries.push(format!("foo{i}")); + expected.push(format!("bar{i}")); + pipe.set(&queries[i], &expected[i]).ignore(); + } + pipe.execute(&mut con); + + pipe.clear(); + for q in &queries { + pipe.get(q); + } + + let got = pipe.query::>(&mut con).unwrap(); + assert_eq!(got, expected); + } + + #[test] + #[serial_test::serial] + #[ignore] // Flaky + fn test_cluster_pipeline_ordering_with_improper_command() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + let mut pipe = cluster_pipe(); + + let mut queries = Vec::new(); + let mut expected = Vec::new(); + for i in 0..10 { + if i == 5 { + pipe.cmd("hset").arg("foo").ignore(); + } else { + let query = format!("foo{i}"); + let r = format!("bar{i}"); + pipe.set(&query, &r).ignore(); + queries.push(query); + expected.push(r); + } + } + pipe.query::<()>(&mut con).unwrap_err(); + + 
std::thread::sleep(std::time::Duration::from_secs(5)); + + pipe.clear(); + for q in &queries { + pipe.get(q); + } + + let got = pipe.query::>(&mut con).unwrap(); + assert_eq!(got, expected); + } + + #[test] + #[serial_test::serial] + fn test_cluster_retries() { + let name = "tryagain"; + + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5), + name, + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + + match requests.fetch_add(1, atomic::Ordering::SeqCst) { + 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_exhaust_retries() { + let name = "tryagain_exhaust_retries"; + + let requests = Arc::new(atomic::AtomicUsize::new(0)); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + { + let requests = requests.clone(); + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + requests.fetch_add(1, atomic::Ordering::SeqCst); + Err(parse_redis_value(b"-TRYAGAIN mock\r\n")) + } + }, + ); + + let result = cmd("GET").arg("test").query::>(&mut connection); + + match result { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::TryAgain => {} + _ => panic!("Expected TryAgain but got {:?}", e.kind()), + }, + } + assert_eq!(requests.load(atomic::Ordering::SeqCst), 3); + } + + #[test] + #[serial_test::serial] + fn test_cluster_move_error_when_new_node_is_added() { + let name = "rebuild_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value(b"-MOVED 123\r\n")), + // Respond with the new masters + 1 => Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(1), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(2), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))), + _ => { + // Check that the correct node receives the request after rebuilding + assert_eq!(port, 6380); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_ask_redirect() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + let count = completed.fetch_add(1, Ordering::SeqCst); + match port { + 6379 => match count { + 0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")), + _ => panic!("Node should not be called now"), + }, + 6380 => match count { + 1 => { + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => panic!("Node should not be called now"), + }, + _ => panic!("Wrong node"), + } + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_ask_error_when_new_node_is_added() { + let name = "ask_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value( + format!("-ASK 123 {name}:6380\r\n").as_bytes(), + )), + 1 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => { + panic!("Unexpected request: {:?}", cmd); + } + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_replica_read() { + let name = "node"; + + // requests should route to replica + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + + match port { + 6380 => Err(Ok(Value::BulkString(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + assert_eq!(value, Ok(Some(123))); + + // requests should route to primary + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6379 => Err(Ok(Value::SimpleString("OK".into()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("SET") + .arg("test") + .arg("123") + .query::>(&mut connection); + assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned())))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_io_error() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + match port { + 6380 => panic!("Node should not be called"), + _ => match completed.fetch_add(1, Ordering::SeqCst) { + 0..=1 => Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "mock-io-error", + )))), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + }, + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + #[serial_test::serial] + fn test_cluster_non_retryable_error_should_not_retry() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { mut connection, .. } = MockEnv::new(name, { + let completed = completed.clone(); + move |cmd: &[u8], _| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + completed.fetch_add(1, Ordering::SeqCst); + Err(Err((ErrorKind::ReadOnly, "").into())) + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + match value { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::ReadOnly => {} + _ => panic!("Expected ReadOnly but got {:?}", e.kind()), + }, + } + assert_eq!(completed.load(Ordering::SeqCst), 1); + } + + fn test_cluster_fan_out( + command: &'static str, + expected_ports: Vec, + slots_config: Option>, + ) { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let mut cmd = redis::Cmd::new(); + for arg in command.split_whitespace() { + cmd.arg(arg); + } + let packed_cmd = cmd.get_packed_command(); + // requests should route to replica + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + if received_cmd == packed_cmd { + ports_clone.lock().unwrap().push(port); + return Err(Ok(Value::SimpleString("OK".into()))); + } + Ok(()) + }, + ); + + let _ = cmd.query::>(&mut connection); + found_ports.lock().unwrap().sort(); + // MockEnv creates 2 mock connections. 
+ assert_eq!(*found_ports.lock().unwrap(), expected_ports); + } + + #[test] + #[serial_test::serial] + fn test_cluster_fan_out_to_all_primaries() { + test_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); + } + + #[test] + #[serial_test::serial] + fn test_cluster_fan_out_to_all_nodes() { + test_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); + } + + #[test] + #[serial_test::serial] + fn test_cluster_fan_out_out_once_to_each_primary_when_no_replicas_are_available() { + test_cluster_fan_out( + "CONFIG SET", + vec![6379, 6381], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: Vec::new(), + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: Vec::new(), + slot_range: (8192..16383), + }, + ]), + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_fan_out_out_once_even_if_primary_has_multiple_slot_ranges() { + test_cluster_fan_out( + "CONFIG SET", + vec![6379, 6380, 6381, 6382], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..4000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (4001..8191), + }, + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (8192..8200), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8201..16383), + }, + ]), + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() { + let name = "test_cluster_split_multi_shard_command_and_combine_arrays_of_values"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + let results = ["foo", "bar", "baz"] + .iter() + .filter_map(|expected_key| { + if cmd_str.contains(expected_key) { + Some(Value::BulkString( + format!("{expected_key}-{port}").into_bytes(), + )) + } else { + None + } + }) + .collect(); + Err(Ok(Value::Array(results))) + }, + ); + + let result = cmd.query::>(&mut connection).unwrap(); + assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]); + } + + #[test] + #[serial_test::serial] + fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests() { + let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests"; + let mut pipeline = redis::pipe(); + pipeline.atomic().set("foo", "bar").get("foo"); + let packed_pipeline = pipeline.get_packed_pipeline(); + + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + return Err(Ok(Value::Array(results))); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection + .req_packed_commands(&packed_pipeline, 3, 1) + .unwrap(); + assert_eq!( + result, + vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ] + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2() { + let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2"; + let mut pipeline = redis::pipe(); + pipeline.atomic().set("foo", "bar").get("foo"); + let packed_pipeline = pipeline.get_packed_pipeline(); + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + let expected_result = Value::Array(results); + let cloned_result = expected_result.clone(); + + let MockEnv { + mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Ok(cloned_result.clone())); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection.req_packed_command(&packed_pipeline).unwrap(); + assert_eq!(result, expected_result); + } + + #[test] + #[serial_test::serial] + fn test_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + let mut con = cluster.connection(); + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_can_be_created_with_partial_slot_coverage() { + let name = "test_cluster_can_be_created_with_partial_slot_coverage"; + let slots_config = Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![], + slot_range: (8201..16380), + }, + ]); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + Err(Ok(Value::SimpleString("PONG".into()))) + }, + ); + + let res = connection.req_command(&redis::cmd("PING")); + assert!(res.is_ok()); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use super::*; + use crate::support::mtls_test::create_cluster_client_from_cluster; + use redis::ConnectionInfo; + + #[test] + #[serial_test::serial] + fn test_cluster_basics_with_mtls() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + + let client = create_cluster_client_from_cluster(&cluster, true).unwrap(); + let mut con = client.get_connection(None).unwrap(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + #[serial_test::serial] + fn test_cluster_should_not_connect_without_mtls() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + + let client = create_cluster_client_from_cluster(&cluster, false).unwrap(); + let connection = client.get_connection(None); + + match cluster.cluster.servers.first().unwrap().connection_info() { + ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { .. }, + .. 
+                } => {
+                    if connection.is_ok() {
+                        panic!("Must NOT be able to connect without client credentials if server accepts TLS");
+                    }
+                }
+                _ => {
+                    if let Err(e) = connection {
+                        panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}");
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster_async.rs b/glide-core/redis-rs/redis/tests/test_cluster_async.rs
new file mode 100644
index 0000000000..7273f98702
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster_async.rs
@@ -0,0 +1,5242 @@
+#![allow(unknown_lints, dependency_on_unit_never_type_fallback)]
+#![cfg(feature = "cluster-async")]
+mod support;
+
+use std::cell::Cell;
+use tokio::sync::Mutex;
+
+use lazy_static::lazy_static;
+
+lazy_static! {
+    static ref CLUSTER_VERSION: Mutex<Cell<usize>> = Mutex::<Cell<usize>>::default();
+}
+
+/// Checks whether the current cluster version is less than `min_version`.
+/// The Valkey version is checked first; if it is absent, the Redis version is checked instead.
+async fn engine_version_less_than(min_version: &str) -> bool {
+    let test_version = crate::get_cluster_version().await;
+    let min_version_usize = crate::version_to_usize(min_version).unwrap();
+    if test_version < min_version_usize {
+        println!(
+            "The engine version is {:?}, which is lower than {:?}",
+            test_version, min_version
+        );
+        return true;
+    }
+    false
+}
+
+/// Returns the engine version as a usize, caching it on first use;
+/// e.g. version 8.0.0 becomes 80000 and 12.0.1 becomes 120001.
+async fn get_cluster_version() -> usize {
+    let cluster_version = CLUSTER_VERSION.lock().await;
+    if cluster_version.get() == 0 {
+        let cluster = crate::support::TestClusterContext::new(3, 0);
+
+        let mut connection = cluster.async_connection(None).await;
+
+        let cmd = redis::cmd("INFO");
+        let info = connection
+            .route_command(
+                &cmd,
+                redis::cluster_routing::RoutingInfo::SingleNode(
+                    redis::cluster_routing::SingleNodeRoutingInfo::Random,
+                ),
+            )
+            .await
+            .unwrap();
+
+        let info_result = redis::from_owned_redis_value::<String>(info).unwrap();
+
+        cluster_version.set(
+            parse_version_from_info(info_result.clone())
+                .unwrap_or_else(|| panic!("Invalid version string in INFO: {info_result}")),
+        );
+    }
+    cluster_version.get()
+}
+
+fn parse_version_from_info(info: String) -> Option<usize> {
+    // Check for valkey_version first.
+    if let Some(version) = info
+        .lines()
+        .find_map(|line| line.strip_prefix("valkey_version:"))
+    {
+        return version_to_usize(version);
+    }
+
+    // Fall back to redis_version if no valkey_version was found.
+    if let Some(version) = info
+        .lines()
+        .find_map(|line| line.strip_prefix("redis_version:"))
+    {
+        return version_to_usize(version);
+    }
+    None
+}
+
+/// Takes a version string (e.g., 8.2.1) and converts it to a usize (e.g., 80201);
+/// version 12.10.0 will become 121000.
+fn version_to_usize(version: &str) -> Option<usize> {
+    version
+        .split('.')
+        .enumerate()
+        .map(|(index, part)| {
+            part.parse::<usize>()
+                .ok()
+                .map(|num| num * 10_usize.pow(2 * (2 - index) as u32))
+        })
+        .sum()
+}
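+
+// Illustrative sanity checks for the version helpers above (a sketch added for
+// clarity, not part of the original suite; it only touches `version_to_usize`
+// and `parse_version_from_info` as defined in this file).
+#[test]
+fn test_version_helpers_sketch() {
+    assert_eq!(version_to_usize("8.2.1"), Some(80201));
+    assert_eq!(version_to_usize("12.10.0"), Some(121000));
+    assert_eq!(
+        parse_version_from_info("valkey_version:8.0.0".to_string()),
+        Some(80000)
+    );
+    assert_eq!(parse_version_from_info("no version here".to_string()), None);
+}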
+
+#[cfg(test)]
+mod cluster_async {
+    use std::{
+        collections::{HashMap, HashSet},
+        net::{IpAddr, SocketAddr},
+        str::from_utf8,
+        sync::{
+            atomic::{self, AtomicBool, AtomicI32, AtomicU16, AtomicU32, Ordering},
+            Arc,
+        },
+        time::Duration,
+    };
+
+    use futures::prelude::*;
+    use futures_time::{future::FutureExt, task::sleep};
+    use once_cell::sync::Lazy;
+    use std::ops::Add;
+
+    use redis::{
+        aio::{ConnectionLike, MultiplexedConnection},
+        cluster::ClusterClient,
+        cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect},
+        cluster_routing::{
+            MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr,
+        },
+        cluster_topology::{get_slot, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES},
+        cmd, from_owned_redis_value, parse_redis_value, AsyncCommands, Cmd, ErrorKind,
+        FromRedisValue, GlideConnectionOptions, InfoDict, IntoConnectionInfo, ProtocolVersion,
+        PubSubChannelOrPattern, PubSubSubscriptionInfo, PubSubSubscriptionKind, PushInfo, PushKind,
+        RedisError, RedisFuture, RedisResult, Script, Value,
+    };
+
+    use crate::support::*;
+    use tokio::sync::mpsc;
+
+    fn broken_pipe_error() -> RedisError {
+        RedisError::from(std::io::Error::new(
+            std::io::ErrorKind::BrokenPipe,
+            "mock-io-error",
+        ))
+    }
+
+    fn validate_subscriptions(
+        pubsub_subs: &PubSubSubscriptionInfo,
+        notifications_rx: &mut mpsc::UnboundedReceiver<PushInfo>,
+        allow_disconnects: bool,
+    ) {
+        let mut subscribe_cnt =
+            if let Some(exact_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Exact) {
+                exact_subs.len()
+            } else {
+                0
+            };
+
+        let mut psubscribe_cnt =
+            if let Some(pattern_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Pattern) {
+                pattern_subs.len()
+            } else {
+                0
+            };
+
+        let mut ssubscribe_cnt =
+            if let Some(sharded_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Sharded) {
+                sharded_subs.len()
+            } else {
+                0
+            };
+
+        for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) {
+            let result = notifications_rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data: _ } = result.unwrap();
+            assert!(
+                kind == PushKind::Subscribe
+                    || kind == PushKind::PSubscribe
+                    || kind == PushKind::SSubscribe
+                    || if allow_disconnects {
+                        kind == PushKind::Disconnection
+                    } else {
+                        false
+                    }
+            );
+            if kind == PushKind::Subscribe {
+                subscribe_cnt -= 1;
+            } else if kind == PushKind::PSubscribe {
+                psubscribe_cnt -= 1;
+            } else if kind == PushKind::SSubscribe {
+                ssubscribe_cnt -= 1;
+            }
+        }
+
+        assert!(subscribe_cnt == 0);
+        assert!(psubscribe_cnt == 0);
+        assert!(ssubscribe_cnt == 0);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_basic_cmd() {
+        let cluster = TestClusterContext::new(3, 0);
+
+        block_on_all(async move {
+            let mut connection = cluster.async_connection(None).await;
+            cmd("SET")
+                .arg("test")
+                .arg("test_data")
+                .query_async(&mut connection)
+                .await?;
+            let res: String = cmd("GET")
+                .arg("test")
+                .clone()
+                .query_async(&mut connection)
+                .await?;
+            assert_eq!(res, "test_data");
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_routing_by_slot_to_replica_with_az_affinity_strategy_to_half_replicas() {
+        // Skip the test if the engine version is lower than Valkey 8.0.
+        if crate::engine_version_less_than("8.0").await {
+            return;
+        }
+
+        let replica_num: u16 = 4;
+        let primaries_num: u16 = 3;
+        let replicas_num_in_client_az = replica_num / 2;
+        let cluster =
+            TestClusterContext::new((replica_num * primaries_num) + primaries_num, replica_num);
+        let az: String = "us-east-1a".to_string();
+
+        let mut connection = cluster.async_connection(None).await;
+        let cluster_addresses: Vec<_> = cluster
+            .cluster
+            .servers
+            .iter()
+            .map(|server| server.connection_info())
+            .collect();
+
+        let mut cmd = redis::cmd("CONFIG");
+        cmd.arg(&["SET", "availability-zone", &az.clone()]);
+
+        for _ in 0..replicas_num_in_client_az {
+            connection
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        12182, // the key "foo" maps to slot 12182
+                        SlotAddr::ReplicaRequired,
+                    ))),
+                )
+                .await
+                .unwrap();
+        }
+
+        let mut
client = ClusterClient::builder(cluster_addresses.clone()) + .read_from(redis::cluster_slotmap::ReadFromReplicaStrategy::AZAffinity( + az.clone(), + )) + .build() + .unwrap() + .get_async_connection(None) + .await + .unwrap(); + + // Each replica in the client az will return the value of foo n times + let n = 4; + for _ in 0..n * replicas_num_in_client_az { + let mut cmd = redis::cmd("GET"); + cmd.arg("foo"); + let _res: RedisResult = cmd.query_async(&mut client).await; + } + + let mut cmd = redis::cmd("INFO"); + cmd.arg("ALL"); + let info = connection + .route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + ) + .await + .unwrap(); + + let info_result = redis::from_owned_redis_value::>(info).unwrap(); + let get_cmdstat = format!("cmdstat_get:calls="); + let n_get_cmdstat = format!("cmdstat_get:calls={}", n); + let client_az = format!("availability_zone:{}", az); + + let mut matching_entries_count: usize = 0; + + for value in info_result.values() { + if value.contains(&get_cmdstat) { + if value.contains(&client_az) && value.contains(&n_get_cmdstat) { + matching_entries_count += 1; + } else { + panic!( + "Invalid entry found: {}. Expected cmdstat_get:calls={} and availability_zone={}", + value, n, az); + } + } + } + + assert_eq!( + (matching_entries_count.try_into() as Result).unwrap(), + replicas_num_in_client_az, + "Test failed: expected exactly '{}' entries with '{}' and '{}', found {}", + replicas_num_in_client_az, + get_cmdstat, + client_az, + matching_entries_count + ); + } + + #[tokio::test] + async fn test_routing_by_slot_to_replica_with_az_affinity_strategy_to_all_replicas() { + // Skip test if version is less then Valkey 8.0 + if crate::engine_version_less_than("8.0").await { + return; + } + + let replica_num: u16 = 4; + let primaries_num: u16 = 3; + let cluster = + TestClusterContext::new((replica_num * primaries_num) + primaries_num, replica_num); + let az: String = "us-east-1a".to_string(); + + let mut connection = cluster.async_connection(None).await; + let cluster_addresses: Vec<_> = cluster + .cluster + .servers + .iter() + .map(|server| server.connection_info()) + .collect(); + + let mut cmd = redis::cmd("CONFIG"); + cmd.arg(&["SET", "availability-zone", &az.clone()]); + + connection + .route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + ) + .await + .unwrap(); + + let mut client = ClusterClient::builder(cluster_addresses.clone()) + .read_from(redis::cluster_slotmap::ReadFromReplicaStrategy::AZAffinity( + az.clone(), + )) + .build() + .unwrap() + .get_async_connection(None) + .await + .unwrap(); + + // Each replica will return the value of foo n times + let n = 4; + for _ in 0..(n * replica_num) { + let mut cmd = redis::cmd("GET"); + cmd.arg("foo"); + let _res: RedisResult = cmd.query_async(&mut client).await; + } + + let mut cmd = redis::cmd("INFO"); + cmd.arg("ALL"); + let info = connection + .route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + ) + .await + .unwrap(); + + let info_result = redis::from_owned_redis_value::>(info).unwrap(); + let get_cmdstat = format!("cmdstat_get:calls="); + let n_get_cmdstat = format!("cmdstat_get:calls={}", n); + let client_az = format!("availability_zone:{}", az); + + let mut matching_entries_count: usize = 0; + + for value in info_result.values() { + if value.contains(&get_cmdstat) { + if value.contains(&client_az) && value.contains(&n_get_cmdstat) { + matching_entries_count += 1; + } else { + panic!( + "Invalid 
entry found: {}. Expected cmdstat_get:calls={} and availability_zone={}", + value, n, az); + } + } + } + + assert_eq!( + (matching_entries_count.try_into() as Result).unwrap(), + replica_num, + "Test failed: expected exactly '{}' entries with '{}' and '{}', found {}", + replica_num.to_string(), + get_cmdstat, + client_az, + matching_entries_count + ); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_basic_eval() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let res: String = cmd("EVAL") + .arg(r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#) + .arg(1) + .arg("key") + .arg("test") + .query_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_basic_script() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let res: String = Script::new( + r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#, + ) + .key("key") + .arg("test") + .invoke_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_route_flush_to_specific_node() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let _: () = connection.set("foo", "bar").await.unwrap(); + let _: () = connection.set("bar", "foo").await.unwrap(); + + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + let res2: Option = connection.get("bar").await.unwrap(); + assert_eq!(res2, Some("foo".to_string())); + + let route = + redis::cluster_routing::Route::new(1, redis::cluster_routing::SlotAddr::Master); + let single_node_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route); + let routing = RoutingInfo::SingleNode(single_node_route); + assert_eq!( + connection + .route_command(&redis::cmd("FLUSHALL"), routing) + .await + .unwrap(), + Value::Okay + ); + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + let res2: Option = connection.get("bar").await.unwrap(); + assert_eq!(res2, None); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_route_flush_to_node_by_address() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut cmd = redis::cmd("INFO"); + // The other sections change with time. 
+ // TODO - after we remove support of redis 6, we can add more than a single section - .arg("Persistence").arg("Memory").arg("Replication") + cmd.arg("Clients"); + let value = connection + .route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + ) + .await + .unwrap(); + + let info_by_address = from_owned_redis_value::>(value).unwrap(); + // find the info of the first returned node + let (address, info) = info_by_address.into_iter().next().unwrap(); + let mut split_address = address.split(':'); + let host = split_address.next().unwrap().to_string(); + let port = split_address.next().unwrap().parse().unwrap(); + + let value = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { host, port }), + ) + .await + .unwrap(); + let new_info = from_owned_redis_value::(value).unwrap(); + + assert_eq!(new_info, info); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_route_info_to_nodes() { + let cluster = TestClusterContext::new(12, 1); + + let split_to_addresses_and_info = |res| -> (Vec, Vec) { + if let Value::Map(values) = res { + let mut pairs: Vec<_> = values + .into_iter() + .map(|(key, value)| { + ( + redis::from_redis_value::(&key).unwrap(), + redis::from_redis_value::(&value).unwrap(), + ) + }) + .collect(); + pairs.sort_by(|(address1, _), (address2, _)| address1.cmp(address2)); + pairs.into_iter().unzip() + } else { + unreachable!("{:?}", res); + } + }; + + block_on_all(async move { + let cluster_addresses: Vec<_> = cluster + .cluster + .servers + .iter() + .map(|server| server.connection_info()) + .collect(); + let client = ClusterClient::builder(cluster_addresses.clone()) + .read_from_replicas() + .build()?; + let mut connection = client.get_async_connection(None).await?; + + let route_to_all_nodes = redis::cluster_routing::MultipleNodeRoutingInfo::AllNodes; + let routing = RoutingInfo::MultiNode((route_to_all_nodes, None)); + let res = connection + .route_command(&redis::cmd("INFO"), routing) + .await + .unwrap(); + let (addresses, infos) = split_to_addresses_and_info(res); + + let mut cluster_addresses: Vec<_> = cluster_addresses + .into_iter() + .map(|info| info.addr.to_string()) + .collect(); + cluster_addresses.sort(); + + assert_eq!(addresses.len(), 12); + assert_eq!(addresses, cluster_addresses); + assert_eq!(infos.len(), 12); + for i in 0..12 { + let split: Vec<_> = addresses[i].split(':').collect(); + assert!(infos[i].contains(&format!("tcp_port:{}", split[1]))); + } + + let route_to_all_primaries = + redis::cluster_routing::MultipleNodeRoutingInfo::AllMasters; + let routing = RoutingInfo::MultiNode((route_to_all_primaries, None)); + let res = connection + .route_command(&redis::cmd("INFO"), routing) + .await + .unwrap(); + let (addresses, infos) = split_to_addresses_and_info(res); + assert_eq!(addresses.len(), 6); + assert_eq!(infos.len(), 6); + // verify that all primaries have the correct port & host, and are marked as primaries. 
+ for i in 0..6 { + assert!(cluster_addresses.contains(&addresses[i])); + let split: Vec<_> = addresses[i].split(':').collect(); + assert!(infos[i].contains(&format!("tcp_port:{}", split[1]))); + assert!(infos[i].contains("role:primary") || infos[i].contains("role:master")); + } + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_resp3() { + if use_protocol() == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.async_connection(None).await; + + let hello: HashMap = redis::cmd("HELLO") + .query_async(&mut connection) + .await + .unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").await.unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").await.unwrap(); + let result: Value = connection.hgetall("hash").await.unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + + Ok(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_basic_pipe() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut pipe = redis::pipe(); + pipe.add_command(cmd("SET").arg("test").arg("test_data").clone()); + pipe.add_command(cmd("SET").arg("{test}3").arg("test_data3").clone()); + pipe.query_async(&mut connection).await?; + let res: String = connection.get("test").await?; + assert_eq!(res, "test_data"); + let res: String = connection.get("{test}3").await?; + assert_eq!(res, "test_data3"); + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_multi_shard_commands() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + + let res: String = connection + .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")]) + .await?; + assert_eq!(res, "OK"); + let res: Vec = connection.mget(&["baz", "foo", "bar"]).await?; + assert_eq!(res, vec!["bazz", "bar", "foo"]); + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_basic_failover() { + block_on_all(async move { + test_failover(&TestClusterContext::new(6, 1), 10, 123, false).await; + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + async fn do_failover( + redis: &mut redis::aio::MultiplexedConnection, + ) -> Result<(), anyhow::Error> { + cmd("CLUSTER").arg("FAILOVER").query_async(redis).await?; + Ok(()) + } + + // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active + #[allow(dead_code)] + async fn test_failover( + env: &TestClusterContext, + requests: i32, + value: i32, + _mtls_enabled: bool, + ) { + let completed = Arc::new(AtomicI32::new(0)); + + let connection = env.async_connection(None).await; + let mut node_conns: Vec = Vec::new(); + + 'outer: loop { + node_conns.clear(); + let cleared_nodes = async { + for server in env.cluster.iter_servers() { + let addr = server.client_addr(); + + #[cfg(feature = "tls-rustls")] + let client = build_single_client( + server.connection_info(), + &server.tls_paths, + _mtls_enabled, + ) + .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': 
{e}")); + + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()) + .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); + + let mut conn = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); + + let info: InfoDict = redis::Cmd::new() + .arg("INFO") + .query_async(&mut conn) + .await + .expect("INFO"); + let role: String = info.get("role").expect("cluster role"); + + if role == "master" { + tokio::time::timeout(std::time::Duration::from_secs(3), async { + Ok(redis::Cmd::new() + .arg("FLUSHALL") + .query_async(&mut conn) + .await?) + }) + .await + .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?; + } + + node_conns.push(conn); + } + Ok::<_, anyhow::Error>(()) + } + .await; + match cleared_nodes { + Ok(()) => break 'outer, + Err(err) => { + // Failed to clear the databases, retry + tracing::warn!("{}", err); + } + } + } + + (0..requests + 1) + .map(|i| { + let mut connection = connection.clone(); + let mut node_conns = node_conns.clone(); + let completed = completed.clone(); + async move { + if i == requests / 2 { + // Failover all the nodes, error only if all the failover requests error + let mut results = future::join_all( + node_conns + .iter_mut() + .map(|conn| Box::pin(do_failover(conn))), + ) + .await; + if results.iter().all(|res| res.is_err()) { + results.pop().unwrap() + } else { + Ok::<_, anyhow::Error>(()) + } + } else { + let key = format!("test-{value}-{i}"); + cmd("SET") + .arg(&key) + .arg(i) + .clone() + .query_async(&mut connection) + .await?; + let res: i32 = cmd("GET") + .arg(key) + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, i); + completed.fetch_add(1, Ordering::SeqCst); + Ok::<_, anyhow::Error>(()) + } + } + }) + .collect::>() + .try_collect() + .await + .unwrap_or_else(|e| panic!("{e}")); + + assert_eq!( + completed.load(Ordering::SeqCst), + requests, + "Some requests never completed!" 
+ ); + } + + static ERROR: Lazy = Lazy::new(Default::default); + + #[derive(Clone)] + struct ErrorConnection { + inner: MultiplexedConnection, + } + + impl Connect for ErrorConnection { + fn connect<'a, T>( + info: T, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + Box::pin(async move { + let (inner, _ip) = MultiplexedConnection::connect( + info, + response_timeout, + connection_timeout, + socket_addr, + glide_connection_options, + ) + .await?; + Ok((ErrorConnection { inner }, None)) + }) + } + } + + impl ConnectionLike for ErrorConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + if ERROR.load(Ordering::SeqCst) { + Box::pin(async move { Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) }) + } else { + self.inner.req_packed_command(cmd) + } + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + self.inner.req_packed_commands(pipeline, offset, count) + } + + fn get_db(&self) -> i64 { + self.inner.get_db() + } + + fn is_closed(&self) -> bool { + true + } + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_error_in_inner_connection() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut con = cluster.async_generic_connection::().await; + + ERROR.store(false, Ordering::SeqCst); + let r: Option = con.get("test").await?; + assert_eq!(r, None::); + + ERROR.store(true, Ordering::SeqCst); + + let result: RedisResult<()> = con.get("test").await; + assert_eq!( + result, + Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) + ); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { + let name = + "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; + + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("".as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { + let name = + "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; + + let MockEnv { + runtime, + async_connection: mut connection, + .. 
+        } = MockEnv::new(name, move |cmd: &[u8], _| {
+            if contains_slice(cmd, b"PING") {
+                Err(Ok(Value::SimpleString("OK".into())))
+            } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                Err(Ok(Value::Array(vec![Value::Array(vec![
+                    Value::Int(0),
+                    Value::Int(16383),
+                    Value::Array(vec![Value::Nil, Value::Int(6379)]),
+                ])])))
+            } else {
+                Err(Ok(Value::Nil))
+            }
+        });
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Value>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Value::Nil));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_cannot_connect_to_server_with_unknown_host_name() {
+        let name = "test_async_cluster_cannot_connect_to_server_with_unknown_host_name";
+        let handler = move |cmd: &[u8], _| {
+            if contains_slice(cmd, b"PING") {
+                Err(Ok(Value::SimpleString("OK".into())))
+            } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                Err(Ok(Value::Array(vec![Value::Array(vec![
+                    Value::Int(0),
+                    Value::Int(16383),
+                    Value::Array(vec![
+                        Value::BulkString("?".as_bytes().to_vec()),
+                        Value::Int(6379),
+                    ]),
+                ])])))
+            } else {
+                Err(Ok(Value::Nil))
+            }
+        };
+        let client_builder = ClusterClient::builder(vec![&*format!("redis://{name}")]);
+        let client: ClusterClient = client_builder.build().unwrap();
+        let _handler = MockConnectionBehavior::register_new(name, Arc::new(handler));
+        let connection = client.get_generic_connection::<MockConnection>(None);
+        assert!(connection.is_err());
+        let err = connection.err().unwrap();
+        assert!(err
+            .to_string()
+            .contains("Error parsing slots: No healthy node found"))
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name(
+    ) {
+        let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name";
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            ..
+        } = MockEnv::new(name, move |cmd: &[u8], _| {
+            if contains_slice(cmd, b"PING") {
+                Err(Ok(Value::SimpleString("OK".into())))
+            } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                Err(Ok(Value::Array(vec![
+                    Value::Array(vec![
+                        Value::Int(0),
+                        Value::Int(7000),
+                        Value::Array(vec![
+                            Value::BulkString(name.as_bytes().to_vec()),
+                            Value::Int(6379),
+                        ]),
+                    ]),
+                    Value::Array(vec![
+                        Value::Int(7001),
+                        Value::Int(16383),
+                        Value::Array(vec![
+                            Value::BulkString("?".as_bytes().to_vec()),
+                            Value::Int(6380),
+                        ]),
+                    ]),
+                ])))
+            } else {
+                Err(Ok(Value::Nil))
+            }
+        });
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Value>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Value::Nil));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_retries() {
+        let name = "tryagain";
+
+        let requests = atomic::AtomicUsize::new(0);
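+
+        // A note on the mock-handler convention used throughout these tests: the
+        // closure returns `Err(reply)` to intercept a command, where `reply` is the
+        // `RedisResult` the client will observe, and `Ok(())` to fall through to the
+        // default mock behavior. A minimal sketch (not a handler used here):
+        //
+        //     move |cmd: &[u8], _port| {
+        //         respond_startup(name, cmd)?; // handle the connection handshake
+        //         Err(Ok(Value::SimpleString("OK".into()))) // canned success reply
+        //     }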
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5),
+            name,
+            move |cmd: &[u8], _| {
+                respond_startup(name, cmd)?;
+
+                match requests.fetch_add(1, atomic::Ordering::SeqCst) {
+                    0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")),
+                    _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_tryagain_exhaust_retries() {
+        let name = "tryagain_exhaust_retries";
+
+        let requests = Arc::new(atomic::AtomicUsize::new(0));
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+            name,
+            {
+                let requests = requests.clone();
+                move |cmd: &[u8], _| {
+                    respond_startup(name, cmd)?;
+                    requests.fetch_add(1, atomic::Ordering::SeqCst);
+                    Err(parse_redis_value(b"-TRYAGAIN mock\r\n"))
+                }
+            },
+        );
+
+        let result = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        match result {
+            Ok(_) => panic!("result should be an error"),
+            Err(e) => match e.kind() {
+                ErrorKind::TryAgain => {}
+                _ => panic!("Expected TryAgain but got {:?}", e.kind()),
+            },
+        }
+        assert_eq!(requests.load(atomic::Ordering::SeqCst), 3);
+    }
+
+    // Obtain the view index associated with the node listening on `called_port`
+    fn get_node_view_index(num_of_views: usize, ports: &Vec<u16>, called_port: u16) -> usize {
+        let port_index = ports
+            .iter()
+            .position(|&p| p == called_port)
+            .unwrap_or_else(|| {
+                panic!(
+                    "CLUSTER SLOTS was called with unknown port: {called_port}; Known ports: {:?}",
+                    ports
+                )
+            });
+        // If we have fewer views than nodes, use the last view
+        if port_index < num_of_views {
+            port_index
+        } else {
+            num_of_views - 1
+        }
+    }
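+
+    // Worked example: with ports [6379, 6380, 6381] and only two topology views,
+    // 6379 maps to view 0 and 6380 to view 1, while 6381 falls back to the last
+    // view because there are fewer views than nodes:
+    //
+    //     assert_eq!(get_node_view_index(2, &vec![6379, 6380, 6381], 6381), 1);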
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_move_error_when_new_node_is_added() {
+        let name = "rebuild_with_extra_nodes";
+
+        let requests = atomic::AtomicUsize::new(0);
+        let started = atomic::AtomicBool::new(false);
+        let refreshed_map = HashMap::from([
+            (6379, atomic::AtomicBool::new(false)),
+            (6380, atomic::AtomicBool::new(false)),
+        ]);
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::new(name, move |cmd: &[u8], port| {
+            if !started.load(atomic::Ordering::SeqCst) {
+                respond_startup(name, cmd)?;
+            }
+            started.store(true, atomic::Ordering::SeqCst);
+
+            if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+                return Err(Ok(Value::SimpleString("OK".into())));
+            }
+
+            let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+            let is_get_cmd = contains_slice(cmd, b"GET");
+            let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+            match i {
+                // Respond that the key exists on a node that does not yet have a connection:
+                0 => Err(parse_redis_value(
+                    format!("-MOVED 123 {name}:6380\r\n").as_bytes(),
+                )),
+                _ => {
+                    if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                        // Should not attempt to refresh slots more than once,
+                        // so we expect a single CLUSTER SLOTS request for each node
+                        assert!(!refreshed_map
+                            .get(&port)
+                            .unwrap()
+                            .swap(true, Ordering::SeqCst));
+                        Err(Ok(Value::Array(vec![
+                            Value::Array(vec![
+                                Value::Int(0),
+                                Value::Int(1),
+                                Value::Array(vec![
+                                    Value::BulkString(name.as_bytes().to_vec()),
+                                    Value::Int(6379),
+                                ]),
+                            ]),
+                            Value::Array(vec![
+                                Value::Int(2),
+                                Value::Int(16383),
+                                Value::Array(vec![
+                                    Value::BulkString(name.as_bytes().to_vec()),
+                                    Value::Int(6380),
+                                ]),
+                            ]),
+                        ])))
+                    } else {
+                        assert_eq!(port, 6380);
+                        assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+                        get_response
+                    }
+                }
+            }
+        });
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    fn test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+        slots_config_vec: Vec<Vec<MockSlotRange>>,
+        ports: Vec<u16>,
+        has_a_majority: bool,
+    ) {
+        assert!(!ports.is_empty() && !slots_config_vec.is_empty());
+        let name = "refresh_topology_moved";
+        let num_of_nodes = ports.len();
+        let requests = atomic::AtomicUsize::new(0);
+        let started = atomic::AtomicBool::new(false);
+        let refresh_calls = Arc::new(atomic::AtomicUsize::new(0));
+        let refresh_calls_cloned = refresh_calls.clone();
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                // Disable the rate limiter to refresh slots immediately on all MOVED errors.
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0),
+            name,
+            move |cmd: &[u8], port| {
+                if !started.load(atomic::Ordering::SeqCst) {
+                    respond_startup_with_replica_using_config(
+                        name,
+                        cmd,
+                        Some(slots_config_vec[0].clone()),
+                    )?;
+                }
+                started.store(true, atomic::Ordering::SeqCst);
+
+                if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+                let is_get_cmd = contains_slice(cmd, b"GET");
+                let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+                let moved_node = ports[0];
+                match i {
+                    // Respond that the key exists on a node that does not yet have a connection:
+                    0 => Err(parse_redis_value(
+                        format!("-MOVED 123 {name}:{moved_node}\r\n").as_bytes(),
+                    )),
+                    _ => {
+                        if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                            refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst);
+                            let view_index =
+                                get_node_view_index(slots_config_vec.len(), &ports, port);
+                            Err(Ok(create_topology_from_config(
+                                name,
+                                slots_config_vec[view_index].clone(),
+                            )))
+                        } else {
+                            assert_eq!(port, moved_node);
+                            assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+                            get_response
+                        }
+                    }
+                }
+            },
+        );
+        runtime.block_on(async move {
+            let res = cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection)
+                .await;
+            assert_eq!(res, Ok(Some(123)));
+            // If there is a majority in the topology views, or if it's a 2-node cluster, we
+            // should be able to calculate the topology on the first try, so each node will be
+            // queried only once with CLUSTER SLOTS. Otherwise, if we don't have a majority, we
+            // expect to see the refresh_slots function called the maximum number of retries.
+            let expected_calls = if has_a_majority || num_of_nodes == 2 {
+                num_of_nodes
+            } else {
+                DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES * num_of_nodes
+            };
+            let mut refreshed_calls = 0;
+            for _ in 0..100 {
+                refreshed_calls = refresh_calls.load(atomic::Ordering::Relaxed);
+                if refreshed_calls == expected_calls {
+                    return;
+                } else {
+                    let sleep_duration = core::time::Duration::from_millis(100);
+                    #[cfg(feature = "tokio-comp")]
+                    tokio::time::sleep(sleep_duration).await;
+                }
+            }
+            panic!("Failed to reach the expected number of topology refresh retries. Found={refreshed_calls}, Expected={expected_calls}")
+        });
+    }
+
+    fn test_async_cluster_refresh_slots_rate_limiter_helper(
+        slots_config_vec: Vec<Vec<MockSlotRange>>,
+        ports: Vec<u16>,
+        should_skip: bool,
+    ) {
+        // This test queries GET, which returns a MOVED error. If `should_skip` is true,
+        // it indicates that we should skip refreshing slots because the specified time
+        // duration since the last refresh slots call has not yet passed. In this case,
+        // we expect CLUSTER SLOTS not to be called on the nodes after receiving the
+        // MOVED error.
+
+        // If `should_skip` is false, we verify that if the MOVED error occurs after the
+        // time duration of the rate limiter has passed, the refresh slots operation
+        // should not be skipped. We assert this by expecting calls to CLUSTER SLOTS on
+        // all nodes.
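+
+        // A minimal sketch of the knob exercised here, assuming the builder method
+        // shown throughout this file (the second argument appears to be an added
+        // random jitter, which these tests always pass as 0 for determinism):
+        //
+        //     ClusterClient::builder(initial_nodes)
+        //         .slots_refresh_rate_limit(Duration::from_millis(10), 0)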
+        let test_name = format!(
+            "test_async_cluster_refresh_slots_rate_limiter_helper_{}",
+            if should_skip {
+                "should_skip"
+            } else {
+                "not_skipping_waiting_time_passed"
+            }
+        );
+
+        let requests = atomic::AtomicUsize::new(0);
+        let started = atomic::AtomicBool::new(false);
+        let refresh_calls = Arc::new(atomic::AtomicUsize::new(0));
+        let refresh_calls_cloned = Arc::clone(&refresh_calls);
+        let wait_duration = Duration::from_millis(10);
+        let num_of_nodes = ports.len();
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{test_name}")])
+                .slots_refresh_rate_limit(wait_duration, 0),
+            test_name.clone().as_str(),
+            move |cmd: &[u8], port| {
+                if !started.load(atomic::Ordering::SeqCst) {
+                    respond_startup_with_replica_using_config(
+                        test_name.as_str(),
+                        cmd,
+                        Some(slots_config_vec[0].clone()),
+                    )?;
+                    started.store(true, atomic::Ordering::SeqCst);
+                }
+
+                if contains_slice(cmd, b"PING") {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+                let is_get_cmd = contains_slice(cmd, b"GET");
+                let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+                let moved_node = ports[0];
+                match i {
+                    // The first request is the initial GET, which we answer with a MOVED error
+                    0 => {
+                        if !should_skip {
+                            // Wait for the wait duration to pass
+                            std::thread::sleep(wait_duration.add(Duration::from_millis(10)));
+                        }
+                        Err(parse_redis_value(
+                            format!("-MOVED 123 {test_name}:{moved_node}\r\n").as_bytes(),
+                        ))
+                    }
+                    _ => {
+                        if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                            refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst);
+                            let view_index =
+                                get_node_view_index(slots_config_vec.len(), &ports, port);
+                            Err(Ok(create_topology_from_config(
+                                test_name.as_str(),
+                                slots_config_vec[view_index].clone(),
+                            )))
+                        } else {
+                            // Even if the slots weren't refreshed, we still expect the command to be
+                            // routed to the redirect host and port received in the MOVED error
+                            assert_eq!(port, moved_node);
+                            assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+                            get_response
+                        }
+                    }
+                }
+            },
+        );
+
+        runtime.block_on(async move {
+            // The first GET request should raise a MOVED error and then refresh the slots
+            let res = cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection)
+                .await;
+            assert_eq!(res, Ok(Some(123)));
+
+            // If `should_skip` is false, we expect CLUSTER SLOTS to be called once per node
+            let expected_calls = if should_skip {
+                0
+            } else {
+                num_of_nodes
+            };
+            for _ in 0..4 {
+                if refresh_calls.load(atomic::Ordering::Relaxed) == expected_calls {
+                    return Ok::<_, RedisError>(());
+                }
+                let _ = sleep(Duration::from_millis(50).into()).await;
+            }
+            panic!("Refresh slots wasn't called as expected!\nExpected CLUSTER SLOTS calls: {}, actual calls: {:?}", expected_calls, refresh_calls.load(atomic::Ordering::Relaxed));
+        }).unwrap()
+    }
+
+    fn test_async_cluster_refresh_topology_in_client_init_get_succeed(
+        slots_config_vec: Vec<Vec<MockSlotRange>>,
+        ports: Vec<u16>,
+    ) {
+        assert!(!ports.is_empty() && !slots_config_vec.is_empty());
+        let name = "refresh_topology_client_init";
+        let started = atomic::AtomicBool::new(false);
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder::<String>(
+                ports
+                    .iter()
+                    .map(|port| format!("redis://{name}:{port}"))
+                    .collect::<Vec<_>>(),
+            ),
+            name,
+            move |cmd: &[u8], port| {
+                let is_started = started.load(atomic::Ordering::SeqCst);
+                if !is_started {
+                    if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+                        return Err(Ok(Value::SimpleString("OK".into())));
+                    } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                        let view_index = get_node_view_index(slots_config_vec.len(), &ports, port);
+                        return Err(Ok(create_topology_from_config(
+                            name,
+                            slots_config_vec[view_index].clone(),
+                        )));
+                    } else if contains_slice(cmd, b"READONLY") {
+                        return Err(Ok(Value::SimpleString("OK".into())));
+                    }
+                }
+                started.store(true, atomic::Ordering::SeqCst);
+                if contains_slice(cmd, b"PING") {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                let is_get_cmd = contains_slice(cmd, b"GET");
+                let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+                {
+                    assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd));
+                    get_response
+                }
+            },
+        );
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    fn generate_topology_view(
+        ports: &[u16],
+        interval: usize,
+        full_slot_coverage: bool,
+    ) -> Vec<MockSlotRange> {
+        let mut slots_res = vec![];
+        let mut start_pos: usize = 0;
+        for (idx, port) in ports.iter().enumerate() {
+            let end_pos: usize = if idx == ports.len() - 1 && full_slot_coverage {
+                16383
+            } else {
+                start_pos + interval
+            };
+            let mock_slot = MockSlotRange {
+                primary_port: *port,
+                replica_ports: vec![],
+                slot_range: (start_pos as u16..end_pos as u16),
+            };
+            slots_res.push(mock_slot);
+            start_pos = end_pos + 1;
+        }
+        slots_res
+    }
+
+    fn get_ports(num_of_nodes: usize) -> Vec<u16> {
+        (6379_u16..6379 + num_of_nodes as u16).collect()
+    }
+
+    fn get_no_majority_topology_view(ports: &[u16]) -> Vec<Vec<MockSlotRange>> {
+        let mut result = vec![];
+        let mut full_coverage = true;
+        for i in 0..ports.len() {
+            result.push(generate_topology_view(ports, i + 1, full_coverage));
+            full_coverage = !full_coverage;
+        }
+        result
+    }
+
+    fn get_topology_with_majority(ports: &[u16]) -> Vec<Vec<MockSlotRange>> {
+        let view: Vec<MockSlotRange> = generate_topology_view(ports, 10, true);
+        let result: Vec<_> = ports.iter().map(|_| view.clone()).collect();
+        result
+    }
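+
+    // For example, `generate_topology_view(&[6379, 6380], 10, true)` yields two
+    // primary-only ranges: 6379 owning slots 0..10 and 6380 owning slots 11..16383
+    // (the last range is stretched to cover all slots). The no-majority helper
+    // varies the interval per view, so no two nodes report the same topology.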
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_refresh_topology_after_moved_error_all_nodes_agree_get_succeed() {
+        let ports = get_ports(3);
+        test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+            get_topology_with_majority(&ports),
+            ports,
+            true,
+        );
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_refresh_topology_in_client_init_all_nodes_agree_get_succeed() {
+        let ports = get_ports(3);
+        test_async_cluster_refresh_topology_in_client_init_get_succeed(
+            get_topology_with_majority(&ports),
+            ports,
+        );
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_refresh_topology_after_moved_error_with_no_majority_get_succeed() {
+        for num_of_nodes in 2..4 {
+            let ports = get_ports(num_of_nodes);
+            test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries(
+                get_no_majority_topology_view(&ports),
+                ports,
+                false,
+            );
+        }
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_refresh_topology_in_client_init_with_no_majority_get_succeed() {
+        for num_of_nodes in 2..4 {
+            let ports = get_ports(num_of_nodes);
+            test_async_cluster_refresh_topology_in_client_init_get_succeed(
+                get_no_majority_topology_view(&ports),
+                ports,
+            );
+        }
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_refresh_topology_even_with_zero_retries() {
+        let name = "test_async_cluster_refresh_topology_even_with_zero_retries";
+
+        let should_refresh = atomic::AtomicBool::new(false);
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                // Disable the rate limiter to refresh slots immediately on the MOVED error.
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0),
+            name,
+            move |cmd: &[u8], port| {
+                if !should_refresh.load(atomic::Ordering::SeqCst) {
+                    respond_startup(name, cmd)?;
+                }
+
+                if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    return Err(Ok(Value::Array(vec![
+                        Value::Array(vec![
+                            Value::Int(0),
+                            Value::Int(1),
+                            Value::Array(vec![
+                                Value::BulkString(name.as_bytes().to_vec()),
+                                Value::Int(6379),
+                            ]),
+                        ]),
+                        Value::Array(vec![
+                            Value::Int(2),
+                            Value::Int(16383),
+                            Value::Array(vec![
+                                Value::BulkString(name.as_bytes().to_vec()),
+                                Value::Int(6380),
+                            ]),
+                        ]),
+                    ])));
+                }
+
+                if contains_slice(cmd, b"GET") {
+                    let get_response = Err(Ok(Value::BulkString(b"123".to_vec())));
+                    match port {
+                        6380 => get_response,
+                        // Respond that the key exists on a node that does not yet have a connection:
+                        _ => {
+                            // Should not attempt to refresh slots more than once:
+                            assert!(!should_refresh.swap(true, Ordering::SeqCst));
+                            Err(parse_redis_value(
+                                format!("-MOVED 123 {name}:6380\r\n").as_bytes(),
+                            ))
+                        }
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        // The user should receive an initial error, because there are no retries and the first request failed.
+        assert_eq!(
+            value,
+            Err(RedisError::from((
+                ErrorKind::Moved,
+                "An error was signalled by the server",
+                "test_async_cluster_refresh_topology_even_with_zero_retries:6380".to_string()
+            )))
+        );
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
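+
+    // The tests that follow exercise slot-map updates driven purely by
+    // `-MOVED <slot> <host>:<port>` redirects, with CLUSTER SLOTS refreshes
+    // effectively disabled through a very large rate-limit interval. The redirect
+    // each mock emits has this shape:
+    //
+    //     Err(parse_redis_value(
+    //         format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+    //     ))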
+
+    #[test]
+    fn test_async_cluster_update_slots_based_on_moved_error_indicates_slot_migration() {
+        // This test simulates the scenario where the client receives a MOVED error indicating that a key is now
+        // stored on the primary node of another shard.
+        // It ensures that the slot is now owned by that primary and its associated replicas.
+        let name = "test_async_cluster_update_slots_based_on_moved_error_indicates_slot_migration";
+        let slots_config = vec![
+            MockSlotRange {
+                primary_port: 6379,
+                replica_ports: vec![7000],
+                slot_range: (0..8000),
+            },
+            MockSlotRange {
+                primary_port: 6380,
+                replica_ports: vec![7001],
+                slot_range: (8001..16380),
+            },
+        ];
+
+        let moved_from_port = 6379;
+        let moved_to_port = 6380;
+        let new_shard_replica_port = 7001;
+
+        // Tracking moved and replica requests for validation
+        let moved_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_moved_requests = moved_requests.clone();
+        let replica_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_replica_requests = replica_requests.clone();
+
+        // Test key and slot
+        let key = "test";
+        let key_slot = 6918;
+
+        // Mock environment setup
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .slots_refresh_rate_limit(Duration::from_secs(1000000), 0) // Rate limiter to disable slot refresh
+                .read_from_replicas(), // Allow reads from replicas
+            name,
+            move |cmd: &[u8], port| {
+                if contains_slice(cmd, b"PING")
+                    || contains_slice(cmd, b"SETNAME")
+                    || contains_slice(cmd, b"READONLY")
+                {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    let slots = create_topology_from_config(name, slots_config.clone());
+                    return Err(Ok(slots));
+                }
+
+                if contains_slice(cmd, b"SET") {
+                    if port == moved_to_port {
+                        // Simulate primary OK response
+                        Err(Ok(Value::SimpleString("OK".into())))
+                    } else if port == moved_from_port {
+                        // Simulate MOVED error for other port
+                        moved_requests.fetch_add(1, Ordering::Relaxed);
+                        Err(parse_redis_value(
+                            format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+                        ))
+                    } else {
+                        panic!("unexpected port for SET command: {port:?}.\n
+                            Expected one of: moved_to_port={moved_to_port}, moved_from_port={moved_from_port}");
+                    }
+                } else if contains_slice(cmd, b"GET") {
+                    if new_shard_replica_port == port {
+                        // Simulate replica response for GET after slot migration
+                        replica_requests.fetch_add(1, Ordering::Relaxed);
+                        Err(Ok(Value::BulkString(b"123".to_vec())))
+                    } else {
+                        panic!("unexpected port for GET command: {port:?}, Expected: {new_shard_replica_port:?}");
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        // First request: Trigger MOVED error and reroute
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Second request: Should be routed directly to the new primary node if the slots map is updated
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Handle slot migration scenario: Ensure the new shard's replicas are accessible
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg(key)
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+        assert_eq!(cloned_replica_requests.load(Ordering::Relaxed), 1);
+
+        // Assert there was only a single MOVED error
+        assert_eq!(cloned_moved_requests.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    fn test_async_cluster_update_slots_based_on_moved_error_indicates_failover() {
+        // This test simulates a failover scenario, where the client receives a MOVED error and the replica becomes the new primary.
+        // The test verifies that the client updates the slot mapping to promote the replica to the primary and routes future requests
+        // to the new primary, ensuring other slots in the shard are also handled by the new primary.
+        let name = "test_async_cluster_update_slots_based_on_moved_error_indicates_failover";
+        let slots_config = vec![
+            MockSlotRange {
+                primary_port: 6379,
+                replica_ports: vec![7001],
+                slot_range: (0..8000),
+            },
+            MockSlotRange {
+                primary_port: 6380,
+                replica_ports: vec![7002],
+                slot_range: (8001..16380),
+            },
+        ];
+
+        let moved_from_port = 6379;
+        let moved_to_port = 7001;
+
+        // Tracking moved requests for validation
+        let moved_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_moved_requests = moved_requests.clone();
+
+        // Test key and slot
+        let key = "test";
+        let key_slot = 6918;
+
+        // Mock environment setup
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .slots_refresh_rate_limit(Duration::from_secs(1000000), 0), // Rate limiter to disable slot refresh
+            name,
+            move |cmd: &[u8], port| {
+                if contains_slice(cmd, b"PING")
+                    || contains_slice(cmd, b"SETNAME")
+                    || contains_slice(cmd, b"READONLY")
+                {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    let slots = create_topology_from_config(name, slots_config.clone());
+                    return Err(Ok(slots));
+                }
+
+                if contains_slice(cmd, b"SET") {
+                    if port == moved_to_port {
+                        // Simulate primary OK response
+                        Err(Ok(Value::SimpleString("OK".into())))
+                    } else if port == moved_from_port {
+                        // Simulate MOVED error for other port
+                        moved_requests.fetch_add(1, Ordering::Relaxed);
+                        Err(parse_redis_value(
+                            format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+                        ))
+                    } else {
+                        panic!("unexpected port for SET command: {port:?}.\n
+                            Expected one of: moved_to_port={moved_to_port}, moved_from_port={moved_from_port}");
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        // First request: Trigger MOVED error and reroute
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Second request: Should be routed directly to the new primary node if the slots map is updated
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Handle failover scenario: Ensure other slots in the same shard are updated to the new primary
+        let key_slot_1044 = "foo2";
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key_slot_1044)
+                .arg("bar2")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Assert there was only a single MOVED error
+        assert_eq!(cloned_moved_requests.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    fn test_async_cluster_update_slots_based_on_moved_error_indicates_new_primary() {
+        // This test simulates the scenario where the client receives a MOVED error indicating that the key now belongs to
+        // an entirely new primary node that wasn't previously known. The test verifies that the client correctly adds the new
+        // primary node to its slot map and routes future requests to the new node.
+        let name = "test_async_cluster_update_slots_based_on_moved_error_indicates_new_primary";
+        let slots_config = vec![
+            MockSlotRange {
+                primary_port: 6379,
+                replica_ports: vec![],
+                slot_range: (0..8000),
+            },
+            MockSlotRange {
+                primary_port: 6380,
+                replica_ports: vec![],
+                slot_range: (8001..16380),
+            },
+        ];
+
+        let moved_from_port = 6379;
+        let moved_to_port = 6381;
+
+        // Tracking moved requests for validation
+        let moved_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_moved_requests = moved_requests.clone();
+
+        // Test key and slot
+        let key = "test";
+        let key_slot = 6918;
+
+        // Mock environment setup
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .slots_refresh_rate_limit(Duration::from_secs(1000000), 0) // Rate limiter to disable slot refresh
+                .read_from_replicas(), // Allow reads from replicas
+            name,
+            move |cmd: &[u8], port| {
+                if contains_slice(cmd, b"PING")
+                    || contains_slice(cmd, b"SETNAME")
+                    || contains_slice(cmd, b"READONLY")
+                {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    let slots = create_topology_from_config(name, slots_config.clone());
+                    return Err(Ok(slots));
+                }
+
+                if contains_slice(cmd, b"SET") {
+                    if port == moved_to_port {
+                        // Simulate primary OK response
+                        Err(Ok(Value::SimpleString("OK".into())))
+                    } else if port == moved_from_port {
+                        // Simulate MOVED error for other port
+                        moved_requests.fetch_add(1, Ordering::Relaxed);
+                        Err(parse_redis_value(
+                            format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+                        ))
+                    } else {
+                        panic!("unexpected port for SET command: {port:?}.\n
+                            Expected one of: moved_to_port={moved_to_port}, moved_from_port={moved_from_port}");
+                    }
+                } else if contains_slice(cmd, b"GET") {
+                    if moved_to_port == port {
+                        // Simulate primary response for GET
+                        Err(Ok(Value::BulkString(b"123".to_vec())))
+                    } else {
+                        panic!(
+                            "unexpected port for GET command: {port:?}, Expected: {moved_to_port}"
+                        );
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        // First request: Trigger MOVED error and reroute
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Second request: Should be routed directly to the new primary node if the slots map is updated
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Third request: The new primary has no replicas, so the read should be directed to it
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg(key)
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+
+        // Assert there was only a single MOVED error
+        assert_eq!(cloned_moved_requests.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    fn test_async_cluster_update_slots_based_on_moved_error_indicates_replica_of_different_shard() {
+        // This test simulates a scenario where the client receives a MOVED error indicating that a key
+        // has been moved to a replica in a different shard. The replica is then promoted to primary and
+        // no longer exists in the shard's replica set.
+        // The test validates that the key gets correctly routed to the new primary and ensures that the
+        // shard updates its mapping accordingly, with only one MOVED error encountered during the process.
+
+        let name = "test_async_cluster_update_slots_based_on_moved_error_indicates_replica_of_different_shard";
+        let slots_config = vec![
+            MockSlotRange {
+                primary_port: 6379,
+                replica_ports: vec![7000],
+                slot_range: (0..8000),
+            },
+            MockSlotRange {
+                primary_port: 6380,
+                replica_ports: vec![7001],
+                slot_range: (8001..16380),
+            },
+        ];
+
+        let moved_from_port = 6379;
+        let moved_to_port = 7001;
+        let primary_shard2 = 6380;
+
+        // Tracking moved requests for validation
+        let moved_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_moved_requests = moved_requests.clone();
+
+        // Test key and slot of the first shard
+        let key = "test";
+        let key_slot = 6918;
+
+        // Test key of the second shard
+        let key_shard2 = "foo"; // slot 12182
+
+        // Mock environment setup
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .slots_refresh_rate_limit(Duration::from_secs(1000000), 0) // Rate limiter to disable slot refresh
+                .read_from_replicas(), // Allow reads from replicas
+            name,
+            move |cmd: &[u8], port| {
+                if contains_slice(cmd, b"PING")
+                    || contains_slice(cmd, b"SETNAME")
+                    || contains_slice(cmd, b"READONLY")
+                {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    let slots = create_topology_from_config(name, slots_config.clone());
+                    return Err(Ok(slots));
+                }
+
+                if contains_slice(cmd, b"SET") {
+                    if port == moved_to_port {
+                        // Simulate primary OK response
+                        Err(Ok(Value::SimpleString("OK".into())))
+                    } else if port == moved_from_port {
+                        // Simulate MOVED error for other port
+                        moved_requests.fetch_add(1, Ordering::Relaxed);
+                        Err(parse_redis_value(
+                            format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+                        ))
+                    } else {
+                        panic!("unexpected port for SET command: {port:?}.\n
+                            Expected one of: moved_to_port={moved_to_port}, moved_from_port={moved_from_port}");
+                    }
+                } else if contains_slice(cmd, b"GET") {
+                    if port == primary_shard2 {
+                        // Simulate second shard primary response for GET
+                        Err(Ok(Value::BulkString(b"123".to_vec())))
+                    } else {
+                        panic!("unexpected port for GET command: {port:?}, Expected: {primary_shard2:?}");
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        // First request: Trigger MOVED error and reroute
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Second request: Should be routed directly to the new primary node if the slots map is updated
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Third request: Verify that the promoted replica is no longer part of the second shard's replicas by
+        // ensuring the response is received from that shard's primary
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg(key_shard2)
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+
+        // Assert there was only a single MOVED error
+        assert_eq!(cloned_moved_requests.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    fn test_async_cluster_update_slots_based_on_moved_error_no_change() {
+        // This test simulates a scenario where the client receives a MOVED error, but the "new" primary is the
+        // same as the old primary (no actual change). It ensures that no additional slot map
+        // updates are required and that subsequent requests are still routed to the same primary node, with
+        // only one MOVED error encountered.
+        let name = "test_async_cluster_update_slots_based_on_moved_error_no_change";
+        let slots_config = vec![
+            MockSlotRange {
+                primary_port: 6379,
+                replica_ports: vec![7000],
+                slot_range: (0..8000),
+            },
+            MockSlotRange {
+                primary_port: 6380,
+                replica_ports: vec![7001],
+                slot_range: (8001..16380),
+            },
+        ];
+
+        let moved_from_port = 6379;
+        let moved_to_port = 6379;
+
+        // Tracking moved requests for validation
+        let moved_requests = Arc::new(atomic::AtomicUsize::new(0));
+        let cloned_moved_requests = moved_requests.clone();
+
+        // Test key and slot of the first shard
+        let key = "test";
+        let key_slot = 6918;
+
+        // Mock environment setup
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .slots_refresh_rate_limit(Duration::from_secs(1000000), 0), // Rate limiter to disable slot refresh
+            name,
+            move |cmd: &[u8], port| {
+                if contains_slice(cmd, b"PING")
+                    || contains_slice(cmd, b"SETNAME")
+                    || contains_slice(cmd, b"READONLY")
+                {
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+
+                if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") {
+                    let slots = create_topology_from_config(name, slots_config.clone());
+                    return Err(Ok(slots));
+                }
+
+                if contains_slice(cmd, b"SET") {
+                    if port == moved_to_port {
+                        if moved_requests.load(Ordering::Relaxed) == 0 {
+                            moved_requests.fetch_add(1, Ordering::Relaxed);
+                            Err(parse_redis_value(
+                                format!("-MOVED {key_slot} {name}:{moved_to_port}\r\n").as_bytes(),
+                            ))
+                        } else {
+                            Err(Ok(Value::SimpleString("OK".into())))
+                        }
+                    } else {
+                        panic!("unexpected port for SET command: {port:?}.\n
+                            Expected one of: moved_to_port={moved_to_port}, moved_from_port={moved_from_port}");
+                    }
+                } else {
+                    panic!("unexpected command {cmd:?}")
+                }
+            },
+        );
+
+        // First request: Trigger MOVED error and reroute
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Second request: Should still be routed to the same primary node
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg(key)
+                .arg("bar")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+
+        // Assert there was only a single MOVED error
+        assert_eq!(cloned_moved_requests.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_reconnect_even_with_zero_retries() {
+        let name = "test_async_cluster_reconnect_even_with_zero_retries";
+
+        let should_reconnect = atomic::AtomicBool::new(true);
+        let connection_count = Arc::new(atomic::AtomicU16::new(0));
+        let connection_count_clone = connection_count.clone();
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0), + name, + move |cmd: &[u8], port| { + match respond_startup(name, cmd) { + Ok(_) => {} + Err(err) => { + connection_count.fetch_add(1, Ordering::Relaxed); + return Err(err); + } + } + + if contains_slice(cmd, b"ECHO") && port == 6379 { + // Should not attempt to refresh slots more than once: + if should_reconnect.swap(false, Ordering::SeqCst) { + Err(Err(broken_pipe_error())) + } else { + Err(Ok(Value::BulkString(b"PONG".to_vec()))) + } + } else { + panic!("unexpected command {cmd:?}") + } + }, + ); + + // We expect 6 calls in total. MockEnv creates both synchronous and asynchronous connections, which make the following calls: + // - 1 call by the sync connection to `CLUSTER SLOTS` for initializing the client's topology map. + // - 3 calls by the async connection to `PING`: one for the user connection when creating the node from initial addresses, + // and two more for checking the user and management connections during client initialization in `refresh_slots`. + // - 1 call by the async connection to `CLIENT SETNAME` for setting up the management connection name. + // - 1 call by the async connection to `CLUSTER SLOTS` for initializing the client's topology map. + // Note: If additional nodes or setup calls are added, this number should increase. + let expected_init_calls = 6; + assert_eq!( + connection_count_clone.load(Ordering::Relaxed), + expected_init_calls + ); + + let value = runtime.block_on(connection.route_command( + &cmd("ECHO"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6379, + }), + )); + + // The user should receive an initial error, because there are no retries and the first request failed. + assert_eq!( + value.unwrap_err().to_string(), + broken_pipe_error().to_string() + ); + + let value = runtime.block_on(connection.route_command( + &cmd("ECHO"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6379, + }), + )); + + assert_eq!(value, Ok(Value::BulkString(b"PONG".to_vec()))); + // `expected_init_calls` plus another PING for a new user connection created from refresh_connections + assert_eq!( + connection_count_clone.load(Ordering::Relaxed), + expected_init_calls + 1 + ); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_refresh_slots_rate_limiter_skips_refresh() { + let ports = get_ports(3); + test_async_cluster_refresh_slots_rate_limiter_helper( + get_topology_with_majority(&ports), + ports, + true, + ); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_refresh_slots_rate_limiter_does_refresh_when_wait_duration_passed() { + let ports = get_ports(3); + test_async_cluster_refresh_slots_rate_limiter_helper( + get_topology_with_majority(&ports), + ports, + false, + ); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_ask_redirect() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. 
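+            // (`MockEnv` hands back the tokio runtime, the mocked async cluster
+            // connection, and a `_handler` guard which appears to keep the
+            // registered mock behavior alive for the duration of the test.)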
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")]),
+            name,
+            {
+                move |cmd: &[u8], port| {
+                    respond_startup_two_nodes(name, cmd)?;
+                    // Respond with an ASK redirect once, then expect the ASKING handshake and
+                    // the retried command on the target node, without a full slot-map rebuild.
+                    let count = completed.fetch_add(1, Ordering::SeqCst);
+                    match port {
+                        6379 => match count {
+                            0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")),
+                            _ => panic!("Node should not be called now"),
+                        },
+                        6380 => match count {
+                            1 => {
+                                assert!(contains_slice(cmd, b"ASKING"));
+                                Err(Ok(Value::Okay))
+                            }
+                            2 => {
+                                assert!(contains_slice(cmd, b"GET"));
+                                Err(Ok(Value::BulkString(b"123".to_vec())))
+                            }
+                            _ => panic!("Node should not be called now"),
+                        },
+                        _ => panic!("Wrong node"),
+                    }
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_ask_save_new_connection() {
+        let name = "node";
+        let ping_attempts = Arc::new(AtomicI32::new(0));
+        let ping_attempts_clone = ping_attempts.clone();
+        let MockEnv {
+            async_connection: mut connection,
+            handler: _handler,
+            runtime,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")]),
+            name,
+            {
+                move |cmd: &[u8], port| {
+                    if port != 6391 {
+                        respond_startup_two_nodes(name, cmd)?;
+                        return Err(parse_redis_value(b"-ASK 14000 node:6391\r\n"));
+                    }
+
+                    if contains_slice(cmd, b"PING") {
+                        ping_attempts_clone.fetch_add(1, Ordering::Relaxed);
+                    }
+                    respond_startup_two_nodes(name, cmd)?;
+                    Err(Ok(Value::Okay))
+                }
+            },
+        );
+
+        for _ in 0..4 {
+            runtime
+                .block_on(
+                    cmd("GET")
+                        .arg("test")
+                        .query_async::<_, Value>(&mut connection),
+                )
+                .unwrap();
+        }
+
+        assert_eq!(ping_attempts.load(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_reset_routing_if_redirect_fails() {
+        let name = "test_async_cluster_reset_routing_if_redirect_fails";
+        let completed = Arc::new(AtomicI32::new(0));
+        let MockEnv {
+            async_connection: mut connection,
+            handler: _handler,
+            runtime,
+            ..
+        } = MockEnv::new(name, move |cmd: &[u8], port| {
+            if port != 6379 && port != 6380 {
+                return Err(Err(broken_pipe_error()));
+            }
+            respond_startup_two_nodes(name, cmd)?;
+            let count = completed.fetch_add(1, Ordering::SeqCst);
+            match (port, count) {
+                // redirect once to a non-existing node
+                (6379, 0) => Err(parse_redis_value(
+                    format!("-ASK 14000 {name}:9999\r\n").as_bytes(),
+                )),
+                // accept the next request
+                (6379, 1) => {
+                    assert!(contains_slice(cmd, b"GET"));
+                    Err(Ok(Value::BulkString(b"123".to_vec())))
+                }
+                _ => panic!("Wrong node. port: {port}, received count: {count}"),
+            }
+        });
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
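+
+    // Unlike MOVED, an `-ASK <slot> <host>:<port>` redirect is one-shot: the client
+    // must send ASKING to the target node before retrying the command, and it must
+    // not rewrite its slot map. The mock handlers above encode exactly that
+    // sequence:
+    //
+    //     0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")),
+    //     1 => { assert!(contains_slice(cmd, b"ASKING")); Err(Ok(Value::Okay)) }
+    //     2 => { assert!(contains_slice(cmd, b"GET")); /* the actual reply */ }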
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_ask_redirect_even_if_original_call_had_no_route() {
+        let name = "node";
+        let completed = Arc::new(AtomicI32::new(0));
+        let MockEnv {
+            async_connection: mut connection,
+            handler: _handler,
+            runtime,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")]),
+            name,
+            {
+                move |cmd: &[u8], port| {
+                    respond_startup_two_nodes(name, cmd)?;
+                    // Respond with an ASK redirect on the first call, then expect the ASKING
+                    // handshake and the retried command on the target node.
+                    let count = completed.fetch_add(1, Ordering::SeqCst);
+                    if count == 0 {
+                        return Err(parse_redis_value(b"-ASK 14000 node:6380\r\n"));
+                    }
+                    match port {
+                        6380 => match count {
+                            1 => {
+                                assert!(
+                                    contains_slice(cmd, b"ASKING"),
+                                    "{:?}",
+                                    std::str::from_utf8(cmd)
+                                );
+                                Err(Ok(Value::Okay))
+                            }
+                            2 => {
+                                assert!(contains_slice(cmd, b"EVAL"));
+                                Err(Ok(Value::Okay))
+                            }
+                            _ => panic!("Node should not be called now"),
+                        },
+                        _ => panic!("Wrong node"),
+                    }
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("EVAL") // EVAL has no key to derive a route from, so it is sent to an arbitrary node
+                .query_async::<_, Value>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Value::Okay));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_ask_error_when_new_node_is_added() {
+        let name = "ask_with_extra_nodes";
+
+        let requests = atomic::AtomicUsize::new(0);
+        let started = atomic::AtomicBool::new(false);
+
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::new(name, move |cmd: &[u8], port| {
+            if !started.load(atomic::Ordering::SeqCst) {
+                respond_startup(name, cmd)?;
+            }
+            started.store(true, atomic::Ordering::SeqCst);
+
+            if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") {
+                return Err(Ok(Value::SimpleString("OK".into())));
+            }
+
+            let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+
+            match i {
+                // Respond that the key exists on a node that does not yet have a connection:
+                0 => Err(parse_redis_value(
+                    format!("-ASK 123 {name}:6380\r\n").as_bytes(),
+                )),
+                1 => {
+                    assert_eq!(port, 6380);
+                    assert!(contains_slice(cmd, b"ASKING"));
+                    Err(Ok(Value::Okay))
+                }
+                2 => {
+                    assert_eq!(port, 6380);
+                    assert!(contains_slice(cmd, b"GET"));
+                    Err(Ok(Value::BulkString(b"123".to_vec())))
+                }
+                _ => {
+                    panic!("Unexpected request: {:?}", cmd);
+                }
+            }
+        });
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_replica_read() {
+        let name = "node";
+
+        // requests should route to replica
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |cmd: &[u8], port| {
+                respond_startup_with_replica(name, cmd)?;
+                match port {
+                    6380 => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                    _ => panic!("Wrong node"),
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+
+        // requests should route to primary
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |cmd: &[u8], port| {
+                respond_startup_with_replica(name, cmd)?;
+                match port {
+                    6379 => Err(Ok(Value::SimpleString("OK".into()))),
+                    _ => panic!("Wrong node"),
+                }
+            },
+        );
+
+        let value = runtime.block_on(
+            cmd("SET")
+                .arg("test")
+                .arg("123")
+                .query_async::<_, Option<Value>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned()))));
+    }
+
+    fn test_async_cluster_fan_out(
+        command: &'static str,
+        expected_ports: Vec<u16>,
+        slots_config: Option<Vec<MockSlotRange>>,
+    ) {
+        let name = "node";
+        let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+        let ports_clone = found_ports.clone();
+        let mut cmd = Cmd::new();
+        for arg in command.split_whitespace() {
+            cmd.arg(arg);
+        }
+        let packed_cmd = cmd.get_packed_command();
+        // set up a mock cluster and record which ports receive the fanned-out command
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |received_cmd: &[u8], port| {
+                respond_startup_with_replica_using_config(
+                    name,
+                    received_cmd,
+                    slots_config.clone(),
+                )?;
+                if received_cmd == packed_cmd {
+                    ports_clone.lock().unwrap().push(port);
+                    return Err(Ok(Value::SimpleString("OK".into())));
+                }
+                Ok(())
+            },
+        );
+
+        let _ = runtime.block_on(cmd.query_async::<_, Option<()>>(&mut connection));
+        found_ports.lock().unwrap().sort();
+        // MockEnv creates 2 mock connections.
+        assert_eq!(*found_ports.lock().unwrap(), expected_ports);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_to_all_primaries() {
+        test_async_cluster_fan_out("FLUSHALL", vec![6379, 6381], None);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_to_all_nodes() {
+        test_async_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None);
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_once_to_each_primary_when_no_replicas_are_available() {
+        test_async_cluster_fan_out(
+            "CONFIG SET",
+            vec![6379, 6381],
+            Some(vec![
+                MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: Vec::new(),
+                    slot_range: (0..8191),
+                },
+                MockSlotRange {
+                    primary_port: 6381,
+                    replica_ports: Vec::new(),
+                    slot_range: (8192..16383),
+                },
+            ]),
+        );
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_once_even_if_primary_has_multiple_slot_ranges() {
+        test_async_cluster_fan_out(
+            "CONFIG SET",
+            vec![6379, 6380, 6381, 6382],
+            Some(vec![
+                MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380],
+                    slot_range: (0..4000),
+                },
+                MockSlotRange {
+                    primary_port: 6381,
+                    replica_ports: vec![6382],
+                    slot_range: (4001..8191),
+                },
+                MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380],
+                    slot_range: (8192..8200),
+                },
+                MockSlotRange {
+                    primary_port: 6381,
+                    replica_ports: vec![6382],
+                    slot_range: (8201..16383),
+                },
+            ]),
+        );
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_route_according_to_passed_argument() {
+        let name = "test_async_cluster_route_according_to_passed_argument";
+
+        let touched_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+        let cloned_ports = touched_ports.clone();
+
+        // record the ports touched by each explicitly routed request
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + cloned_ports.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + let mut cmd = cmd("GET"); + cmd.arg("test"); + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllMasters, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6381]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6380, 6381, 6382]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6382, + }), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6382]); + touched_ports.clear(); + } + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_fan_out_and_aggregate_numeric_response_with_min() { + let name = "test_async_cluster_fan_out_and_aggregate_numeric_response"; + let mut cmd = Cmd::new(); + cmd.arg("SLOWLOG").arg("LEN"); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + + let res = 6383 - port as i64; + Err(Ok(Value::Int(res))) // this results in 1,2,3,4 + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, i64>(&mut connection)) + .unwrap(); + assert_eq!(result, 10, "{result}"); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_fan_out_and_aggregate_logical_array_response() { + let name = "test_async_cluster_fan_out_and_aggregate_logical_array_response"; + let mut cmd = Cmd::new(); + cmd.arg("SCRIPT") + .arg("EXISTS") + .arg("foo") + .arg("bar") + .arg("baz") + .arg("barvaz"); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
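+            // (SCRIPT EXISTS fans out to all primaries and the per-key arrays are
+            // combined with a logical AND: a script only counts as cached if every
+            // node reports it. The expected [0, 0, 0, 1] below is the AND of
+            // [0, 0, 1, 1] from port 6381 and [0, 1, 0, 1] from port 6379.)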
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |received_cmd: &[u8], port| {
+                respond_startup_with_replica_using_config(name, received_cmd, None)?;
+
+                if port == 6381 {
+                    return Err(Ok(Value::Array(vec![
+                        Value::Int(0),
+                        Value::Int(0),
+                        Value::Int(1),
+                        Value::Int(1),
+                    ])));
+                } else if port == 6379 {
+                    return Err(Ok(Value::Array(vec![
+                        Value::Int(0),
+                        Value::Int(1),
+                        Value::Int(0),
+                        Value::Int(1),
+                    ])));
+                }
+
+                panic!("unexpected port {port}");
+            },
+        );
+
+        let result = runtime
+            .block_on(cmd.query_async::<_, Vec<i32>>(&mut connection))
+            .unwrap();
+        assert_eq!(result, vec![0, 0, 0, 1], "{result:?}");
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_and_return_one_succeeded_response() {
+        let name = "test_async_cluster_fan_out_and_return_one_succeeded_response";
+        let mut cmd = Cmd::new();
+        cmd.arg("SCRIPT").arg("KILL");
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |received_cmd: &[u8], port| {
+                respond_startup_with_replica_using_config(name, received_cmd, None)?;
+                if port == 6381 {
+                    return Err(Ok(Value::Okay));
+                }
+                Err(Err((
+                    ErrorKind::NotBusy,
+                    "No scripts in execution right now",
+                )
+                    .into()))
+            },
+        );
+
+        let result = runtime
+            .block_on(cmd.query_async::<_, Value>(&mut connection))
+            .unwrap();
+        assert_eq!(result, Value::Okay, "{result:?}");
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes() {
+        let name = "test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes";
+        let mut cmd = Cmd::new();
+        cmd.arg("SCRIPT").arg("KILL");
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+        } = MockEnv::with_client_builder(
+            ClusterClient::builder(vec![&*format!("redis://{name}")])
+                .retries(0)
+                .read_from_replicas(),
+            name,
+            move |received_cmd: &[u8], _port| {
+                respond_startup_with_replica_using_config(name, received_cmd, None)?;
+
+                Err(Err((
+                    ErrorKind::NotBusy,
+                    "No scripts in execution right now",
+                )
+                    .into()))
+            },
+        );
+
+        let result = runtime
+            .block_on(cmd.query_async::<_, Value>(&mut connection))
+            .unwrap_err();
+        assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind());
+    }
+
+    #[test]
+    #[serial_test::serial]
+    fn test_async_cluster_fan_out_and_return_all_succeeded_response() {
+        let name = "test_async_cluster_fan_out_and_return_all_succeeded_response";
+        let cmd = cmd("FLUSHALL");
+        let MockEnv {
+            runtime,
+            async_connection: mut connection,
+            handler: _handler,
+            ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Okay, "{result:?}"); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure() { + let name = "test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure"; + let cmd = cmd("FLUSHALL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())); + } + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind()); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + let ports = vec![6379, 6380, 6381]; + let slots_config_vec = generate_topology_view(&ports, 1000, true); + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + if port == 6380 { + return Err(Ok(Value::BulkString("foo".as_bytes().to_vec()))); + } else if port == 6381 { + return Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))); + } + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, String>(&mut connection)) + .unwrap(); + assert_eq!(result, "foo", "{result:?}"); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_config(name, received_cmd, None, false)?; + if port == 6380 { + return Err(Ok(Value::Nil)); + } + Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))) + }, + ); + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::ResponseError); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil() { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_config(name, received_cmd, None, false)?; + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Nil, "{result:?}"); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_fan_out_and_return_map_of_results_for_special_response_policy() { + let name = "foo"; + let mut cmd = Cmd::new(); + cmd.arg("LATENCY").arg("LATEST"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::BulkString( + format!("latency: {port}").into_bytes(), + ))) + }, + ); + + // TODO once RESP3 is in, return this as a map + let mut result = runtime + .block_on(cmd.query_async::<_, Vec<(String, String)>>(&mut connection)) + .unwrap(); + result.sort(); + assert_eq!( + result, + vec![ + (format!("{name}:6379"), "latency: 6379".to_string()), + (format!("{name}:6380"), "latency: 6380".to_string()), + (format!("{name}:6381"), "latency: 6381".to_string()), + (format!("{name}:6382"), "latency: 6382".to_string()) + ], + "{result:?}" + ); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_fan_out_and_combine_arrays_of_values() { + let name = "foo"; + let cmd = cmd("KEYS"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(0)
+            .read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            Err(Ok(Value::Array(vec![Value::BulkString(
+                format!("key:{port}").into_bytes(),
+            )])))
+        },
+    );
+
+    let mut result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    result.sort();
+    assert_eq!(
+        result,
+        vec!["key:6379".to_string(), "key:6381".to_string(),],
+        "{result:?}"
+    );
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values() {
+    let name = "test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")])
+            .retries(0)
+            .read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+            let results = ["foo", "bar", "baz"]
+                .iter()
+                .filter_map(|expected_key| {
+                    if cmd_str.contains(expected_key) {
+                        Some(Value::BulkString(
+                            format!("{expected_key}-{port}").into_bytes(),
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Err(Ok(Value::Array(results)))
+        },
+    );
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_handle_asking_error_in_split_multi_shard_command() {
+    let name = "test_async_cluster_handle_asking_error_in_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let asking_called = Arc::new(AtomicU16::new(0));
+    let asking_called_cloned = asking_called.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(),
+        name,
+        move |received_cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(name, received_cmd, None)?;
+            let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+            if cmd_str.contains("ASKING") && port == 6382 {
+                asking_called_cloned.fetch_add(1, Ordering::Relaxed);
+            }
+            if port == 6380 && cmd_str.contains("baz") {
+                return Err(parse_redis_value(
+                    format!("-ASK 14000 {name}:6382\r\n").as_bytes(),
+                ));
+            }
+            let results = ["foo", "bar", "baz"]
+                .iter()
+                .filter_map(|expected_key| {
+                    if cmd_str.contains(expected_key) {
+                        Some(Value::BulkString(
+                            format!("{expected_key}-{port}").into_bytes(),
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Err(Ok(Value::Array(results)))
+        },
+    );
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap();
+    assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6382"]);
+    assert_eq!(asking_called.load(Ordering::Relaxed), 1);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_pass_errors_from_split_multi_shard_command() {
+    let name = "test_async_cluster_pass_errors_from_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::new(name, move |received_cmd: &[u8], port| {
+        respond_startup_with_replica_using_config(name, received_cmd, None)?;
+        let cmd_str = std::str::from_utf8(received_cmd).unwrap();
+        if cmd_str.contains("foo") || cmd_str.contains("baz") {
+            Err(Err((ErrorKind::IoError, "error").into()))
+        } else {
+            Err(Ok(Value::Array(vec![Value::BulkString(
+                format!("{port}").into_bytes(),
+            )])))
+        }
+    });
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap_err();
+    assert_eq!(result.kind(), ErrorKind::IoError);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_handle_missing_slots_in_split_multi_shard_command() {
+    let name = "test_async_cluster_handle_missing_slots_in_split_multi_shard_command";
+    let mut cmd = cmd("MGET");
+    cmd.arg("foo").arg("bar").arg("baz");
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::new(name, move |received_cmd: &[u8], port| {
+        respond_startup_with_replica_using_config(
+            name,
+            received_cmd,
+            Some(vec![MockSlotRange {
+                primary_port: 6381,
+                replica_ports: vec![6382],
+                slot_range: (8192..16383),
+            }]),
+        )?;
+        Err(Ok(Value::Array(vec![Value::BulkString(
+            format!("{port}").into_bytes(),
+        )])))
+    });
+
+    let result = runtime
+        .block_on(cmd.query_async::<_, Vec<String>>(&mut connection))
+        .unwrap_err();
+    assert!(
+        matches!(result.kind(), ErrorKind::ConnectionNotFoundForRoute)
+            || result.is_connection_dropped()
+    );
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_with_username_and_password() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .username(RedisCluster::username().to_string())
+                .password(RedisCluster::password().to_string())
+        },
+        false,
+    );
+    cluster.disable_default_user();
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        cmd("SET")
+            .arg("test")
+            .arg("test_data")
+            .query_async(&mut connection)
+            .await?;
+        let res: String = cmd("GET")
+            .arg("test")
+            .clone()
+            .query_async(&mut connection)
+            .await?;
+        assert_eq!(res, "test_data");
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_io_error() {
+    let name = "node";
+    let completed = Arc::new(AtomicI32::new(0));
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_two_nodes(name, cmd)?;
+            // Error twice with io-error, ensure connection is reestablished w/out calling
+            // other node (i.e., not doing a full slot rebuild)
+            match port {
+                6380 => panic!("Node should not be called"),
+                _ => match completed.fetch_add(1, Ordering::SeqCst) {
+                    0..=1 => Err(Err(RedisError::from((
+                        ErrorKind::FatalSendError,
+                        "mock-io-error",
+                    )))),
+                    _ => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                },
+            }
+        },
+    );
+
+    let value = runtime.block_on(
+        cmd("GET")
+            .arg("test")
+            .query_async::<_, Option<i32>>(&mut connection),
+    );
+
+    assert_eq!(value, Ok(Some(123)));
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_non_retryable_error_should_not_retry() {
+    let name = "node";
+    let completed = Arc::new(AtomicI32::new(0));
+    let MockEnv {
+        async_connection: mut connection,
+        handler: _handler,
+        runtime,
+        ..
+    } = MockEnv::new(name, {
+        let completed = completed.clone();
+        move |cmd: &[u8], _| {
+            respond_startup_two_nodes(name, cmd)?;
+            // Fail every request with a non-retryable error, and ensure the error is
+            // surfaced to the caller without any retry attempt.
+            completed.fetch_add(1, Ordering::SeqCst);
+            Err(Err((ErrorKind::ReadOnly, "").into()))
+        }
+    });
+
+    let value = runtime.block_on(
+        cmd("GET")
+            .arg("test")
+            .query_async::<_, Option<i32>>(&mut connection),
+    );
+
+    match value {
+        Ok(_) => panic!("result should be an error"),
+        Err(e) => match e.kind() {
+            ErrorKind::ReadOnly => {}
+            _ => panic!("Expected ReadOnly but got {:?}", e.kind()),
+        },
+    }
+    assert_eq!(completed.load(Ordering::SeqCst), 1);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_non_retryable_io_error_should_not_retry() {
+    let name = "test_async_cluster_non_retryable_io_error_should_not_retry";
+    let requests = atomic::AtomicUsize::new(0);
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(3),
+        name,
+        move |cmd: &[u8], _port| {
+            respond_startup_two_nodes(name, cmd)?;
+            let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+            match i {
+                0 => Err(Err(RedisError::from((ErrorKind::IoError, "io-error")))),
+                _ => {
+                    panic!("Expected not to be retried!")
+                }
+            }
+        },
+    );
+    runtime
+        .block_on(async move {
+            let res = cmd("INCR")
+                .arg("foo")
+                .query_async::<_, Option<i32>>(&mut connection)
+                .await;
+            assert!(res.is_err());
+            let err = res.unwrap_err();
+            assert!(err.is_io_error());
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_retry_safe_io_error_should_be_retried() {
+    let name = "test_async_cluster_retry_safe_io_error_should_be_retried";
+    let requests = atomic::AtomicUsize::new(0);
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(3),
+        name,
+        move |cmd: &[u8], _port| {
+            respond_startup_two_nodes(name, cmd)?;
+            let i = requests.fetch_add(1, atomic::Ordering::SeqCst);
+            match i {
+                0 => Err(Err(RedisError::from((
+                    ErrorKind::FatalSendError,
+                    "server didn't receive the request, safe to retry",
+                )))),
+                _ => Err(Ok(Value::Int(1))),
+            }
+        },
+    );
+    runtime
+        .block_on(async move {
+            let res = cmd("INCR")
+                .arg("foo")
+                .query_async::<_, i32>(&mut connection)
+                .await;
+            assert!(res.is_ok());
+            let value = res.unwrap();
+            assert_eq!(value, 1);
+            Ok::<_, RedisError>(())
+        })
+        .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_read_from_primary() {
+    let name = "node";
+    let found_ports = Arc::new(std::sync::Mutex::new(Vec::new()));
+    let ports_clone = found_ports.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        ..
+ } = MockEnv::new(name, move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6379, 6379, 6382, 6382]); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_round_robin_read_from_replica() { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6380, 6381, 6383, 6384]); + } + + fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option { + // Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master. + // Otherwise, returns None. 
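+    // For reference, CLUSTER NODES replies with one line per node, e.g. (illustrative values only):
+    //   07c37dfeb235213a872192d90877d0cd55635b91 127.0.0.1:6379@16379 myself,master - 0 0 1 connected 0-5460
+    // so the closure below picks the line flagged with both "myself" and "master" and returns its
+    // first whitespace-separated field, the node ID.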
+    let get_node_id = |str: &str| {
+        let parts: Vec<&str> = str.split('\n').collect();
+        for node_entry in parts {
+            if node_entry.contains("myself") && node_entry.contains("master") {
+                let node_entry_parts: Vec<&str> = node_entry.split(' ').collect();
+                let node_id = node_entry_parts[0];
+                return Some(node_id.to_string());
+            }
+        }
+        None
+    };
+
+    match cluster_nodes_output {
+        Value::BulkString(val) => match from_utf8(&val) {
+            Ok(str_res) => get_node_id(str_res),
+            Err(e) => panic!("failed to decode CLUSTER NODES response: {:?}", e),
+        },
+        Value::VerbatimString { format: _, text } => get_node_id(&text),
+        _ => panic!("Received unexpected response: {:?}", cluster_nodes_output),
+    }
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_handle_complete_server_disconnect_without_panicking() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.retries(2),
+        false,
+    );
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        drop(cluster);
+        for _ in 0..5 {
+            let cmd = cmd("PING");
+            let result = connection
+                .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
+                .await;
+            // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing requests
+            assert!(result.is_err());
+            // This will route to all nodes - different path through the code.
+            let result = connection.req_packed_command(&cmd).await;
+            // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing requests
+            assert!(result.is_err());
+        }
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_test_fast_reconnect() {
+    // Note the 3-second connection check interval, used to differentiate between notification-driven and periodic reconnects
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(0)
+                .periodic_connections_checks(Duration::from_secs(3))
+        },
+        false,
+    );
+
+    // For tokio-comp, do 3 consecutive disconnects and ensure each reconnect succeeds in less than 100ms,
+    // which is more than enough for local connections even with TLS.
+    // More than 1 run is done to ensure it is the fast reconnect notification that triggers the reconnect
+    // and not the periodic interval.
+    // For other async implementations, only the periodic connection check is available; hence,
+    // do a single run, sleeping for the periodic connection check interval, allowing it to reestablish connections.
+    block_on_all(async move {
+        let mut disconnecting_con = cluster.async_connection(None).await;
+        let mut monitoring_con = cluster.async_connection(None).await;
+
+        #[cfg(feature = "tokio-comp")]
+        let tries = 0..3;
+        #[cfg(not(feature = "tokio-comp"))]
+        let tries = 0..1;
+
+        for _ in tries {
+            // get connection id
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("ID");
+            let res = disconnecting_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            assert!(res.is_ok());
+            let res = res.unwrap();
+            let id = {
+                match res {
+                    Value::Int(id) => id,
+                    _ => {
+                        panic!("Wrong return value for CLIENT ID command: {:?}", res);
+                    }
+                }
+            };
+
+            // ask server to kill the connection
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("KILL").arg("ID").arg(id).arg("SKIPME").arg("NO");
+            let res = disconnecting_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            // assert server has closed connection
+            assert_eq!(res, Ok(Value::Int(1)));
+
+            #[cfg(feature = "tokio-comp")]
+            // ensure reconnect happened in less than 100ms
+            sleep(futures_time::time::Duration::from_millis(100)).await;
+
+            #[cfg(not(feature = "tokio-comp"))]
+            // no fast notification is available, wait for 1 periodic check + overhead
+            sleep(futures_time::time::Duration::from_secs(3 + 1)).await;
+
+            let mut cmd = redis::cmd("CLIENT");
+            cmd.arg("LIST").arg("TYPE").arg("NORMAL");
+            let res = monitoring_con
+                .route_command(
+                    &cmd,
+                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+                        0,
+                        SlotAddr::Master,
+                    ))),
+                )
+                .await;
+            assert!(res.is_ok());
+            let res = res.unwrap();
+            let client_list: String = {
+                match res {
+                    // RESP2
+                    Value::BulkString(client_info) => {
+                        // ensure 4 connections - 2 for each client; it's safe to unwrap here
+                        String::from_utf8(client_info).unwrap()
+                    }
+                    // RESP3
+                    Value::VerbatimString { format: _, text } => text,
+                    _ => {
+                        panic!("Wrong return type for CLIENT LIST command: {:?}", res);
+                    }
+                }
+            };
+            assert_eq!(client_list.chars().filter(|&x| x == '\n').count(), 4);
+        }
+        Ok(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() {
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    let use_sharded = redis_ver.starts_with("7.");
+
+    let mut client_subscriptions = PubSubSubscriptionInfo::from([(
+        PubSubSubscriptionKind::Exact,
+        HashSet::from([PubSubChannelOrPattern::from("test_channel".as_bytes())]),
+    )]);
+
+    if use_sharded {
+        client_subscriptions.insert(
+            PubSubSubscriptionKind::Sharded,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        );
+    }
+
+    // note topology change detection is not activated since no topology change is expected
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(3)
+                .use_protocol(ProtocolVersion::RESP3)
+                .pubsub_subscriptions(client_subscriptions.clone())
+                .periodic_connections_checks(Duration::from_secs(1))
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+        let mut _listening_con =
cluster.async_connection(Some(tx.clone())).await; + // Note, publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // simulate passive disconnect + drop(cluster); + + // recreate the cluster, the assumtion is that the cluster is built with exactly the same params (ports, slots map...) + let _cluster = + TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| builder, false); + + // sleep for 1 periodic_connections_checks + overhead + sleep(futures_time::time::Duration::from_secs(1 + 1)).await; + + // new subscription notifications due to resubscriptions + validate_subscriptions(&client_subscriptions, &mut rx, true); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut 
client_subscriptions = PubSubSubscriptionInfo::from([
+        // test_channel_? is used as it maps to slot 14212, which is owned by the last node in both the 3- and 6-node config
+        // (assuming slots allocation is monotonically increasing starting from node 0)
+        (
+            PubSubSubscriptionKind::Exact,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        ),
+    ]);
+
+    if use_sharded {
+        client_subscriptions.insert(
+            PubSubSubscriptionKind::Sharded,
+            HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]),
+        );
+    }
+
+    let slot_14212 = get_slot(b"test_channel_?");
+    assert_eq!(slot_14212, 14212);
+
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder
+                .retries(3)
+                .use_protocol(ProtocolVersion::RESP3)
+                .pubsub_subscriptions(client_subscriptions.clone())
+                // periodic connection check is required to detect the disconnect from the last node
+                .periodic_connections_checks(Duration::from_secs(1))
+                // periodic topology check is required to detect topology change
+                .periodic_topology_checks(Duration::from_secs(1))
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PushInfo>();
+        let mut _listening_con = cluster.async_connection(Some(tx.clone())).await;
+        // Note, publishing connection has the same pubsub config
+        let mut publishing_con = cluster.async_connection(None).await;
+
+        // short sleep to allow the server to push subscription notification
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+
+        // validate subscriptions
+        validate_subscriptions(&client_subscriptions, &mut rx, false);
+
+        // validate PUBLISH
+        let result = cmd("PUBLISH")
+            .arg("test_channel_?")
+            .arg("test_message")
+            .query_async(&mut publishing_con)
+            .await;
+        assert_eq!(
+            result,
+            Ok(Value::Int(2)) // 2 connections with the same pubsub config
+        );
+
+        sleep(futures_time::time::Duration::from_secs(1)).await;
+        let result = rx.try_recv();
+        assert!(result.is_ok());
+        let PushInfo { kind, data } = result.unwrap();
+        assert_eq!(
+            (kind, data),
+            (
+                PushKind::Message,
+                vec![
+                    Value::BulkString("test_channel_?".into()),
+                    Value::BulkString("test_message".into()),
+                ]
+            )
+        );
+
+        if use_sharded {
+            // validate SPUBLISH
+            let result = cmd("SPUBLISH")
+                .arg("test_channel_?")
+                .arg("test_message")
+                .query_async(&mut publishing_con)
+                .await;
+            assert_eq!(
+                result,
+                Ok(Value::Int(2)) // 2 connections with the same pubsub config
+            );
+
+            sleep(futures_time::time::Duration::from_secs(1)).await;
+            let result = rx.try_recv();
+            assert!(result.is_ok());
+            let PushInfo { kind, data } = result.unwrap();
+            assert_eq!(
+                (kind, data),
+                (
+                    PushKind::SMessage,
+                    vec![
+                        Value::BulkString("test_channel_?".into()),
+                        Value::BulkString("test_message".into()),
+                    ]
+                )
+            );
+        }
+
+        // drop and recreate a cluster with more nodes
+        drop(cluster);
+
+        // recreate the cluster; the assumption is that the cluster is built with exactly the same params (ports, slots map...)
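+        // After the 6-node cluster is up, the loop below polls INFO SERVER, routed by slot 14212,
+        // until the reply reports the port of the node that now owns that slot - i.e. until the
+        // client has rediscovered the new topology - before re-validating PUBLISH/SPUBLISH.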
+ let cluster = + TestClusterContext::new_with_cluster_client_builder(6, 0, |builder| builder, false); + + // assume slot 14212 will reside in the last node + let last_server_port = { + let addr = cluster.cluster.servers.last().unwrap().addr.clone(); + match addr { + redis::ConnectionAddr::TcpTls { + host: _, + port, + insecure: _, + tls_params: _, + } => port, + redis::ConnectionAddr::Tcp(_, port) => port, + _ => { + panic!("Wrong server address type: {:?}", addr); + } + } + }; + + // wait for new topology discovery + let max_requests = 5; + let mut i = 0; + let mut cmd = redis::cmd("INFO"); + cmd.arg("SERVER"); + loop { + if i == max_requests { + panic!("Failed to recover and discover new topology"); + } + i += 1; + + if let Ok(res) = publishing_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + slot_14212, + SlotAddr::Master, + ))), + ) + .await + { + match res { + Value::VerbatimString { format: _, text } => { + if text.contains(format!("tcp_port:{}", last_server_port).as_str()) { + // new topology rediscovered + break; + } + } + _ => { + panic!("Wrong return type for INFO SERVER command: {:?}", res); + } + } + sleep(futures_time::time::Duration::from_secs(1)).await; + } + } + + // sleep for one one cycle of topology refresh + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate PUBLISH + let result = redis::cmd("PUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + loop { + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + // ignore disconnection and subscription notifications due to resubscriptions + if kind == PushKind::Message { + assert_eq!( + data, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ); + break; + } + } + + if use_sharded { + // validate SPUBLISH + let result = redis::cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + drop(publishing_con); + drop(_listening_con); + + Ok(()) + }) + .unwrap(); + + block_on_all(async move { + sleep(futures_time::time::Duration::from_secs(10)).await; + Ok(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_resp3_pubsub() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([ + ( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ), + ( + PubSubSubscriptionKind::Pattern, + HashSet::from([ + PubSubChannelOrPattern::from("test_*".as_bytes()), + PubSubChannelOrPattern::from("*".as_bytes()), + ]), + ), + ]); + + if use_sharded { + client_subscriptions.insert( + 
PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut connection = cluster.async_connection(Some(tx.clone())).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + validate_subscriptions(&client_subscriptions, &mut rx, false); + + let slot_14212 = get_slot(b"test_channel_?"); + assert_eq!(slot_14212, 14212); + + let slot_0_route = + redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); + let node_0_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route); + + // node 0 route is used to ensure that the publish is propagated correctly + let result = connection + .route_command( + redis::Cmd::new() + .arg("PUBLISH") + .arg("test_channel_?") + .arg("test_message"), + RoutingInfo::SingleNode(node_0_route.clone()), + ) + .await; + assert!(result.is_ok()); + + sleep(futures_time::time::Duration::from_secs(1)).await; + + let mut pmsg_cnt = 0; + let mut msg_cnt = 0; + for _ in 0..3 { + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data: _ } = result.unwrap(); + assert!(kind == PushKind::Message || kind == PushKind::PMessage); + if kind == PushKind::Message { + msg_cnt += 1; + } else { + pmsg_cnt += 1; + } + } + assert_eq!(msg_cnt, 1); + assert_eq!(pmsg_cnt, 2); + + if use_sharded { + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut connection) + .await; + assert_eq!(result, Ok(Value::Int(1))); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_periodic_checks_update_topology_after_failover() { + // This test aims to validate the functionality of periodic topology checks by detecting and updating topology changes. + // We will repeatedly execute CLUSTER NODES commands against the primary node responsible for slot 0, recording its node ID. + // Once we've successfully completed commands with the current primary, we will initiate a failover within the same shard. + // Since we are not executing key-based commands, we won't encounter MOVED errors that trigger a slot refresh. + // Consequently, we anticipate that only the periodic topology check will detect this change and trigger topology refresh. + // If successful, the node to which we route the CLUSTER NODES command should be the newly promoted node with a different node ID. 
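+    // (CLUSTER FAILOVER TAKEOVER promotes the replica unilaterally, without coordinating with its
+    // primary, which makes it a convenient way to force a topology change from within the test.)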
+ let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .periodic_topology_checks(Duration::from_millis(10)) + // Disable the rate limiter to refresh slots immediately on all MOVED errors + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut prev_master_id = "".to_string(); + let max_requests = 5000; + let mut i = 0; + loop { + if i == 10 { + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("FAILOVER"); + cmd.arg("TAKEOVER"); + let res = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode( + Route::new(0, SlotAddr::ReplicaRequired), + )), + ) + .await; + assert!(res.is_ok()); + } else if i == max_requests { + break; + } else { + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("NODES"); + let res = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode( + Route::new(0, SlotAddr::Master), + )), + ) + .await + .expect("Failed executing CLUSTER NODES"); + let node_id = get_queried_node_id_if_master(res); + if let Some(current_master_id) = node_id { + if prev_master_id.is_empty() { + prev_master_id = current_master_id; + } else if prev_master_id != current_master_id { + return Ok::<_, RedisError>(()); + } + } + } + i += 1; + let _ = sleep(futures_time::time::Duration::from_millis(10)).await; + } + panic!("Topology change wasn't found!"); + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_recover_disconnected_management_connections() { + // This test aims to verify that the management connections used for periodic checks are reconnected, in case that they get killed. + // In order to test this, we choose a single node, kill all connections to it which aren't user connections, and then wait until new + // connections are created. 
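+    // The management connection is identified in CLIENT LIST output by its client name
+    // (MANAGEMENT_CONN_NAME) and killed via CLIENT KILL ID; user connections are left untouched.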
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder.periodic_topology_checks(Duration::from_millis(10))
+                // Disable the rate limiter to refresh slots immediately
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
+            1,
+            SlotAddr::Master,
+        )));
+
+        let mut connection = cluster.async_connection(None).await;
+        let max_requests = 5000;
+
+        let connections =
+            get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+        assert!(connections.contains_key(MANAGEMENT_CONN_NAME));
+        let management_conn_id = connections.get(MANAGEMENT_CONN_NAME).unwrap();
+
+        // Kill the management connection, now that we have its connection ID
+        kill_connection(&mut connection, management_conn_id).await;
+
+        let connections =
+            get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+        assert!(!connections.contains_key(MANAGEMENT_CONN_NAME));
+
+        for _ in 0..max_requests {
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+
+            let connections =
+                get_clients_names_to_ids(&mut connection, routing.clone().into()).await;
+            if connections.contains_key(MANAGEMENT_CONN_NAME) {
+                return Ok(());
+            }
+        }
+
+        panic!("Topology connection didn't reconnect!");
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_with_client_name() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.client_name(RedisCluster::client_name().to_string()),
+        false,
+    );
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let client_info: String = cmd("CLIENT")
+            .arg("INFO")
+            .query_async(&mut connection)
+            .await
+            .unwrap();
+
+        let client_attrs = parse_client_info(&client_info);
+
+        assert!(
+            client_attrs.contains_key("name"),
+            "Could not detect the 'name' attribute in CLIENT INFO output"
+        );
+
+        assert_eq!(
+            client_attrs["name"],
+            RedisCluster::client_name(),
+            "Incorrect client name, expecting: {}, got {}",
+            RedisCluster::client_name(),
+            client_attrs["name"]
+        );
+        Ok::<_, RedisError>(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_reroute_from_replica_if_in_loading_state() {
+    /* Test a replica in loading state. The expected behaviour is that the request will be directed to a different replica
+    or to the primary, depending on the read-from-replica policy. */
+    let name = "test_async_cluster_reroute_from_replica_if_in_loading_state";
+
+    let load_errors: Arc<_> = Arc::new(std::sync::Mutex::new(vec![]));
+    let load_errors_clone = load_errors.clone();
+
+    // requests should route to replica
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(
+                name,
+                cmd,
+                Some(vec![MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380, 6381],
+                    slot_range: (0..16383),
+                }]),
+            )?;
+            match port {
+                6380 | 6381 => {
+                    load_errors_clone.lock().unwrap().push(port);
+                    Err(parse_redis_value(b"-LOADING\r\n"))
+                }
+                6379 => Err(Ok(Value::BulkString(b"123".to_vec()))),
+                _ => panic!("Wrong node"),
+            }
+        },
+    );
+    for _n in 0..3 {
+        let value = runtime.block_on(
+            cmd("GET")
+                .arg("test")
+                .query_async::<_, Option<i32>>(&mut connection),
+        );
+        assert_eq!(value, Ok(Some(123)));
+    }
+
+    let mut load_errors_guard = load_errors.lock().unwrap();
+    load_errors_guard.sort();
+
+    // We expect to get only 2 loading errors, since the 2 replicas are in a loading state.
+    // The third iteration will be directed to the primary, since the connections of the replicas were removed.
+    assert_eq!(*load_errors_guard, vec![6380, 6381]);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_read_from_primary_when_primary_loading() {
+    // Test a primary in loading state. The expected behaviour is that the request will be retried until the primary is no longer in a loading state.
+    let name = "test_async_cluster_read_from_primary_when_primary_loading";
+
+    const RETRIES: u32 = 3;
+    const ITERATIONS: u32 = 2;
+    let load_errors = Arc::new(AtomicU32::new(0));
+    let load_errors_clone = load_errors.clone();
+
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+    } = MockEnv::with_client_builder(
+        ClusterClient::builder(vec![&*format!("redis://{name}")]),
+        name,
+        move |cmd: &[u8], port| {
+            respond_startup_with_replica_using_config(
+                name,
+                cmd,
+                Some(vec![MockSlotRange {
+                    primary_port: 6379,
+                    replica_ports: vec![6380, 6381],
+                    slot_range: (0..16383),
+                }]),
+            )?;
+            match port {
+                6379 => {
+                    let attempts = load_errors_clone.fetch_add(1, Ordering::Relaxed) + 1;
+                    if attempts % RETRIES == 0 {
+                        Err(Ok(Value::BulkString(b"123".to_vec())))
+                    } else {
+                        Err(parse_redis_value(b"-LOADING\r\n"))
+                    }
+                }
+                _ => panic!("Wrong node"),
+            }
+        },
+    );
+    for _n in 0..ITERATIONS {
+        runtime
+            .block_on(
+                cmd("GET")
+                    .arg("test")
+                    .query_async::<_, Value>(&mut connection),
+            )
+            .unwrap();
+    }
+
+    assert_eq!(load_errors.load(Ordering::Relaxed), ITERATIONS * RETRIES);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_can_be_created_with_partial_slot_coverage() {
+    let name = "test_async_cluster_can_be_created_with_partial_slot_coverage";
+    let slots_config = Some(vec![
+        MockSlotRange {
+            primary_port: 6379,
+            replica_ports: vec![],
+            slot_range: (0..8000),
+        },
+        MockSlotRange {
+            primary_port: 6381,
+            replica_ports: vec![],
+            slot_range: (8201..16380),
+        },
+    ]);
+
+    let MockEnv {
+        async_connection: mut connection,
+        handler: _handler,
+        runtime,
+        ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + Err(Ok(Value::SimpleString("PONG".into()))) + }, + ); + + let res = runtime.block_on(connection.req_packed_command(&redis::cmd("PING"))); + assert!(res.is_ok()); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_reconnect_after_complete_server_disconnect() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder.retries(2) + // Disable the rate limiter to refresh slots immediately + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + let cmd = cmd("PING"); + + let result = connection + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + // This will route to all nodes - different path through the code. + let result = connection.req_packed_command(&cmd).await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let max_requests = 5; + let mut i = 0; + let mut last_err = None; + loop { + if i == max_requests { + break; + } + i += 1; + match connection.req_packed_command(&cmd).await { + Ok(result) => { + assert_eq!(result, Value::SimpleString("PONG".to_string())); + return Ok::<_, RedisError>(()); + } + Err(err) => { + last_err = Some(err); + let _ = sleep(futures_time::time::Duration::from_secs(1)).await; + } + } + } + panic!("Failed to recover after all nodes went down. Last error: {last_err:?}"); + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_reconnect_after_complete_server_disconnect_route_to_many() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + + // recreate cluster + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let cmd = cmd("PING"); + + let max_requests = 5; + let mut i = 0; + let mut last_err = None; + loop { + if i == max_requests { + break; + } + i += 1; + // explicitly route to all primaries and request all succeeded + match connection + .route_command( + &cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(redis::cluster_routing::ResponsePolicy::AllSucceeded), + )), + ) + .await + { + Ok(result) => { + assert_eq!(result, Value::SimpleString("PONG".to_string())); + return Ok::<_, RedisError>(()); + } + Err(err) => { + last_err = Some(err); + let _ = sleep(futures_time::time::Duration::from_secs(1)).await; + } + } + } + panic!("Failed to recover after all nodes went down. 
Last error: {last_err:?}"); + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_blocking_command_when_cluster_drops() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + futures::future::join( + async { + let res = connection.blpop::<&str, f64>("foo", 0.0).await; + assert!(res.is_err()); + println!("blpop returned error {:?}", res.map_err(|e| e.to_string())); + }, + async { + let _ = sleep(futures_time::time::Duration::from_secs(3)).await; + drop(cluster); + }, + ) + .await; + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_saves_reconnected_connection() { + let name = "test_async_cluster_saves_reconnected_connection"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let get_attempts = AtomicI32::new(0); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |cmd: &[u8], port| { + if port == 6380 { + respond_startup_two_nodes(name, cmd)?; + return Err(parse_redis_value( + format!("-MOVED 123 {name}:6379\r\n").as_bytes(), + )); + } + + if contains_slice(cmd, b"PING") { + let connect_attempt = ping_attempts_clone.fetch_add(1, Ordering::Relaxed); + let past_get_attempts = get_attempts.load(Ordering::Relaxed); + // We want connection checks to fail after the first GET attempt, until it retries. Hence, we wait for 5 PINGs - + // 1. initial connection, + // 2. refresh slots on client creation, + // 3. refresh_connections `check_connection` after first GET failed, + // 4. refresh_connections `connect_and_check` after first GET failed, + // 5. reconnect on 2nd GET attempt. + // more than 5 attempts mean that the server reconnects more than once, which is the behavior we're testing against. + if past_get_attempts != 1 || connect_attempt > 3 { + respond_startup_two_nodes(name, cmd)?; + } + if connect_attempt > 5 { + panic!("Too many pings!"); + } + Err(Err(RedisError::from(( + ErrorKind::FatalSendError, + "mock-io-error", + )))) + } else { + respond_startup_two_nodes(name, cmd)?; + let past_get_attempts = get_attempts.fetch_add(1, Ordering::Relaxed); + // we fail the initial GET request, and after that we'll fail the first reconnect attempt, in the `refresh_connections` attempt. + if past_get_attempts == 0 { + // Error once with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + Err(Err(RedisError::from(( + ErrorKind::FatalSendError, + "mock-io-error", + )))) + } else { + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + } + }, + ); + + for _ in 0..4 { + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + // If you need to change the number here due to a change in the cluster, you probably also need to adjust the test. + // See the PING counts above to explain why 5 is the target number. 
+    assert_eq!(ping_attempts.load(Ordering::Acquire), 5);
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_periodic_checks_use_management_connection() {
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| {
+            builder.periodic_topology_checks(Duration::from_millis(10))
+                // Disable the rate limiter to refresh slots immediately on the periodic checks
+                .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+        },
+        false,
+    );
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let mut client_list = "".to_string();
+        let max_requests = 1000;
+        let mut i = 0;
+        loop {
+            if i == max_requests {
+                break;
+            } else {
+                client_list = cmd("CLIENT")
+                    .arg("LIST")
+                    .query_async::<_, String>(&mut connection)
+                    .await
+                    .expect("Failed executing CLIENT LIST");
+                let mut client_list_parts = client_list.split('\n');
+                if client_list_parts
+                    .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster"))
+                    && client_list.matches(MANAGEMENT_CONN_NAME).count() == 1 {
+                    return Ok::<_, RedisError>(());
+                }
+            }
+            i += 1;
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+        }
+        panic!("Couldn't find a management connection, or the connection wasn't used to execute CLUSTER SLOTS {:?}", client_list);
+    })
+    .unwrap();
+}
+
+async fn get_clients_names_to_ids(
+    connection: &mut ClusterConnection,
+    routing: Option<RoutingInfo>,
+) -> HashMap<String, String> {
+    let mut client_list_cmd = redis::cmd("CLIENT");
+    client_list_cmd.arg("LIST");
+    let value = match routing {
+        Some(routing) => connection.route_command(&client_list_cmd, routing).await,
+        None => connection.req_packed_command(&client_list_cmd).await,
+    }
+    .unwrap();
+    let string = String::from_owned_redis_value(value).unwrap();
+    string
+        .split('\n')
+        .filter_map(|line| {
+            if line.is_empty() {
+                return None;
+            }
+            let key_values = line
+                .split(' ')
+                .filter_map(|value| {
+                    let mut split = value.split('=');
+                    match (split.next(), split.next()) {
+                        (Some(key), Some(val)) => Some((key, val)),
+                        _ => None,
+                    }
+                })
+                .collect::<HashMap<_, _>>();
+            match (key_values.get("name"), key_values.get("id")) {
+                (Some(key), Some(val)) if !val.is_empty() => {
+                    Some((key.to_string(), val.to_string()))
+                }
+                _ => None,
+            }
+        })
+        .collect()
+}
+
+async fn kill_connection(killer_connection: &mut ClusterConnection, connection_to_kill: &str) {
+    let default_routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(
+        Route::new(0, SlotAddr::Master),
+    ));
+    kill_connection_with_routing(killer_connection, connection_to_kill, default_routing).await;
+}
+
+async fn kill_connection_with_routing(
+    killer_connection: &mut ClusterConnection,
+    connection_to_kill: &str,
+    routing: RoutingInfo,
+) {
+    let mut cmd = redis::cmd("CLIENT");
+    cmd.arg("KILL");
+    cmd.arg("ID");
+    cmd.arg(connection_to_kill);
+    // Kill the given connection on the node targeted by `routing`
+    assert!(killer_connection.route_command(&cmd, routing).await.is_ok());
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_only_management_connection_is_reconnected_after_connection_failure() {
+    // This test will check two aspects:
+    // 1. Ensuring that after a disconnection in the management connection, a new management connection is established.
+    // 2. Confirming that a failure in the management connection does not impact the user connection, which should remain intact.
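+    // Rough flow: name the user connection via CLIENT SETNAME, record both connection IDs,
+    // CLIENT KILL the management connection, then poll CLIENT LIST until a management
+    // connection with a *new* ID shows up while the user connection keeps its original ID.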
+    let cluster = TestClusterContext::new_with_cluster_client_builder(
+        3,
+        0,
+        |builder| builder.periodic_topology_checks(Duration::from_millis(10)),
+        false,
+    );
+    block_on_all(async move {
+        let mut connection = cluster.async_connection(None).await;
+        let _client_list = "".to_string();
+        let max_requests = 500;
+        let mut i = 0;
+        // Set the name of the client connection to 'user-connection', so we'll be able to identify it later on
+        assert!(cmd("CLIENT")
+            .arg("SETNAME")
+            .arg("user-connection")
+            .query_async::<_, Value>(&mut connection)
+            .await
+            .is_ok());
+        // Get the client list
+        let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode(
+            SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await;
+
+        // Get the connection ID of 'user-connection'
+        let user_conn_id = names_to_ids.get("user-connection").unwrap();
+        // Get the connection ID of the management connection
+        let management_conn_id = names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap();
+        // Get another connection that will be used to kill the management connection
+        let mut killer_connection = cluster.async_connection(None).await;
+        kill_connection(&mut killer_connection, management_conn_id).await;
+        loop {
+            // In this loop we'll wait for the new management connection to be established
+            if i == max_requests {
+                break;
+            } else {
+                let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode(
+                    SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await;
+                if names_to_ids.contains_key(MANAGEMENT_CONN_NAME) {
+                    // A management connection is found
+                    let curr_management_conn_id =
+                        names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap();
+                    let curr_user_conn_id =
+                        names_to_ids.get("user-connection").unwrap();
+                    // Confirm that the management connection has a new connection ID, and verify that the user connection remains unaffected.
+                    if (curr_management_conn_id != management_conn_id)
+                        && (curr_user_conn_id == user_conn_id)
+                    {
+                        return Ok::<_, RedisError>(());
+                    }
+                } else {
+                    i += 1;
+                    let _ = sleep(futures_time::time::Duration::from_millis(50)).await;
+                    continue;
+                }
+            }
+        }
+        panic!(
+            "No reconnection of the management connection found, or there was an unwanted reconnection of the user connection.
+            \nprev_management_conn_id={:?},prev_user_conn_id={:?}\nclient list={:?}",
+            management_conn_id, user_conn_id, names_to_ids
+        );
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial_test::serial]
+fn test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd() {
+    // This test verifies that non-key-based commands do not get routed to a random node
+    // when no connection is found for the given route. Instead, the appropriate error
+    // should be raised.
+    let name = "test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd";
+    let request_counter = Arc::new(AtomicU32::new(0));
+    let cloned_req_counter = request_counter.clone();
+    let MockEnv {
+        runtime,
+        async_connection: mut connection,
+        handler: _handler,
+        ..
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |received_cmd: &[u8], _| { + let slots_config_vec = vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0_u16..8000_u16), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + // Don't cover all slots + slot_range: (8001_u16..12000_u16), + }, + ]; + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + // If requests are sent to random nodes, they will be caught and counted here. + request_counter.fetch_add(1, Ordering::Relaxed); + Err(Ok(Value::Nil)) + }, + ); + + runtime + .block_on(async move { + let uncovered_slot = 16000; + let route = redis::cluster_routing::Route::new( + uncovered_slot, + redis::cluster_routing::SlotAddr::Master, + ); + let single_node_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route); + let routing = RoutingInfo::SingleNode(single_node_route); + let res = connection + .route_command(&redis::cmd("FLUSHALL"), routing) + .await; + assert!(res.is_err()); + let res_err = res.unwrap_err(); + assert_eq!( + res_err.kind(), + ErrorKind::ConnectionNotFoundForRoute, + "{:?}", + res_err + ); + assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 0); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_route_to_random_on_key_based_cmd() { + // This test verifies that key-based commands get routed to a random node + // when no connection is found for the given route. The command should + // then be redirected correctly by the server's MOVED error. + let name = "test_async_cluster_route_to_random_on_key_based_cmd"; + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. 
+ } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + move |received_cmd: &[u8], _| { + let slots_config_vec = vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0_u16..8000_u16), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + // Don't cover all slots + slot_range: (8001_u16..12000_u16), + }, + ]; + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + if contains_slice(received_cmd, b"GET") { + if request_counter.fetch_add(1, Ordering::Relaxed) == 0 { + return Err(parse_redis_value( + format!("-MOVED 12182 {name}:6380\r\n").as_bytes(), + )); + } else { + return Err(Ok(Value::SimpleString("bar".into()))); + } + } + panic!("unexpected command {:?}", received_cmd); + }, + ); + + runtime + .block_on(async move { + // The keyslot of "foo" is 12182 and it isn't covered by any node, so we expect the + // request to be routed to a random node and then to be redirected to the MOVED node (2 requests in total) + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 2); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[serial_test::serial] + fn test_async_cluster_do_not_retry_when_receiver_was_dropped() { + let name = "test_async_cluster_do_not_retry_when_receiver_was_dropped"; + let cmd = cmd("FAKE_COMMAND"); + let packed_cmd = cmd.get_packed_command(); + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(5) + .max_retry_wait(2) + .min_retry_wait(2), + name, + move |received_cmd: &[u8], _| { + respond_startup(name, received_cmd)?; + + if received_cmd == packed_cmd { + cloned_req_counter.fetch_add(1, Ordering::Relaxed); + return Err(Err((ErrorKind::TryAgain, "seriously, try again").into())); + } + + Err(Ok(Value::Okay)) + }, + ); + + runtime.block_on(async move { + let err = cmd + .query_async::<_, Value>(&mut connection) + .timeout(futures_time::time::Duration::from_millis(1)) + .await + .unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::TimedOut); + + // we sleep here, to allow the cluster connection time to retry. We expect it won't, but without this + // sleep the test will complete before the the runtime gave the connection time to retry, which would've made the + // test pass regardless of whether the connection tries retrying or not. 
+            sleep(Duration::from_millis(10).into()).await;
+        });
+
+        assert_eq!(request_counter.load(Ordering::Relaxed), 1);
+    }
+
+    #[cfg(feature = "tls-rustls")]
+    mod mtls_test {
+        use crate::support::mtls_test::create_cluster_client_from_cluster;
+        use redis::ConnectionInfo;
+
+        use super::*;
+
+        #[test]
+        #[serial_test::serial]
+        fn test_async_cluster_basic_cmd_with_mtls() {
+            let cluster = TestClusterContext::new_with_mtls(3, 0);
+            block_on_all(async move {
+                let client = create_cluster_client_from_cluster(&cluster, true).unwrap();
+                let mut connection = client.get_async_connection(None).await.unwrap();
+                cmd("SET")
+                    .arg("test")
+                    .arg("test_data")
+                    .query_async(&mut connection)
+                    .await?;
+                let res: String = cmd("GET")
+                    .arg("test")
+                    .clone()
+                    .query_async(&mut connection)
+                    .await?;
+                assert_eq!(res, "test_data");
+                Ok::<_, RedisError>(())
+            })
+            .unwrap();
+        }
+
+        #[test]
+        #[serial_test::serial]
+        fn test_async_cluster_should_not_connect_without_mtls_enabled() {
+            let cluster = TestClusterContext::new_with_mtls(3, 0);
+            block_on_all(async move {
+                let client = create_cluster_client_from_cluster(&cluster, false).unwrap();
+                let connection = client.get_async_connection(None).await;
+                match cluster.cluster.servers.first().unwrap().connection_info() {
+                    ConnectionInfo {
+                        addr: redis::ConnectionAddr::TcpTls { .. },
+                        ..
+                    } => {
+                        if connection.is_ok() {
+                            panic!("Must NOT be able to connect without client credentials if server accepts TLS");
+                        }
+                    }
+                    _ => {
+                        if let Err(e) = connection {
+                            panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}");
+                        }
+                    }
+                }
+                Ok::<_, RedisError>(())
+            }).unwrap();
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_cluster_scan.rs b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs
new file mode 100644
index 0000000000..cfc4bae594
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs
@@ -0,0 +1,860 @@
+#![cfg(feature = "cluster-async")]
+mod support;
+
+#[cfg(test)]
+mod test_cluster_scan_async {
+    use crate::support::*;
+    use rand::Rng;
+    use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo};
+    use redis::{cmd, from_redis_value, ObjectType, RedisResult, ScanStateRC, Value};
+    use std::time::Duration;
+
+    async fn kill_one_node(
+        cluster: &TestClusterContext,
+        slot_distribution: Vec<(String, String, String, Vec<Vec<u16>>)>,
+    ) -> RoutingInfo {
+        let mut cluster_conn = cluster.async_connection(None).await;
+        let distribution_clone = slot_distribution.clone();
+        let index_of_random_node = rand::thread_rng().gen_range(0..slot_distribution.len());
+        let random_node = distribution_clone.get(index_of_random_node).unwrap();
+        let random_node_route_info = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+            host: random_node.1.clone(),
+            port: random_node.2.parse::<u16>().unwrap(),
+        });
+        let random_node_id = &random_node.0;
+        // Tell every other node to FORGET the victim before shutting it down
+        for node in &distribution_clone {
+            if random_node_id == &node.0 {
+                continue;
+            }
+            let node_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                host: node.1.clone(),
+                port: node.2.parse::<u16>().unwrap(),
+            });
+
+            let mut forget_cmd = cmd("CLUSTER");
+            forget_cmd.arg("FORGET").arg(random_node_id);
+            let _: RedisResult<Value> = cluster_conn
+                .route_command(&forget_cmd, node_route.clone())
+                .await;
+        }
+        let mut shutdown_cmd = cmd("SHUTDOWN");
+        shutdown_cmd.arg("NOSAVE");
+        let _: RedisResult<Value> = cluster_conn
+            .route_command(&shutdown_cmd, random_node_route_info.clone())
+            .await;
+        random_node_route_info
+    }
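+
+    // A minimal sketch of the cursor loop that every test below repeats: keep
+    // calling `cluster_scan` with the returned `ScanStateRC` until it reports
+    // completion. This assumes the async `ClusterConnection` type returned by
+    // `cluster.async_connection(None)`.
+    #[allow(dead_code)]
+    async fn collect_all_keys(
+        connection: &mut redis::cluster_async::ClusterConnection,
+    ) -> RedisResult<Vec<String>> {
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        loop {
+            // Each call yields the next cursor state plus a batch of keys.
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) =
+                connection.cluster_scan(scan_state_rc, None, None).await?;
+            scan_state_rc = next_cursor;
+            for key in scan_keys {
+                keys.push(from_redis_value(&key)?);
+            }
+            // `is_finished` flips once every slot on every node has been covered.
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        Ok(keys)
+    }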
+
+    #[tokio::test]
+    #[serial_test::serial]
+    async fn test_async_cluster_scan() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+
+        // Set some keys
+        for i in 0..10 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check that all keys were scanned
+        keys.sort();
+        keys.dedup();
+        for (i, key) in keys.iter().enumerate() {
+            assert_eq!(key.to_owned(), format!("key{}", i));
+        }
+    }
+
+    #[tokio::test]
+    #[serial_test::serial] // test cluster scan with slot migration in the middle
+    async fn test_async_cluster_scan_with_migration() {
+        let cluster = TestClusterContext::new(3, 0);
+
+        let mut connection = cluster.async_connection(None).await;
+        // Set some keys
+        let mut expected_keys: Vec<String> = Vec::new();
+
+        for i in 0..1000 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.extend(scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                let mut cluster_nodes = cluster.get_cluster_nodes().await;
+                let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                cluster
+                    .migrate_slots_from_node_to_another(slot_distribution.clone())
+                    .await;
+                for node in &slot_distribution {
+                    let ready = cluster
+                        .wait_for_connection_is_ready(&RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::ByAddress {
+                                host: node.1.clone(),
+                                port: node.2.parse::<u16>().unwrap(),
+                            },
+                        ))
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+
+                cluster_nodes = cluster.get_cluster_nodes().await;
+                // Compare slot distribution before and after migration
+                let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                assert_ne!(slot_distribution, new_slot_distribution);
+            }
+        }
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check that all keys were scanned
+        assert_eq!(keys, expected_keys);
+    }
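+
+    // A scan that overlaps a slot migration may return the same key from both the
+    // source and the destination node, which is why the test above sorts and
+    // dedups both sides before comparing them.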
.arg("value") + .query_async(&mut connection) + .await; + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + let mut result: RedisResult = Ok(Value::Nil); + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + let (next_cursor, scan_keys) = match scan_response { + Ok((cursor, keys)) => (cursor, keys), + Err(e) => { + result = Err(e); + break; + } + }; + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + let cluster_nodes = cluster.get_cluster_nodes().await; + let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + // simulate node failure + let killed_node_routing = kill_one_node(&cluster, slot_distribution.clone()).await; + let ready = cluster.wait_for_fail_to_finish(&killed_node_routing).await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + let cluster_nodes = cluster.get_cluster_nodes().await; + let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + assert_ne!(slot_distribution, new_slot_distribution); + } + } + // We expect an error of finding address + assert!(result.is_err()); + } + + #[tokio::test] + #[serial_test::serial] // Test cluster scan with killing all masters during scan + async fn test_async_cluster_scan_with_all_masters_down() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + .retries(1) + }, + false, + ); + + let mut connection = cluster.async_connection(None).await; + + let mut expected_keys: Vec = Vec::new(); + + cluster.wait_for_cluster_up(); + + let mut cluster_nodes = cluster.get_cluster_nodes().await; + + let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + let masters = cluster.get_masters(&cluster_nodes).await; + let replicas = cluster.get_replicas(&cluster_nodes).await; + + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + for replica in replicas.iter() { + let mut failover_cmd = cmd("CLUSTER"); + let _: RedisResult = connection + .route_command( + failover_cmd.arg("FAILOVER").arg("TAKEOVER"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }), + ) + .await; + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } 
+                for master in masters.iter() {
+                    for replica in replicas.clone() {
+                        let mut forget_cmd = cmd("CLUSTER");
+                        forget_cmd.arg("FORGET").arg(master[0].clone());
+                        let _: RedisResult<Value> = connection
+                            .route_command(
+                                &forget_cmd,
+                                RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                                    host: replica[1].clone(),
+                                    port: replica[2].parse::<u16>().unwrap(),
+                                }),
+                            )
+                            .await;
+                    }
+                }
+                for master in masters.iter() {
+                    let mut shut_cmd = cmd("SHUTDOWN");
+                    shut_cmd.arg("NOSAVE");
+                    let _ = connection
+                        .route_command(
+                            &shut_cmd,
+                            RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                                host: master[1].clone(),
+                                port: master[2].parse::<u16>().unwrap(),
+                            }),
+                        )
+                        .await;
+                    let ready = cluster
+                        .wait_for_fail_to_finish(&RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::ByAddress {
+                                host: master[1].clone(),
+                                port: master[2].parse::<u16>().unwrap(),
+                            },
+                        ))
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+                for replica in replicas.iter() {
+                    let ready = cluster
+                        .wait_for_connection_is_ready(&RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::ByAddress {
+                                host: replica[1].clone(),
+                                port: replica[2].parse::<u16>().unwrap(),
+                            },
+                        ))
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+                cluster_nodes = cluster.get_cluster_nodes().await;
+                let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes);
+                assert_ne!(slot_distribution, new_slot_distribution);
+            }
+        }
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check that all keys were scanned
+        assert_eq!(keys, expected_keys);
+    }
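+
+    // FAILOVER TAKEOVER promotes a replica unilaterally, without a quorum vote,
+    // which is what allows the scan above to keep making progress once every
+    // original master is gone.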
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Test cluster scan with killing all replicas during scan
+    async fn test_async_cluster_scan_with_all_replicas_down() {
+        let cluster = TestClusterContext::new_with_cluster_client_builder(
+            6,
+            1,
+            |builder| {
+                builder
+                    .slots_refresh_rate_limit(Duration::from_secs(0), 0)
+                    .retries(1)
+            },
+            false,
+        );
+
+        let mut connection = cluster.async_connection(None).await;
+
+        let mut expected_keys: Vec<String> = Vec::new();
+
+        for server in cluster.cluster.servers.iter() {
+            let address = server.addr.clone().to_string();
+            let host_and_port = address.split(':');
+            let host = host_and_port.clone().next().unwrap().to_string();
+            let port = host_and_port
+                .clone()
+                .last()
+                .unwrap()
+                .parse::<u16>()
+                .unwrap();
+            let ready = cluster
+                .wait_for_connection_is_ready(&RoutingInfo::SingleNode(
+                    SingleNodeRoutingInfo::ByAddress { host, port },
+                ))
+                .await;
+            match ready {
+                Ok(_) => {}
+                Err(e) => {
+                    println!("error: {:?}", e);
+                    break;
+                }
+            }
+        }
+
+        let cluster_nodes = cluster.get_cluster_nodes().await;
+
+        let replicas = cluster.get_replicas(&cluster_nodes).await;
+
+        for i in 0..1000 {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                for replica in replicas.iter() {
+                    let mut shut_cmd = cmd("SHUTDOWN");
+                    shut_cmd.arg("NOSAVE");
+                    let ready: RedisResult<Value> = connection
+                        .route_command(
+                            &shut_cmd,
+                            RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress {
+                                host: replica[1].clone(),
+                                port: replica[2].parse::<u16>().unwrap(),
+                            }),
+                        )
+                        .await;
+                    match ready {
+                        Ok(_) => {}
+                        Err(e) => {
+                            println!("error: {:?}", e);
+                            break;
+                        }
+                    }
+                }
+                let new_cluster_nodes = cluster.get_cluster_nodes().await;
+                assert_ne!(cluster_nodes, new_cluster_nodes);
+            }
+        }
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        // check that all keys were scanned
+        assert_eq!(keys, expected_keys);
+    }
+    #[tokio::test]
+    #[serial_test::serial]
+    // Test cluster scan with setting keys for each iteration
+    async fn test_async_cluster_scan_set_in_the_middle() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            let key = format!("key{}", i);
+            i += 1;
+            let res: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            assert!(res.is_ok());
+        }
+        // Check that all keys set before the scan were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() >= expected_keys.len());
+    }
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Test cluster scan with deleting keys for each iteration
+    async fn test_async_cluster_scan_del_in_the_middle() {
+        let cluster = TestClusterContext::new(3, 0);
+
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            i -= 1;
+            let key = format!("key{}", i);
+
+            let res: Result<(), redis::RedisError> = redis::cmd("DEL")
+                .arg(&key)
+                .query_async(&mut connection)
+                .await;
+            assert!(res.is_ok());
+            expected_keys.remove(i as usize);
+        }
+        // Check that all keys that were not deleted were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() >= expected_keys.len());
+    }
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Testing cluster scan with Pattern option
+    async fn test_async_cluster_scan_with_pattern() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key:pattern:{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            let non_relevant_key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&non_relevant_key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 500 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan_with_pattern(scan_state_rc, "key:pattern:*", None, None)
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check that exactly the matching keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Testing cluster scan with TYPE option
+    async fn test_async_cluster_scan_with_type() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SADD")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            let key = format!("key-that-is-not-set{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 500 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, Some(ObjectType::Set))
+                .await
+                .unwrap();
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        // Check that exactly the set-typed keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
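+
+    // COUNT is only a hint bounding how much work the server does per SCAN call,
+    // so batches requested with Some(100) are expected to be at least as large as
+    // batches fetched with the default count, which is what the next test checks.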
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Testing cluster scan with COUNT option
+    async fn test_async_cluster_scan_with_count() {
+        let cluster = TestClusterContext::new(3, 0);
+        let mut connection = cluster.async_connection(None).await;
+        let mut expected_keys: Vec<String> = Vec::new();
+        let mut i = 0;
+        // Set some keys
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            expected_keys.push(key);
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+
+        // Scan the keys
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = vec![];
+        let mut comparing_times = 0;
+        loop {
+            let (next_cursor, scan_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc.clone(), Some(100), None)
+                .await
+                .unwrap();
+            let (_, scan_without_count_keys): (ScanStateRC, Vec<Value>) = connection
+                .cluster_scan(scan_state_rc, None, None)
+                .await
+                .unwrap();
+            if !scan_keys.is_empty() && !scan_without_count_keys.is_empty() {
+                assert!(scan_keys.len() >= scan_without_count_keys.len());
+
+                comparing_times += 1;
+            }
+            scan_state_rc = next_cursor;
+            let mut scan_keys = scan_keys
+                .into_iter()
+                .map(|v| from_redis_value(&v).unwrap())
+                .collect::<Vec<String>>();
+            keys.append(&mut scan_keys);
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+        assert!(comparing_times > 0);
+        // Check that all keys were scanned
+        keys.sort();
+        keys.dedup();
+        expected_keys.sort();
+        expected_keys.dedup();
+        for key in expected_keys.iter() {
+            assert!(keys.contains(key));
+        }
+        assert!(keys.len() == expected_keys.len());
+    }
+
+    #[tokio::test]
+    #[serial_test::serial]
+    // Testing that cluster scan returns an error when the connection fails in the
+    // middle, and that once the cluster is up again the scan can continue without
+    // any problem
+    async fn test_async_cluster_scan_failover() {
+        let mut cluster = TestClusterContext::new_with_cluster_client_builder(
+            3,
+            0,
+            |builder| builder.retries(1),
+            false,
+        );
+        let mut connection = cluster.async_connection(None).await;
+        let mut i = 0;
+        loop {
+            let key = format!("key{}", i);
+            let _: Result<(), redis::RedisError> = redis::cmd("SET")
+                .arg(&key)
+                .arg("value")
+                .query_async(&mut connection)
+                .await;
+            i += 1;
+            if i == 1000 {
+                break;
+            }
+        }
+        let mut scan_state_rc = ScanStateRC::new();
+        let mut keys: Vec<String> = Vec::new();
+        let mut count = 0;
+        loop {
+            count += 1;
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+            if count == 5 {
+                drop(cluster);
+                let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> = connection
+                    .cluster_scan(scan_state_rc.clone(), None, None)
+                    .await;
+                assert!(scan_response.is_err());
+                break;
+            };
+        }
+        cluster = TestClusterContext::new(3, 0);
+        connection = cluster.async_connection(None).await;
+        loop {
+            let scan_response: RedisResult<(ScanStateRC, Vec<Value>)> =
+                connection.cluster_scan(scan_state_rc, None, None).await;
+            if scan_response.is_err() {
+                println!("error: {:?}", scan_response);
+            }
+            let (next_cursor, scan_keys) = scan_response.unwrap();
+            scan_state_rc = next_cursor;
+            keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap()));
+            if scan_state_rc.is_finished() {
+                break;
+            }
+        }
+    }
+}
diff --git a/glide-core/redis-rs/redis/tests/test_geospatial.rs b/glide-core/redis-rs/redis/tests/test_geospatial.rs
new file mode 100644
index 0000000000..8bec9a1d73
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_geospatial.rs
@@ -0,0 +1,197 @@
+#![cfg(feature = "geospatial")]
+
+use assert_approx_eq::assert_approx_eq;
+
+use redis::geo::{Coord, RadiusOptions, RadiusOrder, RadiusSearchResult, Unit};
+use redis::{Commands, RedisResult};
+
+mod support;
+use crate::support::*;
+
+// Each constant is (longitude, latitude, member name), matching GEOADD's argument order.
+const PALERMO: (&str, &str, &str) = ("13.361389", "38.115556", "Palermo");
+const CATANIA: (&str, &str, &str) = ("15.087269", "37.502669", "Catania");
+const AGRIGENTO: (&str, &str, &str) = ("13.5833332", "37.316667", "Agrigento");
+
+#[test]
+fn test_geoadd_single_tuple() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", PALERMO), Ok(1));
+}
+
+#[test]
+fn test_geoadd_multiple_tuples() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+}
+
+#[test]
+fn test_geodist_existing_members() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let dist: f64 = con
+        .geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Kilometers)
+        .unwrap();
+    assert_approx_eq!(dist, 166.2742, 0.001);
+}
+
+#[test]
+fn test_geodist_support_option() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    // We should be able to extract the value as an Option<_>, so we can detect
+    // if a member is missing
+
+    let result: RedisResult<Option<f64>> = con.geo_dist("my_gis", PALERMO.2, "none", Unit::Meters);
+    assert_eq!(result, Ok(None));
+
+    let result: RedisResult<Option<f64>> =
+        con.geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Meters);
+    assert_ne!(result, Ok(None));
+
+    let dist = result.unwrap().unwrap();
+    assert_approx_eq!(dist, 166_274.151_6, 0.01);
+}
+
+#[test]
+fn test_geohash() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+    let result: RedisResult<Vec<String>> = con.geo_hash("my_gis", PALERMO.2);
+    assert_eq!(result, Ok(vec![String::from("sqc8b49rny0")]));
+
+    let result: RedisResult<Vec<String>> = con.geo_hash("my_gis", &[PALERMO.2, CATANIA.2]);
+    assert_eq!(
+        result,
+        Ok(vec![
+            String::from("sqc8b49rny0"),
+            String::from("sqdtr74hyu0"),
+        ])
+    );
+}
+
+#[test]
+fn test_geopos() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let result: Vec<Vec<f64>> = con.geo_pos("my_gis", &[PALERMO.2]).unwrap();
+    assert_eq!(result.len(), 1);
+
+    assert_approx_eq!(result[0][0], 13.36138, 0.0001);
+    assert_approx_eq!(result[0][1], 38.11555, 0.0001);
+
+    // Using the Coord struct
+    let result: Vec<Coord<f64>> = con.geo_pos("my_gis", &[PALERMO.2, CATANIA.2]).unwrap();
+    assert_eq!(result.len(), 2);
+
+    assert_approx_eq!(result[0].longitude, 13.36138, 0.0001);
+    assert_approx_eq!(result[0].latitude, 38.11555, 0.0001);
+
+    assert_approx_eq!(result[1].longitude, 15.08726, 0.0001);
+    assert_approx_eq!(result[1].latitude, 37.50266, 0.0001);
+}
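+
+// `geo_pos` can decode positions either into plain `Vec<f64>` pairs or into the
+// typed `Coord<f64>` struct, as above; the next test builds on `Coord` for
+// writing as well, via `Coord::lon_lat`.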
+
+#[test]
+fn test_use_coord_struct() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(
+        con.geo_add(
+            "my_gis",
+            (Coord::lon_lat(13.361_389, 38.115_556), "Palermo")
+        ),
+        Ok(1)
+    );
+
+    let result: Vec<Coord<f64>> = con.geo_pos("my_gis", "Palermo").unwrap();
+    assert_eq!(result.len(), 1);
+
+    assert_approx_eq!(result[0].longitude, 13.36138, 0.0001);
+    assert_approx_eq!(result[0].latitude, 38.11555, 0.0001);
+}
+
+#[test]
+fn test_georadius() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2));
+
+    let mut geo_radius = |opts: RadiusOptions| -> Vec<RadiusSearchResult> {
+        con.geo_radius("my_gis", 15.0, 37.0, 200.0, Unit::Kilometers, opts)
+            .unwrap()
+    };
+
+    // Simple request, without extra data
+    let mut result = geo_radius(RadiusOptions::default());
+    result.sort_by(|a, b| Ord::cmp(&a.name, &b.name));
+
+    assert_eq!(result.len(), 2);
+
+    assert_eq!(result[0].name.as_str(), "Catania");
+    assert_eq!(result[0].coord, None);
+    assert_eq!(result[0].dist, None);
+
+    assert_eq!(result[1].name.as_str(), "Palermo");
+    assert_eq!(result[1].coord, None);
+    assert_eq!(result[1].dist, None);
+
+    // Get data with multiple fields
+    let result = geo_radius(RadiusOptions::default().with_dist().order(RadiusOrder::Asc));
+
+    assert_eq!(result.len(), 2);
+
+    assert_eq!(result[0].name.as_str(), "Catania");
+    assert_eq!(result[0].coord, None);
+    assert_approx_eq!(result[0].dist.unwrap(), 56.4413, 0.001);
+
+    assert_eq!(result[1].name.as_str(), "Palermo");
+    assert_eq!(result[1].coord, None);
+    assert_approx_eq!(result[1].dist.unwrap(), 190.4424, 0.001);
+
+    let result = geo_radius(
+        RadiusOptions::default()
+            .with_coord()
+            .order(RadiusOrder::Desc)
+            .limit(1),
+    );
+
+    assert_eq!(result.len(), 1);
+
+    assert_eq!(result[0].name.as_str(), "Palermo");
+    assert_approx_eq!(result[0].coord.as_ref().unwrap().longitude, 13.361_389);
+    assert_approx_eq!(result[0].coord.as_ref().unwrap().latitude, 38.115_556);
+    assert_eq!(result[0].dist, None);
+}
+
+#[test]
+fn test_georadius_by_member() {
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA, AGRIGENTO]), Ok(3));
+
+    // Simple request, without extra data
+    let opts = RadiusOptions::default().order(RadiusOrder::Asc);
+    let result: Vec<RadiusSearchResult> = con
+        .geo_radius_by_member("my_gis", AGRIGENTO.2, 100.0, Unit::Kilometers, opts)
+        .unwrap();
+    let names: Vec<_> = result.iter().map(|c| c.name.as_str()).collect();
+
+    assert_eq!(names, vec!["Agrigento", "Palermo"]);
+}
diff --git a/glide-core/redis-rs/redis/tests/test_module_json.rs b/glide-core/redis-rs/redis/tests/test_module_json.rs
new file mode 100644
index 0000000000..08fed23930
--- /dev/null
+++ b/glide-core/redis-rs/redis/tests/test_module_json.rs
@@ -0,0 +1,540 @@
+#![cfg(feature = "json")]
+
+use std::assert_eq;
+use std::collections::HashMap;
+
+use redis::{JsonCommands, ProtocolVersion};
+
+use redis::{
+    ErrorKind, RedisError, RedisResult,
+    Value::{self, *},
+};
+
+use crate::support::*;
+mod support;
+
+use serde::Serialize;
+// adds json! macro for quick json generation on the fly.
+use serde_json::json;
+
+const TEST_KEY: &str = "my_json";
+
+const MTLS_NOT_ENABLED: bool = false;
+
+#[test]
+fn test_module_json_serialize_error() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    #[derive(Debug, Serialize)]
+    struct InvalidSerializedStruct {
+        // Maps in serde_json must have string-like keys (numbers and strings);
+        // anything else will cause the serialization to fail. This is basically
+        // the only way to make serialization fail at runtime, since Rust doesn't
+        // provide a way to enforce this at compile time.
+        // (Key type reconstructed here as Option<bool>; any non-string-like key works.)
+        pub invalid_json: HashMap<Option<bool>, i64>,
+    }
+
+    let mut test_invalid_value: InvalidSerializedStruct = InvalidSerializedStruct {
+        invalid_json: HashMap::new(),
+    };
+
+    test_invalid_value.invalid_json.insert(None, 2i64);
+
+    let set_invalid: RedisResult<bool> = con.json_set(TEST_KEY, "$", &test_invalid_value);
+
+    assert_eq!(
+        set_invalid,
+        Err(RedisError::from((
+            ErrorKind::Serialize,
+            "Serialization Error",
+            String::from("key must be string")
+        )))
+    );
+}
+
+#[test]
+fn test_module_json_arr_append() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64], "nested": {"a": [1i64, 2i64]}, "nested2": {"a": 42i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_append: RedisResult<Value> = con.json_arr_append(TEST_KEY, "$..a", &3i64);
+
+    assert_eq!(json_append, Ok(Array(vec![Int(2i64), Int(3i64), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_index() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrindex: RedisResult<Value> = con.json_arr_index(TEST_KEY, "$..a", &2i64);
+
+    assert_eq!(json_arrindex, Ok(Array(vec![Int(1i64), Int(-1i64)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrindex_2: RedisResult<Value> =
+        con.json_arr_index_ss(TEST_KEY, "$..a", &2i64, &0, &0);
+
+    assert_eq!(json_arrindex_2, Ok(Array(vec![Int(1i64), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_insert() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": [3i64 ,4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrinsert: RedisResult<Value> = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64);
+
+    assert_eq!(json_arrinsert, Ok(Array(vec![Int(2), Int(3)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[1i64 ,2i64 ,3i64 ,2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrinsert_2: RedisResult<Value> = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64);
+
+    assert_eq!(json_arrinsert_2, Ok(Array(vec![Int(5), Nil])));
+}
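+
+// The recursive JSONPath `$..a` matches every member named `a`, so these commands
+// return one entry per match, with `Nil` for matches the operation does not apply
+// to (for example an array command hitting a scalar or boolean).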
+
+#[test]
+fn test_module_json_arr_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrlen: RedisResult<Value> = con.json_arr_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_arrlen, Ok(Array(vec![Int(1), Int(2)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrlen_2: RedisResult<Value> = con.json_arr_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_arrlen_2, Ok(Array(vec![Int(4), Nil])));
+}
+
+#[test]
+fn test_module_json_arr_pop() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrpop: RedisResult<Value> = con.json_arr_pop(TEST_KEY, "$..a", -1);
+
+    assert_eq!(
+        json_arrpop,
+        Ok(Array(vec![
+            // the popped values come back as their JSON string representation, as bytes
+            BulkString(Vec::from("3".as_bytes())),
+            BulkString(Vec::from("4".as_bytes()))
+        ]))
+    );
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":["foo", "bar"], "nested": {"a": false}, "nested2": {"a":[]}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrpop_2: RedisResult<Value> = con.json_arr_pop(TEST_KEY, "$..a", -1);
+
+    assert_eq!(
+        json_arrpop_2,
+        Ok(Array(vec![
+            BulkString(Vec::from("\"bar\"".as_bytes())),
+            Nil,
+            Nil
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_arr_trim() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [], "nested": {"a": [1i64, 4u64]}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_arrtrim: RedisResult<Value> = con.json_arr_trim(TEST_KEY, "$..a", 1, 1);
+
+    assert_eq!(json_arrtrim, Ok(Array(vec![Int(0), Int(1)])));
+
+    let update_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": [1i64, 2i64, 3i64, 4i64], "nested": {"a": false}}),
+    );
+
+    assert_eq!(update_initial, Ok(true));
+
+    let json_arrtrim_2: RedisResult<Value> = con.json_arr_trim(TEST_KEY, "$..a", 1, 1);
+
+    assert_eq!(json_arrtrim_2, Ok(Array(vec![Int(1), Nil])));
+}
+
+#[test]
+fn test_module_json_clear() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"obj": {"a": 1i64, "b": 2i64}, "arr": [1i64, 2i64, 3i64], "str": "foo", "bool": true, "int": 42i64, "float": std::f64::consts::PI}));
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_clear: RedisResult<i64> = con.json_clear(TEST_KEY, "$.*");
+
+    assert_eq!(json_clear, Ok(4));
+
+    let checking_value: RedisResult<String> = con.json_get(TEST_KEY, "$");
+
+    // float is set to 0 and serde_json serializes 0f64 to 0.0, which is a different string.
+    // Note: the key order in the response can change; that's not really a problem
+    // if you're just deserializing it anyway, but it is a little surprising.
+    assert_eq!(
+        checking_value,
+        Ok("[{\"arr\":[],\"bool\":true,\"float\":0,\"int\":0,\"obj\":{},\"str\":\"foo\"}]".into())
+    );
+}
+
+#[test]
+fn test_module_json_del() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a": 1i64, "nested": {"a": 2i64, "b": 3i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_del: RedisResult<i64> = con.json_del(TEST_KEY, "$..a");
+
+    assert_eq!(json_del, Ok(2));
+}
+
+#[test]
+fn test_module_json_get() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":2i64, "b": 3i64, "nested": {"a": 4i64, "b": null}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_get: RedisResult<String> = con.json_get(TEST_KEY, "$..b");
+
+    assert_eq!(json_get, Ok("[3,null]".into()));
+
+    let json_get_multi: RedisResult<String> = con.json_get(TEST_KEY, vec!["..a", "$..b"]);
+
+    if json_get_multi != Ok("{\"$..b\":[3,null],\"..a\":[2,4]}".into())
+        && json_get_multi != Ok("{\"..a\":[2,4],\"$..b\":[3,null]}".into())
+    {
+        panic!("test_error: incorrect response from json_get_multi");
+    }
+}
+
+#[test]
+fn test_module_json_mget() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial_a: RedisResult<bool> = con.json_set(
+        format!("{TEST_KEY}-a"),
+        "$",
+        &json!({"a":1i64, "b": 2i64, "nested": {"a": 3i64, "b": null}}),
+    );
+    let set_initial_b: RedisResult<bool> = con.json_set(
+        format!("{TEST_KEY}-b"),
+        "$",
+        &json!({"a":4i64, "b": 5i64, "nested": {"a": 6i64, "b": null}}),
+    );
+
+    assert_eq!(set_initial_a, Ok(true));
+    assert_eq!(set_initial_b, Ok(true));
+
+    let json_mget: RedisResult<Value> = con.json_get(
+        vec![format!("{TEST_KEY}-a"), format!("{TEST_KEY}-b")],
+        "$..a",
+    );
+
+    assert_eq!(
+        json_mget,
+        Ok(Array(vec![
+            BulkString(Vec::from("[1,3]".as_bytes())),
+            BulkString(Vec::from("[4,6]".as_bytes()))
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_num_incr_by() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"b","b":[{"a":2i64}, {"a":5i64}, {"a":"c"}]}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") {
+        // cannot increment a string
+        let json_numincrby_a: RedisResult<Vec<Value>> = con.json_num_incr_by(TEST_KEY, "$.a", 2);
+        assert_eq!(json_numincrby_a, Ok(vec![Nil]));
+
+        let json_numincrby_b: RedisResult<Vec<Value>> = con.json_num_incr_by(TEST_KEY, "$..a", 2);
+
+        // however numbers can be incremented
+        assert_eq!(json_numincrby_b, Ok(vec![Nil, Int(4), Int(7), Nil]));
+    } else {
+        // cannot increment a string
+        let json_numincrby_a: RedisResult<String> = con.json_num_incr_by(TEST_KEY, "$.a", 2);
+        assert_eq!(json_numincrby_a, Ok("[null]".into()));
+
+        let json_numincrby_b: RedisResult<String> = con.json_num_incr_by(TEST_KEY, "$..a", 2);
+
+        // however numbers can be incremented
+        assert_eq!(json_numincrby_b, Ok("[null,4,7,null]".into()));
+    }
+}
+
+#[test]
+fn test_module_json_obj_keys() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_objkeys: RedisResult<Value> = con.json_obj_keys(TEST_KEY, "$..a");
+
+    assert_eq!(
+        json_objkeys,
+        Ok(Array(vec![
+            Nil,
+            Array(vec![
+                BulkString(Vec::from("b".as_bytes())),
+                BulkString(Vec::from("c".as_bytes()))
+            ])
+        ]))
+    );
+}
+
+#[test]
+fn test_module_json_obj_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_objlen: RedisResult<Value> = con.json_obj_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_objlen, Ok(Array(vec![Nil, Int(2)])));
+}
+
+#[test]
+fn test_module_json_set() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"key": "value"}));
+
+    assert_eq!(set, Ok(true));
+}
+
+#[test]
+fn test_module_json_str_append() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i64}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_strappend: RedisResult<Value> = con.json_str_append(TEST_KEY, "$..a", "\"baz\"");
+
+    assert_eq!(json_strappend, Ok(Array(vec![Int(6), Int(8), Nil])));
+
+    let json_get_check: RedisResult<String> = con.json_get(TEST_KEY, "$");
+
+    assert_eq!(
+        json_get_check,
+        Ok("[{\"a\":\"foobaz\",\"nested\":{\"a\":\"hellobaz\"},\"nested2\":{\"a\":31}}]".into())
+    );
+}
+
+#[test]
+fn test_module_json_str_len() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i32}}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_strlen: RedisResult<Value> = con.json_str_len(TEST_KEY, "$..a");
+
+    assert_eq!(json_strlen, Ok(Array(vec![Int(3), Int(5), Nil])));
+}
+
+#[test]
+fn test_module_json_toggle() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(TEST_KEY, "$", &json!({"bool": true}));
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_toggle_a: RedisResult<Value> = con.json_toggle(TEST_KEY, "$.bool");
+    assert_eq!(json_toggle_a, Ok(Array(vec![Int(0)])));
+
+    let json_toggle_b: RedisResult<Value> = con.json_toggle(TEST_KEY, "$.bool");
+    assert_eq!(json_toggle_b, Ok(Array(vec![Int(1)])));
+}
+
+#[test]
+fn test_module_json_type() {
+    let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED);
+    let mut con = ctx.connection();
+
+    let set_initial: RedisResult<bool> = con.json_set(
+        TEST_KEY,
+        "$",
+        &json!({"a":2i64, "nested": {"a": true}, "foo": "bar"}),
+    );
+
+    assert_eq!(set_initial, Ok(true));
+
+    let json_type_a: RedisResult<Value> = con.json_type(TEST_KEY, "$..foo");
+    let json_type_b: RedisResult<Value> = con.json_type(TEST_KEY, "$..a");
+    let json_type_c: RedisResult<Value> = con.json_type(TEST_KEY, "$..dummy");
+
+    let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default();
+    if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") {
+        // In RESP3 current RedisJSON always gives response 
in an array. + assert_eq!( + json_type_a, + Ok(Array(vec![Array(vec![BulkString(Vec::from( + "string".as_bytes() + ))])])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])])) + ); + assert_eq!(json_type_c, Ok(Array(vec![Array(vec![])]))); + } else { + assert_eq!( + json_type_a, + Ok(Array(vec![BulkString(Vec::from("string".as_bytes()))])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])) + ); + assert_eq!(json_type_c, Ok(Array(vec![]))); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_sentinel.rs b/glide-core/redis-rs/redis/tests/test_sentinel.rs new file mode 100644 index 0000000000..24cd13bd67 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_sentinel.rs @@ -0,0 +1,496 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "sentinel")] +mod support; + +use std::collections::HashMap; + +use redis::{ + sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo}, + Client, Connection, ConnectionAddr, ConnectionInfo, +}; + +use crate::support::*; + +fn parse_replication_info(value: &str) -> HashMap<&str, &str> { + let info_map: std::collections::HashMap<&str, &str> = value + .split("\r\n") + .filter(|line| !line.trim_start().starts_with('#')) + .filter_map(|line| line.split_once(':')) + .collect(); + info_map +} + +fn assert_is_master_role(replication_info: String) { + let info_map = parse_replication_info(&replication_info); + assert_eq!(info_map.get("role"), Some(&"master")); +} + +fn assert_replica_role_and_master_addr(replication_info: String, expected_master: &ConnectionInfo) { + let info_map = parse_replication_info(&replication_info); + + assert_eq!(info_map.get("role"), Some(&"slave")); + + let (master_host, master_port) = match &expected_master.addr { + ConnectionAddr::Tcp(host, port) => (host, port), + ConnectionAddr::TcpTls { + host, + port, + insecure: _, + tls_params: _, + } => (host, port), + ConnectionAddr::Unix(..) => panic!("Unexpected master connection type"), + }; + + assert_eq!(info_map.get("master_host"), Some(&master_host.as_str())); + assert_eq!( + info_map.get("master_port"), + Some(&master_port.to_string().as_str()) + ); +} + +fn assert_is_connection_to_master(conn: &mut Connection) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_is_master_role(info); +} + +fn assert_connection_is_replica_of_correct_master(conn: &mut Connection, master_client: &Client) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_replica_role_and_master_addr(info, master_client.get_connection_info()); +} + +/// Get replica clients from the sentinel in a rotating fashion, asserting that they are +/// indeed replicas of the given master, and returning a list of their addresses. 
+fn connect_to_all_replicas(
+    sentinel: &mut Sentinel,
+    master_name: &str,
+    master_client: &Client,
+    node_conn_info: &SentinelNodeConnectionInfo,
+    number_of_replicas: u16,
+) -> Vec<ConnectionAddr> {
+    let mut replica_conn_infos = vec![];
+
+    for _ in 0..number_of_replicas {
+        let replica_client = sentinel
+            .replica_rotate_for(master_name, Some(node_conn_info))
+            .unwrap();
+        let mut replica_con = replica_client.get_connection(None).unwrap();
+
+        assert!(!replica_conn_infos.contains(&replica_client.get_connection_info().addr));
+        replica_conn_infos.push(replica_client.get_connection_info().addr.clone());
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, master_client);
+    }
+
+    replica_conn_infos
+}
+
+fn assert_connect_to_known_replicas(
+    sentinel: &mut Sentinel,
+    replica_conn_infos: &[ConnectionAddr],
+    master_name: &str,
+    master_client: &Client,
+    node_conn_info: &SentinelNodeConnectionInfo,
+    number_of_connections: u32,
+) {
+    for _ in 0..number_of_connections {
+        let replica_client = sentinel
+            .replica_rotate_for(master_name, Some(node_conn_info))
+            .unwrap();
+        let mut replica_con = replica_client.get_connection(None).unwrap();
+
+        assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr));
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, master_client);
+    }
+}
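+
+// `replica_rotate_for` hands out replicas in rotation, which is why the helper
+// above can expect `number_of_replicas` distinct addresses before any repeat.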
+
+#[test]
+fn test_sentinel_connect_to_random_replica() {
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, 3, 3);
+    let node_conn_info: SentinelNodeConnectionInfo = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    let mut replica_con = sentinel
+        .replica_for(master_name, Some(&node_conn_info))
+        .unwrap()
+        .get_connection(None)
+        .unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+    assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client);
+}
+
+#[test]
+fn test_sentinel_connect_to_multiple_replicas() {
+    let number_of_replicas = 3;
+    let master_name = "master1";
+    let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3);
+    let node_conn_info = cluster.sentinel_node_connection_info();
+    let sentinel = cluster.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    let replica_conn_infos = connect_to_all_replicas(
+        sentinel,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        number_of_replicas,
+    );
+
+    assert_connect_to_known_replicas(
+        sentinel,
+        &replica_conn_infos,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        10,
+    );
+}
+
+#[test]
+fn test_sentinel_server_down() {
+    let number_of_replicas = 3;
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, number_of_replicas, 3);
+    let node_conn_info = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+    let mut master_con = master_client.get_connection(None).unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    context.cluster.sentinel_servers[0].stop();
+    std::thread::sleep(std::time::Duration::from_millis(25));
+
+    let sentinel = context.sentinel_mut();
+
+    let replica_conn_infos = connect_to_all_replicas(
+        sentinel,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        number_of_replicas,
+    );
+
+    assert_connect_to_known_replicas(
+        sentinel,
+        &replica_conn_infos,
+        master_name,
+        &master_client,
+        &node_conn_info,
+        10,
+    );
+}
+
+#[test]
+fn test_sentinel_client() {
+    let master_name = "master1";
+    let mut context = TestSentinelContext::new(2, 3, 3);
+    let mut master_client = SentinelClient::build(
+        context.sentinels_connection_info().clone(),
+        String::from(master_name),
+        Some(context.sentinel_node_connection_info()),
+        redis::sentinel::SentinelServerType::Master,
+    )
+    .unwrap();
+
+    let mut replica_client = SentinelClient::build(
+        context.sentinels_connection_info().clone(),
+        String::from(master_name),
+        Some(context.sentinel_node_connection_info()),
+        redis::sentinel::SentinelServerType::Replica,
+    )
+    .unwrap();
+
+    let mut master_con = master_client.get_connection().unwrap();
+
+    assert_is_connection_to_master(&mut master_con);
+
+    let node_conn_info = context.sentinel_node_connection_info();
+    let sentinel = context.sentinel_mut();
+    let master_client = sentinel
+        .master_for(master_name, Some(&node_conn_info))
+        .unwrap();
+
+    for _ in 0..20 {
+        let mut replica_con = replica_client.get_connection().unwrap();
+
+        assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client);
+    }
+}
+
+#[cfg(feature = "aio")]
+pub mod async_tests {
+    use redis::{
+        aio::MultiplexedConnection,
+        sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo},
+        Client, ConnectionAddr, GlideConnectionOptions, RedisError,
+    };
+
+    use crate::{assert_is_master_role, assert_replica_role_and_master_addr, support::*};
+
+    async fn async_assert_is_connection_to_master(conn: &mut MultiplexedConnection) {
+        let info: String = redis::cmd("INFO")
+            .arg("REPLICATION")
+            .query_async(conn)
+            .await
+            .unwrap();
+
+        assert_is_master_role(info);
+    }
+
+    async fn async_assert_connection_is_replica_of_correct_master(
+        conn: &mut MultiplexedConnection,
+        master_client: &Client,
+    ) {
+        let info: String = redis::cmd("INFO")
+            .arg("REPLICATION")
+            .query_async(conn)
+            .await
+            .unwrap();
+
+        assert_replica_role_and_master_addr(info, master_client.get_connection_info());
+    }
+
+    /// Async version of connect_to_all_replicas
+    async fn async_connect_to_all_replicas(
+        sentinel: &mut Sentinel,
+        master_name: &str,
+        master_client: &Client,
+        node_conn_info: &SentinelNodeConnectionInfo,
+        number_of_replicas: u16,
+    ) -> Vec<ConnectionAddr> {
+        let mut replica_conn_infos = vec![];
+
+        for _ in 0..number_of_replicas {
+            let replica_client = sentinel
+                .async_replica_rotate_for(master_name, Some(node_conn_info))
+                .await
+                .unwrap();
+            let mut replica_con = replica_client
+                .get_multiplexed_async_connection(GlideConnectionOptions::default())
+                .await
+                .unwrap();
+
+            assert!(
+                !replica_conn_infos.contains(&replica_client.get_connection_info().addr),
+                "pushing {:?} into {:?}",
+                replica_client.get_connection_info().addr,
+                replica_conn_infos
+            );
+            replica_conn_infos.push(replica_client.get_connection_info().addr.clone());
+
+            async_assert_connection_is_replica_of_correct_master(&mut replica_con, master_client)
+                .await;
+        }
+
+        replica_conn_infos
+    }
+
+    async fn async_assert_connect_to_known_replicas(
+        sentinel: &mut Sentinel,
+        replica_conn_infos: &[ConnectionAddr],
+        master_name: &str,
+        master_client: &Client,
+        node_conn_info: &SentinelNodeConnectionInfo,
+        number_of_connections: u32,
+    ) {
+        for _ in 0..number_of_connections {
+            let replica_client = sentinel
+                
.async_replica_rotate_for(master_name, Some(node_conn_info)) + .await + .unwrap(); + let mut replica_con = replica_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap(); + + assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr)); + + async_assert_connection_is_replica_of_correct_master(&mut replica_con, master_client) + .await; + } + } + + #[test] + fn test_sentinel_connect_to_random_replica_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + let mut replica_con = sentinel + .async_replica_for(master_name, Some(&node_conn_info)) + .await? + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + async_assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_connect_to_multiple_replicas_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_server_down_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = context.sentinel_node_connection_info(); + + block_on_all(async move { + let sentinel = context.sentinel_mut(); + + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + context.cluster.sentinel_servers[0].stop(); + std::thread::sleep(std::time::Duration::from_millis(25)); + + let sentinel = context.sentinel_mut(); + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_client_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let 
mut master_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Master, + ) + .unwrap(); + + let mut replica_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Replica, + ) + .unwrap(); + + block_on_all(async move { + let mut master_con = master_client.get_async_connection().await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + + // Read commands to the replica node + for _ in 0..20 { + let mut replica_con = replica_client.get_async_connection().await?; + + async_assert_connection_is_replica_of_correct_master( + &mut replica_con, + &master_client, + ) + .await; + } + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_streams.rs b/glide-core/redis-rs/redis/tests/test_streams.rs new file mode 100644 index 0000000000..bf06028b95 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_streams.rs @@ -0,0 +1,627 @@ +#![cfg(feature = "streams")] + +use redis::streams::*; +use redis::{Commands, Connection, RedisResult, ToRedisArgs}; + +mod support; +use crate::support::*; + +use std::collections::BTreeMap; +use std::str; +use std::thread::sleep; +use std::time::Duration; + +fn xadd(con: &mut Connection) { + let _: RedisResult = + con.xadd("k1", "1000-0", &[("hello", "world"), ("redis", "streams")]); + let _: RedisResult = con.xadd("k1", "1000-1", &[("hello", "world2")]); + let _: RedisResult = con.xadd("k2", "2000-0", &[("hello", "world")]); + let _: RedisResult = con.xadd("k2", "2000-1", &[("hello", "world2")]); +} + +fn xadd_keyrange(con: &mut Connection, key: &str, start: i32, end: i32) { + for _i in start..end { + let _: RedisResult = con.xadd(key, "*", &[("h", "w")]); + } +} + +#[test] +fn test_cmd_options() { + // Tests the following command option builders.... + // xclaim_options + // xread_options + // maxlen enum + + // test read options + + let empty = StreamClaimOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let empty = StreamReadOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let opts = StreamClaimOptions::default() + .idle(50) + .time(500) + .retry(3) + .with_force() + .with_justid(); + + assert_args!( + &opts, + "IDLE", + "50", + "TIME", + "500", + "RETRYCOUNT", + "3", + "FORCE", + "JUSTID" + ); + + // test maxlen options + + assert_args!(StreamMaxlen::Approx(10), "MAXLEN", "~", "10"); + assert_args!(StreamMaxlen::Equals(10), "MAXLEN", "=", "10"); + + // test read options + + let opts = StreamReadOptions::default() + .noack() + .block(100) + .count(200) + .group("group-name", "consumer-name"); + + assert_args!( + &opts, + "GROUP", + "group-name", + "consumer-name", + "BLOCK", + "100", + "COUNT", + "200", + "NOACK" + ); + + // should skip noack because of missing group(,) + let opts = StreamReadOptions::default().noack().block(100).count(200); + + assert_args!(&opts, "BLOCK", "100", "COUNT", "200"); +} + +#[test] +fn test_assorted_1() { + // Tests the following commands.... 
+ // xadd + // xadd_map (skip this for now) + // xadd_maxlen + // xread + // xlen + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // smoke test that we get the same id back + let result: RedisResult = con.xadd("k0", "1000-0", &[("x", "y")]); + assert_eq!(result.unwrap(), "1000-0"); + + // xread reply + let reply: StreamReadReply = con.xread(&["k1", "k2", "k3"], &["0", "0", "0"]).unwrap(); + + // verify reply contains 2 keys even though we asked for 3 + assert_eq!(&reply.keys.len(), &2usize); + + // verify first key & first id exist + assert_eq!(&reply.keys[0].key, "k1"); + assert_eq!(&reply.keys[0].ids.len(), &2usize); + assert_eq!(&reply.keys[0].ids[0].id, "1000-0"); + + // lookup the key in StreamId map + let hello: Option = reply.keys[0].ids[0].get("hello"); + assert_eq!(hello, Some("world".to_string())); + + // verify the second key was written + assert_eq!(&reply.keys[1].key, "k2"); + assert_eq!(&reply.keys[1].ids.len(), &2usize); + assert_eq!(&reply.keys[1].ids[0].id, "2000-0"); + + // test xadd_map + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + map.insert("ab", "cd"); + map.insert("ef", "gh"); + map.insert("ij", "kl"); + let _: RedisResult = con.xadd_map("k3", "3000-0", map); + + let reply: StreamRangeReply = con.xrange_all("k3").unwrap(); + assert!(reply.ids[0].contains_key("ab")); + assert!(reply.ids[0].contains_key("ef")); + assert!(reply.ids[0].contains_key("ij")); + + // test xadd w/ maxlength below... + + // add 100 things to k4 + xadd_keyrange(&mut con, "k4", 0, 100); + + // test xlen.. should have 100 items + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(100)); + + // test xadd_maxlen + let _: RedisResult = + con.xadd_maxlen("k4", StreamMaxlen::Equals(10), "*", &[("h", "w")]); + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(10)); +} + +#[test] +fn test_xgroup_create() { + // Tests the following commands.... + // xadd + // xinfo_stream + // xgroup_create + // xinfo_groups + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // no key exists... this call breaks the connection pipe for some reason + let reply: RedisResult = con.xinfo_stream("k10"); + assert!(reply.is_err()); + + // redo the connection because the above error + con = ctx.connection(); + + // key should exist + let reply: StreamInfoStreamReply = con.xinfo_stream("k1").unwrap(); + assert_eq!(&reply.first_entry.id, "1000-0"); + assert_eq!(&reply.last_entry.id, "1000-1"); + assert_eq!(&reply.last_generated_id, "1000-1"); + + // xgroup create (existing stream) + let result: RedisResult = con.xgroup_create("k1", "g1", "$"); + assert!(result.is_ok()); + + // xinfo groups (existing stream) + let result: RedisResult = con.xinfo_groups("k1"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g1"); +} + +#[test] +fn test_assorted_2() { + // Tests the following commands.... 
+ // xadd + // xinfo_stream + // xinfo_groups + // xinfo_consumer + // xgroup_create_mkstream + // xread_options + // xack + // xpending + // xpending_count + // xpending_consumer_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // test xgroup create w/ mkstream @ 0 + let result: RedisResult = con.xgroup_create_mkstream("k99", "g99", "0"); + assert!(result.is_ok()); + + // Since nothing exists on this stream yet, + // it should have the defaults returned by the client + let result: RedisResult = con.xinfo_groups("k99"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g99"); + assert_eq!(&reply.groups[0].last_delivered_id, &"0-0"); + + // call xadd on k99 just so we can read from it + // using consumer g99 and test xinfo_consumers + let _: RedisResult = con.xadd("k99", "1000-0", &[("a", "b"), ("c", "d")]); + let _: RedisResult = con.xadd("k99", "1000-1", &[("e", "f"), ("g", "h")]); + + // test empty PEL + let empty_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(empty_reply.count(), 0); + if let StreamPendingReply::Empty = empty_reply { + // looks good + } else { + panic!("Expected StreamPendingReply::Empty but got Data"); + } + + // passing options w/ group triggers XREADGROUP + // using ID=">" means all undelivered ids + // otherwise, ID="0 | ms-num" means all pending already + // sent to this client + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys[0].ids.len(), 2); + + // read xinfo consumers again, should have 2 messages for the c99 consumer + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 2); + + // ack one of these messages + let result: RedisResult = con.xack("k99", "g99", &["1000-0"]); + assert_eq!(result, Ok(1)); + + // get pending messages already seen by this client + // we should only have one now.. + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &["0"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys.len(), 1); + + // we should also have one pending here... + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 1); + + // add more and read so we can test xpending + let _: RedisResult = con.xadd("k99", "1001-0", &[("i", "j"), ("k", "l")]); + let _: RedisResult = con.xadd("k99", "1001-1", &[("m", "n"), ("o", "p")]); + let _: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + + // call xpending here... 
+ // this has a different reply from what the count variations return + let data_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(data_reply.count(), 3); + + if let StreamPendingReply::Data(data) = data_reply { + assert_stream_pending_data(data) + } else { + panic!("Expected StreamPendingReply::Data but got Empty"); + } + + // both count variations have the same reply types + let reply: StreamPendingCountReply = con.xpending_count("k99", "g99", "-", "+", 10).unwrap(); + assert_eq!(reply.ids.len(), 3); + + let reply: StreamPendingCountReply = con + .xpending_consumer_count("k99", "g99", "-", "+", 10, "c99") + .unwrap(); + assert_eq!(reply.ids.len(), 3); + + for StreamPendingId { + id, + consumer, + times_delivered, + last_delivered_ms: _, + } in reply.ids + { + assert!(!id.is_empty()); + assert!(!consumer.is_empty()); + assert!(times_delivered > 0); + } +} + +fn assert_stream_pending_data(data: StreamPendingData) { + assert_eq!(data.start_id, "1000-1"); + assert_eq!(data.end_id, "1001-1"); + assert_eq!(data.consumers.len(), 1); + assert_eq!(data.consumers[0].name, "c99"); +} + +#[test] +fn test_xadd_maxlen_map() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + for i in 0..10 { + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + let idx = i.to_string(); + map.insert("idx", &idx); + let _: RedisResult = + con.xadd_maxlen_map("maxlen_map", StreamMaxlen::Equals(3), "*", map); + } + + let result: RedisResult = con.xlen("maxlen_map"); + assert_eq!(result, Ok(3)); + let reply: StreamRangeReply = con.xrange_all("maxlen_map").unwrap(); + + assert_eq!(reply.ids[0].get("idx"), Some("7".to_string())); + assert_eq!(reply.ids[1].get("idx"), Some("8".to_string())); + assert_eq!(reply.ids[2].get("idx"), Some("9".to_string())); +} + +#[test] +fn test_xread_options_deleted_pel_entry() { + // Test xread_options behaviour with deleted entry + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h1", "w1")]); + // read the pending items for this key & group + let result: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h2", "w2")]); + let result_deleted_entry: StreamReadReply = con + .xread_options( + &["k1"], + &["0"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!( + result.keys[0].ids.len(), + result_deleted_entry.keys[0].ids.len() + ); + assert_eq!( + result.keys[0].ids[0].id, + result_deleted_entry.keys[0].ids[0].id + ); +} +#[test] +fn test_xclaim() { + // Tests the following commands.... + // xclaim + // xclaim_options + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // xclaim test basic idea: + // 1. we need to test adding messages to a group + // 2. then xreadgroup needs to define a consumer and read pending + // messages without acking them + // 3. then we need to sleep 5ms and call xpending + // 4. 
+    //    from here we should be able to claim messages
+    //    past the idle time and read them from a different consumer
+
+    // create the group
+    let result: RedisResult<String> = con.xgroup_create_mkstream("k1", "g1", "$");
+    assert!(result.is_ok());
+
+    // add some keys
+    xadd_keyrange(&mut con, "k1", 0, 10);
+
+    // read the pending items for this key & group
+    let reply: StreamReadReply = con
+        .xread_options(
+            &["k1"],
+            &[">"],
+            &StreamReadOptions::default().group("g1", "c1"),
+        )
+        .unwrap();
+    // verify we have 10 ids
+    assert_eq!(reply.keys[0].ids.len(), 10);
+
+    // save this StreamId for later
+    let claim = &reply.keys[0].ids[0];
+    let _claim_1 = &reply.keys[0].ids[1];
+    let claim_justids = &reply.keys[0]
+        .ids
+        .iter()
+        .map(|msg| &msg.id)
+        .collect::<Vec<&String>>();
+
+    // sleep for 5ms
+    sleep(Duration::from_millis(5));
+
+    // grab this id if it has been idle for more than 4ms
+    let reply: StreamClaimReply = con
+        .xclaim("k1", "g1", "c2", 4, &[claim.id.clone()])
+        .unwrap();
+    assert_eq!(reply.ids.len(), 1);
+    assert_eq!(reply.ids[0].id, claim.id);
+
+    // grab all pending ids for this key...
+    // we should have 9 in c1 and 1 in c2
+    let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap();
+    if let StreamPendingReply::Data(data) = reply {
+        assert_eq!(data.consumers[0].name, "c1");
+        assert_eq!(data.consumers[0].pending, 9);
+        assert_eq!(data.consumers[1].name, "c2");
+        assert_eq!(data.consumers[1].pending, 1);
+    }
+
+    // sleep for 5ms
+    sleep(Duration::from_millis(5));
+
+    // let's test some of the xclaim_options
+    // call force on the same claim.id
+    let _: StreamClaimReply = con
+        .xclaim_options(
+            "k1",
+            "g1",
+            "c3",
+            4,
+            &[claim.id.clone()],
+            StreamClaimOptions::default().with_force(),
+        )
+        .unwrap();
+
+    let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap();
+    // we should have 9 w/ c1 and 1 w/ c3 now
+    if let StreamPendingReply::Data(data) = reply {
+        assert_eq!(data.consumers[1].name, "c3");
+        assert_eq!(data.consumers[1].pending, 1);
+    }
+
+    // sleep for 5ms
+    sleep(Duration::from_millis(5));
+
+    // claim and only return JUSTID
+    let claimed: Vec<String> = con
+        .xclaim_options(
+            "k1",
+            "g1",
+            "c5",
+            4,
+            claim_justids,
+            StreamClaimOptions::default().with_force().with_justid(),
+        )
+        .unwrap();
+    // we just claimed the original 10 ids
+    // and only returned the ids
+    assert_eq!(claimed.len(), 10);
+}
+
+#[test]
+fn test_xdel() {
+    // Tests the following commands....
+    // xdel
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    // add some keys
+    xadd(&mut con);
+
+    // delete the first stream item for this key
+    let result: RedisResult<i32> = con.xdel("k1", &["1000-0"]);
+    // returns the number of items deleted
+    assert_eq!(result, Ok(1));
+
+    let result: RedisResult<i32> = con.xdel("k2", &["2000-0", "2000-1", "2000-2"]);
+    // should equal 2 since the last id doesn't exist
+    assert_eq!(result, Ok(2));
+}
+
+#[test]
+fn test_xtrim() {
+    // Tests the following commands....
+    // xtrim
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    // add some keys
+    xadd_keyrange(&mut con, "k1", 0, 100);
+
+    // trim key to 50
+    // returns the number of items evicted from the stream
+    let result: RedisResult<i32> = con.xtrim("k1", StreamMaxlen::Equals(50));
+    assert_eq!(result, Ok(50));
+    // trimming to 10 should evict 40 more items
+    let result: RedisResult<i32> = con.xtrim("k1", StreamMaxlen::Equals(10));
+    assert_eq!(result, Ok(40));
+}
+
+#[test]
+fn test_xgroup() {
+    // Tests the following commands....
+    // xgroup_create_mkstream
+    // xgroup_destroy
+    // xgroup_delconsumer
+
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    // test xgroup create w/ mkstream @ 0
+    let result: RedisResult<String> = con.xgroup_create_mkstream("k1", "g1", "0");
+    assert!(result.is_ok());
+
+    // destroy this new stream group
+    let result: RedisResult<i32> = con.xgroup_destroy("k1", "g1");
+    assert_eq!(result, Ok(1));
+
+    // add some keys
+    xadd(&mut con);
+
+    // create the group again using an existing stream
+    let result: RedisResult<String> = con.xgroup_create("k1", "g1", "0");
+    assert!(result.is_ok());
+
+    // read from the group so we can register the consumer
+    let reply: StreamReadReply = con
+        .xread_options(
+            &["k1"],
+            &[">"],
+            &StreamReadOptions::default().group("g1", "c1"),
+        )
+        .unwrap();
+    assert_eq!(reply.keys[0].ids.len(), 2);
+
+    let result: RedisResult<i32> = con.xgroup_delconsumer("k1", "g1", "c1");
+    // returns the number of pending messages this client had open
+    assert_eq!(result, Ok(2));
+
+    let result: RedisResult<i32> = con.xgroup_destroy("k1", "g1");
+    assert_eq!(result, Ok(1));
+}
+
+#[test]
+fn test_xrange() {
+    // Tests the following commands....
+    // xrange (-/+ variations)
+    // xrange_all
+    // xrange_count
+
+    let ctx = TestContext::new();
+    let mut con = ctx.connection();
+
+    xadd(&mut con);
+
+    // xrange replies
+    let reply: StreamRangeReply = con.xrange_all("k1").unwrap();
+    assert_eq!(reply.ids.len(), 2);
+
+    let reply: StreamRangeReply = con.xrange("k1", "1000-1", "+").unwrap();
+    assert_eq!(reply.ids.len(), 1);
+
+    let reply: StreamRangeReply = con.xrange("k1", "-", "1000-0").unwrap();
+    assert_eq!(reply.ids.len(), 1);
+
+    let reply: StreamRangeReply = con.xrange_count("k1", "-", "+", 1).unwrap();
+    assert_eq!(reply.ids.len(), 1);
+}
+
+#[test]
+fn test_xrevrange() {
+    // Tests the following commands....
+ // xrevrange (+/- variations) + // xrevrange_all + // xrevrange_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // xrange replies + let reply: StreamRangeReply = con.xrevrange_all("k1").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "1000-1", "-").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "+", "1000-1").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrevrange_count("k1", "+", "-", 1).unwrap(); + assert_eq!(reply.ids.len(), 1); +} diff --git a/glide-core/redis-rs/redis/tests/test_types.rs b/glide-core/redis-rs/redis/tests/test_types.rs new file mode 100644 index 0000000000..d5df513efb --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_types.rs @@ -0,0 +1,606 @@ +mod support; + +#[cfg(test)] +mod types { + use redis::{FromRedisValue, ToRedisArgs, Value}; + #[test] + fn test_is_single_arg() { + let sslice: &[_] = &["foo"][..]; + let nestslice: &[_] = &[sslice][..]; + let nestvec = vec![nestslice]; + let bytes = b"Hello World!"; + let twobytesslice: &[_] = &[bytes, bytes][..]; + let twobytesvec = vec![bytes, bytes]; + + assert!("foo".is_single_arg()); + assert!(sslice.is_single_arg()); + assert!(nestslice.is_single_arg()); + assert!(nestvec.is_single_arg()); + assert!(bytes.is_single_arg()); + + assert!(!twobytesslice.is_single_arg()); + assert!(!twobytesvec.is_single_arg()); + } + + /// The `FromRedisValue` trait provides two methods for parsing: + /// - `fn from_redis_value(&Value) -> Result` + /// - `fn from_owned_redis_value(Value) -> Result` + /// The `RedisParseMode` below allows choosing between the two + /// so that test logic does not need to be duplicated for each. + enum RedisParseMode { + Owned, + Ref, + } + + impl RedisParseMode { + /// Calls either `FromRedisValue::from_owned_redis_value` or + /// `FromRedisValue::from_redis_value`. 
+ fn parse_redis_value( + &self, + value: redis::Value, + ) -> Result { + match self { + Self::Owned => redis::FromRedisValue::from_owned_redis_value(value), + Self::Ref => redis::FromRedisValue::from_redis_value(&value), + } + } + } + + #[test] + fn test_info_dict() { + use redis::{InfoDict, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let d: InfoDict = parse_mode + .parse_redis_value(Value::SimpleString( + "# this is a comment\nkey1:foo\nkey2:42\n".into(), + )) + .unwrap(); + + assert_eq!(d.get("key1"), Some("foo".to_string())); + assert_eq!(d.get("key2"), Some(42i64)); + assert_eq!(d.get::("key3"), None); + } + } + + #[test] + fn test_i32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::BulkString("42".into())); + assert_eq!(i, Ok(42i32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("42x".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_u32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42u32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("-1".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3])); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec)); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'])); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16])); + } + } + + #[test] + fn test_box_slice() { + use redis::{FromRedisValue, Value}; + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3].into_boxed_slice())); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec.into_boxed_slice())); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'].into_boxed_slice())); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16].into_boxed_slice())); + + assert_eq!( + Box::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - 
TypeError: \"Conversion to alloc::boxed::Box<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_arc_slice() { + use redis::{FromRedisValue, Value}; + use std::sync::Arc; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(Arc::from(vec![1i32, 2, 3]))); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(content_vec))); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(vec![b'1']))); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(Arc::from(vec![1_u16]))); + + assert_eq!( + Arc::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - TypeError: \"Conversion to alloc::sync::Arc<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_single_bool_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![true])); + } + } + + #[test] + fn test_single_i32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![1i32])); + } + } + + #[test] + fn test_single_u32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("42".into())); + + assert_eq!(v, Ok(vec![42u32])); + } + } + + #[test] + fn test_single_string_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(vec!["1".to_string()])); + } + } + + #[test] + fn test_tuple() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])])); + + assert_eq!(v, Ok(((1i32, 2, 3,),))); + } + } + + #[test] + fn test_hashmap() { + use fnv::FnvHasher; + use redis::{ErrorKind, Value}; + use std::collections::HashMap; + use std::hash::BuildHasherDefault; + + type Hm = HashMap; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + let mut e: Hm = HashMap::new(); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + type Hasher = BuildHasherDefault; + type HmHasher = HashMap; + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + 
Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + + let fnv = Hasher::default(); + let mut e: HmHasher = HashMap::with_hasher(fnv); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + let v: Result = + parse_mode.parse_redis_value(Value::Array(vec![Value::BulkString("a".into())])); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_bool() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::BulkString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::BulkString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::SimpleString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::SimpleString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v, Ok(true)); + } + } + + #[cfg(feature = "bytes")] + #[test] + fn test_bytes() { + use bytes::Bytes; + use redis::{ErrorKind, RedisResult, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let content_bytes = Bytes::from_static(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(content_bytes)); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() { + use std::str::FromStr; + + use redis::{ErrorKind, FromRedisValue, RedisResult, Value}; + use uuid::Uuid; + + let uuid = Uuid::from_str("abab64b7-e265-4052-a41b-23e1e28674bf").unwrap(); + let bytes = uuid.as_bytes().to_vec(); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::BulkString(bytes)); + assert_eq!(v, Ok(uuid)); + + let v: RedisResult = + FromRedisValue::from_redis_value(&Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = 
FromRedisValue::from_redis_value(&Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + + #[test] + fn test_cstring() { + use redis::{ErrorKind, RedisResult, Value}; + use std::ffi::CString; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(CString::new(content).unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v, Ok(CString::new("garbage").unwrap())); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(CString::new("OK").unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("gar\0bage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_types_to_redis_args() { + use redis::ToRedisArgs; + use std::collections::BTreeMap; + use std::collections::BTreeSet; + use std::collections::HashMap; + use std::collections::HashSet; + + assert!(!5i32.to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!String::from("x").to_redis_args().is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can be used on something HMSET + assert!(![("a", 5), ("b", 6), ("C", 7)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can also be used on something HMSET + assert!(![("d", 8), ("e", 9), ("f", 10)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + } + + #[test] + fn test_large_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = i; + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_u8_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [u8; 1000] = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = (i % 256) as u8; + } + + let vec = (&array).to_redis_args(); + assert_eq!(vec.len(), 1); + assert_eq!(array.len(), vec[0].len()); + + let value = Value::Array(vec[0].iter().map(|val| 
Value::Int(*val as i64)).collect()); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [u8; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_string_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [String; 1000] = [(); 1000].map(|_| String::new()); + for (i, item) in array.iter_mut().enumerate() { + *item = format!("{i}"); + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [String; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_0_length_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let array: [usize; 0] = [0; 0]; + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&Value::Nil).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_attributes() { + use redis::{parse_redis_value, FromRedisValue, Value}; + let bytes: &[u8] = b"*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n"; + let val = parse_redis_value(bytes).unwrap(); + { + // The case user doesn't expect attributes from server + let x: Vec = redis::FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!(x, vec![1, 2, 3]); + } + { + // The case user wants raw value from server + let x: Value = FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!( + x, + Value::Array(vec![ + Value::Int(1), + Value::Int(2), + Value::Attribute { + data: Box::new(Value::Int(3)), + attributes: vec![( + Value::SimpleString("ttl".to_string()), + Value::Int(3600) + )] + } + ]) + ) + } + } +} diff --git a/glide-core/redis-rs/rustfmt.toml b/glide-core/redis-rs/rustfmt.toml new file mode 100644 index 0000000000..0d564415cb --- /dev/null +++ b/glide-core/redis-rs/rustfmt.toml @@ -0,0 +1,2 @@ +use_try_shorthand = true +edition = "2018" diff --git a/glide-core/redis-rs/src/main.rs b/glide-core/redis-rs/src/main.rs new file mode 100644 index 0000000000..a5610f8be9 --- /dev/null +++ b/glide-core/redis-rs/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Dummy source to bypass ORT OSS Tool virtual workspace restrictions."); +} diff --git a/glide-core/src/client/mod.rs b/glide-core/src/client/mod.rs index 6a8ebe278a..cfe8d6dc05 100644 --- a/glide-core/src/client/mod.rs +++ b/glide-core/src/client/mod.rs @@ -9,7 +9,10 @@ use futures::FutureExt; use logger_core::{log_info, log_warn}; use redis::aio::ConnectionLike; use redis::cluster_async::ClusterConnection; -use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo}; +use redis::cluster_routing::{ + MultipleNodeRoutingInfo, ResponsePolicy, Routable, RoutingInfo, SingleNodeRoutingInfo, +}; +use redis::cluster_slotmap::ReadFromReplicaStrategy; use redis::{Cmd, ErrorKind, ObjectType, PushInfo, RedisError, RedisResult, 
ScanStateRC, Value}; pub use standalone_client::StandaloneClient; use std::io; @@ -31,21 +34,22 @@ pub const DEFAULT_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_millis(2 pub const DEFAULT_PERIODIC_TOPOLOGY_CHECKS_INTERVAL: Duration = Duration::from_secs(60); pub const INTERNAL_CONNECTION_TIMEOUT: Duration = Duration::from_millis(250); pub const FINISHED_SCAN_CURSOR: &str = "finished"; -// The value of 1000 for the maximum number of inflight requests is determined based on Little's Law in queuing theory: -// -// Expected maximum request rate: 50,000 requests/second -// Expected response time: 1 millisecond -// -// According to Little's Law, the maximum number of inflight requests required to fully utilize the maximum request rate is: -// (50,000 requests/second) × (1 millisecond / 1000 milliseconds) = 50 requests -// -// The value of 1000 provides a buffer for bursts while still allowing full utilization of the maximum request rate. + +/// The value of 1000 for the maximum number of inflight requests is determined based on Little's Law in queuing theory: +/// +/// Expected maximum request rate: 50,000 requests/second +/// Expected response time: 1 millisecond +/// +/// According to Little's Law, the maximum number of inflight requests required to fully utilize the maximum request rate is: +/// (50,000 requests/second) × (1 millisecond / 1000 milliseconds) = 50 requests +/// +/// The value of 1000 provides a buffer for bursts while still allowing full utilization of the maximum request rate. pub const DEFAULT_MAX_INFLIGHT_REQUESTS: u32 = 1000; -// The connection check interval is currently not exposed to the user via ConnectionRequest, -// as improper configuration could negatively impact performance or pub/sub resiliency. -// A 3-second interval provides a reasonable balance between connection validation -// and performance overhead. +/// The connection check interval is currently not exposed to the user via ConnectionRequest, +/// as improper configuration could negatively impact performance or pub/sub resiliency. +/// A 3-second interval provides a reasonable balance between connection validation +/// and performance overhead. pub const CONNECTION_CHECKS_INTERVAL: Duration = Duration::from_secs(3); pub(super) fn get_port(address: &NodeAddress) -> u16 { @@ -258,9 +262,9 @@ impl Client { if let Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) = routing { - let cmdname = cmd.command().unwrap_or_default(); - let cmdname = String::from_utf8_lossy(&cmdname); - if redis::cluster_routing::is_readonly_cmd(cmdname.as_bytes()) { + let cmd_name = cmd.command().unwrap_or_default(); + let cmd_name = String::from_utf8_lossy(&cmd_name); + if redis::cluster_routing::is_readonly_cmd(cmd_name.as_bytes()) { // A read-only command, go ahead and send it to a random node RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random) } else { @@ -269,7 +273,7 @@ impl Client { log_warn( "send_command", format!( - "User provided 'Random' routing which is not suitable for the writeable command '{cmdname}'. Changing it to 'RandomPrimary'" + "User provided 'Random' routing which is not suitable for the writeable command '{cmd_name}'. Changing it to 'RandomPrimary'" ), ); RoutingInfo::SingleNode(SingleNodeRoutingInfo::RandomPrimary) @@ -473,6 +477,69 @@ impl Client { self.inflight_requests_allowed .fetch_add(1, Ordering::SeqCst) } + + /// Update the password used to authenticate with the servers. + /// If None is passed, the password will be removed. 
+    /// If `immediate_auth` is true, the password will be used to authenticate with the servers immediately using the `AUTH` command.
+    /// The default behavior is to update the password without authenticating immediately.
+    /// If the password is empty or None and `immediate_auth` is true, the password will be updated, and an error will be returned.
+    pub async fn update_connection_password(
+        &mut self,
+        password: Option<String>,
+        immediate_auth: bool,
+    ) -> RedisResult<Value> {
+        let timeout = self.request_timeout;
+        // The password update operation is wrapped in a timeout to prevent it from blocking indefinitely.
+        // If the operation times out, an error is returned.
+        // Since the password update is not a command that goes through the regular command pipeline,
+        // it does not get the regular timeout handling, so we need to handle it separately.
+        match tokio::time::timeout(timeout, async {
+            match self.internal_client {
+                ClientWrapper::Standalone(ref mut client) => {
+                    client.update_connection_password(password.clone()).await
+                }
+                ClientWrapper::Cluster { ref mut client } => {
+                    client.update_connection_password(password.clone()).await
+                }
+            }
+        })
+        .await
+        {
+            Ok(result) => {
+                if immediate_auth {
+                    self.send_immediate_auth(password).await
+                } else {
+                    result
+                }
+            }
+            Err(_elapsed) => Err(RedisError::from((
+                ErrorKind::IoError,
+                "Password update operation timed out, please check the connection",
+            ))),
+        }
+    }
+
+    async fn send_immediate_auth(&mut self, password: Option<String>) -> RedisResult<Value> {
+        match &password {
+            Some(pw) if pw.is_empty() => Err(RedisError::from((
+                ErrorKind::UserOperationError,
+                "Empty password provided for authentication",
+            ))),
+            None => Err(RedisError::from((
+                ErrorKind::UserOperationError,
+                "No password provided for authentication",
+            ))),
+            Some(password) => {
+                let routing = RoutingInfo::MultiNode((
+                    MultipleNodeRoutingInfo::AllNodes,
+                    Some(ResponsePolicy::AllSucceeded),
+                ));
+                let mut cmd = redis::cmd("AUTH");
+                cmd.arg(password);
+                self.send_command(&cmd, Some(routing)).await
+            }
+        }
+    }
 }

 fn load_cmd(code: &[u8]) -> Cmd {
@@ -511,8 +578,6 @@ async fn create_cluster_client(
         .into_iter()
         .map(|address| get_connection_info(&address, tls_mode, redis_connection_info.clone()))
         .collect();
-    let read_from = request.read_from.unwrap_or_default();
-    let read_from_replicas = !matches!(read_from, ReadFrom::Primary); // TODO - implement different read from replica strategies.
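// A minimal, self-contained sketch of the timeout pattern used by
// `update_connection_password` above: the update future is bounded with
// `tokio::time::timeout`, and the `Elapsed` error is mapped onto a domain
// error instead of letting the call block indefinitely. The names
// `apply_password_update` and `UpdateError` are hypothetical stand-ins,
// not part of the real API.
use std::time::Duration;

#[derive(Debug)]
struct UpdateError(&'static str);

async fn apply_password_update(_password: Option<String>) -> Result<(), UpdateError> {
    // Stand-in for the real connection-level password update.
    Ok(())
}

async fn update_with_timeout(password: Option<String>) -> Result<(), UpdateError> {
    match tokio::time::timeout(Duration::from_millis(250), apply_password_update(password)).await {
        // The inner future finished in time; propagate its result.
        Ok(result) => result,
        // The timeout elapsed first; surface a dedicated error.
        Err(_elapsed) => Err(UpdateError("password update timed out")),
    }
}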
let periodic_topology_checks = match request.periodic_checks { Some(PeriodicCheck::Disabled) => None, Some(PeriodicCheck::Enabled) => Some(DEFAULT_PERIODIC_TOPOLOGY_CHECKS_INTERVAL), @@ -522,9 +587,12 @@ async fn create_cluster_client( let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes) .connection_timeout(INTERNAL_CONNECTION_TIMEOUT) .retries(DEFAULT_RETRIES); - if read_from_replicas { - builder = builder.read_from_replicas(); - } + let read_from_strategy = request.read_from.unwrap_or_default(); + builder = builder.read_from(match read_from_strategy { + ReadFrom::AZAffinity(az) => ReadFromReplicaStrategy::AZAffinity(az), + ReadFrom::PreferReplica => ReadFromReplicaStrategy::RoundRobin, + ReadFrom::Primary => ReadFromReplicaStrategy::AlwaysFromPrimary, + }); if let Some(interval_duration) = periodic_topology_checks { builder = builder.periodic_topology_checks(interval_duration); } @@ -618,12 +686,14 @@ fn sanitized_request_string(request: &ConnectionRequest) -> String { let database_id = format!("\ndatabase ID: {}", request.database_id); let rfr_strategy = request .read_from + .clone() .map(|rfr| { format!( "\nRead from Replica mode: {}", match rfr { ReadFrom::Primary => "Only primary", ReadFrom::PreferReplica => "Prefer replica", + ReadFrom::AZAffinity(_) => "Prefer replica in user's availability zone", } ) }) @@ -712,8 +782,7 @@ impl Client { }) }) .await - .map_err(|_| ConnectionError::Timeout) - .and_then(|res| res) + .map_err(|_| ConnectionError::Timeout)? } } diff --git a/glide-core/src/client/reconnecting_connection.rs b/glide-core/src/client/reconnecting_connection.rs index 14311173ee..39a4c1db62 100644 --- a/glide-core/src/client/reconnecting_connection.rs +++ b/glide-core/src/client/reconnecting_connection.rs @@ -13,13 +13,23 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; +use telemetrylib::Telemetry; use tokio::sync::{mpsc, Notify}; use tokio::task; use tokio::time::timeout; -use tokio_retry::Retry; +use tokio_retry2::{Retry, RetryError}; use super::{run_with_timeout, DEFAULT_CONNECTION_ATTEMPT_TIMEOUT}; +/// The reason behind the call to `reconnect()` +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum ReconnectReason { + /// A connection was dropped (for any reason) + ConnectionDropped, + /// Connection creation error + CreateError, +} + /// The object that is used in order to recreate a connection after a disconnect. struct ConnectionBackend { /// This signal is reset when a connection disconnects, and set when a new `ConnectionState` has been set with a `Connected` state. 
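// A self-contained sketch of the read-from mapping performed in
// `create_cluster_client` above: the user-facing `ReadFrom` preference is
// translated into an internal replica-selection strategy, with `AZAffinity`
// carrying the client's availability zone. Both enums here are simplified
// stand-ins for the real glide-core/redis-rs types.
enum ReadFrom {
    Primary,
    PreferReplica,
    AZAffinity(String),
}

#[derive(Debug, PartialEq)]
enum ReadFromReplicaStrategy {
    AlwaysFromPrimary,
    RoundRobin,
    AZAffinity(String),
}

fn to_replica_strategy(read_from: ReadFrom) -> ReadFromReplicaStrategy {
    match read_from {
        ReadFrom::Primary => ReadFromReplicaStrategy::AlwaysFromPrimary,
        ReadFrom::PreferReplica => ReadFromReplicaStrategy::RoundRobin,
        ReadFrom::AZAffinity(az) => ReadFromReplicaStrategy::AZAffinity(az),
    }
}

fn main() {
    let strategy = to_replica_strategy(ReadFrom::AZAffinity("us-east-1a".into()));
    assert_eq!(
        strategy,
        ReadFromReplicaStrategy::AZAffinity("us-east-1a".into())
    );
}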
@@ -103,6 +113,7 @@ async fn create_connection( connection_backend: ConnectionBackend, retry_strategy: RetryStrategy, push_sender: Option>, + discover_az: bool, ) -> Result { let client = &connection_backend.connection_info; let connection_options = GlideConnectionOptions { @@ -110,8 +121,13 @@ async fn create_connection( disconnect_notifier: Some::>(Box::new( TokioDisconnectNotifier::new(), )), + discover_az, + }; + let action = || async { + get_multiplexed_connection(client, &connection_options) + .await + .map_err(RetryError::transient) }; - let action = || get_multiplexed_connection(client, &connection_options); match Retry::spawn(retry_strategy.get_iterator(), action).await { Ok(connection) => { @@ -125,6 +141,7 @@ async fn create_connection( .addr ), ); + Telemetry::incr_total_connections(1); Ok(ReconnectingConnection { inner: Arc::new(InnerReconnectingConnection { state: Mutex::new(ConnectionState::Connected(connection)), @@ -151,7 +168,7 @@ async fn create_connection( }), connection_options, }; - connection.reconnect(); + connection.reconnect(ReconnectReason::CreateError); Err((connection, err)) } } @@ -189,6 +206,7 @@ impl ReconnectingConnection { redis_connection_info: RedisConnectionInfo, tls_mode: TlsMode, push_sender: Option>, + discover_az: bool, ) -> Result { log_debug( "connection creation", @@ -201,7 +219,7 @@ impl ReconnectingConnection { connection_available_signal: ManualResetEvent::new(true), client_dropped_flagged: AtomicBool::new(false), }; - create_connection(backend, connection_retry_strategy, push_sender).await + create_connection(backend, connection_retry_strategy, push_sender, discover_az).await } pub(crate) fn node_address(&self) -> String { @@ -221,6 +239,9 @@ impl ReconnectingConnection { } pub(super) fn mark_as_dropped(&self) { + // Update the telemetry for each connection that is dropped. A dropped connection + // will not be re-connected, so update the telemetry here + Telemetry::decr_total_connections(1); self.inner .backend .client_dropped_flagged @@ -245,7 +266,10 @@ impl ReconnectingConnection { } } - pub(super) fn reconnect(&self) { + /// Attempt to re-connect the connection. + /// + /// This function spawns a task to perform the reconnection in the background + pub(super) fn reconnect(&self, reason: ReconnectReason) { { let mut guard = self.inner.state.lock().unwrap(); if matches!(*guard, ConnectionState::Reconnecting) { @@ -259,6 +283,13 @@ impl ReconnectingConnection { log_debug("reconnect", "starting"); let connection_clone = self.clone(); + + if reason.eq(&ReconnectReason::ConnectionDropped) { + // Attempting to reconnect a connection that was dropped (for any reason) - update the telemetry by reducing + // the number of opened connections by 1, it will be incremented by 1 after a successful re-connect + Telemetry::decr_total_connections(1); + } + // The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the // background, regardless of whether the calling task is dropped or not. 
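// A small sketch of the detached-task pattern the comment below describes:
// `tokio::task::spawn` hands the reconnect loop to the runtime, so the
// returned `JoinHandle` can be dropped while the work still runs to
// completion in the background. `try_reconnect` is a hypothetical
// placeholder for the real reconnection logic.
async fn try_reconnect() -> bool {
    // Stand-in for one reconnection attempt.
    true
}

fn start_background_reconnect() {
    // The handle is intentionally discarded; dropping a tokio JoinHandle
    // detaches the task instead of cancelling it.
    tokio::task::spawn(async {
        loop {
            if try_reconnect().await {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
    });
}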
         task::spawn(async move {
@@ -293,6 +324,7 @@
                         .set();
                     *guard = ConnectionState::Connected(connection);
                 }
+                Telemetry::incr_total_connections(1);
                 return;
             }
             Err(_) => tokio::time::sleep(sleep_duration).await,
diff --git a/glide-core/src/client/standalone_client.rs b/glide-core/src/client/standalone_client.rs
index b1a699a546..c5e69fd6dd 100644
--- a/glide-core/src/client/standalone_client.rs
+++ b/glide-core/src/client/standalone_client.rs
@@ -2,8 +2,9 @@
  * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
  */
 use super::get_redis_connection_info;
-use super::reconnecting_connection::ReconnectingConnection;
+use super::reconnecting_connection::{ReconnectReason, ReconnectingConnection};
 use super::{ConnectionRequest, NodeAddress, TlsMode};
+use crate::client::types::ReadFrom as ClientReadFrom;
 use crate::retry_strategies::RetryStrategy;
 use futures::{future, stream, StreamExt};
 use logger_core::log_debug;
@@ -13,7 +14,9 @@
 use redis::aio::ConnectionLike;
 use redis::cluster_routing::{self, is_readonly_cmd, ResponsePolicy, Routable, RoutingInfo};
 use redis::{PushInfo, RedisError, RedisResult, Value};
 use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
+use telemetrylib::Telemetry;
 use tokio::sync::mpsc;
 use tokio::task;
@@ -21,7 +24,11 @@
 enum ReadFrom {
     Primary,
     PreferReplica {
-        latest_read_replica_index: Arc,
+        latest_read_replica_index: Arc,
     },
+    AZAffinity {
+        client_az: String,
+        last_read_replica_index: Arc,
+    },
 }
@@ -46,6 +53,13 @@
 pub struct StandaloneClient {
     inner: Arc,
 }
+impl Drop for StandaloneClient {
+    fn drop(&mut self) {
+        // Client was dropped; reduce the number of clients
+        Telemetry::decr_total_clients(1);
+    }
+}
+
 pub enum StandaloneClientConnectionError {
     NoAddressesProvided,
     FailedConnection(Vec<(Option, RedisError)>),
 }
@@ -112,6 +126,11 @@ impl StandaloneClient {
         // randomize pubsub nodes, maybe a better option is to always use the primary
         let pubsub_node_index = rand::thread_rng().gen_range(0..node_count);
         let pubsub_addr = &connection_request.addresses[pubsub_node_index];
+        let discover_az = matches!(
+            connection_request.read_from,
+            Some(ClientReadFrom::AZAffinity(_))
+        );
+
         let mut stream = stream::iter(connection_request.addresses.iter())
             .map(|address| async {
                 get_connection_and_replication_info(
@@ -124,6 +143,7 @@
                     },
                     tls_mode.unwrap_or(TlsMode::NoTls),
                     &push_sender,
+                    discover_az,
                 )
                 .await
                 .map_err(|err| (format!("{}:{}", address.host, address.port), err))
@@ -193,6 +213,9 @@
             Self::start_periodic_connection_check(node.clone());
         }

+        // Successfully created new client.
Update the telemetry + Telemetry::incr_total_clients(1); + Ok(Self { inner: Arc::new(DropWrapper { primary_index, @@ -210,7 +233,7 @@ impl StandaloneClient { &self, latest_read_replica_index: &Arc, ) -> &ReconnectingConnection { - let initial_index = latest_read_replica_index.load(std::sync::atomic::Ordering::Relaxed); + let initial_index = latest_read_replica_index.load(Ordering::Relaxed); let mut check_count = 0; loop { check_count += 1; @@ -230,15 +253,53 @@ impl StandaloneClient { let _ = latest_read_replica_index.compare_exchange_weak( initial_index, index, - std::sync::atomic::Ordering::Relaxed, - std::sync::atomic::Ordering::Relaxed, + Ordering::Relaxed, + Ordering::Relaxed, ); return connection; } } } - fn get_connection(&self, readonly: bool) -> &ReconnectingConnection { + async fn round_robin_read_from_replica_az_awareness( + &self, + latest_read_replica_index: &Arc, + client_az: String, + ) -> &ReconnectingConnection { + let initial_index = latest_read_replica_index.load(Ordering::Relaxed); + let mut retries = 0usize; + + loop { + retries = retries.saturating_add(1); + // Looped through all replicas; no connected replica found in the same AZ. + if retries > self.inner.nodes.len() { + // Attempt a fallback to any available replica in other AZs or primary. + return self.round_robin_read_from_replica(latest_read_replica_index); + } + + // Calculate index based on initial index and check count. + let index = (initial_index + retries) % self.inner.nodes.len(); + let replica = &self.inner.nodes[index]; + + // Attempt to get a connection and retrieve the replica's AZ. + if let Ok(connection) = replica.get_connection().await { + if let Some(replica_az) = connection.get_az().as_deref() { + if replica_az == client_az { + // Update `latest_used_replica` with the index of this replica. 
+ let _ = latest_read_replica_index.compare_exchange_weak( + initial_index, + index, + Ordering::Relaxed, + Ordering::Relaxed, + ); + return replica; + } + } + } + } + } + + async fn get_connection(&self, readonly: bool) -> &ReconnectingConnection { if self.inner.nodes.len() == 1 || !readonly { return self.get_primary_connection(); } @@ -248,6 +309,16 @@ impl StandaloneClient { ReadFrom::PreferReplica { latest_read_replica_index, } => self.round_robin_read_from_replica(latest_read_replica_index), + ReadFrom::AZAffinity { + client_az, + last_read_replica_index, + } => { + self.round_robin_read_from_replica_az_awareness( + last_read_replica_index, + client_az.to_string(), + ) + .await + } } } @@ -260,7 +331,7 @@ impl StandaloneClient { match result { Err(err) if err.is_unrecoverable_error() => { log_warn("send request", format!("received disconnect error `{err}`")); - reconnecting_connection.reconnect(); + reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped); Err(err) } _ => result, @@ -343,7 +414,7 @@ impl StandaloneClient { cmd: &redis::Cmd, readonly: bool, ) -> RedisResult { - let reconnecting_connection = self.get_connection(readonly); + let reconnecting_connection = self.get_connection(readonly).await; Self::send_request(cmd, reconnecting_connection).await } @@ -377,7 +448,7 @@ impl StandaloneClient { "pipeline request", format!("received disconnect error `{err}`"), ); - reconnecting_connection.reconnect(); + reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped); Err(err) } _ => result, @@ -414,7 +485,7 @@ impl StandaloneClient { .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) { log_debug("StandaloneClient", "heartbeat triggered reconnect"); - reconnecting_connection.reconnect(); + reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped); } } }); @@ -435,6 +506,7 @@ impl StandaloneClient { "StandaloneClient", "connection checker stopped after connection was dropped", ); + // Client was dropped, checker can stop. return; } @@ -453,11 +525,25 @@ impl StandaloneClient { "StandaloneClient", "connection checker has triggered reconnect", ); - reconnecting_connection.reconnect(); + reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped); } } }); } + + /// Update the password used to authenticate with the servers. + /// If the password is `None`, the password will be removed. + pub async fn update_connection_password( + &mut self, + password: Option, + ) -> RedisResult { + self.get_connection(false) + .await + .get_connection() + .await? 
+ .update_connection_password(password.clone()) + .await + } } async fn get_connection_and_replication_info( @@ -466,6 +552,7 @@ async fn get_connection_and_replication_info( connection_info: &redis::RedisConnectionInfo, tls_mode: TlsMode, push_sender: &Option>, + discover_az: bool, ) -> Result<(ReconnectingConnection, Value), (ReconnectingConnection, RedisError)> { let result = ReconnectingConnection::new( address, @@ -473,6 +560,7 @@ async fn get_connection_and_replication_info( connection_info.clone(), tls_mode, push_sender.clone(), + discover_az, ) .await; let reconnecting_connection = match result { @@ -483,7 +571,8 @@ async fn get_connection_and_replication_info( let mut multiplexed_connection = match reconnecting_connection.get_connection().await { Ok(multiplexed_connection) => multiplexed_connection, Err(err) => { - reconnecting_connection.reconnect(); + // NOTE: this block is never reached + reconnecting_connection.reconnect(ReconnectReason::ConnectionDropped); return Err((reconnecting_connection, err)); } }; @@ -492,7 +581,10 @@ async fn get_connection_and_replication_info( .send_packed_command(redis::cmd("INFO").arg("REPLICATION")) .await { - Ok(replication_status) => Ok((reconnecting_connection, replication_status)), + Ok(replication_status) => { + // Connection established + we got the INFO output + Ok((reconnecting_connection, replication_status)) + } Err(err) => Err((reconnecting_connection, err)), } } @@ -503,6 +595,10 @@ fn get_read_from(read_from: Option) -> ReadFrom { Some(super::ReadFrom::PreferReplica) => ReadFrom::PreferReplica { latest_read_replica_index: Default::default(), }, + Some(super::ReadFrom::AZAffinity(az)) => ReadFrom::AZAffinity { + client_az: az, + last_read_replica_index: Default::default(), + }, None => ReadFrom::Primary, } } diff --git a/glide-core/src/client/types.rs b/glide-core/src/client/types.rs index ef4be661e6..0c7680b3a6 100644 --- a/glide-core/src/client/types.rs +++ b/glide-core/src/client/types.rs @@ -53,11 +53,12 @@ impl ::std::fmt::Display for NodeAddress { } } -#[derive(PartialEq, Eq, Clone, Copy, Default)] +#[derive(PartialEq, Eq, Clone, Default)] pub enum ReadFrom { #[default] Primary, PreferReplica, + AZAffinity(String), } #[derive(PartialEq, Eq, Clone, Copy, Default)] @@ -99,7 +100,20 @@ impl From for ConnectionRequest { protobuf::ReadFrom::Primary => ReadFrom::Primary, protobuf::ReadFrom::PreferReplica => ReadFrom::PreferReplica, protobuf::ReadFrom::LowestLatency => todo!(), - protobuf::ReadFrom::AZAffinity => todo!(), + protobuf::ReadFrom::AZAffinity => { + if let Some(client_az) = chars_to_string_option(&value.client_az) { + ReadFrom::AZAffinity(client_az) + } else { + log_warn( + "types", + format!( + "Failed to convert availability zone string: '{:?}'. 
Falling back to `ReadFrom::PreferReplica`", + value.client_az + ), + ); + ReadFrom::PreferReplica + } + } }); let client_name = chars_to_string_option(&value.client_name); diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index 4a43da7da7..6ba9dc757c 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -22,6 +22,10 @@ pub(crate) enum ExpectedReturnType<'a> { ArrayOfStrings, ArrayOfBools, ArrayOfDoubleOrNull, + FTAggregateReturnType, + FTSearchReturnType, + FTProfileReturnType(&'a Option>), + FTInfoReturnType, Lolwut, ArrayOfStringAndArrays, ArrayOfArraysOfDoubleOrNull, @@ -891,6 +895,255 @@ pub(crate) fn convert_to_expected_type( format!("(response was {:?})", get_value_type(&value)), ) .into()), + }, + ExpectedReturnType::FTAggregateReturnType => match value { + /* + Example of the response + 1) "3" + 2) 1) "condition" + 2) "refurbished" + 3) "bicylces" + 4) 1) "bicycle:9" + 3) 1) "condition" + 2) "used" + 3) "bicylces" + 4) 1) "bicycle:1" + 2) "bicycle:2" + 3) "bicycle:3" + 4) "bicycle:4" + 4) 1) "condition" + 2) "new" + 3) "bicylces" + 4) 1) "bicycle:5" + 2) "bicycle:6" + + Converting response to (array of maps) + 1) 1# "condition" => "refurbished" + 2# "bicylces" => + 1) "bicycle:9" + 2) 1# "condition" => "used" + 2# "bicylces" => + 1) "bicycle:1" + 2) "bicycle:2" + 3) "bicycle:3" + 4) "bicycle:4" + 3) 1# "condition" => "new" + 2# "bicylces" => + 1) "bicycle:5" + 2) "bicycle:6" + + Very first element in the response is meaningless and should be ignored. + */ + Value::Array(array) => { + let mut res = Vec::with_capacity(array.len() - 1); + for aggregation in array.into_iter().skip(1) { + let Value::Array(fields) = aggregation else { + return Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.AGGREGATE", + format!("(`fields` was {:?})", get_value_type(&aggregation)), + ) + .into()); + }; + res.push(convert_array_to_map_by_type( + fields, + None, + None, + )?); + } + Ok(Value::Array(res)) + } + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.AGGREGATE", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()), + }, + ExpectedReturnType::FTSearchReturnType => match value { + /* + Example of the response + 1) (integer) 2 + 2) "json:2" + 3) 1) "__VEC_score" + 2) "11.1100006104" + 3) "$" + 4) "{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}" + 4) "json:0" + 5) 1) "__VEC_score" + 2) "91" + 3) "$" + 4) "{\"vec\":[1,2,3,4,5,6]}" + + Converting response to + 1) (integer) 2 + 2) 1# "json:2" => + 1# "__VEC_score" => "11.1100006104" + 2# "$" => "{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}" + 2# "json:0" => + 1# "__VEC_score" => "91" + 2# "$" => "{\"vec\":[1,2,3,4,5,6]}" + + Response may contain only 1 element, no conversion in that case. + */ + Value::Array(ref array) if array.len() == 1 => Ok(value), + Value::Array(mut array) => { + Ok(Value::Array(vec![ + array.remove(0), + convert_to_expected_type(Value::Array(array), Some(ExpectedReturnType::Map { + key_type: &Some(ExpectedReturnType::BulkString), + value_type: &Some(ExpectedReturnType::Map { + key_type: &Some(ExpectedReturnType::BulkString), + value_type: &Some(ExpectedReturnType::BulkString), + }), + }))? 
+ ])) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.SEARCH", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) + }, + ExpectedReturnType::FTInfoReturnType => match value { + /* + Example of the response + 1) index_name + 2) "957fa3ca-2280-467d-873f-8763a36fbd5a" + 3) creation_timestamp + 4) (integer) 1728348101740745 + 5) key_type + 6) HASH + 7) key_prefixes + 8) 1) "blog:post:" + 9) fields + 10) 1) 1) identifier + 2) category + 3) field_name + 4) category + 5) type + 6) TAG + 7) option + 8) + 2) 1) identifier + 2) vec + 3) field_name + 4) VEC + 5) type + 6) VECTOR + 7) option + 8) + 9) vector_params + 10) 1) algorithm + 2) HNSW + 3) data_type + 4) FLOAT32 + 5) dimension + 6) (integer) 2 + ... + + Converting response to + 1# "index_name" => "957fa3ca-2280-467d-873f-8763a36fbd5a" + 2# "creation_timestamp" => 1728348101740745 + 3# "key_type" => "HASH" + 4# "key_prefixes" => + 1) "blog:post:" + 5# "fields" => + 1) 1# "identifier" => "category" + 2# "field_name" => "category" + 3# "type" => "TAG" + 4# "option" => "" + 2) 1# "identifier" => "vec" + 2# "field_name" => "VEC" + 3# "type" => "TAVECTORG" + 4# "option" => "" + 5# "vector_params" => + 1# "algorithm" => "HNSW" + 2# "data_type" => "FLOAT32" + 3# "dimension" => 2 + ... + + Map keys (odd array elements) are simple strings, not bulk strings. + */ + Value::Array(_) => { + let Value::Map(mut map) = convert_to_expected_type(value, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))? else { unreachable!() }; + let Some(fields_pair) = map.iter_mut().find(|(key, _)| { + *key == Value::SimpleString("fields".into()) + }) else { return Ok(Value::Map(map)) }; + let (fields_key, fields_value) = std::mem::replace(fields_pair, (Value::Nil, Value::Nil)); + let Value::Array(fields) = fields_value else { + return Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.INFO", + format!("(`fields` was {:?})", get_value_type(&fields_value)), + ).into()); + }; + let fields = fields.into_iter().map(|field| { + let Value::Map(mut field_params) = convert_to_expected_type(field, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))? else { unreachable!() }; + let Some(vector_params_pair) = field_params.iter_mut().find(|(key, _)| { + *key == Value::SimpleString("vector_params".into()) + }) else { return Ok(Value::Map(field_params)) }; + let (vector_params_key, vector_params_value) = std::mem::replace(vector_params_pair, (Value::Nil, Value::Nil)); + let _ = std::mem::replace(vector_params_pair, (vector_params_key, convert_to_expected_type(vector_params_value, Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }))?)); + Ok(Value::Map(field_params)) + }).collect::>>()?; + let _ = std::mem::replace(fields_pair, (fields_key, Value::Array(fields))); + Ok(Value::Map(map)) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.INFO", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) + }, + ExpectedReturnType::FTProfileReturnType(type_of_query) => match value { + /* + Example of the response + 1) + 2) 1) 1) "parse.time" + 2) 119 + 2) 1) "all.count" + 2) 4 + 3) 1) "sync.time" + 2) 0 + ... + + Converting response to + 1) + 2) 1# "parse.time" => 119 + 2# "all.count" => 4 + 3# "sync.time" => 0 + ... + + Converting first array element as it is needed for the inner query and second element to a map. 
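All four new FT.* arms in this hunk (FT.AGGREGATE, FT.SEARCH, FT.INFO, FT.PROFILE) lean on the same primitive: pairing a flat RESP array [k1, v1, k2, v2, ...] into a map. A minimal, dependency-free sketch of that pairing step, with a local Val enum standing in for redis::Value (the names here are illustrative, not the crate's API):

```rust
// Local stand-in for redis::Value, just enough to show the pairing step.
#[derive(Debug)]
enum Val {
    Bulk(String),
    Map(Vec<(Val, Val)>),
}

// Turn [k1, v1, k2, v2, ...] into a map, the way the FT.* arms do.
fn array_to_map(items: Vec<Val>) -> Result<Val, String> {
    if items.len() % 2 != 0 {
        return Err(format!("expected an even number of elements, got {}", items.len()));
    }
    let mut pairs = Vec::with_capacity(items.len() / 2);
    let mut iter = items.into_iter();
    while let (Some(key), Some(value)) = (iter.next(), iter.next()) {
        pairs.push((key, value));
    }
    Ok(Val::Map(pairs))
}

fn main() {
    let reply = vec![
        Val::Bulk("__VEC_score".into()),
        Val::Bulk("91".into()),
        Val::Bulk("$".into()),
        Val::Bulk("{\"vec\":[1,2,3]}".into()),
    ];
    println!("{:?}", array_to_map(reply).unwrap());
}
```

FT.SEARCH applies this twice, once for the outer map keyed by document id and once per document, which is why the hunk above nests one Map ExpectedReturnType inside another.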
+ */ + Value::Array(mut array) if array.len() == 2 => { + let res = vec![ + convert_to_expected_type(array.remove(0), *type_of_query)?, + convert_to_expected_type(array.remove(0), Some(ExpectedReturnType::Map { + key_type: &Some(ExpectedReturnType::SimpleString), + value_type: &Some(ExpectedReturnType::Double), + }))?]; + + Ok(Value::Array(res)) + }, + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted for FT.PROFILE", + format!("(response was {:?})", get_value_type(&value)), + ) + .into()) } } } @@ -1140,10 +1393,12 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { // TODO use enum to avoid mistakes match command.as_slice() { - b"HGETALL" | b"CONFIG GET" | b"FT.CONFIG GET" | b"HELLO" => Some(ExpectedReturnType::Map { - key_type: &None, - value_type: &None, - }), + b"HGETALL" | b"CONFIG GET" | b"FT.CONFIG GET" | b"FT._ALIASLIST" | b"HELLO" => { + Some(ExpectedReturnType::Map { + key_type: &None, + value_type: &None, + }) + } b"XCLAIM" => { if cmd.position(b"JUSTID").is_some() { Some(ExpectedReturnType::ArrayOfStrings) @@ -1256,6 +1511,17 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { key_type: &None, value_type: &None, }), + b"FT.AGGREGATE" => Some(ExpectedReturnType::FTAggregateReturnType), + b"FT.SEARCH" => Some(ExpectedReturnType::FTSearchReturnType), + // TODO replace with tuple + b"FT.PROFILE" => Some(ExpectedReturnType::FTProfileReturnType( + if cmd.arg_idx(2).is_some_and(|a| a == b"SEARCH") { + &Some(ExpectedReturnType::FTSearchReturnType) + } else { + &Some(ExpectedReturnType::FTAggregateReturnType) + }, + )), + b"FT.INFO" => Some(ExpectedReturnType::FTInfoReturnType), _ => None, } } diff --git a/glide-core/src/lib.rs b/glide-core/src/lib.rs index 8da08e99f9..9b22a8bb55 100644 --- a/glide-core/src/lib.rs +++ b/glide-core/src/lib.rs @@ -17,3 +17,4 @@ pub mod scripts_container; pub use client::ConnectionRequest; pub mod cluster_scan_container; pub mod request_type; +pub use telemetrylib::Telemetry; diff --git a/glide-core/src/protobuf/command_request.proto b/glide-core/src/protobuf/command_request.proto index 5b2b826acc..30b33362af 100644 --- a/glide-core/src/protobuf/command_request.proto +++ b/glide-core/src/protobuf/command_request.proto @@ -508,6 +508,11 @@ message ClusterScan { optional string object_type = 4; } +message UpdateConnectionPassword { + optional string password = 1; + bool immediate_auth = 2; +} + message CommandRequest { uint32 callback_idx = 1; @@ -517,6 +522,7 @@ message CommandRequest { ScriptInvocation script_invocation = 4; ScriptInvocationPointers script_invocation_pointers = 5; ClusterScan cluster_scan = 6; + UpdateConnectionPassword update_connection_password = 7; } - Routes route = 7; + Routes route = 8; } diff --git a/glide-core/src/protobuf/connection_request.proto b/glide-core/src/protobuf/connection_request.proto index e8f54c042a..5f4db44b00 100644 --- a/glide-core/src/protobuf/connection_request.proto +++ b/glide-core/src/protobuf/connection_request.proto @@ -70,6 +70,7 @@ message ConnectionRequest { } PubSubSubscriptions pubsub_subscriptions = 13; uint32 inflight_requests_limit = 14; + string client_az = 15; } message ConnectionRetryStrategy { diff --git a/glide-core/src/retry_strategies.rs b/glide-core/src/retry_strategies.rs index dbe5683347..1a5157d225 100644 --- a/glide-core/src/retry_strategies.rs +++ b/glide-core/src/retry_strategies.rs @@ -3,7 +3,7 @@ */ use crate::client::ConnectionRetryStrategy; use std::time::Duration; -use tokio_retry::strategy::{jitter, ExponentialBackoff}; +use 
tokio_retry2::strategy::{jitter_range, ExponentialBackoff}; #[derive(Clone, Debug)] pub(super) struct RetryStrategy { @@ -27,7 +27,7 @@ impl RetryStrategy { pub(super) fn get_iterator(&self) -> impl Iterator { ExponentialBackoff::from_millis(self.exponent_base as u64) .factor(self.factor as u64) - .map(jitter) + .map(jitter_range(0.8, 1.2)) .take(self.number_of_retries as usize) } } @@ -56,6 +56,7 @@ pub(crate) fn get_exponential_backoff( } #[cfg(feature = "socket-layer")] +#[allow(dead_code)] pub(crate) fn get_fixed_interval_backoff( fixed_interval: u32, number_of_retries: u32, @@ -77,23 +78,39 @@ mod tests { let mut counter = 0; for duration in intervals { counter += 1; - assert!(duration.as_millis() <= interval_duration as u128); + let upper_limit = (interval_duration as f32 * 1.2) as u128; + let lower_limit = (interval_duration as f32 * 0.8) as u128; + assert!( + lower_limit <= duration.as_millis() || duration.as_millis() <= upper_limit, + "{:?}ms <= {:?}ms <= {:?}ms", + lower_limit, + duration.as_millis(), + upper_limit + ); } assert_eq!(counter, retries); } #[test] fn test_exponential_backoff_with_jitter() { - let retries = 3; - let base = 10; - let factor = 5; + let retries = 5; + let base = 2; + let factor = 100; let intervals = get_exponential_backoff(base, factor, retries).get_iterator(); let mut counter = 0; for duration in intervals { counter += 1; let unjittered_duration = factor * (base.pow(counter)); - assert!(duration.as_millis() <= unjittered_duration as u128); + let upper_limit = (unjittered_duration as f32 * 1.2) as u128; + let lower_limit = (unjittered_duration as f32 * 0.8) as u128; + assert!( + lower_limit <= duration.as_millis() || duration.as_millis() <= upper_limit, + "{:?}ms <= {:?}ms <= {:?}ms", + lower_limit, + duration.as_millis(), + upper_limit + ); } assert_eq!(counter, retries); diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 50445c881d..b9db4e6d99 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -11,11 +11,10 @@ use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::response; use crate::response::Response; -use crate::retry_strategies::get_fixed_interval_backoff; use bytes::Bytes; use directories::BaseDirs; -use dispose::{Disposable, Dispose}; use logger_core::{log_debug, log_error, log_info, log_trace, log_warn}; +use once_cell::sync::Lazy; use protobuf::{Chars, Message}; use redis::cluster_routing::{ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, @@ -23,18 +22,18 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value}; use std::cell::Cell; +use std::collections::HashSet; use std::rc::Rc; +use std::sync::RwLock; use std::{env, str}; use std::{io, thread}; use thiserror::Error; -use tokio::io::ErrorKind::AddrInUse; use tokio::net::{UnixListener, UnixStream}; use tokio::runtime::Builder; use tokio::sync::mpsc; use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::task; -use tokio_retry::Retry; use tokio_util::task::LocalPoolHandle; use ClosingReason::*; use PipeListeningResult::*; @@ -53,20 +52,6 @@ pub const ZSET: &str = "zset"; pub const HASH: &str = "hash"; pub const STREAM: &str = "stream"; -/// struct containing all objects needed to bind to a socket and clean it. 
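The change above moves retry_strategies.rs from tokio_retry's jitter, which scales each delay by a random factor in [0, 1), to tokio_retry2's jitter_range(0.8, 1.2), which keeps every delay within 20% of its nominal value. A self-contained sketch of the resulting schedule; the inline LCG is a hypothetical stand-in for the crate's RNG:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Reproduces the shape of
// ExponentialBackoff::from_millis(base).factor(factor).map(jitter_range(0.8, 1.2)):
// delay(attempt) = factor * base^attempt, scaled by a random factor in [0.8, 1.2).
fn jittered_backoff(base: u64, factor: u64, retries: u32) -> Vec<Duration> {
    let mut seed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .subsec_nanos() as u64;
    (1..=retries)
        .map(|attempt| {
            // Tiny LCG, a stand-in for the crate's RNG.
            seed = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let unit = (seed >> 33) as f64 / (1u64 << 31) as f64; // in [0, 1)
            let jitter = 0.8 + 0.4 * unit; // in [0.8, 1.2)
            let nominal = factor.saturating_mul(base.saturating_pow(attempt));
            Duration::from_millis((nominal as f64 * jitter) as u64)
        })
        .collect()
}

fn main() {
    for (i, delay) in jittered_backoff(2, 100, 5).iter().enumerate() {
        println!("attempt {}: {:?}", i + 1, delay);
    }
}
```

With base 2 and factor 100, the values the updated test uses, the nominal delays are 200, 400, 800, 1600 and 3200 ms, and each emitted delay should land inside the 0.8x to 1.2x window that the reworked assertions describe.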
-struct SocketListener { - socket_path: String, - cleanup_socket: bool, -} - -impl Dispose for SocketListener { - fn dispose(self) { - if self.cleanup_socket { - close_socket(&self.socket_path); - } - } -} - /// struct containing all objects needed to read from a unix stream. struct UnixStreamListener { read_socket: Rc, @@ -390,7 +375,7 @@ async fn invoke_script( async fn send_transaction( request: Transaction, - mut client: Client, + client: &mut Client, routing: Option, ) -> ClientUsageResult { let mut pipeline = redis::Pipeline::with_capacity(request.commands.capacity()); @@ -476,7 +461,7 @@ fn get_route( } } -fn handle_request(request: CommandRequest, client: Client, writer: Rc) { +fn handle_request(request: CommandRequest, mut client: Client, writer: Rc) { task::spawn_local(async move { let mut updated_inflight_counter = true; let client_clone = client.clone(); @@ -504,7 +489,7 @@ fn handle_request(request: CommandRequest, client: Client, writer: Rc) { } command_request::Command::Transaction(transaction) => { match get_route(request.route.0, None) { - Ok(routes) => send_transaction(transaction, client, routes).await, + Ok(routes) => send_transaction(transaction, &mut client, routes).await, Err(e) => Err(e), } } @@ -537,6 +522,17 @@ fn handle_request(request: CommandRequest, client: Client, writer: Rc) { Err(e) => Err(e), } } + command_request::Command::UpdateConnectionPassword( + update_connection_password_command, + ) => client + .update_connection_password( + update_connection_password_command + .password + .map(|chars| chars.to_string()), + update_connection_password_command.immediate_auth, + ) + .await + .map_err(|err| err.into()), }, None => { log_debug( @@ -734,109 +730,6 @@ async fn listen_on_client_stream(socket: UnixStream) { log_trace("client closing", "closing connection"); } -enum SocketCreationResult { - // Socket creation was successful, returned a socket listener. - Created(UnixListener), - // There's an existing a socket listener. - PreExisting, - // Socket creation failed with an error. - Err(io::Error), -} - -impl SocketListener { - fn new(socket_path: String) -> Self { - SocketListener { - socket_path, - // Don't cleanup the socket resources unless we know that the socket is in use, and owned by this listener. - cleanup_socket: false, - } - } - - /// Return true if it's possible to connect to socket. - async fn socket_is_available(&self) -> bool { - if UnixStream::connect(&self.socket_path).await.is_ok() { - return true; - } - - let retry_strategy = get_fixed_interval_backoff(10, 3); - - let action = || async { - UnixStream::connect(&self.socket_path) - .await - .map(|_| ()) - .map_err(|_| ()) - }; - let result = Retry::spawn(retry_strategy.get_iterator(), action).await; - result.is_ok() - } - - async fn get_socket_listener(&self) -> SocketCreationResult { - const RETRY_COUNT: u8 = 3; - let mut retries = RETRY_COUNT; - while retries > 0 { - match UnixListener::bind(self.socket_path.clone()) { - Ok(listener) => { - return SocketCreationResult::Created(listener); - } - Err(err) if err.kind() == AddrInUse => { - if self.socket_is_available().await { - return SocketCreationResult::PreExisting; - } else { - // socket file might still exist, even if nothing is listening on it. 
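The new UpdateConnectionPassword arm in handle_request above is mostly plumbing: lift the optional protobuf password into an Option<String>, pass it along with the immediate_auth flag, and map the client result back onto the response. A condensed sketch of that shape with simplified stand-in types (not the crate's real Client or generated protobuf structs):

```rust
// Simplified stand-ins; the real code uses the protobuf-generated request
// type and glide-core's Client.
struct UpdateConnectionPassword {
    password: Option<String>,
    immediate_auth: bool,
}

struct Client {
    stored_password: Option<String>,
}

impl Client {
    fn update_connection_password(
        &mut self,
        password: Option<String>,
        immediate_auth: bool,
    ) -> Result<&'static str, String> {
        // None clears the stored credential, Some(_) replaces it. With
        // immediate_auth the real client would also issue AUTH on its live
        // connections instead of waiting for the next reconnect.
        self.stored_password = password;
        let _ = immediate_auth;
        Ok("OK")
    }
}

fn main() {
    let request = UpdateConnectionPassword {
        password: Some("s3cret".into()),
        immediate_auth: true,
    };
    let mut client = Client { stored_password: None };
    let reply = client.update_connection_password(request.password, request.immediate_auth);
    println!("{reply:?}, stored = {:?}", client.stored_password);
}
```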
- close_socket(&self.socket_path); - retries -= 1; - continue; - } - } - Err(err) => { - return SocketCreationResult::Err(err); - } - } - } - SocketCreationResult::Err(io::Error::new( - io::ErrorKind::Other, - "Failed to connect to socket", - )) - } - - pub(crate) async fn listen_on_socket(&mut self, init_callback: InitCallback) - where - InitCallback: FnOnce(Result) + Send + 'static, - { - // Bind to socket - let listener = match self.get_socket_listener().await { - SocketCreationResult::Created(listener) => listener, - SocketCreationResult::Err(err) => { - log_info("listen_on_socket", format!("failed with error: {err}")); - init_callback(Err(err.to_string())); - return; - } - SocketCreationResult::PreExisting => { - init_callback(Ok(self.socket_path.clone())); - return; - } - }; - - self.cleanup_socket = true; - init_callback(Ok(self.socket_path.clone())); - let local_set_pool = LocalPoolHandle::new(num_cpus::get()); - loop { - match listener.accept().await { - Ok((stream, _addr)) => { - local_set_pool.spawn_pinned(move || listen_on_client_stream(stream)); - } - Err(err) => { - log_debug( - "listen_on_socket", - format!("Socket closed with error: `{err}`"), - ); - return; - } - } - } - } -} - #[derive(Debug)] /// Enum describing the reason that a socket listener stopped listening on a socket. pub enum ClosingReason { @@ -924,23 +817,114 @@ pub fn start_socket_listener_internal( init_callback: InitCallback, socket_path: Option, ) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { + static INITIALIZED_SOCKETS: Lazy>> = + Lazy::new(|| RwLock::new(HashSet::new())); + + let socket_path = socket_path.unwrap_or_else(get_socket_path); + + { + // Optimize for already initialized + let initialized_sockets = INITIALIZED_SOCKETS + .read() + .expect("Failed to acquire sockets db read guard"); + if initialized_sockets.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + } + + // Retry with write lock, will be dropped upon the function completion + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + if sockets_write_guard.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + + let (tx, rx) = std::sync::mpsc::channel(); + let socket_path_cloned = socket_path.clone(); + let init_callback_cloned = init_callback.clone(); + let tx_cloned = tx.clone(); thread::Builder::new() .name("socket_listener_thread".to_string()) .spawn(move || { - let runtime = Builder::new_current_thread().enable_all().build(); - match runtime { - Ok(runtime) => { - let mut listener = Disposable::new(SocketListener::new( - socket_path.unwrap_or_else(get_socket_path), - )); - runtime.block_on(listener.listen_on_socket(init_callback)); - } - Err(err) => init_callback(Err(err.to_string())), + let init_result = { + let runtime = match Builder::new_current_thread().enable_all().build() { + Err(err) => { + log_error( + "listen_on_socket", + format!("Error failed to create a new tokio thread: {err}"), + ); + return Err(err); + } + Ok(runtime) => runtime, + }; + + runtime.block_on(async move { + let listener_socket = match UnixListener::bind(socket_path_cloned.clone()) { + Err(err) => { + log_error( + "listen_on_socket", + format!("Error failed to bind listening socket: {err}"), + ); + return Err(err); + } + Ok(listener_socket) => listener_socket, + }; + + // Signal initialization is successful. 
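The rewritten start_socket_listener_internal above combines two ideas: a process-wide registry of already-initialized socket paths that short-circuits repeat calls, and a std::sync::mpsc channel that blocks the caller until the spawned listener thread reports whether the bind succeeded. A condensed, runnable sketch of that handshake; OnceLock and the placeholder bind are assumptions of this sketch, not the crate's code:

```rust
use std::collections::HashSet;
use std::sync::{mpsc, Mutex, OnceLock};
use std::thread;

// Process-wide registry of socket paths that already have a listener thread.
static INITIALIZED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();

fn start_listener(path: String) -> Result<String, String> {
    let registry = INITIALIZED.get_or_init(|| Mutex::new(HashSet::new()));
    if registry.lock().unwrap().contains(&path) {
        return Ok(path); // fast path: someone already listens here
    }

    let (tx, rx) = mpsc::channel();
    let thread_path = path.clone();
    thread::spawn(move || {
        // Placeholder for UnixListener::bind(&thread_path); signal the
        // outcome *before* entering the accept loop so the caller can return.
        let bound = !thread_path.is_empty();
        let _ = tx.send(bound);
        if bound {
            // accept loop would run here
        }
    });

    match rx.recv() {
        Ok(true) => {
            registry.lock().unwrap().insert(path.clone());
            Ok(path)
        }
        _ => Err("listener thread failed to bind".into()),
    }
}

fn main() {
    println!("{:?}", start_listener("/tmp/example.sock".into()));
    println!("{:?}", start_listener("/tmp/example.sock".into())); // hits the fast path
}
```

Sending on the channel before invoking the callback matters for the same reason the IMPORTANT note just below gives: the caller's recv() must complete so embedding runtimes such as Python can finish their main function.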
+ // IMPORTANT: + // tx.send() must be called before init_callback_cloned() to ensure runtimes, such as Python, can properly complete the main function + let _ = tx.send(true); + init_callback_cloned(Ok(socket_path_cloned.clone())); + + let local_set_pool = LocalPoolHandle::new(num_cpus::get()); + loop { + match listener_socket.accept().await { + Ok((stream, _addr)) => { + local_set_pool + .spawn_pinned(move || listen_on_client_stream(stream)); + } + Err(err) => { + log_error( + "listen_on_socket", + format!("Error accepting connection: {err}"), + ); + break; + } + } + } + + // ensure socket file removal + drop(listener_socket); + let _ = std::fs::remove_file(socket_path_cloned.clone()); + + // no more listening on socket - update the sockets db + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + sockets_write_guard.remove(&socket_path_cloned); + Ok(()) + }) }; + + if let Err(err) = init_result { + init_callback(Err(err.to_string())); + let _ = tx_cloned.send(false); + } + Ok(()) }) .expect("Thread spawn failed. Cannot report error because callback was moved."); + + // wait for thread initialization signaling, callback invocation is done in the thread + let _ = rx.recv().map(|res| { + if res { + sockets_write_guard.insert(socket_path); + } + }); } /// Creates a new thread with a main loop task listening on the socket for new connections. @@ -950,7 +934,7 @@ pub fn start_socket_listener_internal( /// * `init_callback` - called when the socket listener fails to initialize, with the reason for the failure. pub fn start_socket_listener(init_callback: InitCallback) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { start_socket_listener_internal(init_callback, None); } diff --git a/glide-core/telemetry/Cargo.toml b/glide-core/telemetry/Cargo.toml new file mode 100644 index 0000000000..73b9cb25ea --- /dev/null +++ b/glide-core/telemetry/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "telemetrylib" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +authors = ["Valkey GLIDE Maintainers"] + +[dependencies] +lazy_static = "1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/glide-core/telemetry/src/lib.rs b/glide-core/telemetry/src/lib.rs new file mode 100644 index 0000000000..886e43a2c8 --- /dev/null +++ b/glide-core/telemetry/src/lib.rs @@ -0,0 +1,68 @@ +use lazy_static::lazy_static; +use serde::Serialize; +use std::sync::RwLock as StdRwLock; + +#[derive(Default, Serialize)] +#[allow(dead_code)] +pub struct Telemetry { + /// Total number of connections opened to Valkey + total_connections: usize, + /// Total number of GLIDE clients + total_clients: usize, +} + +lazy_static! { + static ref TELEMETRY: StdRwLock = StdRwLock::::default(); +} + +const MUTEX_WRITE_ERR: &str = "Failed to obtain write lock for mutex. Poisoned mutex"; +const MUTEX_READ_ERR: &str = "Failed to obtain read lock for mutex. 
Poisoned mutex"; + +impl Telemetry { + /// Increment the total number of connections by `incr_by` + /// Return the number of total connections after the increment + pub fn incr_total_connections(incr_by: usize) -> usize { + let mut t = TELEMETRY.write().expect(MUTEX_WRITE_ERR); + t.total_connections = t.total_connections.saturating_add(incr_by); + t.total_connections + } + + /// Decrease the total number of connections by `decr_by` + /// Return the number of total connections after the decrease + pub fn decr_total_connections(decr_by: usize) -> usize { + let mut t = TELEMETRY.write().expect(MUTEX_WRITE_ERR); + t.total_connections = t.total_connections.saturating_sub(decr_by); + t.total_connections + } + + /// Increment the total number of clients by `incr_by` + /// Return the number of total clients after the increment + pub fn incr_total_clients(incr_by: usize) -> usize { + let mut t = TELEMETRY.write().expect(MUTEX_WRITE_ERR); + t.total_clients = t.total_clients.saturating_add(incr_by); + t.total_clients + } + + /// Decrease the total number of clients by `decr_by` + /// Return the number of total clients after the decrease + pub fn decr_total_clients(decr_by: usize) -> usize { + let mut t = TELEMETRY.write().expect(MUTEX_WRITE_ERR); + t.total_clients = t.total_clients.saturating_sub(decr_by); + t.total_clients + } + + /// Return the number of active connections + pub fn total_connections() -> usize { + TELEMETRY.read().expect(MUTEX_READ_ERR).total_connections + } + + /// Return the number of active clients + pub fn total_clients() -> usize { + TELEMETRY.read().expect(MUTEX_READ_ERR).total_clients + } + + /// Reset the telemetry collected thus far + pub fn reset() { + *TELEMETRY.write().expect(MUTEX_WRITE_ERR) = Telemetry::default(); + } +} diff --git a/glide-core/tests/test_client.rs b/glide-core/tests/test_client.rs index 024c0d74bc..ffc672fee6 100644 --- a/glide-core/tests/test_client.rs +++ b/glide-core/tests/test_client.rs @@ -3,8 +3,27 @@ */ mod utilities; +#[macro_export] +/// Compare `$expected` with `$actual`. This macro, will exit the test process +/// if the assertion fails. Unlike `assert_eq!` - this also works in tasks +macro_rules! 
async_assert_eq { + ($expected:expr, $actual:expr) => {{ + if $actual != $expected { + println!( + "{}:{}: Expected: {:?} != Actual: {:?}", + file!(), + line!(), + $actual, + $expected + ); + std::process::exit(1); + } + }}; +} + #[cfg(test)] pub(crate) mod shared_client_tests { + use glide_core::Telemetry; use std::collections::HashMap; use super::*; @@ -44,7 +63,9 @@ pub(crate) mod shared_client_tests { }) .await } - BackingServer::Cluster(cluster) => create_cluster_client(cluster, configuration).await, + BackingServer::Cluster(cluster) => { + create_cluster_client(cluster.as_ref(), configuration).await + } } } @@ -540,6 +561,98 @@ pub(crate) mod shared_client_tests { }); } + #[test] + #[serial_test::serial] + fn test_client_telemetry_standalone() { + Telemetry::reset(); + block_on_all(async move { + // create a server with 2 clients + let server_config = TestConfiguration { + use_tls: false, + ..Default::default() + }; + + let test_basics = utilities::setup_test_basics_internal(&server_config).await; + let server = BackingServer::Standalone(test_basics.server); + + // setup_test_basics_internal internally, starts a single client connection + assert_eq!(Telemetry::total_connections(), 1); + assert_eq!(Telemetry::total_clients(), 1); + + { + // Create 2 more clients, confirm that they are tracked + let _client1 = create_client(&server, server_config.clone()).await; + let _client2 = create_client(&server, server_config).await; + + // Each client maintains a single connection + assert_eq!(Telemetry::total_connections(), 3); + assert_eq!(Telemetry::total_clients(), 3); + + // Connections are dropped here + } + + // Confirm 1 connection & client remain + assert_eq!(Telemetry::total_connections(), 1); + assert_eq!(Telemetry::total_clients(), 1); + }); + } + + #[test] + #[serial_test::serial] + fn test_client_telemetry_cluster() { + Telemetry::reset(); + block_on_all(async { + let local_set = tokio::task::LocalSet::default(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + // We use 2 tasks to let "dispose" be called. In addition, the task that checks for the cleanup + // does not start until the cluster is up and running. We use a channel to communicate this between + // the tasks + local_set.spawn_local(async move { + let cluster = cluster::setup_default_cluster().await; + async_assert_eq!(Telemetry::total_connections(), 0); + async_assert_eq!(Telemetry::total_clients(), 0); + + // Each client opens 12 connections + println!("Creating 1st cluster client..."); + let _c1 = cluster::setup_default_client(&cluster).await; + async_assert_eq!(Telemetry::total_connections(), 12); + async_assert_eq!(Telemetry::total_clients(), 1); + + println!("Creating 2nd cluster client..."); + let _c2 = cluster::setup_default_client(&cluster).await; + async_assert_eq!(Telemetry::total_connections(), 24); + async_assert_eq!(Telemetry::total_clients(), 2); + + let _ = tx.send(1).await; + // client is dropped and eventually disposed here + }); + + local_set.spawn_local(async move { + let _ = rx.recv().await; + println!("Cluster terminated. 
Wait for the telemetry to clear"); + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + assert_eq!(Telemetry::total_connections(), 0); + assert_eq!(Telemetry::total_clients(), 0); + }); + local_set.await; + }); + } + + #[test] + #[serial_test::serial] + fn test_multi_key_no_args_in_cluster() { + block_on_all(async { + let cluster = cluster::setup_default_cluster().await; + println!("Creating 1st cluster client..."); + let mut c1 = cluster::setup_default_client(&cluster).await; + let result = c1.send_command(&redis::cmd("MSET"), None).await; + assert!(result.is_err()); + let e = result.unwrap_err(); + assert!(e.kind().clone().eq(&redis::ErrorKind::ResponseError)); + assert!(e.to_string().contains("wrong number of arguments")); + }); + } + #[rstest] #[serial_test::serial] #[timeout(SHORT_CLUSTER_TEST_TIMEOUT)] diff --git a/glide-core/tests/test_socket_listener.rs b/glide-core/tests/test_socket_listener.rs index a242eb80d1..6f2aa566b9 100644 --- a/glide-core/tests/test_socket_listener.rs +++ b/glide-core/tests/test_socket_listener.rs @@ -172,7 +172,7 @@ mod socket_listener { } fn read_from_socket(buffer: &mut Vec, socket: &mut UnixStream) -> usize { - buffer.resize(100, 0_u8); + buffer.resize(300, 0_u8); socket.read(buffer).unwrap() } @@ -518,8 +518,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_working_after_socket_listener_was_dropped() { - let socket_path = - get_socket_path_from_name("test_working_after_socket_listener_was_dropped".to_string()); + let socket_path = get_socket_path_from_name(format!( + "{}_test_working_after_socket_listener_was_dropped", + std::process::id() + )); close_socket(&socket_path); // create a socket listener and drop it, to simulate a panic in a previous iteration. 
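The added close_socket call above exists because a Unix domain socket file outlives its listener: dropping the listener closes the descriptor but leaves the filesystem entry behind, and a later bind on the same path fails with AddrInUse until the file is removed. A Unix-only sketch that demonstrates this (the temp path is arbitrary):

```rust
// Unix-only: run on Linux or macOS.
use std::os::unix::net::UnixListener;

fn main() -> std::io::Result<()> {
    let path = std::env::temp_dir().join("glide_demo.sock");
    let _ = std::fs::remove_file(&path);

    {
        let _listener = UnixListener::bind(&path)?;
    } // listener dropped here, but the socket file remains on disk

    assert!(UnixListener::bind(&path).is_err()); // AddrInUse
    std::fs::remove_file(&path)?; // the explicit cleanup the test performs
    let _listener = UnixListener::bind(&path)?; // now succeeds
    Ok(())
}
```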
Builder::new_current_thread() @@ -528,6 +530,8 @@ mod socket_listener { .unwrap() .block_on(async { let _ = UnixListener::bind(socket_path.clone()).unwrap(); + // UDS sockets require explicit removal of the socket file + close_socket(&socket_path); }); const CALLBACK_INDEX: u32 = 99; @@ -554,9 +558,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_multiple_listeners_competing_for_the_socket() { - let socket_path = get_socket_path_from_name( - "test_multiple_listeners_competing_for_the_socket".to_string(), - ); + let socket_path = get_socket_path_from_name(format!( + "{}_test_multiple_listeners_competing_for_the_socket", + std::process::id() + )); close_socket(&socket_path); let server = Arc::new(RedisServer::new(ServerType::Tcp { tls: false })); diff --git a/glide-core/tests/test_standalone_client.rs b/glide-core/tests/test_standalone_client.rs index c007f7a6e0..c118d6d28f 100644 --- a/glide-core/tests/test_standalone_client.rs +++ b/glide-core/tests/test_standalone_client.rs @@ -193,20 +193,19 @@ mod standalone_client_tests { } fn test_read_from_replica(config: ReadFromReplicaTestConfig) { - let mut mocks = create_primary_mock_with_replicas( + let mut servers = create_primary_mock_with_replicas( config.number_of_initial_replicas - config.number_of_missing_replicas, ); let mut cmd = redis::cmd("GET"); cmd.arg("foo"); - for mock in mocks.iter() { + for server in servers.iter() { for _ in 0..3 { - mock.add_response(&cmd, "$-1\r\n".to_string()); + server.add_response(&cmd, "$-1\r\n".to_string()); } } - let mut addresses = get_mock_addresses(&mocks); - + let mut addresses = get_mock_addresses(&servers); for i in 4 - config.number_of_missing_replicas..4 { addresses.push(redis::ConnectionAddr::Tcp( "foo".to_string(), @@ -221,19 +220,32 @@ mod standalone_client_tests { let mut client = StandaloneClient::create_client(connection_request.into(), None) .await .unwrap(); - for mock in mocks.drain(1..config.number_of_replicas_dropped_after_connection + 1) { - mock.close().await; + logger_core::log_info( + "Test", + format!( + "Closing {} servers after connection established", + config.number_of_replicas_dropped_after_connection + ), + ); + for server in servers.drain(1..config.number_of_replicas_dropped_after_connection + 1) { + server.close().await; } + logger_core::log_info( + "Test", + format!("sending {} messages", config.number_of_requests_sent), + ); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; for _ in 0..config.number_of_requests_sent { let _ = client.send_command(&cmd).await; } }); assert_eq!( - mocks[0].get_number_of_received_commands(), + servers[0].get_number_of_received_commands(), config.expected_primary_reads ); - let mut replica_reads: Vec<_> = mocks + let mut replica_reads: Vec<_> = servers .iter() .skip(1) .map(|mock| mock.get_number_of_received_commands()) @@ -261,6 +273,18 @@ mod standalone_client_tests { }); } + #[rstest] + #[serial_test::serial] + #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] + fn test_read_from_replica_az_affinity() { + test_read_from_replica(ReadFromReplicaTestConfig { + read_from: ReadFrom::AZAffinity, + expected_primary_reads: 0, + expected_replica_reads: vec![1, 1, 1], + ..Default::default() + }); + } + #[rstest] #[serial_test::serial] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] @@ -294,7 +318,9 @@ mod standalone_client_tests { test_read_from_replica(ReadFromReplicaTestConfig { read_from: ReadFrom::PreferReplica, expected_primary_reads: 0, - expected_replica_reads: vec![2, 3], + // Since we drop 1 
replica after connection establishment + // we expect all reads to be handled by the remaining replicas + expected_replica_reads: vec![3, 3], number_of_replicas_dropped_after_connection: 1, number_of_requests_sent: 6, ..Default::default() diff --git a/glide-core/tests/utilities/cluster.rs b/glide-core/tests/utilities/cluster.rs index 9e7c356f4e..8f7ed6aca0 100644 --- a/glide-core/tests/utilities/cluster.rs +++ b/glide-core/tests/utilities/cluster.rs @@ -5,9 +5,9 @@ use super::{create_connection_request, ClusterMode, TestConfiguration}; use futures::future::{join_all, BoxFuture}; use futures::FutureExt; use glide_core::client::Client; -use glide_core::connection_request::NodeAddress; use once_cell::sync::Lazy; use redis::{ConnectionAddr, RedisConnectionInfo}; +use serde::Deserialize; use std::process::Command; use std::sync::Mutex; use std::time::Duration; @@ -23,6 +23,14 @@ enum ClusterType { TcpTls, } +#[derive(Deserialize, Clone, Debug)] +struct ValkeyServerInfo { + host: String, + port: u32, + pid: u32, + is_primary: bool, +} + impl ClusterType { fn build_addr(use_tls: bool, host: &str, port: u16) -> redis::ConnectionAddr { if use_tls { @@ -40,15 +48,27 @@ impl ClusterType { pub struct RedisCluster { cluster_folder: String, - addresses: Vec, use_tls: bool, password: Option, + servers: Vec, } impl Drop for RedisCluster { fn drop(&mut self) { + let pids: Vec = self + .servers + .iter() + .map(|server| format!("{}", server.pid)) + .collect(); + let pids = pids.join(","); Self::execute_cluster_script( - vec!["stop", "--cluster-folder", &self.cluster_folder], + vec![ + "stop", + "--cluster-folder", + &self.cluster_folder, + "--pids", + &pids, + ], self.use_tls, self.password.clone(), ); @@ -119,44 +139,46 @@ impl RedisCluster { script_args.push(&replicas_num); } let (stdout, stderr) = Self::execute_cluster_script(script_args, use_tls, None); - let (cluster_folder, addresses) = Self::parse_start_script_output(&stdout, &stderr); + let (cluster_folder, servers) = Self::parse_start_script_output(&stdout, &stderr); let mut password: Option = None; if let Some(info) = conn_info { password.clone_from(&info.password); }; RedisCluster { cluster_folder, - addresses, use_tls, password, + servers, } } - fn parse_start_script_output(output: &str, errors: &str) -> (String, Vec) { - let cluster_folder = output.split("CLUSTER_FOLDER=").collect::>(); - assert!( - !cluster_folder.is_empty() && cluster_folder.len() >= 2, - "Received output: {output}, stderr: {errors}" - ); - let cluster_folder = cluster_folder.get(1).unwrap().lines(); - let cluster_folder = cluster_folder.collect::>(); - let cluster_folder = cluster_folder.first().unwrap().to_string(); - - let output_parts = output.split("CLUSTER_NODES=").collect::>(); - assert!( - !output_parts.is_empty() && output_parts.len() >= 2, - "Received output: {output}, stderr: {errors}" - ); - let nodes = output_parts.get(1).unwrap().split(','); - let mut address_vec: Vec = Vec::new(); - for node in nodes { - let node_parts = node.split(':').collect::>(); - let mut address_info = NodeAddress::new(); - address_info.host = node_parts.first().unwrap().to_string().into(); - address_info.port = node_parts.get(1).unwrap().parse::().unwrap(); - address_vec.push(address_info); + fn value_after_prefix(prefix: &str, line: &str) -> Option { + if !line.starts_with(prefix) { + return None; + } + Some(line[prefix.len()..].to_string()) + } + + fn parse_start_script_output(output: &str, _errors: &str) -> (String, Vec) { + let prefixes = vec!["CLUSTER_FOLDER", "SERVERS_JSON"]; + let 
mut values = std::collections::HashMap::::new(); + let lines: Vec<&str> = output.split('\n').map(|line| line.trim()).collect(); + for line in lines { + for prefix in &prefixes { + let prefix_with_shave = format!("{prefix}="); + if line.starts_with(&prefix_with_shave) { + values.insert( + prefix.to_string(), + Self::value_after_prefix(&prefix_with_shave, line).unwrap_or_default(), + ); + } + } } - (cluster_folder, address_vec) + + let cluster_folder = values.get("CLUSTER_FOLDER").unwrap(); + let cluster_nodes_json = values.get("SERVERS_JSON").unwrap(); + let servers: Vec = serde_json::from_str(cluster_nodes_json).unwrap(); + (cluster_folder.clone(), servers) } fn execute_cluster_script( @@ -180,6 +202,7 @@ impl RedisCluster { }, args.join(" ") ); + let output = if cfg!(target_os = "windows") { Command::new("cmd") .args(["/C", &cmd]) @@ -204,11 +227,9 @@ impl RedisCluster { } pub fn get_server_addresses(&self) -> Vec { - self.addresses + self.servers .iter() - .map(|address| { - ClusterType::build_addr(self.use_tls, &address.host, address.port as u16) - }) + .map(|server| ClusterType::build_addr(self.use_tls, &server.host, server.port as u16)) .collect() } } @@ -230,11 +251,11 @@ async fn setup_acl_for_cluster( } pub async fn create_cluster_client( - cluster: &Option, + cluster: Option<&RedisCluster>, mut configuration: TestConfiguration, ) -> Client { let addresses = if !configuration.shared_server { - cluster.as_ref().unwrap().get_server_addresses() + cluster.unwrap().get_server_addresses() } else { get_shared_cluster_addresses(configuration.use_tls) }; @@ -263,7 +284,36 @@ pub async fn setup_test_basics_internal(configuration: TestConfiguration) -> Clu } else { None }; - let client = create_cluster_client(&cluster, configuration).await; + let client = create_cluster_client(cluster.as_ref(), configuration).await; + ClusterTestBasics { cluster, client } +} + +pub async fn setup_default_cluster() -> RedisCluster { + let test_config = TestConfiguration::default(); + RedisCluster::new(false, &test_config.connection_info, None, None) +} + +pub async fn setup_default_client(cluster: &RedisCluster) -> Client { + let test_config = TestConfiguration::default(); + create_cluster_client(Some(cluster), test_config).await +} + +pub async fn setup_cluster_with_replicas( + configuration: TestConfiguration, + replicas_num: u16, + primaries_num: u16, +) -> ClusterTestBasics { + let cluster = if !configuration.shared_server { + Some(RedisCluster::new( + configuration.use_tls, + &configuration.connection_info, + Some(primaries_num), + Some(replicas_num), + )) + } else { + None + }; + let client = create_cluster_client(cluster.as_ref(), configuration).await; ClusterTestBasics { cluster, client } } @@ -274,3 +324,34 @@ pub async fn setup_test_basics(use_tls: bool) -> ClusterTestBasics { }) .await } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_start_script_output() { + let script_output = r#" +INFO:root:## Executing cluster_manager.py with the following args: + Namespace(host='127.0.0.1', tls=False, auth=None, log='info', logfile=None, action='start', cluster_mode=True, folder_path='/Users/user/glide-for-redis/utils/clusters', ports=None, shard_count=3, replica_count=2, prefix='redis-cluster', load_module=None) +INFO:root:2024-11-05 16:05:44.024796+00:00 Starting script for cluster /Users/user/glide-for-redis/utils/clusters/redis-cluster-2024-11-05T16-05-44Z-2bz4YS +LOG_FILE=/Users/user/glide-for-redis/utils/clusters/redis-cluster-2024-11-05T16-05-44Z-2bz4YS/cluster_manager.log 
+SERVERS_JSON=[{"host": "127.0.0.1", "port": 39163, "pid": 59428, "is_primary": true}, {"host": "127.0.0.1", "port": 23178, "pid": 59436, "is_primary": true}, {"host": "127.0.0.1", "port": 25186, "pid": 59453, "is_primary": true}, {"host": "127.0.0.1", "port": 52500, "pid": 59432, "is_primary": false}, {"host": "127.0.0.1", "port": 48252, "pid": 59461, "is_primary": false}, {"host": "127.0.0.1", "port": 19544, "pid": 59444, "is_primary": false}, {"host": "127.0.0.1", "port": 37455, "pid": 59440, "is_primary": false}, {"host": "127.0.0.1", "port": 9282, "pid": 59449, "is_primary": false}, {"host": "127.0.0.1", "port": 19843, "pid": 59457, "is_primary": false}] +INFO:root:Created Cluster Redis in 24.8926 seconds +CLUSTER_FOLDER=/Users/user/glide-for-redis/utils/clusters/redis-cluster-2024-11-05T16-05-44Z-2bz4YS +CLUSTER_NODES=127.0.0.1:39163,127.0.0.1:23178,127.0.0.1:25186,127.0.0.1:52500,127.0.0.1:48252,127.0.0.1:19544,127.0.0.1:37455,127.0.0.1:9282,127.0.0.1:19843 + "#; + let (folder, servers) = RedisCluster::parse_start_script_output(script_output, ""); + assert_eq!(servers.len(), 9); + assert_eq!( + folder, + "/Users/user/glide-for-redis/utils/clusters/redis-cluster-2024-11-05T16-05-44Z-2bz4YS" + ); + + let server_0 = servers.first().unwrap(); + assert_eq!(server_0.pid, 59428); + assert_eq!(server_0.port, 39163); + assert_eq!(server_0.host, "127.0.0.1"); + assert!(server_0.is_primary); + } +} diff --git a/glide-core/tests/utilities/mocks.rs b/glide-core/tests/utilities/mocks.rs index 160e8a3189..33b8ae4121 100644 --- a/glide-core/tests/utilities/mocks.rs +++ b/glide-core/tests/utilities/mocks.rs @@ -5,14 +5,15 @@ use futures_intrusive::sync::ManualResetEvent; use redis::{Cmd, ConnectionAddr, Value}; use std::collections::HashMap; use std::io; +use std::io::Read; +use std::io::Write; use std::net::TcpListener; +use std::net::TcpStream as StdTcpStream; use std::str::from_utf8; use std::sync::{ atomic::{AtomicU16, Ordering}, Arc, }; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpStream; use tokio::sync::mpsc::UnboundedSender; pub struct MockedRequest { @@ -29,20 +30,24 @@ pub struct ServerMock { closing_completed_signal: Arc, } -async fn read_from_socket(buffer: &mut Vec, socket: &mut TcpStream) -> Option { - let _ = socket.readable().await; - - loop { - match socket.try_read_buf(buffer) { +fn read_from_socket( + buffer: &mut [u8], + socket: &mut StdTcpStream, + closing_signal: &Arc, +) -> Option { + while !closing_signal.is_set() { + let read_res = socket.read(buffer); // read() is using timeout + match read_res { Ok(0) => { return None; } - Ok(size) => return Some(size), + Ok(size) => { + return Some(size); + } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::Interrupted => { - tokio::task::yield_now().await; continue; } Err(_) => { @@ -50,43 +55,53 @@ async fn read_from_socket(buffer: &mut Vec, socket: &mut TcpStream) -> Optio } } } + // If we reached here, it means we got a signal to terminate + None +} + +/// Escape and print a RESP message +fn log_resp_message(msg: &str) { + logger_core::log_info( + "Test", + format!( + "{:?} {}", + std::thread::current().id(), + msg.replace('\r', "\\r").replace('\n', "\\n") + ), + ); } -async fn receive_and_respond_to_next_message( +fn receive_and_respond_to_next_message( receiver: &mut tokio::sync::mpsc::UnboundedReceiver, - socket: &mut TcpStream, + socket: &mut StdTcpStream, received_commands: &Arc, constant_responses: &HashMap, closing_signal: &Arc, ) -> bool { - let mut buffer = Vec::with_capacity(1024); - 
let size = tokio::select! { - size = read_from_socket(&mut buffer, socket) => { - let Some(size) = size else { - return false; - }; - size - }, - _ = closing_signal.wait() => { + let mut buffer = vec![0; 1024]; + let size = match read_from_socket(&mut buffer, socket, closing_signal) { + Some(size) => size, + None => { return false; } }; - let message = from_utf8(&buffer[..size]).unwrap().to_string(); + log_resp_message(&message); + let setinfo_count = message.matches("SETINFO").count(); if setinfo_count > 0 { let mut buffer = Vec::new(); for _ in 0..setinfo_count { super::encode_value(&Value::Okay, &mut buffer).unwrap(); } - socket.write_all(&buffer).await.unwrap(); + socket.write_all(&buffer).unwrap(); return true; } if let Some(response) = constant_responses.get(&message) { let mut buffer = Vec::new(); super::encode_value(response, &mut buffer).unwrap(); - socket.write_all(&buffer).await.unwrap(); + socket.write_all(&buffer).unwrap(); return true; } let Ok(request) = receiver.try_recv() else { @@ -94,7 +109,7 @@ async fn receive_and_respond_to_next_message( }; received_commands.fetch_add(1, Ordering::AcqRel); assert_eq!(message, request.expected_message); - socket.write_all(request.response.as_bytes()).await.unwrap(); + socket.write_all(request.response.as_bytes()).unwrap(); true } @@ -127,15 +142,11 @@ impl ServerMock { let closing_signal_clone = closing_signal.clone(); let closing_completed_signal = Arc::new(ManualResetEvent::new(false)); let closing_completed_signal_clone = closing_completed_signal.clone(); - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .thread_name(format!("ServerMock - {address}")) - .enable_all() - .build() - .unwrap(); - runtime.spawn(async move { - let listener = tokio::net::TcpListener::from_std(listener).unwrap(); - let mut socket = listener.accept().await.unwrap().0; + let address_clone = address.clone(); + std::thread::spawn(move || { + logger_core::log_info("Test", format!("ServerMock started on: {}", address_clone)); + let mut socket: StdTcpStream = listener.accept().unwrap().0; + let _ = socket.set_read_timeout(Some(std::time::Duration::from_millis(10))); while receive_and_respond_to_next_message( &mut receiver, @@ -143,17 +154,25 @@ impl ServerMock { &received_commands_clone, &constant_responses, &closing_signal_clone, - ) - .await - {} + ) {} + + // Terminate the connection + let _ = socket.shutdown(std::net::Shutdown::Both); + // Now notify exit completed closing_completed_signal_clone.set(); + + logger_core::log_info( + "Test", + format!("{:?} ServerMock exited", std::thread::current().id()), + ); }); + Self { request_sender, address, received_commands, - runtime: Some(runtime), + runtime: None, closing_signal, closing_completed_signal, } @@ -186,6 +205,5 @@ impl Mock for ServerMock { impl Drop for ServerMock { fn drop(&mut self) { self.closing_signal.set(); - self.runtime.take().unwrap().shutdown_background(); } } diff --git a/glide-core/tests/utilities/mod.rs b/glide-core/tests/utilities/mod.rs index 765c1cffb1..7318d6e640 100644 --- a/glide-core/tests/utilities/mod.rs +++ b/glide-core/tests/utilities/mod.rs @@ -659,6 +659,10 @@ pub fn create_connection_request( connection_request.client_name = client_name.deref().into(); } + if let Some(client_az) = &configuration.client_az { + connection_request.client_az = client_az.deref().into(); + } + connection_request } @@ -673,6 +677,7 @@ pub struct TestConfiguration { pub read_from: Option, pub database_id: u32, pub client_name: Option, + pub client_az: Option, pub 
protocol: ProtocolVersion, } diff --git a/go/Cargo.toml b/go/Cargo.toml index 6d6c4ecb15..48556820fd 100644 --- a/go/Cargo.toml +++ b/go/Cargo.toml @@ -9,7 +9,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } protobuf = { version = "3.3.0", features = [] } diff --git a/go/DEVELOPER.md b/go/DEVELOPER.md index a1d81f79e6..35879e9ed3 100644 --- a/go/DEVELOPER.md +++ b/go/DEVELOPER.md @@ -105,7 +105,7 @@ Before starting this step, make sure you've installed all software requirements. git clone --branch ${VERSION} https://github.com/valkey-io/valkey-glide.git cd valkey-glide ``` -2. Initialize git submodule: +2. Initialize git submodules: ```bash git submodule update --init --recursive ``` @@ -163,7 +163,7 @@ go test -race ./... -run TestConnectionRequestProtobufGeneration_allFieldsSet -v After pulling new changes, ensure that you update the submodules by running the following command: ```bash -git submodule update +git submodule update --init --recursive ``` ### Generate protobuf files diff --git a/go/api/commands.go b/go/api/commands.go index 1d00ff8270..5778ec9069 100644 --- a/go/api/commands.go +++ b/go/api/commands.go @@ -132,9 +132,12 @@ type StringCommands interface { // Sets multiple keys to multiple values in a single operation. // // Note: - // When in cluster mode, the command may route to multiple nodes when keys in keyValueMap map to different hash slots. - // - // See [valkey.io] for details. + // In cluster mode, if keys in `keyValueMap` map to different hash slots, the command + // will be split across these slots and executed separately for each. This means the command + // is atomic only at the slot level. If one or more slot-specific requests fail, the entire + // call will return the first encountered error, even though some requests may have succeeded + // while others did not. If this behavior impacts your application logic, consider splitting + // the request into sub-requests per slot to ensure atomicity. // // Parameters: // keyValueMap - A key-value map consisting of keys and their respective values to set. @@ -153,9 +156,12 @@ type StringCommands interface { // Retrieves the values of multiple keys. // // Note: - // When in cluster mode, the command may route to multiple nodes when keys map to different hash slots. - // - // See [valkey.io] for details. + // In cluster mode, if keys in `keys` map to different hash slots, the command + // will be split across these slots and executed separately for each. This means the command + // is atomic only at the slot level. If one or more slot-specific requests fail, the entire + // call will return the first encountered error, even though some requests may have succeeded + // while others did not. If this behavior impacts your application logic, consider splitting + // the request into sub-requests per slot to ensure atomicity. // // Parameters: // keys - A list of keys to retrieve values for. 
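The reworked MSET/MGET notes above suggest splitting a multi-key request into per-slot sub-requests when slot-level atomicity matters. A sketch of that grouping; slot_of uses a placeholder hash rather than the real CRC16-XMODEM, but the hashtag rule, hashing only the {...} substring when one is present, matches cluster behavior:

```rust
use std::collections::HashMap;

// Hashtag rule: if the key contains "{...}" with at least one character
// inside, only that substring is hashed.
fn hashtag(key: &str) -> &str {
    if let (Some(open), Some(close)) = (key.find('{'), key.find('}')) {
        if close > open + 1 {
            return &key[open + 1..close];
        }
    }
    key
}

// Placeholder hash standing in for CRC16-XMODEM % 16384.
fn slot_of(key: &str) -> u16 {
    let h = hashtag(key)
        .bytes()
        .fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u32));
    (h % 16384) as u16
}

fn main() {
    let pairs = [("user:{1}:name", "a"), ("user:{1}:age", "b"), ("order:42", "c")];
    let mut by_slot: HashMap<u16, Vec<(&str, &str)>> = HashMap::new();
    for (key, value) in pairs {
        by_slot.entry(slot_of(key)).or_default().push((key, value));
    }
    for (slot, group) in &by_slot {
        println!("MSET for slot {slot}: {group:?}"); // one atomic sub-request per slot
    }
}
```

Each group then becomes one MSET the cluster applies atomically within its slot, so a failure stays isolated to its group instead of aborting the whole call.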
diff --git a/go/src/lib.rs b/go/src/lib.rs index 55b3c9515c..193d0d1bea 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -16,6 +16,7 @@ use std::{ ffi::{c_void, CString}, mem, os::raw::{c_char, c_double, c_long, c_ulong}, + ptr, }; use tokio::runtime::Builder; use tokio::runtime::Runtime; @@ -78,6 +79,23 @@ impl Default for CommandResponse { } } +impl Default for CommandResponse { + fn default() -> Self { + CommandResponse { + response_type: ResponseType::default(), + int_value: 0, + float_value: 0.0, + bool_value: false, + string_value: ptr::null_mut(), + string_value_len: 0, + array_value: ptr::null_mut(), + array_value_len: 0, + map_key: ptr::null_mut(), + map_value: ptr::null_mut(), + } + } +} + #[repr(C)] #[derive(Debug, Default)] pub enum ResponseType { diff --git a/java/Cargo.toml b/java/Cargo.toml index 6428f67fa6..c8fa49fe3f 100644 --- a/java/Cargo.toml +++ b/java/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../logger_core"} diff --git a/java/README.md b/java/README.md index 9b14dab87d..4264d4c838 100644 --- a/java/README.md +++ b/java/README.md @@ -9,11 +9,13 @@ Valkey General Language Independent Driver for the Enterprise (GLIDE), is an ope The release of Valkey GLIDE was tested on the following platforms: Linux: -- Ubuntu 22.04.1 (x86_64) -- Amazon Linux 2023 (AL2023) (x86_64) + +- Ubuntu 22.04.1 (x86_64 and aarch64) +- Amazon Linux 2023 (AL2023) (x86_64) macOS: -- macOS 12.7 (Apple silicon/aarch_64 and Intel/x86_64) + +- macOS 14.7 (Apple silicon/aarch_64) ## Layout of Java code The Java client contains the following parts: @@ -55,7 +57,6 @@ Additionally, consider installing the Gradle plugin, [OS Detector](https://githu There are 4 types of classifiers for Valkey GLIDE which are ``` osx-aarch_64 -osx-x86_64 linux-aarch_64 linux-x86_64 ``` @@ -69,11 +70,6 @@ dependencies { implementation group: 'io.valkey', name: 'valkey-glide', version: '1.+', classifier: 'osx-aarch_64' } -// osx-x86_64 -dependencies { - implementation group: 'io.valkey', name: 'valkey-glide', version: '1.+', classifier: 'osx-x86_64' -} - // linux-aarch_64 dependencies { implementation group: 'io.valkey', name: 'valkey-glide', version: '1.+', classifier: 'linux-aarch_64' @@ -105,14 +101,6 @@ Maven: [1.0.0,2.0.0) - - - io.valkey - valkey-glide - osx-x86_64 - [1.0.0,2.0.0) - - io.valkey @@ -136,9 +124,6 @@ SBT: // osx-aarch_64 libraryDependencies += "io.valkey" % "valkey-glide" % "1.+" classifier "osx-aarch_64" -// osx-x86_64 -libraryDependencies += "io.valkey" % "valkey-glide" % "1.+" classifier "osx-x86_64" - // linux-aarch_64 libraryDependencies += "io.valkey" % "valkey-glide" % "1.+" classifier "linux-aarch_64" diff --git a/java/client/build.gradle b/java/client/build.gradle index 46fa8f4cee..cc2446251d 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -22,7 +22,6 @@ dependencies { // At the moment, Windows is not supported implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-x86_64' implementation group: 'io.netty', name: 
'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-aarch_64'
-    implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-x86_64'
     implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-aarch_64'
 
     // junit
@@ -165,8 +164,8 @@ jar.dependsOn('copyNativeLib')
 javadoc.dependsOn('copyNativeLib')
 copyNativeLib.dependsOn('buildRustRelease')
 compileTestJava.dependsOn('copyNativeLib')
-test.dependsOn('buildRust')
-testFfi.dependsOn('buildRust')
+test.dependsOn('buildRustRelease')
+testFfi.dependsOn('buildRustRelease')
 
 test {
     exclude "glide/ffi/FfiTest.class"
diff --git a/java/client/src/main/java/glide/api/BaseClient.java b/java/client/src/main/java/glide/api/BaseClient.java
index 0283113975..fd8015cc2e 100644
--- a/java/client/src/main/java/glide/api/BaseClient.java
+++ b/java/client/src/main/java/glide/api/BaseClient.java
@@ -275,6 +275,7 @@
 import glide.connectors.resources.ThreadPoolResource;
 import glide.connectors.resources.ThreadPoolResourceAllocator;
 import glide.ffi.resolvers.GlideValueResolver;
+import glide.ffi.resolvers.StatisticsResolver;
 import glide.managers.BaseResponseResolver;
 import glide.managers.CommandManager;
 import glide.managers.ConnectionManager;
@@ -385,6 +386,15 @@ protected static CompletableFuture createClient(
         }
     }
 
+    /**
+     * Returns statistics collected internally by GLIDE core.
+     *
+     * @return A {@link Map} that contains the statistics collected internally by GLIDE core.
+     */
+    public Map<String, String> getStatistics() {
+        return (HashMap<String, String>) StatisticsResolver.getStatistics();
+    }
+
     /**
      * Return a next pubsub message if it is present.
      *
@@ -767,6 +777,66 @@ protected Map handleLcsIdxResponse(Map response)
         return response;
     }
 
+    /**
+     * Update the current connection with a new password.
+     *
+     * <p>This method is useful in scenarios where the server password has changed or when utilizing
+     * short-lived passwords for enhanced security. It allows the client to update its password to
+     * reconnect upon disconnection without the need to recreate the client instance. This ensures
+     * that the internal reconnection mechanism can handle reconnection seamlessly, preventing the
+     * loss of in-flight commands.
+     *
+     * @apiNote This method updates the client's internal password configuration and does not
+     *     perform password rotation on the server side.
+     * @param password A new password to set.
+     * @param immediateAuth A boolean flag. If true, the client will
+     *     authenticate immediately with the new password against all connections, using the AUTH
+     *     command. If the password supplied is an empty string, the client will not perform auth
+     *     and a warning will be returned. The default is `false`.
+     * @return "OK".
+     * @example

{@code
+     * String response = client.resetConnectionPassword("new_password", RE_AUTHENTICATE).get();
+     * assert response.equals("OK");
+     * }
+     */
+    public CompletableFuture<String> updateConnectionPassword(
+            @NonNull String password, boolean immediateAuth) {
+        return commandManager.submitPasswordUpdate(
+                Optional.of(password), immediateAuth, this::handleStringResponse);
+    }
+
+    /**
+     * Update the current connection by removing the password.
+     *
+     * <p>This method is useful in scenarios where the server password has changed or when utilizing
+     * short-lived passwords for enhanced security. It allows the client to update its password to
+     * reconnect upon disconnection without the need to recreate the client instance. This ensures
+     * that the internal reconnection mechanism can handle reconnection seamlessly, preventing the
+     * loss of in-flight commands.
+     *
+     * @apiNote This method updates the client's internal password configuration and does not
+     *     perform password rotation on the server side.
+     * @param immediateAuth A boolean flag. If true, the client will
+     *     authenticate immediately with the new password against all connections, using the AUTH
+     *     command. If the password supplied is an empty string, the client will not perform auth
+     *     and a warning will be returned. The default is `false`.
+     * @return "OK".
+     * @example

{@code
+     * String response = client.updateConnectionPassword(true).get();
+     * assert response.equals("OK");
+     * }
+ */ + public CompletableFuture updateConnectionPassword(boolean immediateAuth) { + return commandManager.submitPasswordUpdate( + Optional.empty(), immediateAuth, this::handleStringResponse); + } + @Override public CompletableFuture del(@NonNull String[] keys) { return commandManager.submitNewCommand(Del, keys, this::handleLongResponse); diff --git a/java/client/src/main/java/glide/api/commands/GenericBaseCommands.java b/java/client/src/main/java/glide/api/commands/GenericBaseCommands.java index a55c1ef1a8..6234672899 100644 --- a/java/client/src/main/java/glide/api/commands/GenericBaseCommands.java +++ b/java/client/src/main/java/glide/api/commands/GenericBaseCommands.java @@ -23,8 +23,12 @@ public interface GenericBaseCommands { * Removes the specified keys from the database. A key is ignored if it does not * exist. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys we wanted to remove. * @return The number of keys that were removed. @@ -40,8 +44,12 @@ public interface GenericBaseCommands { * Removes the specified keys from the database. A key is ignored if it does not * exist. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys we wanted to remove. * @return The number of keys that were removed. @@ -56,8 +64,12 @@ public interface GenericBaseCommands { /** * Returns the number of keys in keys that exist in the database. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys list to check. * @return The number of keys that exist. 
If the same existing key is mentioned in keys @@ -73,8 +85,12 @@ public interface GenericBaseCommands { /** * Returns the number of keys in keys that exist in the database. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys list to check. * @return The number of keys that exist. If the same existing key is mentioned in keys @@ -93,8 +109,12 @@ public interface GenericBaseCommands { * specified keys and ignores non-existent ones. However, this command does not block the server, * while DEL does. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The list of keys to unlink. * @return The number of keys that were unlinked. @@ -112,8 +132,12 @@ public interface GenericBaseCommands { * specified keys and ignores non-existent ones. However, this command does not block the server, * while DEL does. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The list of keys to unlink. * @return The number of keys that were unlinked. @@ -952,8 +976,12 @@ CompletableFuture pexpireAt( /** * Updates the last access time of specified keys. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. 
If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys to update last access time. * @return The number of keys that were updated. @@ -968,8 +996,12 @@ CompletableFuture pexpireAt( /** * Updates the last access time of specified keys. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys The keys to update last access time. * @return The number of keys that were updated. diff --git a/java/client/src/main/java/glide/api/commands/StringBaseCommands.java b/java/client/src/main/java/glide/api/commands/StringBaseCommands.java index 3f46f6a2cb..20f13c30f2 100644 --- a/java/client/src/main/java/glide/api/commands/StringBaseCommands.java +++ b/java/client/src/main/java/glide/api/commands/StringBaseCommands.java @@ -249,8 +249,12 @@ public interface StringBaseCommands { /** * Retrieves the values of multiple keys. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys A list of keys to retrieve values for. * @return An array of values corresponding to the provided keys.
@@ -267,8 +271,12 @@ public interface StringBaseCommands { /** * Retrieves the values of multiple keys. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys - * map to different hash slots. + * @apiNote In cluster mode, if keys in keys map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keys A list of keys to retrieve values for. * @return An array of values corresponding to the provided keys.
@@ -285,11 +293,15 @@ public interface StringBaseCommands { /** * Sets multiple keys to multiple values in a single operation. * - * @apiNote When in cluster mode, the command may route to multiple nodes when keys in - * keyValueMap map to different hash slots. + * @apiNote In cluster mode, if keys in keyValueMap map to different hash slots, the + * command will be split across these slots and executed separately for each. This means the + * command is atomic only at the slot level. If one or more slot-specific requests fail, the + * entire call will return the first encountered error, even though some requests may have + * succeeded while others did not. If this behavior impacts your application logic, consider + * splitting the request into sub-requests per slot to ensure atomicity. * @see valkey.io for details. * @param keyValueMap A key-value map consisting of keys and their respective values to set. - * @return Always OK. + * @return A simple OK response. * @example *
{@code
      * String result = client.mset(Map.of("key1", "value1", "key2", "value2")).get();
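The @apiNote above tells callers to split multi-key requests per hash slot, which presumes a way to compute the slot a key maps to. Valkey cluster assigns a key to slot CRC16(key) mod 16384, hashing only the non-empty hash tag between the first `{` and the following `}` when one exists (the rule that CLUSTER KEYSLOT implements). Below is a minimal sketch of that rule; `SlotUtil` and `keySlot` are illustrative names, not part of this PR:

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper (not part of this PR): computes the cluster hash slot for a
// key, following the CLUSTER KEYSLOT rules: CRC16/XMODEM of the key (or of its
// non-empty hash tag, if any) modulo 16384.
public final class SlotUtil {
    private SlotUtil() {}

    public static int keySlot(String key) {
        int start = key.indexOf('{');
        if (start >= 0) {
            int end = key.indexOf('}', start + 1);
            if (end > start + 1) { // hash only a non-empty {tag}
                key = key.substring(start + 1, end);
            }
        }
        return crc16(key.getBytes(StandardCharsets.UTF_8)) % 16384;
    }

    // Bitwise CRC16/XMODEM (polynomial 0x1021, initial value 0x0000).
    private static int crc16(byte[] data) {
        int crc = 0;
        for (byte b : data) {
            crc ^= (b & 0xFF) << 8;
            for (int i = 0; i < 8; i++) {
                crc = ((crc & 0x8000) != 0) ? ((crc << 1) ^ 0x1021) : (crc << 1);
                crc &= 0xFFFF;
            }
        }
        return crc;
    }
}
```

The helper can be cross-checked against a live server with the CLUSTER KEYSLOT command.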
@@ -301,11 +313,15 @@ public interface StringBaseCommands {
     /**
      * Sets multiple keys to multiple values in a single operation.
      *
-     * @apiNote When in cluster mode, the command may route to multiple nodes when keys in 
-     *     keyValueMap map to different hash slots.
+     * @apiNote In cluster mode, if keys in keyValueMap map to different hash slots, the
+     *     command will be split across these slots and executed separately for each. This means the
+     *     command is atomic only at the slot level. If one or more slot-specific requests fail, the
+     *     entire call will return the first encountered error, even though some requests may have
+     *     succeeded while others did not. If this behavior impacts your application logic, consider
+     *     splitting the request into sub-requests per slot to ensure atomicity.
      * @see valkey.io for details.
      * @param keyValueMap A key-value map consisting of keys and their respective values to set.
-     * @return Always OK.
+     * @return A simple OK response.
      * @example
      *     
{@code
      * String result = client.msetBinary(Map.of(gs("key1"), gs("value1"), gs("key2"), gs("value2"))).get();
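With a slot helper available, the per-slot splitting recommended by the @apiNote above becomes mechanical: group the entries of `keyValueMap` by slot and issue one `mset` per group. A hedged sketch, reusing the hypothetical `SlotUtil.keySlot` from the earlier sketch and only the `mset(Map)` overload documented in this interface; `GlideClusterClient` is assumed to be connected:

```java
import glide.api.GlideClusterClient;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public final class PerSlotMset {
    private PerSlotMset() {}

    // Splits a cross-slot MSET into one sub-request per hash slot, so each
    // sub-request stays atomic and failures can be inspected individually
    // instead of surfacing only the first encountered error.
    public static Map<Integer, CompletableFuture<String>> msetPerSlot(
            GlideClusterClient client, Map<String, String> keyValueMap) {
        Map<Integer, Map<String, String>> bySlot = new HashMap<>();
        for (Map.Entry<String, String> e : keyValueMap.entrySet()) {
            bySlot.computeIfAbsent(SlotUtil.keySlot(e.getKey()), s -> new HashMap<>())
                    .put(e.getKey(), e.getValue());
        }
        Map<Integer, CompletableFuture<String>> pending = new HashMap<>();
        for (Map.Entry<Integer, Map<String, String>> e : bySlot.entrySet()) {
            pending.put(e.getKey(), client.mset(e.getValue()));
        }
        return pending;
    }
}
```

Each future resolves to "OK" on success, so a failed slot can be retried without re-sending sub-requests that already succeeded.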
diff --git a/java/client/src/main/java/glide/api/commands/TransactionsBaseCommands.java b/java/client/src/main/java/glide/api/commands/TransactionsBaseCommands.java
index 199357fdfe..d4db27785e 100644
--- a/java/client/src/main/java/glide/api/commands/TransactionsBaseCommands.java
+++ b/java/client/src/main/java/glide/api/commands/TransactionsBaseCommands.java
@@ -15,8 +15,12 @@ public interface TransactionsBaseCommands {
      * will only execute commands if the watched keys are not modified before execution of the
      * transaction.
      *
-     * @apiNote When in cluster mode, the command may route to multiple nodes when keys
-     *     map to different hash slots.
+     * @apiNote In cluster mode, if keys in keys map to different hash slots, the command
+     *     will be split across these slots and executed separately for each. This means the command
+     *     is atomic only at the slot level. If one or more slot-specific requests fail, the entire
+     *     call will return the first encountered error, even though some requests may have succeeded
+     *     while others did not. If this behavior impacts your application logic, consider splitting
+     *     the request into sub-requests per slot to ensure atomicity.
      * @see valkey.io for details.
      * @param keys The keys to watch.
      * @return OK.
@@ -41,8 +45,12 @@ public interface TransactionsBaseCommands {
      * will only execute commands if the watched keys are not modified before execution of the
      * transaction.
      *
-     * @apiNote When in cluster mode, the command may route to multiple nodes when keys
-     *     map to different hash slots.
+     * @apiNote In cluster mode, if keys in keys map to different hash slots, the command
+     *     will be split across these slots and executed separately for each. This means the command
+     *     is atomic only at the slot level. If one or more slot-specific requests fail, the entire
+     *     call will return the first encountered error, even though some requests may have succeeded
+     *     while others did not. If this behavior impacts your application logic, consider splitting
+     *     the request into sub-requests per slot to ensure atomicity.
      * @see valkey.io for details.
      * @param keys The keys to watch.
      * @return OK.
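The same @apiNote applies to `watch`; an alternative to splitting is to avoid the cross-slot case entirely with hash tags, which pin related keys to one slot so the whole WATCH is a single atomic request. A hedged sketch; the key names are illustrative, and `client` is assumed to be a connected `GlideClusterClient` exposing the `watch` overload above:

```java
import glide.api.GlideClusterClient;
import java.util.concurrent.ExecutionException;

public final class WatchSameSlot {
    private WatchSameSlot() {}

    public static void watchCart(GlideClusterClient client)
            throws ExecutionException, InterruptedException {
        // Only the substring inside {...} is hashed, so both keys map to the
        // same slot and WATCH is sent as one atomic, single-slot request.
        String[] keys = {"{user:42}:cart", "{user:42}:wishlist"};
        String response = client.watch(keys).get();
        assert response.equals("OK");
    }
}
```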
diff --git a/java/client/src/main/java/glide/api/commands/servermodules/FT.java b/java/client/src/main/java/glide/api/commands/servermodules/FT.java
new file mode 100644
index 0000000000..1f27772e1c
--- /dev/null
+++ b/java/client/src/main/java/glide/api/commands/servermodules/FT.java
@@ -0,0 +1,931 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.commands.servermodules;
+
+import static glide.api.models.GlideString.gs;
+import static glide.utils.ArrayTransformUtils.castArray;
+import static glide.utils.ArrayTransformUtils.concatenateArrays;
+
+import glide.api.BaseClient;
+import glide.api.GlideClient;
+import glide.api.GlideClusterClient;
+import glide.api.models.ClusterValue;
+import glide.api.models.GlideString;
+import glide.api.models.commands.FT.FTAggregateOptions;
+import glide.api.models.commands.FT.FTCreateOptions;
+import glide.api.models.commands.FT.FTCreateOptions.FieldInfo;
+import glide.api.models.commands.FT.FTProfileOptions;
+import glide.api.models.commands.FT.FTSearchOptions;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import lombok.NonNull;
+
+/** Module for vector search commands. */
+public class FT {
+    /**
+     * Creates an index and initiates a backfill of that index.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module
+     *     API.
+     * @return "OK".
+     * @example
+     *     
{@code
+     * // Create an index for vectors of size 2:
+     * FT.create(client, "my_idx1", new FieldInfo[] {
+     *     new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
+     * }).get();
+     *
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, "my_idx2",
+     *     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     * }).get();
+     * }
+     */
+    public static CompletableFuture<String> create(
+            @NonNull BaseClient client, @NonNull String indexName, @NonNull FieldInfo[] schema) {
+        // Note: bug in meme DB - command fails if cmd is too short even though all mandatory args
+        // are present
+        // TODO confirm whether it is fixed and update docs if needed
+        return create(client, indexName, schema, FTCreateOptions.builder().build());
+    }
+
+    /**
+     * Creates an index and initiates a backfill of that index.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module
+     *     API.
+     * @param options Additional parameters for the command - see {@link FTCreateOptions}.
+     * @return "OK".
+     * @example
{@code
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, "json_idx1",
+     *     new FieldInfo[] { new FieldInfo("$.vec", "VEC",
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     *     },
+     *     FTCreateOptions.builder().dataType(JSON).prefixes(new String[] {"json:"}).build()
+     * ).get();
+     * }
+ */ + public static CompletableFuture create( + @NonNull BaseClient client, + @NonNull String indexName, + @NonNull FieldInfo[] schema, + @NonNull FTCreateOptions options) { + return create(client, gs(indexName), schema, options); + } + + /** + * Creates an index and initiates a backfill of that index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module + * API. + * @return "OK". + * @example + *
{@code
+     * // Create an index for vectors of size 2:
+     * FT.create(client, gs("my_idx1"), new FieldInfo[] {
+     *     new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
+     * }).get();
+     *
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, gs("my_idx2"),
+     *     new FieldInfo[] { new FieldInfo(gs("$.vec"), gs("VEC"),
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     * }).get();
+     * }
+     */
+    public static CompletableFuture<String> create(
+            @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull FieldInfo[] schema) {
+        // Note: bug in meme DB - command fails if cmd is too short even though all mandatory args
+        // are present
+        // TODO confirm whether it is fixed and update docs if needed
+        return create(client, indexName, schema, FTCreateOptions.builder().build());
+    }
+
+    /**
+     * Creates an index and initiates a backfill of that index.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param schema Fields to populate into the index. Equivalent to `SCHEMA` block in the module
+     *     API.
+     * @param options Additional parameters for the command - see {@link FTCreateOptions}.
+     * @return "OK".
+     * @example
{@code
+     * // Create a 6-dimensional JSON index using the HNSW algorithm:
+     * FT.create(client, gs("json_idx1"),
+     *     new FieldInfo[] { new FieldInfo(gs("$.vec"), gs("VEC"),
+     *         VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
+     *     },
+     *     FTCreateOptions.builder().dataType(JSON).prefixes(new String[] {"json:"}).build()
+     * ).get();
+     * }
+ */ + public static CompletableFuture create( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull FieldInfo[] schema, + @NonNull FTCreateOptions options) { + var args = + Stream.of( + new GlideString[] {gs("FT.CREATE"), indexName}, + options.toArgs(), + new GlideString[] {gs("SCHEMA")}, + Arrays.stream(schema) + .map(FieldInfo::toArgs) + .flatMap(Arrays::stream) + .toArray(GlideString[]::new)) + .flatMap(Arrays::stream) + .toArray(GlideString[]::new); + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @param options The search options - see {@link FTSearchOptions}. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes.
+     *     If {@link FTSearchOptions.FTSearchOptionsBuilder#count()} is set, or {@link
+     *     FTSearchOptions.FTSearchOptionsBuilder#limit(int, int)} is set with values 0, 0,
+     *     the command returns an array with only one element - the count of the documents.
+     * @example
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, "json_idx1", "*=>[KNN 2 @VEC $query_vec]",
+     *         FTSearchOptions.builder().params(Map.of(gs("query_vec"), gs(vector))).build())
+     *     .get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("__VEC_score"), gs("11.1100006104"), gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("__VEC_score"), gs("91"), gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, + @NonNull String indexName, + @NonNull String query, + @NonNull FTSearchOptions options) { + var args = + concatenateArrays( + new GlideString[] {gs("FT.SEARCH"), gs(indexName), gs(query)}, options.toArgs()); + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @param options The search options - see {@link FTSearchOptions}. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes.
+     *     If {@link FTSearchOptions.FTSearchOptionsBuilder#count()} is set, or {@link
+     *     FTSearchOptions.FTSearchOptionsBuilder#limit(int, int)} is set with values 0, 0,
+     *     the command returns an array with only one element - the count of the documents.
+     * @example
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, gs("json_idx1"), gs("*=>[KNN 2 @VEC $query_vec]"),
+     *         FTSearchOptions.builder().params(Map.of(gs("query_vec"), gs(vector))).build())
+     *     .get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("__VEC_score"), gs("11.1100006104"), gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("__VEC_score"), gs("91"), gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull GlideString query, + @NonNull FTSearchOptions options) { + var args = + concatenateArrays(new GlideString[] {gs("FT.SEARCH"), indexName, query}, options.toArgs()); + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes. + * @example + *
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, "json_idx1", "*").get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + var args = new GlideString[] {gs("FT.SEARCH"), gs(indexName), gs(query)}; + return executeCommand(client, args, false); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. + * @return A two element array, where first element is count of documents in result set, and the + * second element, which has format + * {@literal Map>} - a mapping between + * document names and map of their attributes. + * @example + *
{@code
+     * byte[] vector = new byte[24];
+     * Arrays.fill(vector, (byte) 0);
+     * var result = FT.search(client, gs("json_idx1"), gs("*")).get();
+     * assertArrayEquals(result, new Object[] { 2L, Map.of(
+     *     gs("json:2"), Map.of(gs("$"), gs("{\"vec\":[1.1,1.2,1.3,1.4,1.5,1.6]}")),
+     *     gs("json:0"), Map.of(gs("$"), gs("{\"vec\":[1,2,3,4,5,6]}")))
+     * });
+     * }
+ */ + public static CompletableFuture search( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + var args = new GlideString[] {gs("FT.SEARCH"), indexName, query}; + return executeCommand(client, args, false); + } + + /** + * Deletes an index and associated content. Indexed document keys are unaffected. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @return "OK". + * @example + *
{@code
+     * FT.dropindex(client, "hash_idx1").get();
+     * }
+ */ + public static CompletableFuture dropindex( + @NonNull BaseClient client, @NonNull String indexName) { + return executeCommand(client, new GlideString[] {gs("FT.DROPINDEX"), gs(indexName)}, false); + } + + /** + * Deletes an index and associated content. Indexed document keys are unaffected. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @return "OK". + * @example + *
{@code
+     * FT.dropindex(client, gs("hash_idx1")).get();
+     * }
+     */
+    public static CompletableFuture<String> dropindex(
+            @NonNull BaseClient client, @NonNull GlideString indexName) {
+        return executeCommand(client, new GlideString[] {gs("FT.DROPINDEX"), indexName}, false);
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @return Results of the last stage of the pipeline.
+     * @example
{@code
+     * // example of using the API:
+     * FT.aggregate(client, "myIndex", "*").get();
+     * // the response contains data in the following format:
+     * Map[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) {
+        return aggregate(client, gs(indexName), gs(query));
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @param options Additional parameters for the command - see {@link FTAggregateOptions}.
+     * @return Results of the last stage of the pipeline.
+     * @example
{@code
+     * // example of using the API:
+     * FTAggregateOptions options = FTAggregateOptions.builder()
+     *     .loadFields(new String[] {"__key"})
+     *     .addClause(
+     *             new FTAggregateOptions.GroupBy(
+     *                     new String[] {"@condition"},
+     *                     new Reducer[] {
+     *                         new Reducer("TOLIST", new String[] {"__key"}, "bicycles")
+     *                     }))
+     *     .build();
+     * FT.aggregate(client, "myIndex", "*", options).get();
+     * // the response contains data in the following format:
+     * Map[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client,
+            @NonNull String indexName,
+            @NonNull String query,
+            @NonNull FTAggregateOptions options) {
+        return aggregate(client, gs(indexName), gs(query), options);
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @return Results of the last stage of the pipeline.
+     * @example
{@code
+     * // example of using the API:
+     * FT.aggregate(client, gs("myIndex"), gs("*")).get();
+     * // the response contains data in the following format:
+     * Map[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    @SuppressWarnings("unchecked")
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) {
+        var args = new GlideString[] {gs("FT.AGGREGATE"), indexName, query};
+        return FT.<Object[]>executeCommand(client, args, false)
+                .thenApply(res -> castArray(res, Map.class));
+    }
+
+    /**
+     * Runs a search query on an index, and performs aggregate transformations on the results.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param query The text query to search.
+     * @param options Additional parameters for the command - see {@link FTAggregateOptions}.
+     * @return Results of the last stage of the pipeline.
+     * @example
{@code
+     * // example of using the API:
+     * FTAggregateOptions options = FTAggregateOptions.builder()
+     *     .loadFields(new String[] {"__key"})
+     *     .addClause(
+     *             new FTAggregateOptions.GroupBy(
+     *                     new String[] {"@condition"},
+     *                     new Reducer[] {
+     *                         new Reducer("TOLIST", new String[] {"__key"}, "bicycles")
+     *                     }))
+     *     .build();
+     * FT.aggregate(client, gs("myIndex"), gs("*"), options).get();
+     * // the response contains data in the following format:
+     * Map[] response = new Map[] {
+     *     Map.of(
+     *         gs("condition"), gs("refurbished"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:9") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("used"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:1"), gs("bicycle:2"), gs("bicycle:3") }
+     *     ),
+     *     Map.of(
+     *         gs("condition"), gs("new"),
+     *         gs("bicycles"), new Object[] { gs("bicycle:0"), gs("bicycle:5") }
+     *     )
+     * };
+     * }
+     */
+    @SuppressWarnings("unchecked")
+    public static CompletableFuture<Map<GlideString, Object>[]> aggregate(
+            @NonNull BaseClient client,
+            @NonNull GlideString indexName,
+            @NonNull GlideString query,
+            @NonNull FTAggregateOptions options) {
+        var args =
+                concatenateArrays(
+                        new GlideString[] {gs("FT.AGGREGATE"), indexName, query}, options.toArgs());
+        return FT.<Object[]>executeCommand(client, args, false)
+                .thenApply(res -> castArray(res, Map.class));
+    }
+
+    /**
+     * Runs a search or aggregation query and collects performance profiling information.
+     *
+     * @param client The client to execute the command.
+     * @param indexName The index name.
+     * @param options Querying and profiling parameters - see {@link FTProfileOptions}.
+     * @return A two-element array. The first element contains results of the query being profiled;
+     *     the second element stores profiling information.
+     * @example
{@code
+     * var options = FTSearchOptions.builder().params(Map.of(
+     *         gs("query_vec"),
+     *         gs(new byte[] { (byte) 0, (byte) 0, (byte) 0, (byte) 0 })))
+     *     .build();
+     * var result = FT.profile(client, "myIndex", new FTProfileOptions("*=>[KNN 2 @VEC $query_vec]", options)).get();
+     * // result[0] contains `FT.SEARCH` response with the given options and query
+     * // result[1] contains profiling data as a `Map`
+     * }
+ */ + public static CompletableFuture profile( + @NonNull BaseClient client, @NonNull String indexName, @NonNull FTProfileOptions options) { + return profile(client, gs(indexName), options); + } + + /** + * Runs a search or aggregation query and collects performance profiling information. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @param options Querying and profiling parameters - see {@link FTProfileOptions}. + * @return A two-element array. The first element contains results of query being profiled, the + * second element stores profiling information. + * @example + *
{@code
+     * var commandLine = new String[] { "*", "LOAD", "1", "__key", "GROUPBY", "1", "@condition", "REDUCE", "COUNT", "0", "AS", "bicycles" };
+     * var result = FT.profile(client, gs("myIndex"), new FTProfileOptions(QueryType.AGGREGATE, commandLine)).get();
+     * // result[0] contains `FT.AGGREGATE` response with the given command line
+     * // result[1] contains profiling data as a `Map`
+     * }
+ */ + public static CompletableFuture profile( + @NonNull BaseClient client, + @NonNull GlideString indexName, + @NonNull FTProfileOptions options) { + var args = concatenateArrays(new GlideString[] {gs("FT.PROFILE"), indexName}, options.toArgs()); + return executeCommand(client, args, false); + } + + /** + * Returns information about a given index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @return Nested maps with info about the index. See example for more details. + * @example + *
{@code
+     * // example of using the API:
+     * Map response = FT.info(client, "myIndex").get();
+     * // the response contains data in the following format:
+     * Map data = Map.of(
+     *     "index_name", gs("myIndex"),
+     *     "index_status", gs("AVAILABLE"),
+     *     "key_type", gs("JSON"),
+     *     "creation_timestamp", 1728348101728771L,
+     *     "key_prefixes", new Object[] { gs("json:") },
+     *     "num_indexed_vectors", 0L,
+     *     "space_usage", 653471L,
+     *     "num_docs", 0L,
+     *     "vector_space_usage", 653471L,
+     *     "index_degradation_percentage", 0L,
+     *     "fulltext_space_usage", 0L,
+     *     "current_lag", 0L,
+     *     "fields", new Object [] {
+     *         Map.of(
+     *             gs("identifier"), gs("$.vec"),
+     *             gs("type"), gs("VECTOR"),
+     *             gs("field_name"), gs("VEC"),
+     *             gs("option"), gs(""),
+     *             gs("vector_params", Map.of(
+     *                 gs("data_type", gs("FLOAT32"),
+     *                 gs("initial_capacity", 1000L,
+     *                 gs("current_capacity", 1000L,
+     *                 gs("distance_metric", gs("L2"),
+     *                 gs("dimension", 6L,
+     *                 gs("block_size", 1024L,
+     *                 gs("algorithm", gs("FLAT")
+     *             )
+     *         ),
+     *         Map.of(
+     *             gs("identifier"), gs("name"),
+     *             gs("type"), gs("TEXT"),
+     *             gs("field_name"), gs("name"),
+     *             gs("option"), gs("")
+     *         ),
+     *     }
+     * );
+     * }
+ */ + public static CompletableFuture> info( + @NonNull BaseClient client, @NonNull String indexName) { + return info(client, gs(indexName)); + } + + /** + * Returns information about a given index. + * + * @param client The client to execute the command. + * @param indexName The index name. + * @return Nested maps with info about the index. See example for more details. + * @example + *
{@code
+     * // example of using the API:
+     * Map response = FT.info(client, gs("myIndex")).get();
+     * // the response contains data in the following format:
+     * Map data = Map.of(
+     *     "index_name", gs("myIndex"),
+     *     "index_status", gs("AVAILABLE"),
+     *     "key_type", gs("JSON"),
+     *     "creation_timestamp", 1728348101728771L,
+     *     "key_prefixes", new Object[] { gs("json:") },
+     *     "num_indexed_vectors", 0L,
+     *     "space_usage", 653471L,
+     *     "num_docs", 0L,
+     *     "vector_space_usage", 653471L,
+     *     "index_degradation_percentage", 0L,
+     *     "fulltext_space_usage", 0L,
+     *     "current_lag", 0L,
+     *     "fields", new Object [] {
+     *         Map.of(
+     *             gs("identifier"), gs("$.vec"),
+     *             gs("type"), gs("VECTOR"),
+     *             gs("field_name"), gs("VEC"),
+     *             gs("option"), gs(""),
+     *             gs("vector_params", Map.of(
+     *                 gs("data_type", gs("FLOAT32"),
+     *                 gs("initial_capacity", 1000L,
+     *                 gs("current_capacity", 1000L,
+     *                 gs("distance_metric", gs("L2"),
+     *                 gs("dimension", 6L,
+     *                 gs("block_size", 1024L,
+     *                 gs("algorithm", gs("FLAT")
+     *             )
+     *         ),
+     *         Map.of(
+     *             gs("identifier"), gs("name"),
+     *             gs("type"), gs("TEXT"),
+     *             gs("field_name"), gs("name"),
+     *             gs("option"), gs("")
+     *         ),
+     *     }
+     * );
+     * }
+ */ + public static CompletableFuture> info( + @NonNull BaseClient client, @NonNull GlideString indexName) { + // TODO inconsistency on cluster client: the outer map is `Map`, + // while inner maps are `Map` + // The outer map converted from `Map` in ClusterValue::ofMultiValueBinary + // TODO server returns all map keys as `SimpleString`, we're safe to convert all to + // `GlideString`s to `String` + + // standalone client returns `Map`, but cluster `Map` + if (client instanceof GlideClusterClient) + return executeCommand(client, new GlideString[] {gs("FT.INFO"), indexName}, true); + return FT.>executeCommand( + client, new GlideString[] {gs("FT.INFO"), indexName}, true) + .thenApply( + map -> + map.entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().toString(), Map.Entry::getValue))); + } + + /** + * Lists all indexes. + * + * @param client The client to execute the command. + * @return An array of index names. + * @example + *
{@code
+     * GlideString[] indices = FT.list(client).get();
+     * }
+ */ + public static CompletableFuture list(@NonNull BaseClient client) { + return FT.executeCommand(client, new GlideString[] {gs("FT._LIST")}, false) + .thenApply(arr -> castArray(arr, GlideString.class)); + } + + /** + * Adds an alias for an index. The new alias name can be used anywhere that an index name is + * required. + * + * @param client The client to execute the command. + * @param aliasName The alias to be added to an index. + * @param indexName The index name for which the alias has to be added. + * @return "OK". + * @example + *
{@code
+     * FT.aliasadd(client, "myalias", "myindex").get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasadd( + @NonNull BaseClient client, @NonNull String aliasName, @NonNull String indexName) { + return aliasadd(client, gs(aliasName), gs(indexName)); + } + + /** + * Adds an alias for an index. The new alias name can be used anywhere that an index name is + * required. + * + * @param client The client to execute the command. + * @param aliasName The alias to be added to an index. + * @param indexName The index name for which the alias has to be added. + * @return "OK". + * @example + *
{@code
+     * FT.aliasadd(client, gs("myalias"), gs("myindex")).get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasadd( + @NonNull BaseClient client, @NonNull GlideString aliasName, @NonNull GlideString indexName) { + var args = new GlideString[] {gs("FT.ALIASADD"), aliasName, indexName}; + + return executeCommand(client, args, false); + } + + /** + * Deletes an existing alias for an index. + * + * @param client The client to execute the command. + * @param aliasName The existing alias to be deleted for an index. + * @return "OK". + * @example + *
{@code
+     * FT.aliasdel(client, "myalias").get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasdel( + @NonNull BaseClient client, @NonNull String aliasName) { + return aliasdel(client, gs(aliasName)); + } + + /** + * Deletes an existing alias for an index. + * + * @param client The client to execute the command. + * @param aliasName The existing alias to be deleted for an index. + * @return "OK". + * @example + *
{@code
+     * FT.aliasdel(client, gs("myalias")).get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasdel( + @NonNull BaseClient client, @NonNull GlideString aliasName) { + var args = new GlideString[] {gs("FT.ALIASDEL"), aliasName}; + + return executeCommand(client, args, false); + } + + /** + * Updates an existing alias to point to a different physical index. This command only affects + * future references to the alias. + * + * @param client The client to execute the command. + * @param aliasName The alias name. This alias will now be pointed to a different index. + * @param indexName The index name for which an existing alias has to be updated. + * @return "OK". + * @example + *
{@code
+     * FT.aliasupdate(client, "myalias", "myindex").get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasupdate( + @NonNull BaseClient client, @NonNull String aliasName, @NonNull String indexName) { + return aliasupdate(client, gs(aliasName), gs(indexName)); + } + + /** + * Update an existing alias to point to a different physical index. This command only affects + * future references to the alias. + * + * @param client The client to execute the command. + * @param aliasName The alias name. This alias will now be pointed to a different index. + * @param indexName The index name for which an existing alias has to be updated. + * @return "OK". + * @example + *
{@code
+     * FT.aliasupdate(client, gs("myalias"), gs("myindex")).get(); // "OK"
+     * }
+ */ + public static CompletableFuture aliasupdate( + @NonNull BaseClient client, @NonNull GlideString aliasName, @NonNull GlideString indexName) { + var args = new GlideString[] {gs("FT.ALIASUPDATE"), aliasName, indexName}; + return executeCommand(client, args, false); + } + + /** + * Lists all index aliases. + * + * @param client The client to execute the command. + * @return A map of index aliases to indices being aliased. + * @example + *
{@code
+     * var aliases = FT.aliaslist(client).get();
+     * // the response contains data in the following format:
+     * Map aliases = Map.of(
+     *     gs("alias"), gs("myIndex"),
+     * );
+     * }
+ */ + public static CompletableFuture> aliaslist( + @NonNull BaseClient client) { + // standalone client returns `Map`, but cluster `Map` + // The map converted from `Map` in ClusterValue::ofMultiValueBinary + // TODO this will fail once an alias name will be non-utf8-compatible + if (client instanceof GlideClient) + return executeCommand(client, new GlideString[] {gs("FT._ALIASLIST")}, true); + return FT.>executeCommand( + client, new GlideString[] {gs("FT._ALIASLIST")}, true) + .thenApply( + map -> + map.entrySet().stream() + .collect(Collectors.toMap(e -> gs(e.getKey()), Map.Entry::getValue))); + } + + /** + * Parse a query and return information about how that query was parsed. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, String, String)} and {@link FT#aggregate(BaseClient, String, + * String)}. + * @return A String representing the execution plan. + * @example + *
{@code
+     * String result = FT.explain(client, "myIndex", "@price:[0 10]").get();
+     * assert result.equals("Field {\n\tprice\n\t0\n\t10\n}");
+     * }
+ */ + public static CompletableFuture explain( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + GlideString[] args = {gs("FT.EXPLAIN"), gs(indexName), gs(query)}; + return FT.executeCommand(client, args, false).thenApply(GlideString::toString); + } + + /** + * Parse a query and return information about how that query was parsed. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, GlideString, GlideString)} and {@link FT#aggregate(BaseClient, + * GlideString, GlideString)}. + * @return A GlideString representing the execution plan. + * @example + *
{@code
+     * GlideString result = FT.explain(client, gs("myIndex"), gs("@price:[0 10]")).get();
+     * assert result.equals("Field {\n\tprice\n\t0\n\t10\n}");
+     * }
+ */ + public static CompletableFuture explain( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + GlideString[] args = {gs("FT.EXPLAIN"), indexName, query}; + return executeCommand(client, args, false); + } + + /** + * Same as the {@link FT#explain(BaseClient, String, String)} except that the results are + * displayed in a different format. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, String, String)} and {@link FT#aggregate(BaseClient, String, + * String)}. + * @return A String[] representing the execution plan. + * @example + *
{@code
+     * String[] result = FT.explaincli(client, "myIndex", "@price:[0 10]").get();
+     * assert Arrays.equals(result, new String[]{
+     *   "Field {",
+     *   "  price",
+     *   "  0",
+     *   "  10",
+     *   "}"
+     * });
+     * }
+ */ + public static CompletableFuture explaincli( + @NonNull BaseClient client, @NonNull String indexName, @NonNull String query) { + CompletableFuture result = explaincli(client, gs(indexName), gs(query)); + return result.thenApply( + ret -> Arrays.stream(ret).map(GlideString::toString).toArray(String[]::new)); + } + + /** + * Same as the {@link FT#explain(BaseClient, String, String)} except that the results are + * displayed in a different format. + * + * @param client The client to execute the command. + * @param indexName The index name to search into. + * @param query The text query to search. It is the same as the query passed as an argument to + * {@link FT#search(BaseClient, GlideString, GlideString)} and {@link FT#aggregate(BaseClient, + * GlideString, GlideString)}. + * @return A GlideString[] representing the execution plan. + * @example + *
{@code
+     * GlideString[] result = FT.explaincli(client, gs("myIndex"), gs("@price:[0 10]")).get();
+     * assert Arrays.equals(result, new GlideString[]{
+     *   gs("Field {"),
+     *   gs("  price"),
+     *   gs("  0"),
+     *   gs("  10"),
+     *   gs("}")
+     * });
+     * }
+ */ + public static CompletableFuture explaincli( + @NonNull BaseClient client, @NonNull GlideString indexName, @NonNull GlideString query) { + GlideString[] args = new GlideString[] {gs("FT.EXPLAINCLI"), indexName, query}; + return FT.executeCommand(client, args, false) + .thenApply(ret -> castArray(ret, GlideString.class)); + } + + /** + * A wrapper for custom command API. + * + * @param client The client to execute the command. + * @param args The command line. + * @param returnsMap - true if command returns a map + */ + @SuppressWarnings("unchecked") + private static CompletableFuture executeCommand( + BaseClient client, GlideString[] args, boolean returnsMap) { + if (client instanceof GlideClient) { + return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r); + } else if (client instanceof GlideClusterClient) { + return ((GlideClusterClient) client) + .customCommand(args) + .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue) + .thenApply(r -> (T) r); + } + throw new IllegalArgumentException( + "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`"); + } +} diff --git a/java/client/src/main/java/glide/api/commands/servermodules/Json.java b/java/client/src/main/java/glide/api/commands/servermodules/Json.java new file mode 100644 index 0000000000..2b94564791 --- /dev/null +++ b/java/client/src/main/java/glide/api/commands/servermodules/Json.java @@ -0,0 +1,2915 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.commands.servermodules; + +import static glide.api.models.GlideString.gs; +import static glide.utils.ArrayTransformUtils.castArray; +import static glide.utils.ArrayTransformUtils.concatenateArrays; + +import glide.api.BaseClient; +import glide.api.GlideClient; +import glide.api.GlideClusterClient; +import glide.api.models.ClusterValue; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.json.JsonArrindexOptions; +import glide.api.models.commands.json.JsonGetOptions; +import glide.api.models.commands.json.JsonGetOptionsBinary; +import glide.utils.ArgsBuilder; +import java.util.concurrent.CompletableFuture; +import lombok.NonNull; + +/** Module for JSON commands. 
*/ +public class Json { + + private static final String JSON_PREFIX = "JSON."; + private static final String JSON_SET = JSON_PREFIX + "SET"; + private static final String JSON_GET = JSON_PREFIX + "GET"; + private static final String JSON_MGET = JSON_PREFIX + "MGET"; + private static final String JSON_NUMINCRBY = JSON_PREFIX + "NUMINCRBY"; + private static final String JSON_NUMMULTBY = JSON_PREFIX + "NUMMULTBY"; + private static final String JSON_ARRAPPEND = JSON_PREFIX + "ARRAPPEND"; + private static final String JSON_ARRINSERT = JSON_PREFIX + "ARRINSERT"; + private static final String JSON_ARRINDEX = JSON_PREFIX + "ARRINDEX"; + private static final String JSON_ARRLEN = JSON_PREFIX + "ARRLEN"; + private static final String[] JSON_DEBUG_MEMORY = new String[] {JSON_PREFIX + "DEBUG", "MEMORY"}; + private static final String[] JSON_DEBUG_FIELDS = new String[] {JSON_PREFIX + "DEBUG", "FIELDS"}; + private static final String JSON_ARRPOP = JSON_PREFIX + "ARRPOP"; + private static final String JSON_ARRTRIM = JSON_PREFIX + "ARRTRIM"; + private static final String JSON_OBJLEN = JSON_PREFIX + "OBJLEN"; + private static final String JSON_OBJKEYS = JSON_PREFIX + "OBJKEYS"; + private static final String JSON_DEL = JSON_PREFIX + "DEL"; + private static final String JSON_FORGET = JSON_PREFIX + "FORGET"; + private static final String JSON_TOGGLE = JSON_PREFIX + "TOGGLE"; + private static final String JSON_STRAPPEND = JSON_PREFIX + "STRAPPEND"; + private static final String JSON_STRLEN = JSON_PREFIX + "STRLEN"; + private static final String JSON_CLEAR = JSON_PREFIX + "CLEAR"; + private static final String JSON_RESP = JSON_PREFIX + "RESP"; + private static final String JSON_TYPE = JSON_PREFIX + "TYPE"; + + private Json() {} + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted string. + * @return A simple "OK" response if the value is successfully set. + * @example + *
{@code
+     * String value = Json.set(client, "doc", ".", "{\"a\": 1.0, \"b\": 2}").get();
+     * assert value.equals("OK");
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String value) { + return executeCommand(client, new String[] {JSON_SET, key, path, value}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted GlideString. + * @return A simple "OK" response if the value is successfully set. + * @example + *
{@code
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{\"a\": 1.0, \"b\": 2}")).get();
+     * assert value.equals("OK");
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString value) { + return executeCommand(client, new GlideString[] {gs(JSON_SET), key, path, value}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted string. + * @param setCondition Set the value only if the given condition is met (within the key or path). + * @return A simple "OK" response if the value is successfully set. If value isn't + * set because of setCondition, returns null. + * @example + *
{@code
+     * String value = Json.set(client, "doc", ".", "{\"a\": 1.0, \"b\": 2}", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * assert value.equals("OK");
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String value, + @NonNull ConditionalChange setCondition) { + return executeCommand( + client, new String[] {JSON_SET, key, path, value, setCondition.getValkeyApi()}); + } + + /** + * Sets the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be set. The key + * will be modified only if value is added as the last child in the specified + * path, or if the specified path acts as the parent of a new child + * being added. + * @param value The value to set at the specific path, in JSON formatted GlideString. + * @param setCondition Set the value only if the given condition is met (within the key or path). + * @return A simple "OK" response if the value is successfully set. If value isn't + * set because of setCondition, returns null. + * @example + *
{@code
+     * String value = Json.set(client, gs("doc"), gs("."), gs("{\"a\": 1.0, \"b\": 2}"), ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get();
+     * assert value.equals("OK");
+     * }
+ */ + public static CompletableFuture set( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString value, + @NonNull ConditionalChange setCondition) { + return executeCommand( + client, + new GlideString[] {gs(JSON_SET), key, path, value, gs(setCondition.getValkeyApi())}); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * String value = Json.get(client, "doc").get();
+     * assert value.equals("{\"a\": 1.0, \"b\": 2}");
+     * }
+ */ + public static CompletableFuture get(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_GET, key}); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * GlideString value = Json.get(client, gs("doc")).get();
+     * assert value.equals(gs("{\"a\": 1.0, \"b\": 2}"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_GET), key}); + } + + /** + * Retrieves the JSON value at the specified paths stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @return + *
    + *
  • If one path is given: + *
      + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, + * if path doesn't exist. If key doesn't exist, returns null.
    • For legacy path (path doesn't start with $): Returns a string + * representation of the value in paths. If paths + * doesn't exist, an error is raised. If key doesn't exist, returns + * null. + *
    + *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, + * and its corresponding value is the value as if the path was executed in the command + * as a single path. + *
+ * If the given paths are a mix of both JSONPath and legacy + * paths, the command behaves as if all are JSONPath paths. + * @example + *
{@code
+     * String value = Json.get(client, "doc", new String[] {"$"}).get();
+     * assert value.equals("{\"a\": 1.0, \"b\": 2}");
+     * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}).get();
+     * assert value.equals("{\"$.a\": [1.0], \"$.b\": [2]}");
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull String key, @NonNull String[] paths) { + return executeCommand(client, concatenateArrays(new String[] {JSON_GET, key}, paths)); + } + + /** + * Retrieves the JSON value at the specified paths stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @return + *
    + *
  • If one path is given: + *
      + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, + * if path doesn't exist. If key doesn't exist, returns null.
    • For legacy path (path doesn't start with $): Returns a string + * representation of the value in paths. If paths + * doesn't exist, an error is raised. If key doesn't exist, returns + * null. + *
    + *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, + * and its corresponding value is the value as if the path was executed in the command + * as a single path. + *
+ * If the given paths are a mix of both JSONPath and legacy + * paths, the command behaves as if all are JSONPath paths. + * @example + *
{@code
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$")}).get();
+     * assert value.equals(gs("{\"a\": 1.0, \"b\": 2}"));
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}).get();
+     * assert value.equals(gs("{\"$.a\": [1.0], \"$.b\": [2]}"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString[] paths) { + return executeCommand(client, concatenateArrays(new GlideString[] {gs(JSON_GET), key}, paths)); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * String value = Json.get(client, "doc", "$", options).get();
+     * assert value.equals("{\n \"a\": \n  1.0\n ,\n \"b\": \n  2\n }");
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull String key, @NonNull JsonGetOptions options) { + return executeCommand( + client, concatenateArrays(new String[] {JSON_GET, key}, options.toArgs())); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return Returns a string representation of the JSON document. If key doesn't + * exist, returns null. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * GlideString value = Json.get(client, gs("doc"), gs("$"), options).get();
+     * assert value.equals(gs("{\n \"a\": \n  1.0\n ,\n \"b\": \n  2\n }"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull JsonGetOptionsBinary options) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).toArray()); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return + *
    + *
  • If one path is given: + *
      + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, + * if path doesn't exist. If key doesn't exist, returns null.
    • For legacy path (path doesn't start with $): Returns a string + * representation of the value in paths. If paths + * doesn't exist, an error is raised. If key doesn't exist, returns + * null. + *
    + *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, + * and its corresponding value is the value as if the path was executed in the command + * as a single path. + *
+ * If the given paths are a mix of both JSONPath and legacy + * paths, the command behaves as if all are JSONPath paths. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * String value = Json.get(client, "doc", new String[] {"$.a", "$.b"}, options).get();
+     * assert value.equals("{\n \"$.a\": [\n  1.0\n ],\n \"$.b\": [\n  2\n ]\n}");
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String[] paths, + @NonNull JsonGetOptions options) { + return executeCommand( + client, concatenateArrays(new String[] {JSON_GET, key}, options.toArgs(), paths)); + } + + /** + * Retrieves the JSON value at the specified path stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param paths List of paths within the JSON document. + * @param options Options for formatting the byte representation of the JSON data. See + * JsonGetOptions. + * @return + *
    + *
  • If one path is given: + *
      + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, + * if path doesn't exist. If key doesn't exist, returns null.
    • For legacy path (path doesn't start with $): Returns a string + * representation of the value in paths. If paths + * doesn't exist, an error is raised. If key doesn't exist, returns + * null. + *
    + *
• If multiple paths are given: Returns a stringified JSON, in which each path is a key, + * and its corresponding value is the value as if the path was executed in the command + * as a single path. + *
+ * If the given paths are a mix of both JSONPath and legacy + * paths, the command behaves as if all are JSONPath paths. + * @example + *
{@code
+     * JsonGetOptions options = JsonGetOptions.builder()
+     *                              .indent("  ")
+     *                              .space(" ")
+     *                              .newline("\n")
+     *                              .build();
+     * GlideString value = Json.get(client, gs("doc"), new GlideString[] {gs("$.a"), gs("$.b")}, options).get();
+     * assert value.equals(gs("{\n \"$.a\": [\n  1.0\n ],\n \"$.b\": [\n  2\n ]\n}"));
+     * }
+ */ + public static CompletableFuture get( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString[] paths, + @NonNull JsonGetOptionsBinary options) { + return executeCommand( + client, + new ArgsBuilder().add(gs(JSON_GET)).add(key).add(options.toArgs()).add(paths).toArray()); + } + + /** + * Retrieves the JSON values at the specified path stored at multiple keys + * . + * + * @apiNote When in cluster mode, if keys in keys map to different hash slots, the + * command will be split across these slots and executed separately for each. This means the + * command is atomic only at the slot level. If one or more slot-specific requests fail, the + * entire call will return the first encountered error, even though some requests may have + * succeeded while others did not. If this behavior impacts your application logic, consider + * splitting the request into sub-requests per slot to ensure atomicity. + * @param client The client to execute the command. + * @param keys The keys of the JSON documents. + * @param path The path within the JSON documents. + * @return An array with requested values for each key. + *
    + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, if + * path doesn't exist. + *
  • For legacy path (path doesn't start with $): Returns a string + * representation of the value in path. If path doesn't exist, + * the corresponding array element will be null. + *
+ * If a key doesn't exist, the corresponding array element will be null + * . + * @example + *
{@code
+     * Json.set(client, "doc1", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
+     * Json.set(client, "doc2", "$", "{\"a\": 1, \"c\": false}").get();
+     * var res = Json.mget(client, new String[] { "doc1", "doc2", "non_existing" }, "$.c").get();
+     * assert Arrays.equals(res, new String[] { "[]", "[false]", null });
+     * }
+ */ + public static CompletableFuture mget( + @NonNull BaseClient client, @NonNull String[] keys, @NonNull String path) { + return Json.executeCommand( + client, concatenateArrays(new String[] {JSON_MGET}, keys, new String[] {path})) + .thenApply(res -> castArray(res, String.class)); + } + + /** + * Retrieves the JSON values at the specified path stored at multiple keys + * . + * + * @apiNote When in cluster mode, if keys in keys map to different hash slots, the + * command will be split across these slots and executed separately for each. This means the + * command is atomic only at the slot level. If one or more slot-specific requests fail, the + * entire call will return the first encountered error, even though some requests may have + * succeeded while others did not. If this behavior impacts your application logic, consider + * splitting the request into sub-requests per slot to ensure atomicity. + * @param client The client to execute the command. + * @param keys The keys of the JSON documents. + * @param path The path within the JSON documents. + * @return An array with requested values for each key. + *
    + *
• For JSONPath (path starts with $): Returns a stringified JSON list of + * replies for every possible path, or a string representation of an empty array, if + * path doesn't exist. + *
  • For legacy path (path doesn't start with $): Returns a string + * representation of the value in path. If path doesn't exist, + * the corresponding array element will be null. + *
+ * If a key doesn't exist, the corresponding array element will be null + * . + * @example + *
{@code
+     * Json.set(client, "doc1", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
+     * Json.set(client, "doc2", "$", "{\"a\": 1, \"c\": false}").get();
+     * var res = Json.mget(client, new GlideString[] { gs("doc1"), gs("doc2"), gs("doc3") }, gs("$.c")).get();
+     * assert Arrays.equals(res, new GlideString[] { gs("[]"), gs("[false]"), null });
+     * }
+ */ + public static CompletableFuture mget( + @NonNull BaseClient client, @NonNull GlideString[] keys, @NonNull GlideString path) { + return Json.executeCommand( + client, + concatenateArrays(new GlideString[] {gs(JSON_MGET)}, keys, new GlideString[] {path})) + .thenApply(res -> castArray(res, GlideString.class)); + } + + /** + * Appends one or more values to the JSON array at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the values + * will be appended. + * @param values The JSON values to be appended to the array.
+ * JSON string values must be wrapped with quotes. For example, to append "foo", + * pass "\"foo\"". + * @return + *
    + *
  • For JSONPath (path starts with $):
    + * Returns a list of integers for every possible path, indicating the new length of the + * array after appending values, or null for JSON values + * matching the path that are not an array. If path does not exist, an + * empty array will be returned. + *
  • For legacy path (path doesn't start with $):
+ * Returns the new length of the array after appending values to the array + * at path. If multiple paths are matched, returns the last updated array. + * If the JSON value at path is not an array or if path doesn't + * exist, an error is raised. If key doesn't exist, an error is raised. + * @example + *
    {@code
    +     * Json.set(client, "doc", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
    +     * var res = Json.arrappend(client, "doc", "$.b", new String[] {"\"three\""}).get();
+     * assert Arrays.equals((Object[]) res, new Object[] {3L}); // New length of the array after appending
    +     * res = Json.arrappend(client, "doc", ".b", new String[] {"\"four\""}).get();
+     * assert res.equals(4L); // New length of the array after appending
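+     * // Hypothetical follow-up, grounded in the JSONPath semantics above: appending to a
+     * // path that holds a non-array value yields null for that match.
+     * res = Json.arrappend(client, "doc", "$.a", new String[] {"\"x\""}).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { null }); // $.a holds a number, not an array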
    +     * }
    + */ + public static CompletableFuture arrappend( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String[] values) { + return executeCommand( + client, concatenateArrays(new String[] {JSON_ARRAPPEND, key, path}, values)); + } + + /** + * Appends one or more values to the JSON array at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the values + * will be appended. + * @param values The JSON values to be appended to the array.
    + * JSON string values must be wrapped with quotes. For example, to append "foo", + * pass "\"foo\"". + * @return + *
      + *
    • For JSONPath (path starts with $):
+ * Returns a list of integers for every possible path, indicating the new length of the + * array after appending values, or null for JSON values + * matching the path that are not an array. If path does not exist, an + * empty array will be returned. + *
    • For legacy path (path doesn't start with $):
+ * Returns the new length of the array after appending values to the array + * at path. If multiple paths are matched, returns the last updated array. + * If the JSON value at path is not an array or if path doesn't + * exist, an error is raised. If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1, \"b\": [\"one\", \"two\"]}").get();
      +     * var res = Json.arrappend(client, gs("doc"), gs("$.b"), new GlideString[] {gs("\"three\"")}).get();
+     * assert Arrays.equals((Object[]) res, new Object[] {3L}); // New length of the array after appending
      +     * res = Json.arrappend(client, gs("doc"), gs(".b"), new GlideString[] {gs("\"four\"")}).get();
+     * assert res.equals(4L); // New length of the array after appending
      +     * }
      + */ + public static CompletableFuture arrappend( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString[] values) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_ARRAPPEND)).add(key).add(path).add(values).toArray()); + } + + /** + * Inserts one or more values into the array at the specified path within the JSON + * document stored at key, before the given index. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The array index before which values are inserted. + * @param values The JSON values to be inserted into the array.
      + * JSON string values must be wrapped with quotes. For example, to insert "foo", + * pass "\"foo\"". + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of integers for every possible path, + * indicating the new length of the array, or null for JSON values matching + * the path that are not an array. If path does not exist, an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the new length of the array. If multiple paths are + * matched, returns the length of the first modified array. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If the index is out of bounds or key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
      +     * var newValues = new String[] { "\"c\"", "{\"key\": \"value\"}", "true", "null", "[\"bar\"]" };
      +     * var res = Json.arrinsert(client, "doc", "$[*]", 0, newValues).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 5L, 6L, 7L }); // New lengths of arrays after insertion
      +     * var doc = Json.get(client, "doc").get();
      +     * assert doc.equals("[[\"c\", {\"key\": \"value\"}, true, null, [\"bar\"]], [\"c\", {\"key\": \"value\"}, "
      +     *     + "true, null, [\"bar\"], \"a\"], [\"c\", {\"key\": \"value\"}, true, null, [\"bar\"], \"a\", \"b\"]]");
      +     *
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
      +     * res = Json.arrinsert(client, "doc", ".", 0, new String[] { "\"c\"" }).get();
+     * assert res.equals(4L); // New length of the root array after insertion
      +     * doc = Json.get(client, "doc").get();
      +     * assert doc.equals("[\"c\", [], [\"a\"], [\"a\", \"b\"]]");
      +     * }
      + */ + public static CompletableFuture arrinsert( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + int index, + @NonNull String[] values) { + return executeCommand( + client, + concatenateArrays( + new String[] {JSON_ARRINSERT, key, path, Integer.toString(index)}, values)); + } + + /** + * Inserts one or more values into the array at the specified path within the JSON + * document stored at key, before the given index. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The array index before which values are inserted. + * @param values The JSON values to be inserted into the array.
      + * JSON string values must be wrapped with quotes. For example, to insert "foo", + * pass "\"foo\"". + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of integers for every possible path, + * indicating the new length of the array, or null for JSON values matching + * the path that are not an array. If path does not exist, an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the new length of the array. If multiple paths are + * matched, returns the length of the first modified array. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If the index is out of bounds or key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
      +     * var newValues = new GlideString[] { gs("\"c\""), gs("{\"key\": \"value\"}"), gs("true"), gs("null"), gs("[\"bar\"]") };
      +     * var res = Json.arrinsert(client, gs("doc"), gs("$[*]"), 0, newValues).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 5L, 6L, 7L }); // New lengths of arrays after insertion
      +     * var doc = Json.get(client, "doc").get();
      +     * assert doc.equals("[[\"c\", {\"key\": \"value\"}, true, null, [\"bar\"]], [\"c\", {\"key\": \"value\"}, "
      +     *     + "true, null, [\"bar\"], \"a\"], [\"c\", {\"key\": \"value\"}, true, null, [\"bar\"], \"a\", \"b\"]]");
      +     *
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\"]]").get();
      +     * res = Json.arrinsert(client, gs("doc"), gs("."), 0, new GlideString[] { gs("\"c\"") }).get();
+     * assert res.equals(4L); // New length of the root array after insertion
      +     * doc = Json.get(client, "doc").get();
      +     * assert doc.equals("[\"c\", [], [\"a\"], [\"a\", \"b\"]]");
      +     * }
      + */ + public static CompletableFuture arrinsert( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + int index, + @NonNull GlideString[] values) { + return executeCommand( + client, + new ArgsBuilder() + .add(gs(JSON_ARRINSERT)) + .add(key) + .add(path) + .add(Integer.toString(index)) + .add(values) + .toArray()); + } + + /** + * Searches for the first occurrence of a scalar JSON value in the arrays at the + * path. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param scalar The scalar value to search for. + * @return + *
        + *
      • For JSONPath (path starts with $): Returns an array with a + * list of integers for every possible path, indicating the index of the matching + * element. The value is -1 if not found. If a value is not an array, its + * corresponding return value is null. + *
      • For legacy path (path doesn't start with $): Returns an integer + * representing the index of matching element, or -1 if not found. If the + * value at the path is not an array, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, key, "$", "{\"a\": [\"value\", 3], \"b\": {\"a\": [3, [\"value\", false], 5]}}").get();
      +     * var result = Json.arrindex(client, key, "$..a", "3").get();
      +     * assert Arrays.equals((Object[]) result, new Object[] {1L, 0L});
      +     *
      +     * result = Json.arrindex(client, key, "$..a", "\"value\"").get();
      +     * assert Arrays.equals((Object[]) result, new Object[] {0L, -1L});
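+     * // Legacy path variant (hypothetical follow-up): a single index is returned
+     * // instead of an array, per the return notes above.
+     * var single = Json.arrindex(client, key, ".a", "3").get();
+     * assert single.equals(1L);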
      +     * }
      + */ + public static CompletableFuture arrindex( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String scalar) { + return arrindex(client, gs(key), gs(path), gs(scalar)); + } + + /** + * Searches for the first occurrence of a scalar JSON value in the arrays at the + * path. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param scalar The scalar value to search for. + * @return + *
        + *
      • For JSONPath (path starts with $): Returns an array with a + * list of integers for every possible path, indicating the index of the matching + * element. The value is -1 if not found. If a value is not an array, its + * corresponding return value is null. + *
      • For legacy path (path doesn't start with $): Returns an integer + * representing the index of matching element, or -1 if not found. If the + * value at the path is not an array, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, key, "$", "{\"a\": [\"value\", 3], \"b\": {\"a\": [3, [\"value\", false], 5]}}").get();
      +     * var result = Json.arrindex(client, gs(key), gs("$..a"), gs("3")).get();
      +     * assert Arrays.equals((Object[]) result, new Object[] {1L, 0L});
      +     *
      +     * // Searches for the first occurrence of null in the arrays
      +     * result = Json.arrindex(client, gs(key), gs("$..a"), gs("null")).get();
      +     * assert Arrays.equals((Object[]) result, new Object[] {-1L, -1L});
      +     * }
      + */ + public static CompletableFuture arrindex( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString scalar) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRINDEX), key, path, scalar}); + } + + /** + * Searches for the first occurrence of a scalar JSON value in the arrays at the + * path. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param scalar The scalar value to search for. + * @param options The additional options for the command. See JsonArrindexOptions. + * @return + *
        + *
      • For JSONPath (path starts with $): Returns an array with a + * list of integers for every possible path, indicating the index of the matching + * element. The value is -1 if not found. If a value is not an array, its + * corresponding return value is null. + *
      • For legacy path (path doesn't start with $): Returns an integer + * representing the index of matching element, or -1 if not found. If the + * value at the path is not an array, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, key, "$", "{\"a\": [\"value\", 3], \"b\": {\"a\": [3, [\"value\", false], 5]}}").get();
      +     * var result = Json.arrindex(client, key, ".a", "3", new JsonArrindexOptions(0L)).get();
+     * assert result.equals(1L);
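+     * // Hypothetical: assuming JsonArrindexOptions also offers a (start, end) constructor,
+     * // the search can be limited to a range (the end index is exclusive in JSON.ARRINDEX):
+     * result = Json.arrindex(client, key, ".a", "3", new JsonArrindexOptions(0L, 2L)).get();
+     * assert result.equals(1L);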
      +     * }
      + */ + public static CompletableFuture arrindex( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String path, + @NonNull String scalar, + @NonNull JsonArrindexOptions options) { + + return executeCommand( + client, + new ArgsBuilder() + .add(JSON_ARRINDEX) + .add(key) + .add(path) + .add(scalar) + .add(options.toArgs()) + .toArray()); + } + + /** + * Searches for the first occurrence of a scalar JSON value in the arrays at the + * path. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param scalar The scalar value to search for. + * @param options The additional options for the command. See JsonArrindexOptions. + * @return + *
        + *
• For JSONPath (path starts with $): Returns an array with a + * list of integers for every possible path, indicating the index of the matching + * element. The value is -1 if not found. If a value is not an array, its + * corresponding return value is null. + *
      • For legacy path (path doesn't start with $): Returns an integer + * representing the index of matching element, or -1 if not found. If the + * value at the path is not an array, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, key, "$", "{\"a\": [\"value\", 3], \"b\": {\"a\": [3, [\"value\", false], 5]}}").get();
      +     * var result = Json.arrindex(client, gs(key), gs(".a"), gs("3"), new JsonArrindexOptions(0L)).get();
+     * assert result.equals(1L);
      +     * }
      + */ + public static CompletableFuture arrindex( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + @NonNull GlideString scalar, + @NonNull JsonArrindexOptions options) { + + return executeCommand( + client, + new ArgsBuilder() + .add(JSON_ARRINDEX) + .add(key) + .add(path) + .add(scalar) + .add(options.toArgs()) + .toArray()); + } + + /** + * Retrieves the length of the array at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of integers for every possible path, + * indicating the length of the array, or null for JSON values matching the + * path that are not an array. If path does not exist, an empty array will + * be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the length of the array. If multiple paths are + * matched, returns the length of the first matching array. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}").get();
      +     * var res = Json.arrlen(client, "doc", "$").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { null }); // No array at the root path.
      +     * res = Json.arrlen(client, "doc", "$.a").get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // Retrieves the length of the array at path $.a.
      +     * res = Json.arrlen(client, "doc", "$..a").get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L, 2L, null }); // Retrieves lengths of arrays found at all levels of the path `..a`.
      +     * res = Json.arrlen(client, "doc", "..a").get();
+     * assert res.equals(3L); // Legacy path retrieves the first array match at path `..a`.
      +     * }
      + */ + public static CompletableFuture arrlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_ARRLEN, key, path}); + } + + /** + * Retrieves the length of the array at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of integers for every possible path, + * indicating the length of the array, or null for JSON values matching the + * path that are not an array. If path does not exist, an empty array will + * be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the length of the array. If multiple paths are + * matched, returns the length of the first matching array. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}").get();
      +     * var res = Json.arrlen(client, gs("doc"), gs("$")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { null }); // No array at the root path.
      +     * res = Json.arrlen(client, gs("doc"), gs("$.a")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // Retrieves the length of the array at path $.a.
      +     * res = Json.arrlen(client, gs("doc"), gs("$..a")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] { 3L, 2L, null }); // Retrieves lengths of arrays found at all levels of the path `..a`.
      +     * res = Json.arrlen(client, gs("doc"), gs("..a")).get();
+     * assert res.equals(3L); // Legacy path retrieves the first array match at path `..a`.
      +     * }
      + */ + public static CompletableFuture arrlen( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key, path}); + } + + /** + * Retrieves the length of the array at the root of the JSON document stored at key. + *
      + * Equivalent to {@link #arrlen(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The array length stored at the root of the document. If document root is not an array, + * an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, null, \"tree\"]").get();
      +     * var res = Json.arrlen(client, "doc").get();
      +     * assert res == 5;
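+     * // Per the notes above (hypothetical follow-up): a missing key yields null.
+     * assert Json.arrlen(client, "no_such_doc").get() == null;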
      +     * }
      + */ + public static CompletableFuture arrlen(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_ARRLEN, key}); + } + + /** + * Retrieves the length of the array at the root of the JSON document stored at key. + * Equivalent to {@link #arrlen(BaseClient, GlideString, GlideString)} with path set + * to gs("."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The array length stored at the root of the document. If document root is not an array, + * an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, null, \"tree\"]").get();
      +     * var res = Json.arrlen(client, gs("doc")).get();
      +     * assert res == 5;
      +     * }
      + */ + public static CompletableFuture arrlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRLEN), key}); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of numbers for every possible path, + * indicating the memory usage. If path does not exist, an empty array will + * be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the memory usage. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugMemory(client, "doc", "..").get();
      +     * assert res == 258L;
      +     * }
      + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_MEMORY, new String[] {key, path})); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
      + * Each non-container JSON value counts as one field. Objects and arrays recursively count one + * field for each of their containing JSON values. Each container value, except the root + * container, counts as one additional field. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of numbers for every possible path, + * indicating the number of fields. If path does not exist, an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the number of fields. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugFields(client, "doc", "$[*]").get();
+     * assert Arrays.equals((Object[]) res, new Object[] {1L, 1L, 1L, 1L, 1L, 0L, 0L, 2L, 3L});
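+     * // The per-element counts above sum to 10; each of the four nested containers adds one
+     * // more field, matching the document total of 14 (hypothetical cross-check of the
+     * // counting rule described above):
+     * var total = Json.debugFields(client, "doc").get();
+     * assert total.equals(14L);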
      +     * }
      + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_FIELDS, new String[] {key, path})); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of numbers for every possible path, + * indicating the memory usage. If path does not exist, an empty array will + * be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the memory usage. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugMemory(client, gs("doc"), gs("..")).get();
      +     * assert res == 258L;
      +     * }
      + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_DEBUG_MEMORY).add(key).add(path).toArray()); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
      + * Each non-container JSON value counts as one field. Objects and arrays recursively count one + * field for each of their containing JSON values. Each container value, except the root + * container, counts as one additional field. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of numbers for every possible path, + * indicating the number of fields. If path does not exist, an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the number of fields. If multiple paths are matched, + * returns the data of the first matching object. If path doesn't exist, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugFields(client, gs("doc"), gs("$[*]")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] {1L, 1L, 1L, 1L, 1L, 0L, 0L, 2L, 3L});
      +     * }
      + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_DEBUG_FIELDS).add(key).add(path).toArray()); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key.
      + * Equivalent to {@link #debugMemory(BaseClient, String, String)} with path set to + * "..". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The total memory usage in bytes of the entire JSON document.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugMemory(client, "doc").get();
      +     * assert res == 258L;
      +     * }
      + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_MEMORY, new String[] {key})); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
      + * Each non-container JSON value counts as one field. Objects and arrays recursively count one + * field for each of their containing JSON values. Each container value, except the root + * container, counts as one additional field.
      + * Equivalent to {@link #debugFields(BaseClient, String, String)} with path set to + * "..". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The total number of fields in the entire JSON document.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugFields(client, "doc").get();
      +     * assert res == 14L;
      +     * }
      + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, concatenateArrays(JSON_DEBUG_FIELDS, new String[] {key})); + } + + /** + * Reports memory usage in bytes of a JSON object at the specified path within the + * JSON document stored at key.
      + * Equivalent to {@link #debugMemory(BaseClient, GlideString, GlideString)} with path + * set to gs(".."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The total memory usage in bytes of the entire JSON document.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugMemory(client, gs("doc")).get();
      +     * assert res == 258L;
      +     * }
      + */ + public static CompletableFuture debugMemory( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(JSON_DEBUG_MEMORY).add(key).toArray()); + } + + /** + * Reports the number of fields at the specified path within the JSON document stored + * at key.
      + * Each non-container JSON value counts as one field. Objects and arrays recursively count one + * field for each of their containing JSON values. Each container value, except the root + * container, counts as one additional field.
      + * Equivalent to {@link #debugFields(BaseClient, GlideString, GlideString)} with path + * set to gs(".."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The total number of fields in the entire JSON document.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2.3, \"foo\", true, null, {}, [], {\"a\":1, \"b\":2}, [1, 2, 3]]").get();
      +     * var res = Json.debugFields(client, gs("doc")).get();
      +     * assert res == 14L;
      +     * }
      + */ + public static CompletableFuture debugFields( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(JSON_DEBUG_FIELDS).add(key).toArray()); + } + + /** + * Pops the last element from the array stored in the root of the JSON document stored at + * key. Equivalent to {@link #arrpop(BaseClient, String, String)} with + * path set to ".". + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representing the popped JSON value, or null if the array + * at document root is empty.
      + * If the JSON value at document root is not an array or if key doesn't exist, an + * error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
      +     * var res = Json.arrpop(client, "doc").get();
      +     * assert res.equals("\"tree\"");
      +     * res = Json.arrpop(client, "doc").get();
      +     * assert res.equals("{\"a\": 42, \"b\": 33}");
      +     * }
      + */ + public static CompletableFuture arrpop(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_ARRPOP, key}); + } + + /** + * Pops the last element from the array located in the root of the JSON document stored at + * key. Equivalent to {@link #arrpop(BaseClient, GlideString, GlideString)} with + * path set to gs("."). + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return Returns a string representing the popped JSON value, or null if the array + * at document root is empty.
      + * If the JSON value at document root is not an array or if key doesn't exist, an + * error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
      +     * var res = Json.arrpop(client, gs("doc")).get();
      +     * assert res.equals(gs("\"tree\""));
      +     * res = Json.arrpop(client, gs("doc")).get();
      +     * assert res.equals(gs("{\"a\": 42, \"b\": 33}"));
      +     * }
      + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRPOP), key}); + } + + /** + * Pops the last element from the array located at path in the JSON document stored + * at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an array with a string for every possible path, representing the popped JSON + * values, or null for JSON values matching the path that are not an array + * or an empty array. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
      +     * var res = Json.arrpop(client, "doc", "$").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { "\"tree\"" });
      +     * res = Json.arrpop(client, "doc", ".").get();
      +     * assert res.equals("{\"a\": 42, \"b\": 33}");
      +     * }
      + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_ARRPOP, key, path}); + } + + /** + * Pops the last element from the array located at path in the JSON document stored + * at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an array with a string for every possible path, representing the popped JSON + * values, or null for JSON values matching the path that are not an array + * or an empty array. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "[1, 2, true, {\"a\": 42, \"b\": 33}, \"tree\"]").get();
      +     * var res = Json.arrpop(client, gs("doc"), gs("$")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { gs("\"tree\"") });
      +     * res = Json.arrpop(client, gs("doc"), gs(".")).get();
      +     * assert res.equals(gs("{\"a\": 42, \"b\": 33}"));
      +     * }
      + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_ARRPOP), key, path}); + } + + /** + * Pops an element from the array located at path in the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The index of the element to pop. Out of boundary indexes are rounded to their + * respective array boundaries. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an array with a string for every possible path, representing the popped JSON + * values, or null for JSON values matching the path that are not an array + * or an empty array. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * String doc = "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\": 42}}}";
      +     * Json.set(client, "doc", "$", doc).get();
      +     * var res = Json.arrpop(client, "doc", "$.a", 1).get();
      +     * assert res.equals("2"); // Pop second element from array at path `$.a`
      +     *
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\", \"c\"]]").get();
      +     * res = Json.arrpop(client, "doc", ".", -1).get());
      +     * assert res.equals("[\"a\", \"b\", \"c\"]"); // Pop last elements at path `.`
      +     * }
      + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, long index) { + return executeCommand(client, new String[] {JSON_ARRPOP, key, path, Long.toString(index)}); + } + + /** + * Pops an element from the array located at path in the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param index The index of the element to pop. Out of boundary indexes are rounded to their + * respective array boundaries. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an array with a string for every possible path, representing the popped JSON + * values, or null for JSON values matching the path that are not an array + * or an empty array. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representing the popped JSON value, or null if the + * array at path is empty. If multiple paths are matched, the value from + * the first matching array that is not empty is returned. If path doesn't + * exist or the value at path is not an array, an error is raised. + *
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * String doc = "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\": 42}}}";
      +     * Json.set(client, "doc", "$", doc).get();
      +     * var res = Json.arrpop(client, gs("doc"), gs("$.a"), 1).get();
      +     * assert res.equals("2"); // Pop second element from array at path `$.a`
      +     *
      +     * Json.set(client, "doc", "$", "[[], [\"a\"], [\"a\", \"b\", \"c\"]]").get();
      +     * res = Json.arrpop(client, gs("doc"), gs("."), -1).get());
      +     * assert res.equals(gs("[\"a\", \"b\", \"c\"]")); // Pop last elements at path `.`
      +     * }
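+ * A further illustrative sketch, not part of this change: draining an array by popping
+ * from the legacy root path until it is empty. Per the contract above, arrpop
+ * returns null once the array at the path is empty. The key name "queue" is
+ * hypothetical, and client is assumed to be a connected client.
+ *
{@code
+     * Json.set(client, "queue", "$", "[1, 2, 3]").get();
+     * Object popped;
+     * while ((popped = Json.arrpop(client, "queue", ".").get()) != null) {
+     *     System.out.println(popped); // prints "3", then "2", then "1"
+     * }
+     * }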
      + */ + public static CompletableFuture arrpop( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path, long index) { + return executeCommand( + client, new GlideString[] {gs(JSON_ARRPOP), key, path, gs(Long.toString(index))}); + } + + /** + * Trims an array at the specified path within the JSON document stored at key + * so that it becomes a subarray [start, end], both inclusive. + *
      + * If start < 0, it is treated as 0.
+ * If end >= size (the size of the array), it is treated as size - 1.
+ * If start >= size or start > end, the array is emptied + * and 0 is returned.
      + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param start The index of the first element to keep, inclusive. + * @param end The index of the last element to keep, inclusive. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an Object[] with a list of integers for every possible path, + * indicating the new length of the array, or null for JSON values matching + * the path that are not an array. If the array is empty, its corresponding return value + * is 0. If path doesn't exist, an empty array will be returned. If an index + * argument is out of bounds, an error is raised. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the new length of the array. If the array is empty, + * its corresponding return value is 0. If multiple paths match, the length of the first + * trimmed array match is returned. If path doesn't exist, or the value at + * path is not an array, an error is raised. If an index argument is out of + * bounds, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{[], [\"a\"], [\"a\", \"b\"], [\"a\", \"b\", \"c\"]}").get();
      +     * var res = Json.arrtrim(client, "doc", "$[*]", 0, 1).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { 0, 1, 2, 2 }); // New lengths of arrays after trimming
      +     *
      +     * Json.set(client, "doc", "$", "{\"children\": [\"John\", \"Jack\", \"Tom\", \"Bob\", \"Mike\"]}").get();
      +     * res = Json.arrtrim(client, "doc", ".children", 0, 1).get();
      +     * assert res == 2; // new length after trimming
      +     * }
      + */ + public static CompletableFuture arrtrim( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, int start, int end) { + return executeCommand( + client, + new String[] {JSON_ARRTRIM, key, path, Integer.toString(start), Integer.toString(end)}); + } + + /** + * Trims an array at the specified path within the JSON document stored at key + * so that it becomes a subarray [start, end], both inclusive. + *
      + * If start < 0, it is treated as 0.
+ * If end >= size (the size of the array), it is treated as size - 1.
+ * If start >= size or start > end, the array is emptied + * and 0 is returned.
      + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param start The index of the first element to keep, inclusive. + * @param end The index of the last element to keep, inclusive. + * @return + *
        + *
      • For JSONPath (path starts with $):
+ * Returns an Object[] with a list of integers for every possible path, + * indicating the new length of the array, or null for JSON values matching + * the path that are not an array. If the array is empty, its corresponding return value + * is 0. If path doesn't exist, an empty array will be returned. If an index + * argument is out of bounds, an error is raised. + *
      • For legacy path (path doesn't start with $):
        + * Returns an integer representing the new length of the array. If the array is empty, + * its corresponding return value is 0. If multiple paths match, the length of the first + * trimmed array match is returned. If path doesn't exist, or the value at + * path is not an array, an error is raised. If an index argument is out of + * bounds, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{[], [\"a\"], [\"a\", \"b\"], [\"a\", \"b\", \"c\"]}").get();
      +     * var res = Json.arrtrim(client, gs("doc"), gs("$[*]"), 0, 1).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { 0, 1, 2, 2 }); // New lengths of arrays after trimming
      +     *
      +     * Json.set(client, "doc", "$", "{\"children\": [\"John\", \"Jack\", \"Tom\", \"Bob\", \"Mike\"]}").get();
      +     * res = Json.arrtrim(client, gs("doc"), gs(".children"), 0, 1).get();
      +     * assert res == 2; // new length after trimming
      +     * }
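+ * A further illustrative sketch, not part of this change: capping a log array to its
+ * last two entries using the clamping rules described above (negative start is
+ * treated as 0, so explicit indexes are used). The key name "log" is hypothetical.
+ *
{@code
+     * Json.set(client, "log", "$", "[\"e1\", \"e2\", \"e3\", \"e4\"]").get();
+     * var len = Json.arrtrim(client, "log", ".", 2, 3).get(); // keep indexes 2..3
+     * assert (Long) len == 2L; // the trimmed array now holds "e3" and "e4"
+     * }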
      + */ + public static CompletableFuture arrtrim( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + int start, + int end) { + return executeCommand( + client, + new ArgsBuilder() + .add(gs(JSON_ARRTRIM)) + .add(key) + .add(path) + .add(Integer.toString(start)) + .add(Integer.toString(end)) + .toArray()); + } + + /** + * Increments or decrements the JSON value(s) at the specified path by number + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to increment or decrement by. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a string representation of an array of strings, indicating the new values + * after incrementing for each matched path.
        + * If a value is not a number, its corresponding return value will be null. + *
+ * If path doesn't exist, a string representation of an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representation of the resulting value after the increment or + * decrement.
        + * If multiple paths match, the result of the last updated value is returned.
        + * If the value at the path is not a number or path doesn't + * exist, an error is raised. + *
      + * If key does not exist, an error is raised.
      + * If the result is out of the range of 64-bit IEEE double, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
      +     * var res = Json.numincrby(client, "doc", "$.d[*]", 10.0).get();
      +     * assert res.equals("[11,12,13]"); // Increment each element in `d` array by 10.
      +     *
      +     * res = Json.numincrby(client, "doc", ".c[1]", 10.0).get();
      +     * assert res.equals("12"); // Increment the second element in the `c` array by 10.
      +     * }
      + */ + public static CompletableFuture numincrby( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, Number number) { + return executeCommand(client, new String[] {JSON_NUMINCRBY, key, path, number.toString()}); + } + + /** + * Increments or decrements the JSON value(s) at the specified path by number + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to increment or decrement by. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a GlideString representation of an array of strings, indicating + * the new values after incrementing for each matched path.
        + * If a value is not a number, its corresponding return value will be null. + *
        + * If path doesn't exist, a byte string representation of an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns a GlideString representation of the resulting value after the + * increment or decrement.
        + * If multiple paths match, the result of the last updated value is returned.
        + * If the value at the path is not a number or path doesn't + * exist, an error is raised. + *
      + * If key does not exist, an error is raised.
      + * If the result is out of the range of 64-bit IEEE double, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
      +     * var res = Json.numincrby(client, gs("doc"), gs("$.d[*]"), 10.0).get();
      +     * assert res.equals(gs("[11,12,13]")); // Increment each element in `d` array by 10.
      +     *
      +     * res = Json.numincrby(client, gs("doc"), gs(".c[1]"), 10.0).get();
      +     * assert res.equals(gs("12")); // Increment the second element in the `c` array by 10.
      +     * }
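+ * A further illustrative sketch, not part of this change: using numincrby as an atomic
+ * counter inside a document; a negative number decrements. The key name "stats"
+ * is hypothetical.
+ *
{@code
+     * Json.set(client, "stats", "$", "{\"hits\": 0}").get();
+     * var hits = Json.numincrby(client, "stats", ".hits", 1).get();
+     * assert hits.equals("1"); // the legacy path returns the new value as a string
+     * hits = Json.numincrby(client, "stats", ".hits", -1).get();
+     * assert hits.equals("0");
+     * }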
      + */ + public static CompletableFuture numincrby( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + Number number) { + return executeCommand( + client, new GlideString[] {gs(JSON_NUMINCRBY), key, path, gs(number.toString())}); + } + + /** + * Multiplies the JSON value(s) at the specified path by number within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to multiply by. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a string representation of an array of strings, indicating the new values + * after multiplication for each matched path.
        + * If a value is not a number, its corresponding return value will be null. + *
+ * If path doesn't exist, a string representation of an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns a string representation of the resulting value after multiplication.
        + * If multiple paths match, the result of the last updated value is returned.
        + * If the value at the path is not a number or path doesn't + * exist, an error is raised. + *
      + * If key does not exist, an error is raised.
      + * If the result is out of the range of 64-bit IEEE double, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
      +     * var res = Json.nummultby(client, "doc", "$.d[*]", 2.0).get();
      +     * assert res.equals("[2,4,6]"); // Multiplies each element in the `d` array by 2.
      +     *
      +     * res = Json.nummultby(client, "doc", ".c[1]", 2.0).get();
      +     * assert res.equals("12"); // Multiplies the second element in the `c` array by 2.
      +     * }
      + */ + public static CompletableFuture nummultby( + @NonNull BaseClient client, @NonNull String key, @NonNull String path, Number number) { + return executeCommand(client, new String[] {JSON_NUMMULTBY, key, path, number.toString()}); + } + + /** + * Multiplies the JSON value(s) at the specified path by number within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @param number The number to multiply by. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a GlideString representation of an array of strings, indicating + * the new values after multiplication for each matched path.
        + * If a value is not a number, its corresponding return value will be null. + *
        + * If path doesn't exist, a byte string representation of an empty array + * will be returned. + *
      • For legacy path (path doesn't start with $):
        + * Returns a GlideString representation of the resulting value after + * multiplication.
        + * If multiple paths match, the result of the last updated value is returned.
        + * If the value at the path is not a number or path doesn't + * exist, an error is raised. + *
      + * If key does not exist, an error is raised.
      + * If the result is out of the range of 64-bit IEEE double, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"c\": [1, 2], \"d\": [1, 2, 3]}").get();
      +     * var res = Json.nummultby(client, gs("doc"), gs("$.d[*]"), 2.0).get();
      +     * assert res.equals(gs("[2,4,6]")); // Multiplies each element in the `d` array by 2.
      +     *
      +     * res = Json.nummultby(client, gs("doc"), gs(".c[1]"), 2.0).get();
      +     * assert res.equals(gs("12")); // Multiplies the second element in the `c` array by 2.
      +     * }
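+ * A further illustrative sketch, not part of this change: doubling a single field with
+ * a legacy path. The key name "cart" is hypothetical.
+ *
{@code
+     * Json.set(client, "cart", "$", "{\"price\": 100}").get();
+     * var price = Json.nummultby(client, "cart", ".price", 2).get();
+     * assert price.equals("200"); // the legacy path returns the new value as a string
+     * }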
      + */ + public static CompletableFuture nummultby( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString path, + Number number) { + return executeCommand( + client, new GlideString[] {gs(JSON_NUMMULTBY), key, path, gs(number.toString())}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key.
      + * Equivalent to {@link #objlen(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The object length stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, "doc").get();
      +     * assert res == 2; // the size of object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     * }
      + */ + public static CompletableFuture objlen(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_OBJLEN, key}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key.
      + * Equivalent to {@link #objlen(BaseClient, GlideString, GlideString)} with path set + * to gs("."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The object length stored at the root of the document. If document root is not an + * object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, gs("doc"), gs(".")).get();
      +     * assert res == 2; // the size of object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     * }
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJLEN), key}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of long integers for every possible + * path, indicating the number of key-value pairs for each matching object, or + * null + * for JSON values matching the path that are not an object. If path + * does not exist, an empty array will be returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns the number of key-value pairs for the object value matching the path. If + * multiple paths are matched, returns the length of the first matching object. If + * path doesn't exist or the value at path is not an object, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, "doc", ".").get(); // legacy path - command returns first value as `Long`
      +     * assert res == 2L; // the size of object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     *
      +     * res = Json.objlen(client, "doc", "$.b").get(); // JSONPath - command returns an array
      +     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // the length of the objects at path `$.b`
      +     * }
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_OBJLEN, key, path}); + } + + /** + * Retrieves the number of key-value pairs in the object values at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[] with a list of long integers for every possible + * path, indicating the number of key-value pairs for each matching object, or + * null + * for JSON values matching the path that are not an object. If path + * does not exist, an empty array will be returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns the number of key-value pairs for the object value matching the path. If + * multiple paths are matched, returns the length of the first matching object. If + * path doesn't exist or the value at path is not an object, an + * error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objlen(client, gs("doc"), gs(".")).get(); // legacy path - command returns first value as `Long`
      +     * assert res == 2L; // the size of object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     *
      +     * res = Json.objlen(client, gs("doc"), gs("$.b")).get(); // JSONPath - command returns an array
      +     * assert Arrays.equals((Object[]) res, new Object[] { 3L }); // the length of the objects at path `$.b`
      +     * }
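+ * A further illustrative sketch, not part of this change: handling both return shapes
+ * at runtime, since a JSONPath yields an Object[] while a legacy path yields a Long,
+ * as described above.
+ *
{@code
+     * Object res = Json.objlen(client, "doc", "$.b").get();
+     * if (res instanceof Object[]) {
+     *     System.out.println(((Object[]) res).length + " JSONPath matches");
+     * } else {
+     *     System.out.println("object length: " + (Long) res);
+     * }
+     * }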
      + */ + public static CompletableFuture objlen( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJLEN), key, path}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key.
+ * Equivalent to {@link #objkeys(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The key names of the object stored at the root of the document. If the document root + * is not an object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, "doc").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { "a", "b" }); // the keys of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_OBJKEYS, key}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key.
+ * Equivalent to {@link #objkeys(BaseClient, GlideString, GlideString)} with path set + * to gs("."). + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return The key names of the object stored at the root of the document. If the document root + * is not an object, an error is raised.
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, gs("doc"), gs(".")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] { gs("a"), gs("b") }); // the keys of the object matching the path `.`, which has 2 keys: 'a' and 'b'.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJKEYS), key}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[][] with each nested array containing key names for + * each matching object for every possible path, indicating the list of object keys for + * each matching object, or null for JSON values matching the path that are + * not an object. If path does not exist, an empty sub-array will be + * returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns an array of object keys for the object value matching the path. If multiple + * paths are matched, returns the keys of the first matching object. If path + * doesn't exist or the value at path is not an object, an error is + * raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, "doc", ".").get(); // legacy path - command returns array for first matched object
      +     * assert Arrays.equals((Object[]) res, new Object[] { "a", "b" }); // key names for the object matching the path `.` as it is the only match.
      +     *
      +     * res = Json.objkeys(client, "doc", "$.b").get(); // JSONPath - command returns an array for each matched object
+     * assert Arrays.deepEquals((Object[]) res, new Object[][] { { "a", "b", "c" } }); // key names as a nested list for objects matching the JSONPath `$.b`.
      +     * }
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_OBJKEYS, key, path}); + } + + /** + * Retrieves the key names in the object values at the specified path within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns an Object[][] with each nested array containing key names for + * each matching object for every possible path, indicating the list of object keys for + * each matching object, or null for JSON values matching the path that are + * not an object. If path does not exist, an empty sub-array will be + * returned. + *
      • For legacy path (path doesn't start with $):
+ * Returns an array of object keys for the object value matching the path. If multiple + * paths are matched, returns the keys of the first matching object. If path + * doesn't exist or the value at path is not an object, an error is + * raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}").get();
      +     * var res = Json.objkeys(client, gs("doc"), gs(".")).get(); // legacy path - command returns array for first matched object
      +     * assert Arrays.equals((Object[]) res, new Object[] { "a", "b" }); // key names for the object matching the path `.` as it is the only match.
      +     *
      +     * res = Json.objkeys(client, gs("doc"), gs("$.b")).get(); // JSONPath - command returns an array for each matched object
+     * assert Arrays.deepEquals((Object[]) res, new Object[][] { { gs("a"), gs("b"), gs("c") } }); // key names as a nested list for objects matching the JSONPath `$.b`.
      +     * }
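+ * A further illustrative sketch, not part of this change: iterating the key names of a
+ * root object with a legacy path.
+ *
{@code
+     * Json.set(client, "doc", "$", "{\"a\": 1, \"b\": 2}").get();
+     * Object[] keys = (Object[]) Json.objkeys(client, "doc", ".").get();
+     * for (Object k : keys) {
+     *     System.out.println(k); // prints "a", then "b"
+     * }
+     * }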
      + */ + public static CompletableFuture objkeys( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_OBJKEYS), key, path}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, "doc").get();
      +     * assert result == 1L;
      +     * }
      + */ + public static CompletableFuture del(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_DEL, key}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, gs("doc")).get();
      +     * assert result == 1L;
      +     * }
      + */ + public static CompletableFuture del(@NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_DEL), key}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, "doc", "$..a").get();
      +     * assert result == 2L;
      +     * }
      + */ + public static CompletableFuture del( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_DEL, key, path}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.del(client, gs("doc"), gs("$..a")).get();
      +     * assert result == 2L;
      +     * }
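+ * A further illustrative sketch, not part of this change: deleting a subtree and
+ * observing the 0 return for a missing key, per the contract above. The key name
+ * "missing_key" is hypothetical.
+ *
{@code
+     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}").get();
+     * assert Json.del(client, "doc", "$.nested").get() == 1L; // one subtree removed
+     * assert Json.del(client, "missing_key").get() == 0L; // nothing to delete
+     * }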
      + */ + public static CompletableFuture del( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_DEL), key, path}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, "doc").get();
      +     * assert result == 1L;
      +     * }
      + */ + public static CompletableFuture forget(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_FORGET, key}); + } + + /** + * Deletes the JSON document stored at key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @return The number of elements deleted. 0 if the key does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, gs("doc")).get();
      +     * assert result == 1L;
      +     * }
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_FORGET), key}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, "doc", "$..a").get();
      +     * assert result == 2L;
      +     * }
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_FORGET, key, path}); + } + + /** + * Deletes the JSON value at the specified path within the JSON document stored at + * key. + * + * @param client The Valkey GLIDE client to execute the command. + * @param key The key of the JSON document. + * @param path Represents the path within the JSON document where the value will be deleted. + * @return The number of elements deleted. 0 if the key does not exist, or if the JSON path is + * invalid or does not exist. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
      +     * Long result = Json.forget(client, gs("doc"), gs("$..a")).get();
      +     * assert result == 2L;
      +     * }
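+ * A further illustrative sketch, not part of this change: forget mirrors del for
+ * whole-document deletion, returning 0 once the key is gone.
+ *
{@code
+     * Json.set(client, "doc", ".", "{\"a\": 1}").get();
+     * assert Json.forget(client, "doc").get() == 1L;
+     * assert Json.forget(client, "doc").get() == 0L; // already deleted
+     * }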
      + */ + public static CompletableFuture forget( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand(client, new GlideString[] {gs(JSON_FORGET), key, path}); + } + + /** + * Toggles a Boolean value stored at the root within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the toggled boolean value at the root of the document, or null for + * JSON values matching the root that are not boolean. If key doesn't exist, + * returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", true).get();
      +     * var res = Json.toggle(client, "doc").get();
      +     * assert res.equals(false);
      +     * res = Json.toggle(client, "doc").get();
      +     * assert res.equals(true);
      +     * }
      + */ + public static CompletableFuture toggle(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_TOGGLE, key}); + } + + /** + * Toggles a Boolean value stored at the root within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the toggled boolean value at the root of the document, or null for + * JSON values matching the root that are not boolean. If key doesn't exist, + * returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", ".", true).get();
      +     * var res = Json.toggle(client, gs("doc")).get();
      +     * assert res.equals(false);
      +     * res = Json.toggle(client, gs("doc")).get();
      +     * assert res.equals(true);
      +     * }
      + */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).toArray()); + } + + /** + * Toggles a Boolean value stored at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a Boolean[] with the toggled boolean value for every possible + * path, or null for JSON values matching the path that are not boolean. + *
      • For legacy path (path doesn't start with $):
        + * Returns the value of the toggled boolean in path. If path + * doesn't exist or the value at path isn't a boolean, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}").get();
      +     * var res = Json.toggle(client, "doc", "$..bool").get();
      +     * assert Arrays.equals((Boolean[]) res, new Boolean[] {false, true, null});
      +     * res = Json.toggle(client, "doc", "bool").get();
      +     * assert res.equals(true);
      +     * var getResult = Json.get(client, "doc", "$").get();
      +     * assert getResult.equals("{\"bool\": true, \"nested\": {\"bool\": true, \"nested\": {\"bool\": 10}}}");
      +     * }
      + */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new String[] {JSON_TOGGLE, key, path}); + } + + /** + * Toggles a Boolean value stored at the specified path within the JSON document + * stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a Boolean[] with the toggled boolean value for every possible + * path, or null for JSON values matching the path that are not boolean. + *
      • For legacy path (path doesn't start with $):
        + * Returns the value of the toggled boolean in path. If path + * doesn't exist or the value at path isn't a boolean, an error is raised. + *
      + * If key doesn't exist, returns null. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}").get();
      +     * var res = Json.toggle(client, gs("doc"), gs("$..bool")).get();
      +     * assert Arrays.equals((Boolean[]) res, new Boolean[] {false, true, null});
      +     * res = Json.toggle(client, gs("doc"), gs("bool")).get();
      +     * assert res.equals(true);
      +     * var getResult = Json.get(client, "doc", "$").get();
      +     * assert getResult.equals("{\"bool\": true, \"nested\": {\"bool\": true, \"nested\": {\"bool\": 10}}}");
      +     * }
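+ * A further illustrative sketch, not part of this change: flipping a single feature
+ * flag with a legacy path; the toggled value is returned directly. The key name
+ * "flags" is hypothetical.
+ *
{@code
+     * Json.set(client, "flags", "$", "{\"beta\": false}").get();
+     * Boolean enabled = (Boolean) Json.toggle(client, "flags", ".beta").get();
+     * assert enabled; // false -> true
+     * }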
+ */ + public static CompletableFuture toggle( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_TOGGLE)).add(key).add(path).toArray()); + } + + /** + * Appends the specified value to the string stored at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be a JSON string literal (quoted); for + * example, to append foo, pass "\"foo\"". + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a list of integer replies for every possible path, indicating the length of + * the resulting string after appending value, or null for + * JSON values matching the path that are not string.
        + * If key doesn't exist, an error is raised. + *
      • For legacy path (path doesn't start with $):
        + * Returns the length of the resulting string after appending value to the + * string at path.
        + * If multiple paths match, the length of the last updated string is returned.
+ * If the JSON value at path is not a string or if path + * doesn't exist, an error is raised.
        + * If key doesn't exist, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
      +     * var res = Json.strappend(client, "doc", "baz", "$..a").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] {6L, 8L, null}); // The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
      +     *
      +     * res = Json.strappend(client, "doc", '"foo"', "nested.a").get();
      +     * assert (Long) res == 11L; // The length of the string value after appending "foo" to the string at path 'nested.array' in the key stored at `doc`.
      +     *
      +     * var getResult = Json.get(client, "doc", "$").get();
      +     * assert getResult.equals("[{\"a\":\"foobaz\", \"nested\": {\"a\": \"hellobazfoo\"}, \"nested2\": {\"a\": 31}}]"); // The updated JSON value in the key stored at `doc`.
      +     * }
+ */ + public static CompletableFuture strappend( + @NonNull BaseClient client, + @NonNull String key, + @NonNull String value, + @NonNull String path) { + return executeCommand( + client, new ArgsBuilder().add(JSON_STRAPPEND).add(key).add(path).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the specified path + * within the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be a JSON string literal (quoted); for + * example, to append foo, pass "\"foo\"". + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a list of integer replies for every possible path, indicating the length of + * the resulting string after appending value, or null for + * JSON values matching the path that are not string.
        + * If key doesn't exist, an error is raised. + *
      • For legacy path (path doesn't start with $):
        + * Returns the length of the resulting string after appending value to the + * string at path.
        + * If multiple paths match, the length of the last updated string is returned.
+ * If the JSON value at path is not a string or if path + * doesn't exist, an error is raised.
        + * If key doesn't exist, an error is raised. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
      +     * var res = Json.strappend(client, gs("doc"), gs("baz"), gs("$..a")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] {6L, 8L, null}); // The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
      +     *
      +     * res = Json.strappend(client, gs("doc"), gs("'\"foo\"'"), gs("nested.a")).get();
      +     * assert (Long) res == 11L; // The length of the string value after appending "foo" to the string at path 'nested.array' in the key stored at `doc`.
      +     *
      +     * var getResult = Json.get(client, gs("doc"), gs("$")).get();
      +     * assert getResult.equals("[{\"a\":\"foobaz\", \"nested\": {\"a\": \"hellobazfoo\"}, \"nested2\": {\"a\": 31}}]"); // The updated JSON value in the key stored at `doc`.
      +     * }
+ */ + public static CompletableFuture strappend( + @NonNull BaseClient client, + @NonNull GlideString key, + @NonNull GlideString value, + @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRAPPEND)).add(key).add(path).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the root within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be a JSON string literal (quoted); for + * example, to append foo, pass "\"foo\"". + * @return Returns the length of the resulting string after appending value to the + * string at the root.
      + * If the JSON value at root is not a string, an error is raised.
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "'\"foo\"'").get();
      +     * var res = Json.strappend(client, "doc", "'\"baz\"'").get();
      +     * assert res == 6L; // The length of the string value after appending "foo" to the string at root in the key stored at `doc`.
      +     *
      +     * var getResult = Json.get(client, "doc").get();
      +     * assert getResult.equals("\"foobaz\""); // The updated JSON value in the key stored at `doc`.
      +     * }
+ */ + public static CompletableFuture strappend( + @NonNull BaseClient client, @NonNull String key, @NonNull String value) { + return executeCommand( + client, new ArgsBuilder().add(JSON_STRAPPEND).add(key).add(value).toArray()); + } + + /** + * Appends the specified value to the string stored at the root within the JSON + * document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param value The value to append to the string. Must be a JSON string literal (quoted); for + * example, to append foo, pass "\"foo\"". + * @return Returns the length of the resulting string after appending value to the + * string at the root.
      + * If the JSON value at root is not a string, an error is raised.
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "'\"foo\"'").get();
      +     * var res = Json.strappend(client, gs("doc"), gs("'\"baz\"'")).get();
      +     * assert res == 6L; // The length of the string value after appending "foo" to the string at root in the key stored at `doc`.
      +     *
      +     * var getResult = Json.get(client, gs("$"), gs("doc")).get();
      +     * assert getResult.equals("\"foobaz\""); // The updated JSON value in the key stored at `doc`.
      +     * }
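+ * A further illustrative sketch, not part of this change: growing a root string and
+ * checking the returned length, following the JSON string quoting described above.
+ *
{@code
+     * Json.set(client, "doc", "$", "\"abc\"").get();
+     * Long len = (Long) Json.strappend(client, "doc", "\"def\"").get();
+     * assert len == 6L; // "abc" + "def"
+     * }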
      + */ + public static CompletableFuture strappend( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString value) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRAPPEND)).add(key).add(value).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the specified path within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a list of integer replies for every possible path, indicating the length of + * the JSON string value, or null for JSON values matching the path that + * are not string. + *
      • For legacy path (path doesn't start with $):
+ * Returns the length of the JSON value at path.
+ * If multiple paths match, the length of the first matched string is returned.
+ * If the JSON value at path is not a string or if path + * doesn't exist, an error is raised. If key doesn't exist, null + * is returned. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
      +     * var res = Json.strlen(client, "doc", "$..a").get();
      +     * assert Arrays.equals((Object[]) res, new Object[] {3L, 5L, null}); // The length of the string values at path '$..a' in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, "doc", "nested.a").get();
      +     * assert (Long) res == 5L; // The length of the JSON value at path 'nested.a' in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, "doc", "$").get();
+     * assert Arrays.equals((Object[]) res, new Object[] {null}); // Returns an array with null since the value at the root path in the JSON document stored at `doc` is not a string.
      +     *
      +     * res = Json.strlen(client, "non_existing_key", ".").get();
      +     * assert res == null; // `key` doesn't exist.
      +     * }
      + */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull String key, @NonNull String path) { + return executeCommand(client, new ArgsBuilder().add(JSON_STRLEN).add(key).add(path).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the specified path within + * the JSON document stored at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return + *
        + *
      • For JSONPath (path starts with $):
        + * Returns a list of integer replies for every possible path, indicating the length of + * the JSON string value, or null for JSON values matching the path that + * are not string. + *
      • For legacy path (path doesn't start with $):
+ * Returns the length of the JSON value at path.
+ * If multiple paths match, the length of the first matched string is returned.
+ * If the JSON value at path is not a string or if path + * doesn't exist, an error is raised. If key doesn't exist, null + * is returned. + *
      + * + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":\"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}").get();
      +     * var res = Json.strlen(client, gs("doc"), gs("$..a")).get();
      +     * assert Arrays.equals((Object[]) res, new Object[] {3L, 5L, null}); // The length of the string values at path '$..a' in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, gs("doc"), gs("nested.a")).get();
      +     * assert (Long) res == 5L; // The length of the JSON value at path 'nested.a' in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, gs("doc"), gs("$")).get();
+     * assert Arrays.equals((Object[]) res, new Object[] {null}); // Returns an array with null since the value at the root path in the JSON document stored at `doc` is not a string.
      +     *
      +     * res = Json.strlen(client, gs("non_existing_key"), gs(".")).get();
      +     * assert res == null; // `key` doesn't exist.
      +     * }
      + */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) { + return executeCommand( + client, new ArgsBuilder().add(gs(JSON_STRLEN)).add(key).add(path).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the root within the JSON document stored + * at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the length of the JSON value at the root.
      + * If the JSON value is not a string, an error is raised.
      + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "\"Hello\"").get();
      +     * var res = Json.strlen(client, "doc").get();
      +     * assert res == 5L; // The length of the JSON value at the root in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, "non_existing_key").get();
      +     * assert res == null; // `key` doesn't exist.
      +     * }
      + */ + public static CompletableFuture strlen(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new ArgsBuilder().add(JSON_STRLEN).add(key).toArray()); + } + + /** + * Returns the length of the JSON string value stored at the root within the JSON document stored + * at key. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return Returns the length of the JSON value at the root.
      + * If the JSON value is not a string, an error is raised.
      + * If key doesn't exist, null is returned. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "\"Hello\"").get();
      +     * var res = Json.strlen(client, gs("doc")).get();
      +     * assert res == 5L; // The length of the JSON value at the root in the key stored at `doc`.
      +     *
      +     * res = Json.strlen(client, gs("non_existing_key")).get();
      +     * assert res == null; // `key` doesn't exist.
      +     * }
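+ * A further illustrative sketch, not part of this change: a null result distinguishes
+ * a missing key from an error, per the contract above. The key name "maybe_missing"
+ * is hypothetical.
+ *
{@code
+     * Object len = Json.strlen(client, "maybe_missing").get();
+     * if (len == null) {
+     *     System.out.println("key does not exist");
+     * } else {
+     *     System.out.println("string length: " + len);
+     * }
+     * }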
+ */ + public static CompletableFuture strlen( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new ArgsBuilder().add(gs(JSON_STRLEN)).add(key).toArray()); + } + + /** + * Clears an array or an object at the root of the JSON document stored at key.
      + * Equivalent to {@link #clear(BaseClient, String, String)} with path set to + * ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return 1 if the document wasn't empty or 0 if it was.
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":1, \"b\":2}").get();
      +     * long res = Json.clear(client, "doc").get();
      +     * assert res == 1;
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{}]");
      +     *
      +     * res = Json.clear(client, "doc").get();
      +     * assert res == 0; // the doc is already empty
      +     * }
+ */ + public static CompletableFuture clear(@NonNull BaseClient client, @NonNull String key) { + return executeCommand(client, new String[] {JSON_CLEAR, key}); + } + + /** + * Clears an array or an object at the root of the JSON document stored at key.
      + * Equivalent to {@link #clear(BaseClient, GlideString, GlideString)} with path set + * to ".". + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @return 1 if the document wasn't empty or 0 if it was.
      + * If key doesn't exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"a\":1, \"b\":2}").get();
      +     * long res = Json.clear(client, gs("doc")).get();
      +     * assert res == 1;
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{}]");
      +     *
      +     * res = Json.clear(client, gs("doc")).get();
      +     * assert res == 0; // the doc is already empty
      +     * }
      + */ + public static CompletableFuture clear( + @NonNull BaseClient client, @NonNull GlideString key) { + return executeCommand(client, new GlideString[] {gs(JSON_CLEAR), key}); + } + + /** + * Clears arrays and objects at the specified path within the JSON document stored at + * key.
      + * Numeric values are set to 0, boolean values are set to false, and + * string values are converted to empty strings. + * + * @param client The client to execute the command. + * @param key The key of the JSON document. + * @param path The path within the JSON document. + * @return The number of containers cleared.
      + * If path doesn't exist, or the value at path is already cleared + * (e.g., an empty array, object, or string), 0 is returned. If key doesn't + * exist, an error is raised. + * @example + *
      {@code
      +     * Json.set(client, "doc", "$", "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true,
      +     *     \"int\": 42, \"float\": 3.14, \"nullVal\": null}").get();
      +     * long res = Json.clear(client, "doc", "$.*").get();
      +     * assert res == 6; // 6 values are cleared: "obj", "arr", "str", "bool", "int", and "float"; "nullVal" is not clearable.
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]");
      +     *
      +     * res = Json.clear(client, "doc", "$.*").get();
      +     * assert res == 0; // containers are already empty and nothing is cleared
      +     * }
+     */
+    public static CompletableFuture<Long> clear(
+            @NonNull BaseClient client, @NonNull String key, @NonNull String path) {
+        return executeCommand(client, new String[] {JSON_CLEAR, key, path});
+    }
+
+    /**
+     * Clears arrays and objects at the specified path within the JSON document stored at
+     * key. Numeric values are set to 0, boolean values are set to false, and
+     * string values are converted to empty strings.
+     *
+     * @param client The client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return The number of containers cleared.
+     *     If path doesn't exist, or the value at path is already cleared
+     *     (e.g., an empty array, object, or string), 0 is returned. If key doesn't
+     *     exist, an error is raised.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", "$", "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true,
      +     *     \"int\": 42, \"float\": 3.14, \"nullVal\": null}").get();
      +     * long res = Json.clear(client, gs("doc"), gs("$.*")).get();
      +     * assert res == 6; // 6 values are cleared: "obj", "arr", "str", "bool", "int", and "float"; "nullVal" is not clearable.
      +     *
      +     * var doc = Json.get(client, "doc", "$").get();
      +     * assert doc.equals("[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]");
      +     *
      +     * res = Json.clear(client, gs("doc"), gs("$.*")).get();
      +     * assert res == 0; // containers are already empty and nothing is cleared
      +     * }
+     */
+    public static CompletableFuture<Long> clear(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_CLEAR), key, path});
+    }
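
All four clear overloads funnel into a single JSON.CLEAR invocation; they differ only in argument encoding and in whether an explicit path is sent. A minimal caller-side sketch (key and document contents are illustrative; assumes a connected client with the server-side JSON module available):

    // A path clear ("$.*") empties matched containers and zeroes/blanks matched
    // scalars; the subsequent root clear reports 1 because the doc was not empty.
    Json.set(client, "cfg", "$", "{\"retries\": 5, \"hosts\": [\"a\", \"b\"]}").get();
    long cleared = Json.clear(client, "cfg", "$.*").get();  // 2: one number, one array
    long clearedAtRoot = Json.clear(client, "cfg").get();   // 1: document becomes {}
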
+
+    /**
+     * Retrieves the JSON document stored at key. The returned result is in the
+     * Valkey or Redis OSS Serialization Protocol (RESP).
+     *
+     * <ul>
+     *   <li>JSON null is mapped to the RESP Null Bulk String.
+     *   <li>JSON Booleans are mapped to RESP Simple Strings.
+     *   <li>JSON integers are mapped to RESP Integers.
+     *   <li>JSON doubles are mapped to RESP Bulk Strings.
+     *   <li>JSON strings are mapped to RESP Bulk Strings.
+     *   <li>JSON arrays are represented as RESP arrays, where the first element is the simple
+     *       string [, followed by the array's elements.
+     *   <li>JSON objects are represented as RESP arrays, where the first element is the simple
+     *       string {, followed by key-value pairs, each of which is a RESP bulk string.
+     * </ul>
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns the JSON document in its RESP form.
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"b1\": 1}, \"c\": 42}");
      +     * Object actualResult = Json.resp(client, "doc").get();
      +     * Object[] expectedResult = new Object[] {
      +     *     "{",
      +     *     new Object[] {"a", new Object[] {"[", 1L, 2L, 3L}},
      +     *     new Object[] {"b", new Object[] {"{", new Object[] {"b1", 1L}}},
      +     *     new Object[] {"c", 42L}
      +     * };
      +     * assertInstanceOf(Object[].class, actualResult);
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * }
+     */
+    public static CompletableFuture<Object> resp(@NonNull BaseClient client, @NonNull String key) {
+        return executeCommand(client, new String[] {JSON_RESP, key});
+    }
+
+    /**
+     * Retrieves the JSON document stored at key. The returned result is in the
+     * Valkey or Redis OSS Serialization Protocol (RESP).
+     *
+     * <ul>
+     *   <li>JSON null is mapped to the RESP Null Bulk String.
+     *   <li>JSON Booleans are mapped to RESP Simple Strings.
+     *   <li>JSON integers are mapped to RESP Integers.
+     *   <li>JSON doubles are mapped to RESP Bulk Strings.
+     *   <li>JSON strings are mapped to RESP Bulk Strings.
+     *   <li>JSON arrays are represented as RESP arrays, where the first element is the simple
+     *       string [, followed by the array's elements.
+     *   <li>JSON objects are represented as RESP arrays, where the first element is the simple
+     *       string {, followed by key-value pairs, each of which is a RESP bulk string.
+     * </ul>
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns the JSON document in its RESP form.
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"b1\": 1}, \"c\": 42}");
      +     * Object actualResultBinary = Json.resp(client, gs("doc")).get();
      +     * Object[] expectedResultBinary = new Object[] {
      +     *     "{",
      +     *     new Object[] {gs("a"), new Object[] {gs("["), 1L, 2L, 3L}},
      +     *     new Object[] {gs("b"), new Object[] {gs("{"), new Object[] {gs("b1"), 1L}}},
      +     *     new Object[] {gs("c"), 42L}
      +     * };
      +     * assertInstanceOf(Object[].class, actualResultBinary);
      +     * assertArrayEquals(expectedResultBinary, (Object[]) actualResultBinary);
      +     * }
+     */
+    public static CompletableFuture<Object> resp(
+            @NonNull BaseClient client, @NonNull GlideString key) {
+        return executeCommand(client, new GlideString[] {gs(JSON_RESP), key});
+    }
+
+    /**
+     * Retrieves the JSON value at the specified path within the JSON document stored at
+     * key. The returned result is in the Valkey or Redis OSS Serialization Protocol
+     * (RESP).
+     *
+     * <ul>
+     *   <li>JSON null is mapped to the RESP Null Bulk String.
+     *   <li>JSON Booleans are mapped to RESP Simple Strings.
+     *   <li>JSON integers are mapped to RESP Integers.
+     *   <li>JSON doubles are mapped to RESP Bulk Strings.
+     *   <li>JSON strings are mapped to RESP Bulk Strings.
+     *   <li>JSON arrays are represented as RESP arrays, where the first element is the simple
+     *       string [, followed by the array's elements.
+     *   <li>JSON objects are represented as RESP arrays, where the first element is the simple
+     *       string {, followed by key-value pairs, each of which is a RESP bulk string.
+     * </ul>
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return
+     *     <ul>
+     *       <li>For JSONPath (path starts with $): Returns a list of
+     *           replies for every possible path, indicating the RESP form of the JSON value. If
+     *           path doesn't exist, returns an empty list.
+     *       <li>For legacy path (path doesn't start with $): Returns a
+     *           single reply for the JSON value at the specified path, in its RESP form. If multiple
+     *           paths match, the value of the first JSON value match is returned. If path
+     *           doesn't exist, an error is raised.
+     *     </ul>
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}");
      +     * Object actualResult = Json.resp(client, "doc", "$..a").get(); // JSONPath returns all possible paths
      +     * Object[] expectedResult = new Object[] {
      +     *                 new Object[] {"[", 1L, 2L, 3L},
      +     *                 new Object[] {"[", 1L, 2L},
      +     *                 42L};
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * // legacy path only returns the first JSON value match
+     * assertArrayEquals(new Object[] {"[", 1L, 2L, 3L}, (Object[]) Json.resp(client, "doc", "..a").get());
      +     * }
+     */
+    public static CompletableFuture<Object> resp(
+            @NonNull BaseClient client, @NonNull String key, @NonNull String path) {
+        return executeCommand(client, new String[] {JSON_RESP, key, path});
+    }
+
+    /**
+     * Retrieves the JSON value at the specified path within the JSON document stored at
+     * key. The returned result is in the Valkey or Redis OSS Serialization Protocol
+     * (RESP).
+     *
+     * <ul>
+     *   <li>JSON null is mapped to the RESP Null Bulk String.
+     *   <li>JSON Booleans are mapped to RESP Simple Strings.
+     *   <li>JSON integers are mapped to RESP Integers.
+     *   <li>JSON doubles are mapped to RESP Bulk Strings.
+     *   <li>JSON strings are mapped to RESP Bulk Strings.
+     *   <li>JSON arrays are represented as RESP arrays, where the first element is the simple
+     *       string [, followed by the array's elements.
+     *   <li>JSON objects are represented as RESP arrays, where the first element is the simple
+     *       string {, followed by key-value pairs, each of which is a RESP bulk string.
+     * </ul>
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path The path within the JSON document.
+     * @return
+     *     <ul>
+     *       <li>For JSONPath (path starts with $): Returns a list of
+     *           replies for every possible path, indicating the RESP form of the JSON value. If
+     *           path doesn't exist, returns an empty list.
+     *       <li>For legacy path (path doesn't start with $): Returns a
+     *           single reply for the JSON value at the specified path, in its RESP form. If multiple
+     *           paths match, the value of the first JSON value match is returned. If path
+     *           doesn't exist, an error is raised.
+     *     </ul>
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", ".", "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}");
      +     * Object actualResult = Json.resp(client, gs("doc"), gs("$..a")).get(); // JSONPath returns all possible paths
      +     * Object[] expectedResult = new Object[] {
      +     *                 new Object[] {gs("["), 1L, 2L, 3L},
      +     *                 new Object[] {gs("["), 1L, 2L},
      +     *                 42L};
      +     * assertArrayEquals(expectedResult, (Object[]) actualResult);
      +     * // legacy path only returns the first JSON value match
+     * assertArrayEquals(new Object[] {gs("["), 1L, 2L, 3L}, (Object[]) Json.resp(client, gs("doc"), gs("..a")).get());
      +     * }
+     */
+    public static CompletableFuture<Object> resp(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_RESP), key, path});
+    }
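
Since resp is typed as CompletableFuture<Object>, callers receive an untyped Object[] tree for arrays and objects and must walk it themselves. A small helper sketch (the method name and output format are illustrative, not part of the module API):

    // Recursively renders the RESP tree returned by Json.resp: nested Object[]
    // nodes become parenthesized lists, scalars are printed via String.valueOf.
    static String renderResp(Object node) {
        if (node instanceof Object[]) {
            StringBuilder sb = new StringBuilder("(");
            for (Object child : (Object[]) node) {
                if (sb.length() > 1) sb.append(", ");
                sb.append(renderResp(child));
            }
            return sb.append(")").toString();
        }
        return String.valueOf(node); // Long, String, GlideString, or null
    }
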
+
+    /**
+     * Retrieves the type of the JSON value at the root of the JSON document stored at
+     * key.
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns the type of the JSON value at root. If key doesn't exist,
+     *     null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", "$", "[1, 2, 3]");
      +     * assertEquals("array", Json.type(client, "doc").get());
      +     *
      +     * Json.set(client, "doc", "$", "{\"a\": 1}");
      +     * assertEquals("object", Json.type(client, "doc").get());
      +     *
      +     * assertNull(Json.type(client, "non_existing_key").get());
      +     * }
+     */
+    public static CompletableFuture<String> type(@NonNull BaseClient client, @NonNull String key) {
+        return executeCommand(client, new String[] {JSON_TYPE, key});
+    }
+
+    /**
+     * Retrieves the type of the JSON value at the root of the JSON document stored at
+     * key.
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @return Returns the type of the JSON value at root. If key doesn't exist,
+     *     null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", "$", "[1, 2, 3]");
      +     * assertEquals(gs("array"), Json.type(client, gs("doc")).get());
      +     *
      +     * Json.set(client, "doc", "$", "{\"a\": 1}");
      +     * assertEquals(gs("object"), Json.type(client, gs("doc")).get());
      +     *
      +     * assertNull(Json.type(client, gs("non_existing_key")).get());
      +     * }
+     */
+    public static CompletableFuture<GlideString> type(
+            @NonNull BaseClient client, @NonNull GlideString key) {
+        return executeCommand(client, new GlideString[] {gs(JSON_TYPE), key});
+    }
+
+    /**
+     * Retrieves the type of the JSON value at the specified path within the JSON
+     * document stored at key.
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path Represents the path within the JSON document where the type will be retrieved.
+     * @return
+     *     <ul>
+     *       <li>For JSONPath (path starts with $): Returns a list of string
+     *           replies for every possible path, indicating the type of the JSON value. If path
+     *           doesn't exist, an empty array will be returned.
+     *       <li>For legacy path (path doesn't start with $): Returns the
+     *           type of the JSON value at path. If multiple paths match, the type of the first JSON
+     *           value match is returned. If path doesn't exist, null will be returned.
+     *     </ul>
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
+     * assertArrayEquals(new Object[]{"object"}, (Object[]) Json.type(client, "doc", "$.nested").get());
+     * assertArrayEquals(new Object[]{"integer"}, (Object[]) Json.type(client, "doc", "$.nested.a").get());
+     * assertArrayEquals(new Object[]{"integer", "object"}, (Object[]) Json.type(client, "doc", "$[*]").get());
      +     * }
+     */
+    public static CompletableFuture<Object> type(
+            @NonNull BaseClient client, @NonNull String key, @NonNull String path) {
+        return executeCommand(client, new String[] {JSON_TYPE, key, path});
+    }
+
+    /**
+     * Retrieves the type of the JSON value at the specified path within the JSON
+     * document stored at key.
+     *
+     * @param client The Valkey GLIDE client to execute the command.
+     * @param key The key of the JSON document.
+     * @param path Represents the path within the JSON document where the type will be retrieved.
+     * @return
+     *     <ul>
+     *       <li>For JSONPath (path starts with $): Returns a list of string
+     *           replies for every possible path, indicating the type of the JSON value. If path
+     *           doesn't exist, an empty array will be returned.
+     *       <li>For legacy path (path doesn't start with $): Returns the
+     *           type of the JSON value at path. If multiple paths match, the type of the first JSON
+     *           value match is returned. If path doesn't exist, null will be returned.
+     *     </ul>
+     *     If key doesn't exist, null is returned.
+     * @example
+     * {@code
      +     * Json.set(client, "doc", "$", "{\"a\": 1, \"nested\": {\"a\": 2, \"b\": 3}}");
+     * assertArrayEquals(new Object[]{gs("object")}, (Object[]) Json.type(client, gs("doc"), gs("$.nested")).get());
+     * assertArrayEquals(new Object[]{gs("integer")}, (Object[]) Json.type(client, gs("doc"), gs("$.nested.a")).get());
+     * assertArrayEquals(new Object[]{gs("integer"), gs("object")}, (Object[]) Json.type(client, gs("doc"), gs("$[*]")).get());
      +     * }
+     */
+    public static CompletableFuture<Object> type(
+            @NonNull BaseClient client, @NonNull GlideString key, @NonNull GlideString path) {
+        return executeCommand(client, new GlideString[] {gs(JSON_TYPE), key, path});
+    }
+
+    /**
+     * A wrapper for custom command API.
+     *
+     * @param client The client to execute the command.
+     * @param args The command line.
+     */
+    private static <T> CompletableFuture<T> executeCommand(BaseClient client, String[] args) {
+        return executeCommand(client, args, false);
+    }
+
+    /**
+     * A wrapper for custom command API.
+     *
+     * @param client The client to execute the command.
+     * @param args The command line.
+     * @param returnsMap - true if command returns a map
+     */
+    @SuppressWarnings({"unchecked", "SameParameterValue"})
+    private static <T> CompletableFuture<T> executeCommand(
+            BaseClient client, String[] args, boolean returnsMap) {
+        if (client instanceof GlideClient) {
+            return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r);
+        } else if (client instanceof GlideClusterClient) {
+            return ((GlideClusterClient) client)
+                    .customCommand(args)
+                    .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue)
+                    .thenApply(r -> (T) r);
+        }
+        throw new IllegalArgumentException(
+                "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`");
+    }
+
+    /**
+     * A wrapper for custom command API.
+     *
+     * @param client The client to execute the command.
+     * @param args The command line.
+     */
+    private static <T> CompletableFuture<T> executeCommand(BaseClient client, GlideString[] args) {
+        return executeCommand(client, args, false);
+    }
+
+    /**
+     * A wrapper for custom command API.
+     *
+     * @param client The client to execute the command.
+     * @param args The command line.
+     * @param returnsMap - true if command returns a map
+     */
+    @SuppressWarnings({"unchecked", "SameParameterValue"})
+    private static <T> CompletableFuture<T> executeCommand(
+            BaseClient client, GlideString[] args, boolean returnsMap) {
+        if (client instanceof GlideClient) {
+            return ((GlideClient) client).customCommand(args).thenApply(r -> (T) r);
+        } else if (client instanceof GlideClusterClient) {
+            return ((GlideClusterClient) client)
+                    .customCommand(args)
+                    .thenApply(returnsMap ? ClusterValue::getMultiValue : ClusterValue::getSingleValue)
+                    .thenApply(r -> (T) r);
+        }
+        throw new IllegalArgumentException(
+                "Unknown type of client, should be either `GlideClient` or `GlideClusterClient`");
+    }
+}
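
Since every public method above delegates to executeCommand, the same static helpers work against both GlideClient and GlideClusterClient; only the ClusterValue unwrapping differs. An end-to-end sketch (host, port, and key are placeholders; assumes the server-side JSON module is loaded, and a surrounding method that declares the checked exceptions):

    GlideClientConfiguration config =
            GlideClientConfiguration.builder()
                    .address(NodeAddress.builder().host("localhost").port(6379).build())
                    .build();
    try (GlideClient client = GlideClient.createClient(config).get()) {
        Json.set(client, "doc", "$", "{\"a\": [1, 2, 3]}").get();
        System.out.println(Json.type(client, "doc").get());  // object
        System.out.println(Json.clear(client, "doc").get()); // 1
    }
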
diff --git a/java/client/src/main/java/glide/api/models/BaseTransaction.java b/java/client/src/main/java/glide/api/models/BaseTransaction.java
index 062150c3d2..3914b05049 100644
--- a/java/client/src/main/java/glide/api/models/BaseTransaction.java
+++ b/java/client/src/main/java/glide/api/models/BaseTransaction.java
@@ -583,7 +583,7 @@ public <ArgType> T mget(@NonNull ArgType[] keys) {
      *
      * @see <a href="https://valkey.io/commands/mset/">valkey.io</a> for details.
      * @param keyValueMap A key-value map consisting of keys and their respective values to set.
-     * @return Command Response - Always OK.
+     * @return Command Response - A simple OK response.
      */
     public <ArgType> T mset(@NonNull Map<ArgType, ArgType> keyValueMap) {
         GlideString[] args = flattenMapToGlideStringArray(keyValueMap);
diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java
new file mode 100644
index 0000000000..d4335099a0
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/FT/FTAggregateOptions.java
@@ -0,0 +1,414 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.FT;
+
+import static glide.api.models.GlideString.gs;
+import static glide.utils.ArrayTransformUtils.concatenateArrays;
+import static glide.utils.ArrayTransformUtils.toGlideStringArray;
+
+import glide.api.BaseClient;
+import glide.api.commands.servermodules.FT;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+import lombok.Builder;
+import lombok.NonNull;
+
+/**
+ * Additional arguments for {@link FT#aggregate(BaseClient, String, String, FTAggregateOptions)}
+ * command.
+ */
+@Builder
+public class FTAggregateOptions {
+    /** Query timeout in milliseconds. */
+    private final Integer timeout;
+
+    private final boolean loadAll;
+
+    private final GlideString[] loadFields;
+
+    private final List<FTAggregateClause> clauses;
+
+    /** Convert to module API. */
+    public GlideString[] toArgs() {
+        var args = new ArrayList<GlideString>();
+        if (loadAll) {
+            args.add(gs("LOAD"));
+            args.add(gs("*"));
+        } else if (loadFields != null) {
+            args.add(gs("LOAD"));
+            args.add(gs(Integer.toString(loadFields.length)));
+            args.addAll(List.of(loadFields));
+        }
+        if (timeout != null) {
+            args.add(gs("TIMEOUT"));
+            args.add(gs(timeout.toString()));
+        }
+        if (!params.isEmpty()) {
+            args.add(gs("PARAMS"));
+            args.add(gs(Integer.toString(params.size() * 2)));
+            params.forEach(
+                    (name, value) -> {
+                        args.add(gs(name));
+                        args.add(value);
+                    });
+        }
+        if (clauses != null) {
+            for (var expression : clauses) {
+                args.addAll(List.of(expression.toArgs()));
+            }
+        }
+        return args.toArray(GlideString[]::new);
+    }
+
+    /**
+     * Query parameters, which could be referenced in the query by $ sign, followed by
+     * the parameter name.
+     */
+    @Builder.Default private final Map<String, GlideString> params = new HashMap<>();
+
+    public static class FTAggregateOptionsBuilder {
+        // private - hiding this API from user
+        void loadAll(boolean loadAll) {}
+
+        void expressions(List<FTAggregateClause> expressions) {}
+
+        /** Load all fields declared in the index. */
+        public FTAggregateOptionsBuilder loadAll() {
+            loadAll = true;
+            return this;
+        }
+
+        /** Load specified fields from the index. */
+        public FTAggregateOptionsBuilder loadFields(@NonNull String[] fields) {
+            loadFields = toGlideStringArray(fields);
+            loadAll = false;
+            return this;
+        }
+
+        /** Load specified fields from the index. */
+        public FTAggregateOptionsBuilder loadFields(@NonNull GlideString[] fields) {
+            loadFields = fields;
+            loadAll = false;
+            return this;
+        }
+
+        /**
+         * Add a {@link Filter}, {@link Limit}, {@link GroupBy}, {@link SortBy} or {@link Apply} clause
+         * to the pipeline. Clauses can be repeated multiple times in any order and freely intermixed;
+         * they are applied in the order specified, with the output of one clause feeding the input of
+         * the next clause.
+         */
+        public FTAggregateOptionsBuilder addClause(@NonNull FTAggregateClause clause) {
+            if (clauses == null) clauses = new ArrayList<>();
+            clauses.add(clause);
+            return this;
+        }
+    }
+
+    /**
+     * A superclass for clauses which could be added to the FT.AGGREGATE pipeline.
+     * A clause could be either:
+     *
+     * <ul>
+     *   <li>{@link Filter}
+     *   <li>{@link Limit}
+     *   <li>{@link GroupBy}
+     *   <li>{@link SortBy}
+     *   <li>{@link Apply}
+     * </ul>
+     */
+    public abstract static class FTAggregateClause {
+        abstract GlideString[] toArgs();
+    }
+
+    enum ClauseType {
+        LIMIT,
+        FILTER,
+        GROUPBY,
+        SORTBY,
+        REDUCE,
+        APPLY
+    }
+
+    /** A clause for limiting the number of retained records. */
+    public static class Limit extends FTAggregateClause {
+        private final int offset;
+        private final int count;
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param offset Starting point from which the records have to be retained.
+         * @param count The total number of records to be retained.
+         */
+        public Limit(int offset, int count) {
+            this.offset = offset;
+            this.count = count;
+        }
+
+        @Override
+        GlideString[] toArgs() {
+            return new GlideString[] {
+                gs(ClauseType.LIMIT.toString()), gs(Integer.toString(offset)), gs(Integer.toString(count))
+            };
+        }
+    }
+
+    /**
+     * Filter the results using a predicate expression relating to values in each result. It is
+     * applied post-query and relates to the current state of the pipeline.
+     */
+    public static class Filter extends FTAggregateClause {
+        private final GlideString expression;
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param expression The expression to filter the results.
+         */
+        public Filter(@NonNull GlideString expression) {
+            this.expression = expression;
+        }
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param expression The expression to filter the results.
+         */
+        public Filter(@NonNull String expression) {
+            this.expression = gs(expression);
+        }
+
+        @Override
+        GlideString[] toArgs() {
+            return new GlideString[] {gs(ClauseType.FILTER.toString()), expression};
+        }
+    }
+
+    /** A clause for grouping the results in the pipeline based on one or more properties. */
+    public static class GroupBy extends FTAggregateClause {
+        private final GlideString[] properties;
+        private final Reducer[] reducers;
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param properties The list of properties to be used for grouping the results in the pipeline.
+         * @param reducers The list of functions that handle the group entries by performing multiple
+         *     aggregate operations.
+         */
+        public GroupBy(@NonNull GlideString[] properties, @NonNull Reducer[] reducers) {
+            this.properties = properties;
+            this.reducers = reducers;
+        }
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param properties The list of properties to be used for grouping the results in the pipeline.
+         * @param reducers The list of functions that handle the group entries by performing multiple
+         *     aggregate operations.
+         */
+        public GroupBy(@NonNull String[] properties, @NonNull Reducer[] reducers) {
+            this.properties = toGlideStringArray(properties);
+            this.reducers = reducers;
+        }
+
+        @Override
+        GlideString[] toArgs() {
+            return concatenateArrays(
+                    new GlideString[] {
+                        gs(ClauseType.GROUPBY.toString()), gs(Integer.toString(properties.length))
+                    },
+                    properties,
+                    Stream.of(reducers).map(Reducer::toArgs).flatMap(Stream::of).toArray(GlideString[]::new));
+        }
+
+        /**
+         * A function that handles the group entries, either counting them, or performing multiple
+         * aggregate operations.
+         */
+        public static class Reducer {
+            private final String function;
+            private final GlideString[] args;
+            private final String name;
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param function The reduction function name for the respective group.
+             * @param args The list of arguments for the reducer.
+             * @param name User defined property name for the reducer.
+             */
+            public Reducer(@NonNull String function, @NonNull GlideString[] args, @NonNull String name) {
+                this.function = function;
+                this.args = args;
+                this.name = name;
+            }
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param function The reduction function name for the respective group.
+             * @param args The list of arguments for the reducer.
+             */
+            public Reducer(@NonNull String function, @NonNull GlideString[] args) {
+                this.function = function;
+                this.args = args;
+                this.name = null;
+            }
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param function The reduction function name for the respective group.
+             * @param args The list of arguments for the reducer.
+             * @param name User defined property name for the reducer.
+             */
+            public Reducer(@NonNull String function, @NonNull String[] args, @NonNull String name) {
+                this.function = function;
+                this.args = toGlideStringArray(args);
+                this.name = name;
+            }
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param function The reduction function name for the respective group.
+             * @param args The list of arguments for the reducer.
+             */
+            public Reducer(@NonNull String function, @NonNull String[] args) {
+                this.function = function;
+                this.args = toGlideStringArray(args);
+                this.name = null;
+            }
+
+            GlideString[] toArgs() {
+                return concatenateArrays(
+                        new GlideString[] {
+                            gs(ClauseType.REDUCE.toString()), gs(function), gs(Integer.toString(args.length))
+                        },
+                        args,
+                        name == null ? new GlideString[0] : new GlideString[] {gs("AS"), gs(name)});
+            }
+        }
+    }
+
+    /** Sort the pipeline using a list of properties. */
+    public static class SortBy extends FTAggregateClause {
+
+        private final SortProperty[] properties;
+        private final Integer max;
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param properties A list of sorting parameters for the sort operation.
+         */
+        public SortBy(@NonNull SortProperty[] properties) {
+            this.properties = properties;
+            this.max = null;
+        }
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param properties A list of sorting parameters for the sort operation.
+         * @param max The MAX value for optimizing the sorting, by sorting only for the n-largest
+         *     elements.
+         */
+        public SortBy(@NonNull SortProperty[] properties, int max) {
+            this.properties = properties;
+            this.max = max;
+        }
+
+        @Override
+        GlideString[] toArgs() {
+            return concatenateArrays(
+                    new GlideString[] {
+                        gs(ClauseType.SORTBY.toString()), gs(Integer.toString(properties.length * 2)),
+                    },
+                    Stream.of(properties)
+                            .map(SortProperty::toArgs)
+                            .flatMap(Stream::of)
+                            .toArray(GlideString[]::new),
+                    max == null ? new GlideString[0] : new GlideString[] {gs("MAX"), gs(max.toString())});
+        }
+
+        public enum SortOrder {
+            ASC,
+            DESC
+        }
+
+        /** A sorting parameter. */
+        public static class SortProperty {
+            private final GlideString property;
+            private final SortOrder order;
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param property The sorting parameter name.
+             * @param order The order for the sorting.
+             */
+            public SortProperty(@NonNull GlideString property, @NonNull SortOrder order) {
+                this.property = property;
+                this.order = order;
+            }
+
+            /**
+             * Initialize a new instance.
+             *
+             * @param property The sorting parameter name.
+             * @param order The order for the sorting.
+             */
+            public SortProperty(@NonNull String property, @NonNull SortOrder order) {
+                this.property = gs(property);
+                this.order = order;
+            }
+
+            GlideString[] toArgs() {
+                return new GlideString[] {property, gs(order.toString())};
+            }
+        }
+    }
+
+    /**
+     * Apply a 1-to-1 transformation on one or more properties and either stores the result as a
+     * new property down the pipeline or replaces any property using this transformation.
+     */
+    public static class Apply extends FTAggregateClause {
+        private final GlideString expression;
+        private final GlideString name;
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param expression The transformation expression.
+         * @param name The new property name to store the result of apply.
+         */
+        public Apply(@NonNull GlideString expression, @NonNull GlideString name) {
+            this.expression = expression;
+            this.name = name;
+        }
+
+        /**
+         * Initialize a new instance.
+         *
+         * @param expression The transformation expression.
+         * @param name The new property name to store the result of apply.
+         */
+        public Apply(@NonNull String expression, @NonNull String name) {
+            this.expression = gs(expression);
+            this.name = gs(name);
+        }
+
+        @Override
+        GlideString[] toArgs() {
+            return new GlideString[] {gs(ClauseType.APPLY.toString()), expression, gs("AS"), name};
+        }
+    }
+}
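
The clause classes compose into a complete FT.AGGREGATE pipeline through the builder. A sketch of the intended usage (field names, expressions, and the resulting argument list are illustrative):

    FTAggregateOptions options =
            FTAggregateOptions.builder()
                    .loadFields(new String[] {"published_at", "category"})
                    .addClause(new Apply("@published_at / 86400", "day"))
                    .addClause(
                            new GroupBy(
                                    new String[] {"@category"},
                                    new Reducer[] {new Reducer("COUNT", new String[0], "num_docs")}))
                    .addClause(
                            new SortBy(
                                    new SortProperty[] {new SortProperty("@num_docs", SortOrder.DESC)}))
                    .build();
    // toArgs() serializes roughly to: LOAD 2 published_at category APPLY
    // "@published_at / 86400" AS day GROUPBY 1 @category REDUCE COUNT 0 AS
    // num_docs SORTBY 2 @num_docs DESC
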
diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java
new file mode 100644
index 0000000000..81bb9e1dce
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/FT/FTCreateOptions.java
@@ -0,0 +1,422 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.FT;
+
+import static glide.api.models.GlideString.gs;
+
+import glide.api.BaseClient;
+import glide.api.commands.servermodules.FT;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.NonNull;
+
+/**
+ * Additional parameters for {@link FT#create(BaseClient, String, FieldInfo[], FTCreateOptions)}
+ * command.
+ */
+@Builder
+public class FTCreateOptions {
+    /** The index data type. If not defined, a {@link DataType#HASH} index is created. */
+    private final DataType dataType;
+
+    /** A list of prefixes of index definitions. */
+    private final GlideString[] prefixes;
+
+    FTCreateOptions(DataType dataType, GlideString[] prefixes) {
+        this.dataType = dataType;
+        this.prefixes = prefixes;
+    }
+
+    public static FTCreateOptionsBuilder builder() {
+        return new FTCreateOptionsBuilder();
+    }
+
+    public GlideString[] toArgs() {
+        var args = new ArrayList<GlideString>();
+        if (dataType != null) {
+            args.add(gs("ON"));
+            args.add(gs(dataType.toString()));
+        }
+        if (prefixes != null && prefixes.length > 0) {
+            args.add(gs("PREFIX"));
+            args.add(gs(Integer.toString(prefixes.length)));
+            args.addAll(List.of(prefixes));
+        }
+        return args.toArray(GlideString[]::new);
+    }
+
+    public static class FTCreateOptionsBuilder {
+        public FTCreateOptionsBuilder prefixes(@NonNull String[] prefixes) {
+            this.prefixes = Stream.of(prefixes).map(GlideString::gs).toArray(GlideString[]::new);
+            return this;
+        }
+    }
+
+    /** Type of the index dataset. */
+    public enum DataType {
+        /** Data stored in hashes. Field identifiers are field names within the hashes. */
+        HASH,
+        /** Data stored as a JSON document. Field identifiers are JSON Path expressions. */
+        JSON
+    }
+
+    /**
+     * A vector search field. Could be one of the following:
+     *
+     * <ul>
+     *   <li>{@link NumericField}
+     *   <li>{@link TextField}
+     *   <li>{@link TagField}
+     *   <li>{@link VectorFieldHnsw}
+     *   <li>{@link VectorFieldFlat}
+     * </ul>
+     */
+    public interface Field {
+        /** Convert to module API. */
+        String[] toArgs();
+    }
+
+    private enum FieldType {
+        NUMERIC,
+        TEXT,
+        TAG,
+        VECTOR
+    }
+
+    /** Field contains a number. */
+    public static class NumericField implements Field {
+        @Override
+        public String[] toArgs() {
+            return new String[] {FieldType.NUMERIC.toString()};
+        }
+    }
+
+    /** Field contains any blob of data. */
+    public static class TextField implements Field {
+        @Override
+        public String[] toArgs() {
+            return new String[] {FieldType.TEXT.toString()};
+        }
+    }
+
+    /**
+     * Tag fields are similar to full-text fields, but they interpret the text as a simple list of
+     * tags delimited by a separator character.
+     * For {@link DataType#HASH} fields, the default separator is a comma (,). For {@link
+     * DataType#JSON} fields, there is no default separator; you must declare one explicitly if
+     * needed.
+     */
+    public static class TagField implements Field {
+        private Optional<Character> separator;
+        private final boolean caseSensitive;
+
+        /** Create a TAG field. */
+        public TagField() {
+            this.separator = Optional.empty();
+            this.caseSensitive = false;
+        }
+
+        /**
+         * Create a TAG field.
+         *
+         * @param separator Specify how text in the attribute is split into individual tags. Must be a
+         *     single character.
+         */
+        public TagField(char separator) {
+            this.separator = Optional.of(separator);
+            this.caseSensitive = false;
+        }
+
+        /**
+         * Create a TAG field.
+         *
+         * @param separator Specify how text in the attribute is split into individual tags. Must be a
+         *     single character.
+         * @param caseSensitive Preserve the original letter cases of tags. If set to False, characters
+         *     are converted to lowercase by default.
+         */
+        public TagField(char separator, boolean caseSensitive) {
+            this.separator = Optional.of(separator);
+            this.caseSensitive = caseSensitive;
+        }
+
+        /**
+         * Create a TAG field.
+         *
+         * @param caseSensitive Preserve the original letter cases of tags. If set to False, characters
+         *     are converted to lowercase by default.
+         */
+        public TagField(boolean caseSensitive) {
+            this.separator = Optional.empty(); // initialized so toArgs() never sees a null separator
+            this.caseSensitive = caseSensitive;
+        }
+
+        @Override
+        public String[] toArgs() {
+            var args = new ArrayList<String>();
+            args.add(FieldType.TAG.toString());
+            if (separator.isPresent()) {
+                args.add("SEPARATOR");
+                args.add(separator.get().toString());
+            }
+            if (caseSensitive) {
+                args.add("CASESENSITIVE");
+            }
+            return args.toArray(String[]::new);
+        }
+    }
+
+    /**
+     * Distance metrics to measure the degree of similarity between two vectors.
+     * These metrics calculate the distance between two vectors, where the smaller the value is, the
+     * closer the two vectors are in the vector space.
+     */
+    public enum DistanceMetric {
+        /** Euclidean distance between two vectors. */
+        L2,
+        /** Inner product of two vectors. */
+        IP,
+        /** Cosine distance of two vectors. */
+        COSINE
+    }
+
+    /** Superclass for vector field implementations, contains common logic. */
+    @AllArgsConstructor(access = AccessLevel.PROTECTED)
+    abstract static class VectorField implements Field {
+        private final Map<VectorAlgorithmParam, String> params;
+        private final VectorAlgorithm algorithm;
+
+        @Override
+        public String[] toArgs() {
+            var args = new ArrayList<String>();
+            args.add(FieldType.VECTOR.toString());
+            args.add(algorithm.toString());
+            args.add(Integer.toString(params.size() * 2));
+            params.forEach(
+                    (name, value) -> {
+                        args.add(name.toString());
+                        args.add(value);
+                    });
+            return args.toArray(String[]::new);
+        }
+    }
+
+    /** Algorithm for vector type fields used for vector similarity search. */
+    private enum VectorAlgorithm {
+        HNSW,
+        FLAT
+    }
+
+    private enum VectorAlgorithmParam {
+        M,
+        EF_CONSTRUCTION,
+        EF_RUNTIME,
+        TYPE,
+        DIM,
+        DISTANCE_METRIC,
+        INITIAL_CAP
+    }
+
+    /**
+     * Vector field that supports vector search by the HNSW (Hierarchical Navigable Small
+     * World) algorithm.
+     * The algorithm provides an approximation of the correct answer in exchange for substantially
+     * lower execution times.
+     */
+    public static class VectorFieldHnsw extends VectorField {
+        private VectorFieldHnsw(Map<VectorAlgorithmParam, String> params) {
+            super(params, VectorAlgorithm.HNSW);
+        }
+
+        /**
+         * Init a builder.
+         *
+         * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two
+         *     vectors. Equivalent to DISTANCE_METRIC on the module API.
+         * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768.
+         *     Equivalent to DIM on the module API.
+         */
+        public static VectorFieldHnswBuilder builder(
+                @NonNull DistanceMetric distanceMetric, int dimensions) {
+            return new VectorFieldHnswBuilder(distanceMetric, dimensions);
+        }
+    }
+
+    public static class VectorFieldHnswBuilder extends VectorFieldBuilder<VectorFieldHnswBuilder> {
+        VectorFieldHnswBuilder(DistanceMetric distanceMetric, int dimensions) {
+            super(distanceMetric, dimensions);
+        }
+
+        @Override
+        public VectorFieldHnsw build() {
+            return new VectorFieldHnsw(params);
+        }
+
+        /**
+         * Number of maximum allowed outgoing edges for each node in the graph in each layer. On layer
+         * zero the maximal number of outgoing edges is doubled. Default is 16, maximum is 512.
+         * Equivalent to M on the module API.
+         */
+        public VectorFieldHnswBuilder numberOfEdges(int numberOfEdges) {
+            params.put(VectorAlgorithmParam.M, Integer.toString(numberOfEdges));
+            return this;
+        }
+
+        /**
+         * (Optional) The number of vectors examined during index construction. Higher values for this
+         * parameter will improve recall ratio at the expense of longer index creation times. Default
+         * value is 200. Maximum value is 4096. Equivalent to EF_CONSTRUCTION on the module
+         * API.
+         */
+        public VectorFieldHnswBuilder vectorsExaminedOnConstruction(int vectorsExaminedOnConstruction) {
+            params.put(
+                    VectorAlgorithmParam.EF_CONSTRUCTION, Integer.toString(vectorsExaminedOnConstruction));
+            return this;
+        }
+
+        /**
+         * (Optional) The number of vectors examined during query operations. Higher values for this
+         * parameter can yield improved recall at the expense of longer query times. The value of this
+         * parameter can be overridden on a per-query basis. Default value is 10. Maximum value is
+         * 4096. Equivalent to EF_RUNTIME on the module API.
+         */
+        public VectorFieldHnswBuilder vectorsExaminedOnRuntime(int vectorsExaminedOnRuntime) {
+            params.put(VectorAlgorithmParam.EF_RUNTIME, Integer.toString(vectorsExaminedOnRuntime));
+            return this;
+        }
+    }
+
+    /**
+     * Vector field that supports vector search by the FLAT (brute force) algorithm.
+     * The algorithm is a brute force linear processing of each vector in the index, yielding exact
+     * answers within the bounds of the precision of the distance computations.
+     */
+    public static class VectorFieldFlat extends VectorField {
+
+        private VectorFieldFlat(Map<VectorAlgorithmParam, String> params) {
+            super(params, VectorAlgorithm.FLAT);
+        }
+
+        /**
+         * Init a builder.
+         *
+         * @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two
+         *     vectors. Equivalent to DISTANCE_METRIC on the module API.
+         * @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768.
+         *     Equivalent to DIM on the module API.
+         */
+        public static VectorFieldFlatBuilder builder(
+                @NonNull DistanceMetric distanceMetric, int dimensions) {
+            return new VectorFieldFlatBuilder(distanceMetric, dimensions);
+        }
+    }
+
+    public static class VectorFieldFlatBuilder extends VectorFieldBuilder<VectorFieldFlatBuilder> {
+        VectorFieldFlatBuilder(DistanceMetric distanceMetric, int dimensions) {
+            super(distanceMetric, dimensions);
+        }
+
+        @Override
+        public VectorFieldFlat build() {
+            return new VectorFieldFlat(params);
+        }
+    }
+
+    abstract static class VectorFieldBuilder<T extends VectorFieldBuilder<T>> {
+        final Map<VectorAlgorithmParam, String> params = new HashMap<>();
+
+        VectorFieldBuilder(DistanceMetric distanceMetric, int dimensions) {
+            params.put(VectorAlgorithmParam.TYPE, "FLOAT32");
+            params.put(VectorAlgorithmParam.DIM, Integer.toString(dimensions));
+            params.put(VectorAlgorithmParam.DISTANCE_METRIC, distanceMetric.toString());
+        }
+
+        /**
+         * Initial vector capacity in the index affecting memory allocation size of the index.
+         * Defaults to 1024. Equivalent to INITIAL_CAP on the module API.
+         */
+        @SuppressWarnings("unchecked")
+        public T initialCapacity(int initialCapacity) {
+            params.put(VectorAlgorithmParam.INITIAL_CAP, Integer.toString(initialCapacity));
+            return (T) this;
+        }
+
+        public abstract VectorField build();
+    }
+
+    /** Field definition to be added into index schema. */
+    public static class FieldInfo {
+        private final GlideString name;
+        private final GlideString alias;
+        private final Field field;
+
+        /**
+         * Field definition to be added into index schema.
+         *
+         * @param name Field name.
+         * @param field The {@link Field} itself.
+         */
+        public FieldInfo(@NonNull String name, @NonNull Field field) {
+            this.name = gs(name);
+            this.field = field;
+            this.alias = null;
+        }
+
+        /**
+         * Field definition to be added into index schema.
+         *
+         * @param name Field name.
+         * @param alias Field alias.
+         * @param field The {@link Field} itself.
+         */
+        public FieldInfo(@NonNull String name, @NonNull String alias, @NonNull Field field) {
+            this.name = gs(name);
+            this.alias = gs(alias);
+            this.field = field;
+        }
+
+        /**
+         * Field definition to be added into index schema.
+         *
+         * @param name Field name.
+         * @param field The {@link Field} itself.
+         */
+        public FieldInfo(@NonNull GlideString name, @NonNull Field field) {
+            this.name = name;
+            this.field = field;
+            this.alias = null;
+        }
+
+        /**
+         * Field definition to be added into index schema.
+         *
+         * @param name Field name.
+         * @param alias Field alias.
+         * @param field The {@link Field} itself.
+         */
+        public FieldInfo(@NonNull GlideString name, @NonNull GlideString alias, @NonNull Field field) {
+            this.name = name;
+            this.alias = alias;
+            this.field = field;
+        }
+
+        /** Convert to module API. */
+        public GlideString[] toArgs() {
+            var args = new ArrayList<GlideString>();
+            args.add(name);
+            if (alias != null) {
+                args.add(gs("AS"));
+                args.add(alias);
+            }
+            args.addAll(Stream.of(field.toArgs()).map(GlideString::gs).collect(Collectors.toList()));
+            return args.toArray(GlideString[]::new);
+        }
+    }
+}
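
A schema is declared as FieldInfo entries pairing an identifier (optionally aliased) with one of the Field implementations above. A sketch (index name, field names, and dimensions are invented):

    FieldInfo[] schema =
            new FieldInfo[] {
                new FieldInfo("price", new NumericField()),
                new FieldInfo(
                        "embedding",
                        "vec",
                        VectorFieldHnsw.builder(DistanceMetric.COSINE, 128).numberOfEdges(32).build())
            };
    FTCreateOptions options =
            FTCreateOptions.builder()
                    .dataType(DataType.HASH)
                    .prefixes(new String[] {"product:"})
                    .build();
    // FT.create(client, "products_idx", schema, options) would then assemble
    // ON HASH PREFIX 1 product: plus the serialized SCHEMA entries.
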
diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java
new file mode 100644
index 0000000000..5d9b7e892d
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/FT/FTProfileOptions.java
@@ -0,0 +1,126 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.FT;
+
+import static glide.api.models.GlideString.gs;
+import static glide.utils.ArrayTransformUtils.concatenateArrays;
+
+import glide.api.commands.servermodules.FT;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.List;
+import lombok.NonNull;
+
+/** Mandatory parameters for {@link FT#profile} command. */
+public class FTProfileOptions {
+    private final QueryType queryType;
+    private final boolean limited;
+    private final GlideString[] commandLine;
+
+    /** Query type being profiled. */
+    enum QueryType {
+        SEARCH,
+        AGGREGATE
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     */
+    public FTProfileOptions(@NonNull String query, @NonNull FTAggregateOptions options) {
+        this(gs(query), options);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     */
+    public FTProfileOptions(@NonNull GlideString query, @NonNull FTAggregateOptions options) {
+        this(query, options, false);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     */
+    public FTProfileOptions(@NonNull String query, @NonNull FTSearchOptions options) {
+        this(gs(query), options);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     */
+    public FTProfileOptions(@NonNull GlideString query, @NonNull FTSearchOptions options) {
+        this(query, options, false);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull String query, @NonNull FTAggregateOptions options, boolean limited) {
+        this(gs(query), options, limited);
+    }
+
+    /**
+     * Profile an aggregation query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#aggregate} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull GlideString query, @NonNull FTAggregateOptions options, boolean limited) {
+        queryType = QueryType.AGGREGATE;
+        commandLine = concatenateArrays(new GlideString[] {query}, options.toArgs());
+        this.limited = limited;
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull String query, @NonNull FTSearchOptions options, boolean limited) {
+        this(gs(query), options, limited);
+    }
+
+    /**
+     * Profile a search query with given parameters.
+     *
+     * @param query The query itself.
+     * @param options {@link FT#search} options.
+     * @param limited Either provide a full verbose output or some brief version (limited).
+     */
+    public FTProfileOptions(
+            @NonNull GlideString query, @NonNull FTSearchOptions options, boolean limited) {
+        queryType = QueryType.SEARCH;
+        commandLine = concatenateArrays(new GlideString[] {query}, options.toArgs());
+        this.limited = limited;
+    }
+
+    /** Convert to module API. */
+    public GlideString[] toArgs() {
+        var args = new ArrayList<GlideString>();
+        args.add(gs(queryType.toString()));
+        if (limited) args.add(gs("LIMITED"));
+        args.add(gs("QUERY"));
+        args.addAll(List.of(commandLine));
+        return args.toArray(GlideString[]::new);
+    }
+}
diff --git a/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java
new file mode 100644
index 0000000000..74407c64c0
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/FT/FTSearchOptions.java
@@ -0,0 +1,133 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.FT;
+
+import static glide.api.models.GlideString.gs;
+
+import glide.api.commands.servermodules.FT;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import lombok.Builder;
+import lombok.NonNull;
+import org.apache.commons.lang3.tuple.Pair;
+
+/** Mandatory parameters for {@link FT#search}. */
+@Builder
+public class FTSearchOptions {
+
+    @Builder.Default private final Map<GlideString, GlideString> identifiers = new HashMap<>();
+
+    /** Query timeout in milliseconds. */
+    private final Integer timeout;
+
+    private final Pair<Integer, Integer> limit;
+
+    @Builder.Default private final boolean count = false;
+
+    /**
+     * Query parameters, which could be referenced in the query by $ sign, followed by
+     * the parameter name.
+     */
+    @Builder.Default private final Map<GlideString, GlideString> params = new HashMap<>();
+
+    // TODO maxstale?
+    // dialect is no-op
+
+    /** Convert to module API. */
+    public GlideString[] toArgs() {
+        var args = new ArrayList<GlideString>();
+        if (!identifiers.isEmpty()) {
+            args.add(gs("RETURN"));
+            int tokenCount = 0;
+            for (var pair : identifiers.entrySet()) {
+                tokenCount++;
+                args.add(pair.getKey());
+                if (pair.getValue() != null) {
+                    tokenCount += 2;
+                    args.add(gs("AS"));
+                    args.add(pair.getValue());
+                }
+            }
+            args.add(1, gs(Integer.toString(tokenCount)));
+        }
+        if (timeout != null) {
+            args.add(gs("TIMEOUT"));
+            args.add(gs(timeout.toString()));
+        }
+        if (!params.isEmpty()) {
+            args.add(gs("PARAMS"));
+            args.add(gs(Integer.toString(params.size() * 2)));
+            params.forEach(
+                    (name, value) -> {
+                        args.add(name);
+                        args.add(value);
+                    });
+        }
+        if (limit != null) {
+            args.add(gs("LIMIT"));
+            args.add(gs(Integer.toString(limit.getLeft())));
+            args.add(gs(Integer.toString(limit.getRight())));
+        }
+        if (count) {
+            args.add(gs("COUNT"));
+        }
+        return args.toArray(GlideString[]::new);
+    }
+
+    public static class FTSearchOptionsBuilder {
+
+        // private - hiding this API from user
+        void limit(Pair<Integer, Integer> limit) {}
+
+        void count(boolean count) {}
+
+        void identifiers(Map<GlideString, GlideString> identifiers) {}
+
+        /** Add a field to be returned. */
+        public FTSearchOptionsBuilder addReturnField(@NonNull String field) {
+            this.identifiers$value.put(gs(field), null);
+            return this;
+        }
+
+        /** Add a field with an alias to be returned. */
+        public FTSearchOptionsBuilder addReturnField(@NonNull String field, @NonNull String alias) {
+            this.identifiers$value.put(gs(field), gs(alias));
+            return this;
+        }
+
+        /** Add a field to be returned. */
+        public FTSearchOptionsBuilder addReturnField(@NonNull GlideString field) {
+            this.identifiers$value.put(field, null);
+            return this;
+        }
+
+        /** Add a field with an alias to be returned. */
+        public FTSearchOptionsBuilder addReturnField(
+                @NonNull GlideString field, @NonNull GlideString alias) {
+            this.identifiers$value.put(field, alias);
+            return this;
+        }
+
+        /**
+         * Configure query pagination. By default only the first 10 documents are returned.
+         *
+         * @param offset Zero-based offset.
+         * @param count Number of elements to return.
+         */
+        public FTSearchOptionsBuilder limit(int offset, int count) {
+            this.limit = Pair.of(offset, count);
+            return this;
+        }
+
+        /**
+         * Once set, the query will return only the number of documents in the result set without
+         * actually returning them.
+         */
+        public FTSearchOptionsBuilder count() {
+            this.count$value = true;
+            this.count$set = true;
+            return this;
+        }
+    }
+}
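
The builder's output mirrors the module's RETURN/PARAMS/LIMIT grammar. A usage sketch (field names and the query parameter are invented; map iteration order makes the exact argument order nondeterministic):

    Map<GlideString, GlideString> params = new HashMap<>();
    params.put(gs("$radius"), gs("100"));
    FTSearchOptions options =
            FTSearchOptions.builder()
                    .addReturnField("title")
                    .addReturnField("embedding", "vec")
                    .params(params)
                    .limit(0, 10)
                    .build();
    // toArgs() then yields something like:
    // RETURN 4 title embedding AS vec PARAMS 2 $radius 100 LIMIT 0 10
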
diff --git a/java/client/src/main/java/glide/api/models/commands/json/JsonArrindexOptions.java b/java/client/src/main/java/glide/api/models/commands/json/JsonArrindexOptions.java
new file mode 100644
index 0000000000..9e430fef47
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/json/JsonArrindexOptions.java
@@ -0,0 +1,54 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.json;
+
+import glide.api.commands.servermodules.Json;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Additional parameters for {@link Json#arrindex} command. */
+public final class JsonArrindexOptions {
+
+    /** The start index, inclusive. Defaults to 0. */
+    private Long start;
+
+    /** The end index, exclusive. */
+    private Long end;
+
+    /**
+     * Search using a start index (inclusive). Defaults to 0 if not provided. Indices
+     * that exceed the array bounds are automatically adjusted to the nearest valid position.
+     */
+    public JsonArrindexOptions(Long start) {
+        this.start = start;
+    }
+
+    /**
+     * Search using a start index (inclusive) and an end index (exclusive). If start
+     * is greater than end, the command returns -1 to indicate that the
+     * value was not found. Indices that exceed the array bounds are automatically adjusted to the
+     * nearest valid position.
+     */
+    public JsonArrindexOptions(Long start, Long end) {
+        this.start = start;
+        this.end = end;
+    }
+
+    /**
+     * Converts JsonArrindexOptions into a String[].
+     *
+     * @return String[]
+     */
+    public String[] toArgs() {
+        List<String> args = new ArrayList<>();
+
+        if (start != null) {
+            args.add(start.toString());
+
+            if (end != null) {
+                args.add(end.toString());
+            }
+        }
+
+        return args.toArray(new String[0]);
+    }
+}
diff --git a/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java
new file mode 100644
index 0000000000..5273e9c8c1
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptions.java
@@ -0,0 +1,64 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.json;
+
+import glide.api.commands.servermodules.Json;
+import java.util.ArrayList;
+import java.util.List;
+import lombok.Builder;
+
+/** Additional parameters for {@link Json#get} command. */
+@Builder
+public final class JsonGetOptions {
+    /** Valkey API string to designate INDENT */
+    public static final String INDENT_VALKEY_API = "INDENT";
+
+    /** Valkey API string to designate NEWLINE */
+    public static final String NEWLINE_VALKEY_API = "NEWLINE";
+
+    /** Valkey API string to designate SPACE */
+    public static final String SPACE_VALKEY_API = "SPACE";
+
+    /** Valkey API string to designate NOESCAPE */
+    public static final String NOESCAPE_VALKEY_API = "NOESCAPE";
+
+    /** Sets an indentation string for nested levels. */
+    private String indent;
+
+    /** Sets a string that's printed at the end of each line. */
+    private String newline;
+
+    /** Sets a string that's put between a key and a value. */
+    private String space;
+
+    /** Allowed to be present for legacy compatibility and has no other effect. */
+    private boolean noescape;
+
+    /**
+     * Converts JsonGetOptions into a String[].
+     *
+     * @return String[]
+     */
+    public String[] toArgs() {
+        List<String> args = new ArrayList<>();
+        if (indent != null) {
+            args.add(INDENT_VALKEY_API);
+            args.add(indent);
+        }
+
+        if (newline != null) {
+            args.add(NEWLINE_VALKEY_API);
+            args.add(newline);
+        }
+
+        if (space != null) {
+            args.add(SPACE_VALKEY_API);
+            args.add(space);
+        }
+
+        if (noescape) {
+            args.add(NOESCAPE_VALKEY_API);
+        }
+
+        return args.toArray(new String[0]);
+    }
+}
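
These options map one-for-one onto JSON.GET's formatting switches. A sketch of requesting pretty-printed output (the chosen values are conventional, not required):

    JsonGetOptions options =
            JsonGetOptions.builder().indent("  ").newline("\n").space(" ").build();
    // options.toArgs() -> ["INDENT", "  ", "NEWLINE", "\n", "SPACE", " "]
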
diff --git a/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java
new file mode 100644
index 0000000000..634b4d298e
--- /dev/null
+++ b/java/client/src/main/java/glide/api/models/commands/json/JsonGetOptionsBinary.java
@@ -0,0 +1,67 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.api.models.commands.json;
+
+import static glide.api.models.GlideString.gs;
+
+import glide.api.commands.servermodules.Json;
+import glide.api.models.GlideString;
+import java.util.ArrayList;
+import java.util.List;
+import lombok.Builder;
+
+/** GlideString version of additional parameters for {@link Json#get} command. */
+@Builder
+public final class JsonGetOptionsBinary {
+    /** Valkey API string to designate INDENT */
+    public static final GlideString INDENT_VALKEY_API = gs("INDENT");
+
+    /** Valkey API string to designate NEWLINE */
+    public static final GlideString NEWLINE_VALKEY_API = gs("NEWLINE");
+
+    /** Valkey API string to designate SPACE */
+    public static final GlideString SPACE_VALKEY_API = gs("SPACE");
+
+    /** Valkey API string to designate NOESCAPE */
+    public static final GlideString NOESCAPE_VALKEY_API = gs("NOESCAPE");
+
+    /** Sets an indentation string for nested levels. */
+    private GlideString indent;
+
+    /** Sets a string that's printed at the end of each line. */
+    private GlideString newline;
+
+    /** Sets a string that's put between a key and a value. */
+    private GlideString space;
+
+    /** Allowed to be present for legacy compatibility and has no other effect. */
+    private boolean noescape;
+
+    /**
+     * Converts JsonGetOptions into a GlideString[].
+     *
+     * @return GlideString[]
+     */
+    public GlideString[] toArgs() {
+        List<GlideString> args = new ArrayList<>();
+        if (indent != null) {
+            args.add(INDENT_VALKEY_API);
+            args.add(indent);
+        }
+
+        if (newline != null) {
+            args.add(NEWLINE_VALKEY_API);
+            args.add(newline);
+        }
+
+        if (space != null) {
+            args.add(SPACE_VALKEY_API);
+            args.add(space);
+        }
+
+        if (noescape) {
+            args.add(NOESCAPE_VALKEY_API);
+        }
+
+        return args.toArray(new GlideString[0]);
+    }
+}
diff --git a/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java b/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java
index e0a4ed5500..7d9d5d5b68 100644
--- a/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java
+++ b/java/client/src/main/java/glide/api/models/configuration/BaseClientConfiguration.java
@@ -74,4 +74,10 @@ public abstract class BaseClientConfiguration {
      * used.
      */
     private final Integer inflightRequestsLimit;
+
+    /**
+     * Availability Zone of the client. If the ReadFrom strategy is AZAffinity, this setting ensures
+     * that readonly commands are directed to replicas within the specified AZ, if one exists.
+     */
+    private final String clientAZ;
 }
diff --git a/java/client/src/main/java/glide/api/models/configuration/ReadFrom.java b/java/client/src/main/java/glide/api/models/configuration/ReadFrom.java
index 2d80ae7b60..29a212d8c7 100644
--- a/java/client/src/main/java/glide/api/models/configuration/ReadFrom.java
+++ b/java/client/src/main/java/glide/api/models/configuration/ReadFrom.java
@@ -9,5 +9,10 @@ public enum ReadFrom {
      * Spread the requests between all replicas in a round-robin manner. If no replica is available,
      * route the requests to the primary.
      */
-    PREFER_REPLICA
+    PREFER_REPLICA,
+    /**
+     * Spread the read requests between replicas in the same client's AZ (Availability Zone) in a
+     * round-robin manner, falling back to other replicas or the primary if needed.
+     */
+    AZ_AFFINITY,
 }
diff --git a/java/client/src/main/java/glide/ffi/resolvers/StatisticsResolver.java b/java/client/src/main/java/glide/ffi/resolvers/StatisticsResolver.java
new file mode 100644
index 0000000000..0bb3731a71
--- /dev/null
+++ b/java/client/src/main/java/glide/ffi/resolvers/StatisticsResolver.java
@@ -0,0 +1,12 @@
+/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
+package glide.ffi.resolvers;
+
+public class StatisticsResolver {
+    // TODO: consider lazy loading the glide_rs library
+    static {
+        NativeUtils.loadGlideLib();
+    }
+
+    /** Return the internal statistics Map object */
+    public static native Object getStatistics();
+}
diff --git a/java/client/src/main/java/glide/managers/CommandManager.java b/java/client/src/main/java/glide/managers/CommandManager.java
index 639e258d81..d069c6bd72 100644
--- a/java/client/src/main/java/glide/managers/CommandManager.java
+++ b/java/client/src/main/java/glide/managers/CommandManager.java
@@ -12,6 +12,7 @@
 import command_request.CommandRequestOuterClass.ScriptInvocationPointers;
 import command_request.CommandRequestOuterClass.SimpleRoutes;
 import command_request.CommandRequestOuterClass.SlotTypes;
+import command_request.CommandRequestOuterClass.UpdateConnectionPassword;
 import glide.api.models.ClusterTransaction;
 import glide.api.models.GlideString;
 import glide.api.models.Script;
@@ -218,6 +219,26 @@ public <T> CompletableFuture<T> submitClusterScan(
         return submitCommandToChannel(command, responseHandler);
     }
 
+    /**
+     * Submit a password update request to GLIDE core.
+     *
+     * @param password A new password to set or empty value to remove the password.
+     * @param immediateAuth immediately perform auth.
+     * @param responseHandler A response handler.
+     * @return A request promise.
+     * @param <T> Type of the response.
+     */
+    public <T> CompletableFuture<T> submitPasswordUpdate(
+            Optional<String> password,
+            boolean immediateAuth,
+            GlideExceptionCheckedFunction<Response, T> responseHandler) {
+        var builder = UpdateConnectionPassword.newBuilder().setImmediateAuth(immediateAuth);
+        password.ifPresent(builder::setPassword);
+
+        var command = CommandRequest.newBuilder().setUpdateConnectionPassword(builder.build());
+        return submitCommandToChannel(command, responseHandler);
+    }
+
     /**
      * Take a command request and send to channel.
* diff --git a/java/client/src/main/java/glide/managers/ConnectionManager.java b/java/client/src/main/java/glide/managers/ConnectionManager.java index a5a8b9c5c3..99b383a9ed 100644 --- a/java/client/src/main/java/glide/managers/ConnectionManager.java +++ b/java/client/src/main/java/glide/managers/ConnectionManager.java @@ -14,6 +14,7 @@ import glide.api.models.configuration.NodeAddress; import glide.api.models.configuration.ReadFrom; import glide.api.models.exceptions.ClosingException; +import glide.api.models.exceptions.ConfigurationError; import glide.connectors.handlers.ChannelHandler; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; @@ -122,6 +123,14 @@ private ConnectionRequest.Builder setupConnectionRequestBuilderBaseConfiguration connectionRequestBuilder.setInflightRequestsLimit(configuration.getInflightRequestsLimit()); } + if (configuration.getReadFrom() == ReadFrom.AZ_AFFINITY) { + if (configuration.getClientAZ() == null) { + throw new ConfigurationError( + "`clientAZ` must be set when read_from is set to `AZ_AFFINITY`"); + } + connectionRequestBuilder.setClientAz(configuration.getClientAZ()); + } + return connectionRequestBuilder; } @@ -200,11 +209,14 @@ private ConnectionRequest.Builder setupConnectionRequestBuilderGlideClusterClien * @return Protobuf defined ReadFrom enum */ private ConnectionRequestOuterClass.ReadFrom mapReadFromEnum(ReadFrom readFrom) { - if (readFrom == ReadFrom.PREFER_REPLICA) { - return ConnectionRequestOuterClass.ReadFrom.PreferReplica; + switch (readFrom) { + case PREFER_REPLICA: + return ConnectionRequestOuterClass.ReadFrom.PreferReplica; + case AZ_AFFINITY: + return ConnectionRequestOuterClass.ReadFrom.AZAffinity; + default: + return ConnectionRequestOuterClass.ReadFrom.Primary; } - - return ConnectionRequestOuterClass.ReadFrom.Primary; } /** Check a response received from Glide. 
*/ diff --git a/java/client/src/main/java/module-info.java b/java/client/src/main/java/module-info.java index 99c4655082..fc280da076 100644 --- a/java/client/src/main/java/module-info.java +++ b/java/client/src/main/java/module-info.java @@ -9,8 +9,11 @@ exports glide.api.models.commands.function; exports glide.api.models.commands.scan; exports glide.api.models.commands.stream; + exports glide.api.models.commands.FT; + exports glide.api.models.commands.json; exports glide.api.models.configuration; exports glide.api.models.exceptions; + exports glide.api.commands.servermodules; requires com.google.protobuf; requires io.netty.codec; diff --git a/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java new file mode 100644 index 0000000000..884f1bed27 --- /dev/null +++ b/java/client/src/test/java/glide/api/commands/servermodules/JsonTest.java @@ -0,0 +1,685 @@ +/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ +package glide.api.commands.servermodules; + +import static glide.api.models.GlideString.gs; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import glide.api.GlideClient; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.json.JsonGetOptions; +import glide.api.models.commands.json.JsonGetOptionsBinary; +import glide.utils.ArgsBuilder; +import glide.utils.ArrayTransformUtils; +import java.util.ArrayList; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class JsonTest { + + private GlideClient glideClient; + + @BeforeEach + void setUp() { + glideClient = mock(GlideClient.class, RETURNS_DEEP_STUBS); + } + + @Test + @SneakyThrows + void set_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {"JSON.SET", key, path, jsonValue})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.set(glideClient, key, path, jsonValue); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_binary_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + GlideString jsonValue = gs("{\"a\": 1.0, \"b\": 2}"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.SET"), key, path, jsonValue})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.set(glideClient, key, path, jsonValue); + String 
actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_with_condition_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + ConditionalChange setCondition = ConditionalChange.ONLY_IF_DOES_NOT_EXIST; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq(new String[] {"JSON.SET", key, path, jsonValue, setCondition.getValkeyApi()})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = + Json.set(glideClient, key, path, jsonValue, setCondition); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void set_binary_with_condition_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + GlideString jsonValue = gs("{\"a\": 1.0, \"b\": 2}"); + ConditionalChange setCondition = ConditionalChange.ONLY_IF_DOES_NOT_EXIST; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "OK"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + new GlideString[] { + gs("JSON.SET"), key, path, jsonValue, gs(setCondition.getValkeyApi()) + })) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = + Json.set(glideClient, key, path, jsonValue, setCondition); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.GET", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.GET"), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_multiple_paths_returns_success() { + // setup + String key = "testKey"; + String path1 = ".firstName"; + String path2 = ".lastName"; + String[] 
paths = new String[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {"JSON.GET", key, path1, path2})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_multiple_paths_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path1 = gs(".firstName"); + GlideString path2 = gs(".lastName"); + GlideString[] paths = new GlideString[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.GET"), key, path1, path2})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_no_path_and_options_returns_success() { + // setup + String key = "testKey"; + JsonGetOptions options = JsonGetOptions.builder().indent("\t").space(" ").newline("\n").build(); + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + ArrayTransformUtils.concatenateArrays( + new String[] {"JSON.GET", key}, options.toArgs()))) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, options); + String actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_no_path_and_options_returns_success() { + // setup + GlideString key = gs("testKey"); + JsonGetOptionsBinary options = + JsonGetOptionsBinary.builder().indent(gs("\t")).space(gs(" ")).newline(gs("\n")).build(); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand( + eq( + new ArgsBuilder() + .add(new GlideString[] {gs("JSON.GET"), key}) + .add(options.toArgs()) + .toArray())) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, options); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_with_multiple_paths_and_options_returns_success() { + // setup + String key = "testKey"; + String path1 = ".firstName"; + String path2 = ".lastName"; + JsonGetOptions options = 
JsonGetOptions.builder().indent("\t").newline("\n").space(" ").build(); + String[] paths = new String[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "{\"a\": 1.0, \"b\": 2}"; + expectedResponse.complete(expectedResponseValue); + ArrayList argsList = new ArrayList<>(); + argsList.add("JSON.GET"); + argsList.add(key); + Collections.addAll(argsList, options.toArgs()); + Collections.addAll(argsList, paths); + when(glideClient.customCommand(eq(argsList.toArray(new String[0]))).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths, options); + String actualResponseValue = actualResponse.get(); + + // verify + assertArrayEquals( + new String[] {"INDENT", "\t", "NEWLINE", "\n", "SPACE", " "}, options.toArgs()); + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void get_binary_with_multiple_paths_and_options_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path1 = gs(".firstName"); + GlideString path2 = gs(".lastName"); + JsonGetOptionsBinary options = + JsonGetOptionsBinary.builder().indent(gs("\t")).newline(gs("\n")).space(gs(" ")).build(); + GlideString[] paths = new GlideString[] {path1, path2}; + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("{\"a\": 1.0, \"b\": 2}"); + expectedResponse.complete(expectedResponseValue); + GlideString[] args = + new ArgsBuilder().add("JSON.GET").add(key).add(options.toArgs()).add(paths).toArray(); + when(glideClient.customCommand(eq(args)).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.get(glideClient, key, paths, options); + GlideString actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.DEL", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.DEL"), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new 
CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.DEL", key, path})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void del_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.DEL"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.del(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_with_no_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.FORGET", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_binary_with_no_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 1L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.FORGET"), key})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new String[] {"JSON.FORGET", key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void forget_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + Long expectedResponseValue = 2L; + expectedResponse.complete(expectedResponseValue); + when(glideClient + 
.customCommand(eq(new GlideString[] {gs("JSON.FORGET"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.forget(glideClient, key, path); + Long actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void resp_without_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.RESP", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.resp(glideClient, key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void resp_binary_without_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new GlideString[] {gs("JSON.RESP"), key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.resp(glideClient, key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void resp_with_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.RESP", key, "$"})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.resp(glideClient, key, "$"); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void resp_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.RESP"), key, gs("$")})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.resp(glideClient, key, gs("$")); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_without_path_returns_success() { + // setup + String key = "testKey"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.TYPE", key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, 
key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_binary_without_path_returns_success() { + // setup + GlideString key = gs("testKey"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new GlideString[] {gs("JSON.TYPE"), key})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_with_path_returns_success() { + // setup + String key = "testKey"; + String path = "$"; + CompletableFuture expectedResponse = new CompletableFuture<>(); + String expectedResponseValue = "foo"; + expectedResponse.complete(expectedResponseValue); + when(glideClient.customCommand(eq(new String[] {"JSON.TYPE", key, path})).thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key, path); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } + + @Test + @SneakyThrows + void type_binary_with_path_returns_success() { + // setup + GlideString key = gs("testKey"); + GlideString path = gs("$"); + CompletableFuture expectedResponse = new CompletableFuture<>(); + GlideString expectedResponseValue = gs("foo"); + expectedResponse.complete(expectedResponseValue); + when(glideClient + .customCommand(eq(new GlideString[] {gs("JSON.TYPE"), key, path})) + .thenApply(any())) + .thenReturn(expectedResponse); + + // exercise + CompletableFuture actualResponse = Json.type(glideClient, key, path); + Object actualResponseValue = actualResponse.get(); + + // verify + assertEquals(expectedResponse, actualResponse); + assertEquals(expectedResponseValue, actualResponseValue); + } +} diff --git a/java/client/src/test/java/glide/managers/ConnectionManagerTest.java b/java/client/src/test/java/glide/managers/ConnectionManagerTest.java index 7a8f1a0d44..80e9783c88 100644 --- a/java/client/src/test/java/glide/managers/ConnectionManagerTest.java +++ b/java/client/src/test/java/glide/managers/ConnectionManagerTest.java @@ -32,6 +32,7 @@ import glide.api.models.configuration.ServerCredentials; import glide.api.models.configuration.StandaloneSubscriptionConfiguration; import glide.api.models.exceptions.ClosingException; +import glide.api.models.exceptions.ConfigurationError; import glide.connectors.handlers.ChannelHandler; import io.netty.channel.ChannelFuture; import java.util.Map; @@ -268,4 +269,58 @@ public void connection_on_resp_pointer_throws_ClosingException() { assertEquals("Unexpected data in response", executionException.getCause().getMessage()); verify(channel).close(); } + + @SneakyThrows + @Test + public void test_convert_config_with_azaffinity_to_protobuf() { + // setup + String az = "us-east-1a"; + GlideClientConfiguration config = + GlideClientConfiguration.builder() + .address(NodeAddress.builder().host(DEFAULT_HOST).port(DEFAULT_PORT).build()) + .useTLS(true) + .readFrom(ReadFrom.AZ_AFFINITY) + .clientAZ(az) + .build(); + 
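+        // The configuration above is expected to translate into the protobuf request built
+        // next: ReadFrom.AZ_AFFINITY maps to ConnectionRequestOuterClass.ReadFrom.AZAffinity,
+        // and the clientAZ value is carried in the client_az field.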
+ ConnectionRequest request = + ConnectionRequest.newBuilder() + .addAddresses( + ConnectionRequestOuterClass.NodeAddress.newBuilder() + .setHost(DEFAULT_HOST) + .setPort(DEFAULT_PORT) + .build()) + .setTlsMode(TlsMode.SecureTls) + .setReadFrom(ConnectionRequestOuterClass.ReadFrom.AZAffinity) + .setClientAz(az) + .build(); + + CompletableFuture completedFuture = new CompletableFuture<>(); + Response response = Response.newBuilder().setConstantResponse(ConstantResponse.OK).build(); + completedFuture.complete(response); + + // execute + when(channel.connect(eq(request))).thenReturn(completedFuture); + CompletableFuture result = connectionManager.connectToValkey(config); + + // verify + assertNull(result.get()); + verify(channel).connect(eq(request)); + } + + @SneakyThrows + @Test + public void test_az_affinity_without_client_az_throws_ConfigurationError() { + // setup + String az = "us-east-1a"; + GlideClientConfiguration config = + GlideClientConfiguration.builder() + .address(NodeAddress.builder().host(DEFAULT_HOST).port(DEFAULT_PORT).build()) + .useTLS(true) + .readFrom(ReadFrom.AZ_AFFINITY) + .build(); + + // verify + assertThrows(ConfigurationError.class, () -> connectionManager.connectToValkey(config)); + } } diff --git a/java/integTest/build.gradle b/java/integTest/build.gradle index d467b4ebbb..53b690aa49 100644 --- a/java/integTest/build.gradle +++ b/java/integTest/build.gradle @@ -11,11 +11,12 @@ dependencies { implementation project(':client') implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' + implementation 'com.google.code.gson:gson:2.10.1' // https://github.com/netty/netty/wiki/Native-transports // At the moment, Windows is not supported implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-x86_64' - implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-x86_64' + implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.100.Final', classifier: 'linux-aarch_64' implementation group: 'io.netty', name: 'netty-transport-native-kqueue', version: '4.1.100.Final', classifier: 'osx-aarch_64' // junit @@ -32,6 +33,7 @@ dependencies { def standaloneHosts = '' def clusterHosts = '' +def azClusterHosts = '' ext { extractAddressesFromClusterManagerOutput = { String output -> @@ -83,6 +85,25 @@ tasks.register('startCluster') { } } +tasks.register('startClusterForAz') { + doLast { + if (System.getProperty("cluster-endpoints") == null) { + new ByteArrayOutputStream().withStream { os -> + exec { + workingDir "${project.rootDir}/../utils" + def args = ['python3', 'cluster_manager.py', 'start', '--cluster-mode', '-r', '4'] + if (System.getProperty("tls") == 'true') args.add(2, '--tls') + commandLine args + standardOutput = os + } + azClusterHosts = extractAddressesFromClusterManagerOutput(os.toString()) + } + } else { + azClusterHosts = System.getProperty("cluster-endpoints") + } + } +} + tasks.register('startStandalone') { doLast { if (System.getProperty("standalone-endpoints") == null) { @@ -102,20 +123,19 @@ tasks.register('startStandalone') { } } - test.dependsOn 'stopAllBeforeTests' stopAllBeforeTests.finalizedBy 'clearDirs' clearDirs.finalizedBy 'startStandalone' clearDirs.finalizedBy 'startCluster' +clearDirs.finalizedBy 'startClusterForAz' test.finalizedBy 'stopAllAfterTests' test.dependsOn ':client:buildRustRelease' tasks.withType(Test) { doFirst { - println "Cluster hosts = ${clusterHosts}" - println 
"Standalone hosts = ${standaloneHosts}" systemProperty 'test.server.standalone', standaloneHosts systemProperty 'test.server.cluster', clusterHosts + systemProperty 'test.server.azcluster', azClusterHosts systemProperty 'test.server.tls', System.getProperty("tls") } diff --git a/java/integTest/src/test/java/glide/ConnectionTests.java b/java/integTest/src/test/java/glide/ConnectionTests.java index de17f54e1c..2aec2e4e6b 100644 --- a/java/integTest/src/test/java/glide/ConnectionTests.java +++ b/java/integTest/src/test/java/glide/ConnectionTests.java @@ -1,11 +1,25 @@ /** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ package glide; +import static glide.TestConfiguration.SERVER_VERSION; +import static glide.TestUtilities.azClusterClientConfig; import static glide.TestUtilities.commonClientConfig; import static glide.TestUtilities.commonClusterClientConfig; +import static glide.api.BaseClient.OK; +import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_NODES; +import static glide.api.models.configuration.RequestRoutingConfiguration.SlotType.PRIMARY; +import static glide.api.models.configuration.RequestRoutingConfiguration.SlotType.REPLICA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import glide.api.GlideClient; import glide.api.GlideClusterClient; +import glide.api.models.ClusterValue; +import glide.api.models.commands.InfoOptions; +import glide.api.models.configuration.ReadFrom; +import glide.api.models.configuration.RequestRoutingConfiguration; +import java.util.Map; +import java.util.stream.Stream; import lombok.SneakyThrows; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -26,4 +40,166 @@ public void cluster_client() { var clusterClient = GlideClusterClient.createClient(commonClusterClientConfig().build()).get(); clusterClient.close(); } + + @SneakyThrows + public GlideClusterClient createAzTestClient(String az) { + return GlideClusterClient.createClient( + azClusterClientConfig() + .readFrom(ReadFrom.AZ_AFFINITY) + .clientAZ(az) + .requestTimeout(2000) + .build()) + .get(); + } + + /** + * Test that the client with AZ affinity strategy routes in a round-robin manner to all replicas + * within the specified AZ. 
+     */
+    @SneakyThrows
+    @Test
+    public void test_routing_by_slot_to_replica_with_az_affinity_strategy_to_all_replicas() {
+        assumeTrue(SERVER_VERSION.isGreaterThanOrEqualTo("8.0.0"), "Skip for versions below 8");
+
+        String az = "us-east-1a";
+
+        // Create client for setting the configs
+        GlideClusterClient configSetClient =
+            GlideClusterClient.createClient(azClusterClientConfig().requestTimeout(2000).build()).get();
+        assertEquals(configSetClient.configResetStat().get(), OK);
+
+        // Get Replica Count for current cluster
+        var clusterInfo =
+            configSetClient
+                .customCommand(
+                    new String[] {"INFO", "REPLICATION"},
+                    new RequestRoutingConfiguration.SlotKeyRoute("key", PRIMARY))
+                .get();
+        long nReplicas =
+            Long.parseLong(
+                Stream.of(((String) clusterInfo.getSingleValue()).split("\\R"))
+                    .map(line -> line.split(":", 2))
+                    .filter(parts -> parts.length == 2 && parts[0].trim().equals("connected_slaves"))
+                    .map(parts -> parts[1].trim())
+                    .findFirst()
+                    .get());
+        long nGetCalls = 3 * nReplicas;
+        String getCmdstat = String.format("cmdstat_get:calls=%d", 3);
+
+        // Setting AZ for all Nodes
+        configSetClient.configSet(Map.of("availability-zone", az), ALL_NODES).get();
+        configSetClient.close();
+
+        // Creating Client with AZ configuration for testing
+        GlideClusterClient azTestClient = createAzTestClient(az);
+        ClusterValue<Map<String, String>> azGetResult =
+            azTestClient.configGet(new String[] {"availability-zone"}, ALL_NODES).get();
+        Map<String, Map<String, String>> azData = azGetResult.getMultiValue();
+
+        // Check that all replicas have the availability zone set to the az
+        for (var entry : azData.entrySet()) {
+            assertEquals(az, entry.getValue().get("availability-zone"));
+        }
+
+        // execute GET commands
+        for (int i = 0; i < nGetCalls; i++) {
+            azTestClient.get("foo").get();
+        }
+
+        ClusterValue<String> infoResult =
+            azTestClient.info(new InfoOptions.Section[] {InfoOptions.Section.ALL}, ALL_NODES).get();
+        Map<String, String> infoData = infoResult.getMultiValue();
+
+        // Check that all replicas have the same number of GET calls
+        long matchingEntries =
+            infoData.values().stream()
+                .filter(value -> value.contains(getCmdstat) && value.contains(az))
+                .count();
+        assertEquals(nReplicas, matchingEntries);
+        azTestClient.close();
+    }
+
+    /**
+     * Test that a client with the AZ affinity strategy routes only to the single replica in the
+     * same AZ.
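+     *
+     * <p>Only the replica serving the slot for "foo" is labeled with the AZ below, so all three
+     * GET calls should land on that one replica (exactly one "cmdstat_get:calls=3" match is
+     * expected).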
+ */ + @SneakyThrows + @Test + public void test_routing_with_az_affinity_strategy_to_1_replica() { + assumeTrue(SERVER_VERSION.isGreaterThanOrEqualTo("8.0.0"), "Skip for versions below 8"); + + String az = "us-east-1a"; + int nGetCalls = 3; + String getCmdstat = String.format("cmdstat_get:calls=%d", nGetCalls); + + GlideClusterClient configSetClient = + GlideClusterClient.createClient(azClusterClientConfig().requestTimeout(2000).build()).get(); + + // reset availability zone for all nodes + configSetClient.configSet(Map.of("availability-zone", ""), ALL_NODES).get(); + assertEquals(configSetClient.configResetStat().get(), OK); + + Long fooSlotKey = + (Long) + configSetClient + .customCommand(new String[] {"CLUSTER", "KEYSLOT", "foo"}) + .get() + .getSingleValue(); + int convertedKey = Integer.parseInt(fooSlotKey.toString()); + configSetClient + .configSet( + Map.of("availability-zone", az), + new RequestRoutingConfiguration.SlotIdRoute(convertedKey, REPLICA)) + .get(); + configSetClient.close(); + + GlideClusterClient azTestClient = createAzTestClient(az); + + // execute GET commands + for (int i = 0; i < nGetCalls; i++) { + azTestClient.get("foo").get(); + } + + ClusterValue infoResult = + azTestClient.info(new InfoOptions.Section[] {InfoOptions.Section.ALL}, ALL_NODES).get(); + Map infoData = infoResult.getMultiValue(); + + // Check that all replicas have the same number of GET calls + long matchingEntries = + infoData.values().stream() + .filter(value -> value.contains(getCmdstat) && value.contains(az)) + .count(); + assertEquals(1, matchingEntries); + azTestClient.close(); + } + + @SneakyThrows + @Test + public void test_az_affinity_non_existing_az() { + assumeTrue(SERVER_VERSION.isGreaterThanOrEqualTo("8.0.0"), "Skip for versions below 8"); + + int nGetCalls = 4; + int nReplicaCalls = 1; + String getCmdstat = String.format("cmdstat_get:calls=%d", nReplicaCalls); + + GlideClusterClient azTestClient = createAzTestClient("non-existing-az"); + assertEquals(azTestClient.configResetStat(ALL_NODES).get(), OK); + + // execute GET commands + for (int i = 0; i < nGetCalls; i++) { + azTestClient.get("foo").get(); + } + + ClusterValue infoResult = + azTestClient + .info(new InfoOptions.Section[] {InfoOptions.Section.COMMANDSTATS}, ALL_NODES) + .get(); + Map infoData = infoResult.getMultiValue(); + + // We expect the calls to be distributed evenly among the replicas + long matchingEntries = + infoData.values().stream().filter(value -> value.contains(getCmdstat)).count(); + assertEquals(4, matchingEntries); + azTestClient.close(); + } } diff --git a/java/integTest/src/test/java/glide/PubSubTests.java b/java/integTest/src/test/java/glide/PubSubTests.java index e4eeab6cad..7b4e835b80 100644 --- a/java/integTest/src/test/java/glide/PubSubTests.java +++ b/java/integTest/src/test/java/glide/PubSubTests.java @@ -307,8 +307,8 @@ public void exact_happy_path(boolean standalone, MessageReadMethod method) { @MethodSource("getTestScenarios") public void exact_happy_path_many_channels(boolean standalone, MessageReadMethod method) { skipTestsOnMac(); - int numChannels = 256; - int messagesPerChannel = 256; + int numChannels = 16; + int messagesPerChannel = 16; var messages = new ArrayList(numChannels * messagesPerChannel); ChannelMode mode = exact(standalone); Map> subscriptions = Map.of(mode, new HashSet<>()); @@ -366,8 +366,8 @@ public void sharded_pubsub_many_channels(MessageReadMethod method) { assumeTrue(SERVER_VERSION.isGreaterThanOrEqualTo("7.0.0"), "This feature added in version 7"); skipTestsOnMac(); - 
int numChannels = 256; - int pubsubMessagesPerChannel = 256; + int numChannels = 16; + int pubsubMessagesPerChannel = 16; var pubsubMessages = new ArrayList(numChannels * pubsubMessagesPerChannel); PubSubClusterChannelMode mode = PubSubClusterChannelMode.SHARDED; Map> subscriptions = Map.of(mode, new HashSet<>()); @@ -444,8 +444,8 @@ public void pattern_many_channels(boolean standalone, MessageReadMethod method) skipTestsOnMac(); String prefix = "channel."; GlideString pattern = gs(prefix + "*"); - int numChannels = 256; - int messagesPerChannel = 256; + int numChannels = 16; + int messagesPerChannel = 16; ChannelMode mode = standalone ? PubSubChannelMode.PATTERN : PubSubClusterChannelMode.PATTERN; var messages = new ArrayList(numChannels * messagesPerChannel); var subscriptions = Map.of(mode, Set.of(pattern)); @@ -482,8 +482,8 @@ public void combined_exact_and_pattern_one_client(boolean standalone, MessageRea skipTestsOnMac(); String prefix = "channel."; GlideString pattern = gs(prefix + "*"); - int numChannels = 256; - int messagesPerChannel = 256; + int numChannels = 16; + int messagesPerChannel = 16; var messages = new ArrayList(numChannels * messagesPerChannel); ChannelMode mode = standalone ? PubSubChannelMode.EXACT : PubSubClusterChannelMode.EXACT; Map> subscriptions = @@ -533,7 +533,7 @@ public void combined_exact_and_pattern_multiple_clients( skipTestsOnMac(); String prefix = "channel."; GlideString pattern = gs(prefix + "*"); - int numChannels = 256; + int numChannels = 16; var messages = new ArrayList(numChannels * 2); ChannelMode mode = exact(standalone); Map> subscriptions = Map.of(mode, new HashSet<>()); @@ -604,7 +604,7 @@ public void combined_exact_pattern_and_sharded_one_client(MessageReadMethod meth String prefix = "channel."; GlideString pattern = gs(prefix + "*"); String shardPrefix = "{shard}"; - int numChannels = 256; + int numChannels = 16; var messages = new ArrayList(numChannels * 2); var shardedMessages = new ArrayList(numChannels); Map> subscriptions = @@ -660,7 +660,7 @@ public void coexistense_of_sync_and_async_read() { String prefix = "channel."; String pattern = prefix + "*"; String shardPrefix = "{shard}"; - int numChannels = 256; + int numChannels = 16; var messages = new ArrayList(numChannels * 2); var shardedMessages = new ArrayList(numChannels); Map> subscriptions = @@ -742,7 +742,7 @@ public void combined_exact_pattern_and_sharded_multi_client(MessageReadMethod me String prefix = "channel."; GlideString pattern = gs(prefix + "*"); String shardPrefix = "{shard}"; - int numChannels = 256; + int numChannels = 16; var exactMessages = new ArrayList(numChannels); var patternMessages = new ArrayList(numChannels); var shardedMessages = new ArrayList(numChannels); diff --git a/java/integTest/src/test/java/glide/SharedClientTests.java b/java/integTest/src/test/java/glide/SharedClientTests.java index 3650c079e3..26a4144e96 100644 --- a/java/integTest/src/test/java/glide/SharedClientTests.java +++ b/java/integTest/src/test/java/glide/SharedClientTests.java @@ -32,7 +32,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -@Timeout(25) // seconds +@Timeout(35) // seconds public class SharedClientTests { private static GlideClient standaloneClient = null; @@ -47,10 +47,18 @@ public static void init() { clusterClient = GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(10000).build()) .get(); - clients = List.of(Arguments.of(standaloneClient), Arguments.of(clusterClient)); } + 
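+    // A sketch of consuming the statistics validated below; the map shape (String keys and
+    // values) is an assumption here, the test itself only asserts the entry count:
+    //   Map<String, String> stats = client.getStatistics();
+    //   stats.forEach((name, value) -> System.out.println(name + " = " + value));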
@SneakyThrows + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("getClients") + public void validate_statistics(BaseClient client) { + assertFalse(client.getStatistics().isEmpty()); + // we expect 2 items in the statistics map + assertEquals(2, client.getStatistics().size()); + } + @AfterAll @SneakyThrows public static void teardown() { diff --git a/java/integTest/src/test/java/glide/TestConfiguration.java b/java/integTest/src/test/java/glide/TestConfiguration.java index 864e384e1d..812c06c301 100644 --- a/java/integTest/src/test/java/glide/TestConfiguration.java +++ b/java/integTest/src/test/java/glide/TestConfiguration.java @@ -16,6 +16,8 @@ public final class TestConfiguration { System.getProperty("test.server.standalone", "").split(","); public static final String[] CLUSTER_HOSTS = System.getProperty("test.server.cluster", "").split(","); + public static final String[] AZ_CLUSTER_HOSTS = + System.getProperty("test.server.azcluster", "").split(","); public static final Semver SERVER_VERSION; public static final boolean TLS = Boolean.parseBoolean(System.getProperty("test.server.tls", "")); diff --git a/java/integTest/src/test/java/glide/TestUtilities.java b/java/integTest/src/test/java/glide/TestUtilities.java index c6f8ef7201..ea03632c39 100644 --- a/java/integTest/src/test/java/glide/TestUtilities.java +++ b/java/integTest/src/test/java/glide/TestUtilities.java @@ -1,6 +1,7 @@ /** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ package glide; +import static glide.TestConfiguration.AZ_CLUSTER_HOSTS; import static glide.TestConfiguration.CLUSTER_HOSTS; import static glide.TestConfiguration.STANDALONE_HOSTS; import static glide.TestConfiguration.TLS; @@ -111,6 +112,17 @@ public static Map parseInfoResponseToMap(String serverInfo) { return builder.useTLS(TLS); } + public static GlideClusterClientConfiguration.GlideClusterClientConfigurationBuilder + azClusterClientConfig() { + var builder = GlideClusterClientConfiguration.builder(); + for (var host : AZ_CLUSTER_HOSTS) { + var parts = host.split(":"); + builder.address( + NodeAddress.builder().host(parts[0]).port(Integer.parseInt(parts[1])).build()); + } + return builder.useTLS(TLS); + } + /** * Deep traverse and compare two objects, including comparing content of all nested collections * recursively. Floating point numbers comparison performed with 1e-6 delta. 
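A minimal sketch of how the new helper composes in a test client (it mirrors createAzTestClient
in ConnectionTests above; the AZ value is illustrative):

    GlideClusterClient azClient =
        GlideClusterClient.createClient(
                azClusterClientConfig()
                        .readFrom(ReadFrom.AZ_AFFINITY)
                        .clientAZ("us-east-1a")
                        .requestTimeout(2000)
                        .build())
            .get();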
diff --git a/java/integTest/src/test/java/glide/cluster/ClusterClientTests.java b/java/integTest/src/test/java/glide/cluster/ClusterClientTests.java index 288b31f5da..b777a0c9fe 100644 --- a/java/integTest/src/test/java/glide/cluster/ClusterClientTests.java +++ b/java/integTest/src/test/java/glide/cluster/ClusterClientTests.java @@ -6,6 +6,8 @@ import static glide.TestUtilities.getRandomString; import static glide.api.BaseClient.OK; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -14,8 +16,11 @@ import glide.api.models.configuration.ServerCredentials; import glide.api.models.exceptions.ClosingException; import glide.api.models.exceptions.RequestException; +import java.util.Map; +import java.util.UUID; import java.util.concurrent.ExecutionException; import lombok.SneakyThrows; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -158,4 +163,117 @@ public void closed_client_throws_ExecutionException_with_ClosingException_as_cau assertThrows(ExecutionException.class, () -> client.set("foo", "bar").get()); assertTrue(executionException.getCause() instanceof ClosingException); } + + @SneakyThrows + @Test + public void test_update_connection_password() { + GlideClusterClient adminClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get(); + String pwd = UUID.randomUUID().toString(); + + try (GlideClusterClient testClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get()) { + // validate that we can use the client + assertNotNull(testClient.info().get()); + + // Update password without re-authentication + assertEquals(OK, testClient.updateConnectionPassword(pwd, false).get()); + + // Verify client still works with old auth + assertNotNull(testClient.info().get()); + + // Update server password + // Kill all other clients to force reconnection + assertEquals("OK", adminClient.configSet(Map.of("requirepass", pwd)).get()); + adminClient.customCommand(new String[] {"CLIENT", "KILL", "TYPE", "NORMAL"}).get(); + + // Verify client auto-reconnects with new password + assertNotNull(testClient.info().get()); + } finally { + adminClient.configSet(Map.of("requirepass", "")).get(); + adminClient.close(); + } + } + + @SneakyThrows + @Test + public void test_update_connection_password_auth_non_valid_pass() { + // Test Client fails on call to updateConnectionPassword with invalid parameters + try (GlideClusterClient testClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get()) { + var emptyPasswordException = + assertThrows( + ExecutionException.class, () -> testClient.updateConnectionPassword("", true).get()); + assertInstanceOf(RequestException.class, emptyPasswordException.getCause()); + + var noPasswordException = + assertThrows( + ExecutionException.class, () -> testClient.updateConnectionPassword(true).get()); + assertInstanceOf(RequestException.class, noPasswordException.getCause()); + } + } + + @SneakyThrows + @Test + public void test_update_connection_password_no_server_auth() { + var pwd = UUID.randomUUID().toString(); + + try (GlideClusterClient testClient = + 
GlideClusterClient.createClient(commonClusterClientConfig().build()).get()) { + // validate that we can use the client + assertNotNull(testClient.info().get()); + + // Test that immediate re-authentication fails when no server password is set. + var exception = + assertThrows( + ExecutionException.class, () -> testClient.updateConnectionPassword(pwd, true).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + } + } + + @SneakyThrows + @Test + public void test_update_connection_password_long() { + var pwd = RandomStringUtils.randomAlphabetic(1000); + + try (GlideClusterClient testClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get()) { + // validate that we can use the client + assertNotNull(testClient.info().get()); + + // Test replacing connection password with a long password string. + assertEquals(OK, testClient.updateConnectionPassword(pwd, false).get()); + } + } + + @Timeout(50) + @SneakyThrows + @Test + public void test_replace_password_immediateAuth_wrong_password() { + var pwd = UUID.randomUUID().toString(); + var notThePwd = UUID.randomUUID().toString(); + + GlideClusterClient adminClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get(); + try (GlideClusterClient testClient = + GlideClusterClient.createClient(commonClusterClientConfig().build()).get()) { + // validate that we can use the client + assertNotNull(testClient.info().get()); + + // set the password to something else + adminClient.configSet(Map.of("requirepass", notThePwd)).get(); + + // Test that re-authentication fails when using wrong password. + var exception = + assertThrows( + ExecutionException.class, () -> testClient.updateConnectionPassword(pwd, true).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + + // But using something else password returns OK + assertEquals(OK, testClient.updateConnectionPassword(notThePwd, true).get()); + } finally { + adminClient.configSet(Map.of("requirepass", "")).get(); + adminClient.close(); + } + } } diff --git a/java/integTest/src/test/java/glide/cluster/CommandTests.java b/java/integTest/src/test/java/glide/cluster/CommandTests.java index ee2682a259..3a1150770a 100644 --- a/java/integTest/src/test/java/glide/cluster/CommandTests.java +++ b/java/integTest/src/test/java/glide/cluster/CommandTests.java @@ -107,7 +107,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -@Timeout(10) // seconds +@Timeout(30) // seconds public class CommandTests { private static GlideClusterClient clusterClient = null; @@ -1660,7 +1660,7 @@ public void fcall_readonly_function() { assertEquals(libName, clusterClient.functionLoad(code, false).get()); // let replica sync with the primary node - assertEquals(1L, clusterClient.wait(1L, 3000L).get()); + assertEquals(1L, clusterClient.wait(1L, 4000L).get()); // fcall on a replica node should fail, because a function isn't guaranteed to be RO var executionException = diff --git a/java/integTest/src/test/java/glide/modules/JsonTests.java b/java/integTest/src/test/java/glide/modules/JsonTests.java index 5d6880ae2e..747a6078b6 100644 --- a/java/integTest/src/test/java/glide/modules/JsonTests.java +++ b/java/integTest/src/test/java/glide/modules/JsonTests.java @@ -2,22 +2,1227 @@ package glide.modules; import static glide.TestUtilities.commonClusterClientConfig; +import static glide.api.BaseClient.OK; +import static glide.api.models.GlideString.gs; +import static 
glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.google.gson.JsonParser; import glide.api.GlideClusterClient; +import glide.api.commands.servermodules.Json; +import glide.api.models.GlideString; +import glide.api.models.commands.ConditionalChange; +import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; +import glide.api.models.commands.json.JsonArrindexOptions; +import glide.api.models.commands.json.JsonGetOptions; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ExecutionException; import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; public class JsonTests { - @Test + + private static GlideClusterClient client; + + @BeforeAll @SneakyThrows - public void check_module_loaded() { - var client = + public static void init() { + client = GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(5000).build()) .get(); + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); + } + + @AfterAll + @SneakyThrows + public static void teardown() { + client.close(); + } + + @Test + @SneakyThrows + public void check_module_loaded() { var info = client.info(new Section[] {Section.MODULES}, RANDOM).get().getSingleValue(); assertTrue(info.contains("# json_core_metrics")); } + + @Test + @SneakyThrows + public void json_set_get() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": 1.0,\"b\": 2}"; + + assertEquals(OK, Json.set(client, key, "$", jsonValue).get()); + + String getResult = Json.get(client, key).get(); + + assertEquals(JsonParser.parseString(jsonValue), JsonParser.parseString(getResult)); + + String getResultWithMultiPaths = Json.get(client, key, new String[] {"$.a", "$.b"}).get(); + + assertEquals( + JsonParser.parseString("{\"$.a\":[1.0],\"$.b\":[2]}"), + JsonParser.parseString(getResultWithMultiPaths)); + + assertNull(Json.get(client, "non_existing_key").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$.d"}).get()); + } + + @Test + @SneakyThrows + public void json_set_get_multiple_values() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": {\"c\": 1, \"d\": 4}, \"b\": {\"c\": 2}, \"c\": true}"; + + assertEquals(OK, Json.set(client, gs(key), gs("$"), gs(jsonValue)).get()); + + GlideString getResult = Json.get(client, gs(key), new GlideString[] {gs("$..c")}).get(); + + assertEquals( + JsonParser.parseString("[true, 1, 2]"), JsonParser.parseString(getResult.getString())); + + String getResultWithMultiPaths = Json.get(client, key, new String[] {"$..c", "$.c"}).get(); + + assertEquals( + JsonParser.parseString("{\"$..c\": [True, 1, 2], \"$.c\": [True]}"), + JsonParser.parseString(getResultWithMultiPaths)); + + assertEquals(OK, Json.set(client, key, "$..c", "\"new_value\"").get()); + String getResultAfterSetNewValue = Json.get(client, key, new String[] {"$..c"}).get(); + assertEquals( + JsonParser.parseString("[\"new_value\", 
\"new_value\", \"new_value\"]"), + JsonParser.parseString(getResultAfterSetNewValue)); + } + + @Test + @SneakyThrows + public void json_set_get_conditional_set() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": 1.0, \"b\": 2}"; + + assertNull(Json.set(client, key, "$", jsonValue, ConditionalChange.ONLY_IF_EXISTS).get()); + assertEquals( + OK, Json.set(client, key, "$", jsonValue, ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get()); + assertNull(Json.set(client, key, "$.a", "4.5", ConditionalChange.ONLY_IF_DOES_NOT_EXIST).get()); + assertEquals("1.0", Json.get(client, key, new String[] {".a"}).get()); + assertEquals(OK, Json.set(client, key, "$.a", "4.5", ConditionalChange.ONLY_IF_EXISTS).get()); + assertEquals("4.5", Json.get(client, key, new String[] {".a"}).get()); + } + + @Test + @SneakyThrows + public void json_set_get_formatting() { + String key = UUID.randomUUID().toString(); + + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": 2, \"c\": {\"d\": 3, \"e\": 4}}").get()); + + String expectedGetResult = + "[\n" + + " {\n" + + " \"a\": 1.0,\n" + + " \"b\": 2,\n" + + " \"c\": {\n" + + " \"d\": 3,\n" + + " \"e\": 4\n" + + " }\n" + + " }\n" + + "]"; + String actualGetResult = + Json.get( + client, + key, + new String[] {"$"}, + JsonGetOptions.builder().indent(" ").newline("\n").space(" ").build()) + .get(); + assertEquals(expectedGetResult, actualGetResult); + + String expectedGetResult2 = + "[\n~{\n~~\"a\":*1.0,\n~~\"b\":*2,\n~~\"c\":*{\n~~~\"d\":*3,\n~~~\"e\":*4\n~~}\n~}\n]"; + String actualGetResult2 = + Json.get( + client, + key, + new String[] {"$"}, + JsonGetOptions.builder().indent("~").newline("\n").space("*").build()) + .get(); + assertEquals(expectedGetResult2, actualGetResult2); + } + + @Test + @SneakyThrows + public void arrappend() { + String key = UUID.randomUUID().toString(); + String doc = "{\"a\": 1, \"b\": [\"one\", \"two\"]}"; + + assertEquals(OK, Json.set(client, key, "$", doc).get()); + + assertArrayEquals( + new Object[] {3L}, + (Object[]) Json.arrappend(client, key, "$.b", new String[] {"\"three\""}).get()); + assertEquals( + 5L, Json.arrappend(client, key, ".b", new String[] {"\"four\"", "\"five\""}).get()); + + String getResult = Json.get(client, key, new String[] {"$"}).get(); + String expectedGetResult = + "[{\"a\": 1, \"b\": [\"one\", \"two\", \"three\", \"four\", \"five\"]}]"; + assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult)); + + assertArrayEquals( + new Object[] {null}, + (Object[]) Json.arrappend(client, key, "$.a", new String[] {"\"value\""}).get()); + + // JSONPath, path doesn't exist + assertArrayEquals( + new Object[] {}, + (Object[]) + Json.arrappend(client, gs(key), gs("$.c"), new GlideString[] {gs("\"value\"")}).get()); + + // Legacy path, path doesn't exist + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".c", new String[] {"\"value\""}).get()); + + // Legacy path, the JSON value at path is not a array + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, key, ".a", new String[] {"\"value\""}).get()); + + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, "non_existing_key", "$.b", new String[] {"\"six\""}).get()); + + assertThrows( + ExecutionException.class, + () -> Json.arrappend(client, "non_existing_key", ".b", new String[] {"\"six\""}).get()); + } + + @Test + @SneakyThrows + public void arrindex() { + String key1 = UUID.randomUUID().toString(); + String key2 = 
UUID.randomUUID().toString(); + String key3 = UUID.randomUUID().toString(); + + String doc1 = + "{\"a\": [1, 3, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\":" + + " 42}}}"; + + String doc2 = + "{\"a\": [1, 3, true, \"foo\", \"meow\", \"m\", \"foo\", \"lol\", false], \"b\": {\"a\":" + + " [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\": 42}, \"empty\": []}}"; + + String doc3 = "{\"a\": 123123}"; + + assertEquals("OK", Json.set(client, key1, "$", doc1).get()); + assertArrayEquals( + new Object[] {2L, -1L, null}, (Object[]) Json.arrindex(client, key1, "$..a", "true").get()); + + assertArrayEquals( + new Object[] {1L, 0L, null}, + (Object[]) Json.arrindex(client, gs(key1), gs("$..a"), gs("3")).get()); + + assertEquals("OK", Json.set(client, key2, "$", doc2).get()); + + assertArrayEquals( + new Object[] {6L, -1L, null}, + (Object[]) + Json.arrindex(client, key2, "$..a", "\"foo\"", new JsonArrindexOptions(6L, 8L)).get()); + + assertArrayEquals( + new Object[] {-1L, -1L, null}, + (Object[]) + Json.arrindex(client, key2, "$..a", "null", new JsonArrindexOptions(6L, 8L)).get()); + assertArrayEquals( + new Object[] {-1L, -1L, null}, + (Object[]) + Json.arrindex(client, gs(key2), gs("$..a"), gs("null"), new JsonArrindexOptions(6L, 8L)) + .get()); + + assertArrayEquals( + new Object[] {6L, -1L, null}, + (Object[]) + Json.arrindex( + client, gs(key2), gs("$..a"), gs("\"foo\""), new JsonArrindexOptions(6L, 8L)) + .get()); + + assertArrayEquals( + new Object[] {6L, -1L, null}, + (Object[]) + Json.arrindex(client, key2, "$..a", "\"foo\"", new JsonArrindexOptions(6L)).get()); + + // value doesn't exist + assertArrayEquals( + new Object[] {null}, + (Object[]) + Json.arrindex(client, key1, "$..b", "true", new JsonArrindexOptions(1L, 3L)).get()); + + // with legacy path + assertEquals(2L, Json.arrindex(client, key1, ".a", "true").get()); + + // element doesn't exist + assertEquals(-1L, Json.arrindex(client, key1, ".a", "\"nonexistent-element\"").get()); + + // empty array + assertThrows( + ExecutionException.class, + () -> Json.arrindex(client, key1, ".empty", "\"nonexistent-element\"").get()); + + assertEquals("OK", Json.set(client, key3, "$", doc3).get()); + + // wrong type error + assertThrows(ExecutionException.class, () -> Json.arrindex(client, key3, ".a", "42").get()); + + // JsonScalar is null + assertThrows(ExecutionException.class, () -> Json.arrindex(client, key3, ".a", "null").get()); + + // start index is larger than the end index + assertEquals( + -1L, Json.arrindex(client, key2, ".a", "false", new JsonArrindexOptions(4L, 2L)).get()); + + // end index is larger than the length of the array + assertEquals( + 8L, + Json.arrindex(client, key2, ".a", "false", new JsonArrindexOptions(0L, 12378798798721L)) + .get()); + } + + @Test + @SneakyThrows + public void arrinsert() { + String key = UUID.randomUUID().toString(); + + String doc = + "{" + + "\"a\": []," + + "\"b\": { \"a\": [1, 2, 3, 4] }," + + "\"c\": { \"a\": \"not an array\" }," + + "\"d\": [{ \"a\": [\"x\", \"y\"] }, { \"a\": [[\"foo\"]] }]," + + "\"e\": [{ \"a\": 42 }, { \"a\": {} }]," + + "\"f\": { \"a\": [true, false, null] }" + + "}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + String[] values = + new String[] { + "\"string_value\"", "123", "{\"key\": \"value\"}", "true", "null", "[\"bar\"]" + }; + var res = Json.arrinsert(client, key, "$..a", 0, values).get(); + + doc = Json.get(client, key).get(); + var expected = + "{" + + " \"a\": [\"string_value\", 123, {\"key\": \"value\"}, true, null, 
[\"bar\"]]," + + " \"b\": {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " 1," + + " 2," + + " 3," + + " 4" + + " ]" + + " }," + + " \"c\": {\"a\": \"not an array\"}," + + " \"d\": [" + + " {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " \"x\"," + + " \"y\"" + + " ]" + + " }," + + " {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " [\"foo\"]" + + " ]" + + " }" + + " ]," + + " \"e\": [{\"a\": 42}, {\"a\": {}}]," + + " \"f\": {" + + " \"a\": [" + + " \"string_value\"," + + " 123," + + " {\"key\": \"value\"}," + + " true," + + " null," + + " [\"bar\"]," + + " true," + + " false," + + " null" + + " ]" + + " }" + + "}"; + + assertEquals(JsonParser.parseString(expected), JsonParser.parseString(doc)); + } + + @Test + @SneakyThrows + public void debug() { + String key = UUID.randomUUID().toString(); + + var doc = + "{ \"key1\": 1, \"key2\": 3.5, \"key3\": {\"nested_key\": {\"key1\": [4, 5]}}, \"key4\":" + + " [1, 2, 3], \"key5\": 0, \"key6\": \"hello\", \"key7\": null, \"key8\":" + + " {\"nested_key\": {\"key1\": 3.5953862697246314e307}}, \"key9\":" + + " 3.5953862697246314e307, \"key10\": true }"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + assertArrayEquals(new Object[] {1L}, (Object[]) Json.debugFields(client, key, "$.key1").get()); + + assertEquals(2L, Json.debugFields(client, gs(key), gs(".key3.nested_key.key1")).get()); + + assertArrayEquals( + new Object[] {16L}, (Object[]) Json.debugMemory(client, key, "$.key4[2]").get()); + + assertEquals(16L, Json.debugMemory(client, gs(key), gs(".key6")).get()); + + assertEquals(504L, Json.debugMemory(client, key).get()); + assertEquals(19L, Json.debugFields(client, gs(key)).get()); + } + + @Test + @SneakyThrows + public void arrlen() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": [1, 2, 3], \"b\": {\"a\": [1, 2], \"c\": {\"a\": 42}}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.arrlen(client, key, "$.a").get(); + assertArrayEquals(new Object[] {3L}, (Object[]) res); + + res = Json.arrlen(client, key, "$..a").get(); + assertArrayEquals(new Object[] {3L, 2L, null}, (Object[]) res); + + // Legacy path retrieves the first array match at ..a + res = Json.arrlen(client, gs(key), gs("..a")).get(); + assertEquals(3L, res); + + doc = "[1, 2, true, null, \"tree\"]"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + // no path + res = Json.arrlen(client, key).get(); + assertEquals(5L, res); + res = Json.arrlen(client, gs(key)).get(); + assertEquals(5L, res); + } + + @Test + @SneakyThrows + public void arrpop() { + String key = UUID.randomUUID().toString(); + String doc = + "{\"a\": [1, 2, true], \"b\": {\"a\": [3, 4, [\"value\", 3, false], 5], \"c\": {\"a\":" + + " 42}}}"; + assertEquals(OK, Json.set(client, key, "$", doc).get()); + + var res = Json.arrpop(client, key, "$.a", 1).get(); + assertArrayEquals(new Object[] {"2"}, (Object[]) res); + + res = Json.arrpop(client, gs(key), gs("$..a")).get(); + assertArrayEquals(new Object[] {gs("true"), gs("5"), null}, (Object[]) res); + + res = Json.arrpop(client, key, "..a").get(); + assertEquals("1", res); + + // Even if only one array element was returned, ensure second array at `..a` was popped + doc = Json.get(client, key, new String[] {"$..a"}).get(); + 
assertEquals("[[],[3,4],42]", doc); + + // Out of index + res = Json.arrpop(client, key, "$..a", 10).get(); + assertArrayEquals(new Object[] {null, "4", null}, (Object[]) res); + + // pop without options + assertEquals(OK, Json.set(client, key, "$", doc).get()); + res = Json.arrpop(client, key).get(); + assertEquals("42", res); + res = Json.arrpop(client, gs(key)).get(); + assertEquals(gs("[3,4]"), res); + } + + @Test + @SneakyThrows + public void clear() { + String key = UUID.randomUUID().toString(); + String json = + "{\"obj\": {\"a\":1, \"b\":2}, \"arr\":[1, 2, 3], \"str\": \"foo\", \"bool\": true," + + " \"int\": 42, \"float\": 3.14, \"nullVal\": null}"; + + assertEquals("OK", Json.set(client, key, "$", json).get()); + + assertEquals(6L, Json.clear(client, key, "$.*").get()); + var doc = Json.get(client, key, new String[] {"$"}).get(); + assertEquals( + "[{\"obj\":{},\"arr\":[],\"str\":\"\",\"bool\":false,\"int\":0,\"float\":0.0,\"nullVal\":null}]", + doc); + assertEquals(0L, Json.clear(client, gs(key), gs(".*")).get()); + + assertEquals(1L, Json.clear(client, gs(key)).get()); + doc = Json.get(client, key, new String[] {"$"}).get(); + assertEquals("[{}]", doc); + + assertThrows( + ExecutionException.class, () -> Json.clear(client, UUID.randomUUID().toString()).get()); + } + + @Test + @SneakyThrows + void numincrby() { + String key = UUID.randomUUID().toString(); + + var jsonValue = + "{" + + " \"key1\": 1," + + " \"key2\": 3.5," + + " \"key3\": {\"nested_key\": {\"key1\": [4, 5]}}," + + " \"key4\": [1, 2, 3]," + + " \"key5\": 0," + + " \"key6\": \"hello\"," + + " \"key7\": null," + + " \"key8\": {\"nested_key\": {\"key1\": 69}}," + + " \"key9\": 1.7976931348623157e308" + + "}"; + + // Set the initial JSON document at the key + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + // Test JSONPath + // Increment integer value (key1) by 5 + String result = Json.numincrby(client, key, "$.key1", 5).get(); + assertEquals("[6]", result); // Expect 1 + 5 = 6 + + // Increment float value (key2) by 2.5 + result = Json.numincrby(client, key, "$.key2", 2.5).get(); + assertEquals("[6]", result); // Expect 3.5 + 2.5 = 6 + + // Increment nested object (key3.nested_key.key1[0]) by 7 + result = Json.numincrby(client, key, "$.key3.nested_key.key1[1]", 7).get(); + assertEquals("[12]", result); // Expect 4 + 7 = 12 + + // Increment array element (key4[1]) by 1 + result = Json.numincrby(client, key, "$.key4[1]", 1).get(); + assertEquals("[3]", result); // Expect 2 + 1 = 3 + + // Increment zero value (key5) by 10.23 (float number) + result = Json.numincrby(client, key, "$.key5", 10.23).get(); + assertEquals("[10.23]", result); // Expect 0 + 10.23 = 10.23 + + // Increment a string value (key6) by a number + result = Json.numincrby(client, key, "$.key6", 99).get(); + assertEquals("[null]", result); // Expect null + + // Increment a None value (key7) by a number + result = Json.numincrby(client, key, "$.key7", 51).get(); + assertEquals("[null]", result); // Expect null + + // Check increment for all numbers in the document using JSON Path (First Null: key3 as an + // entire object. Second Null: The path checks under key3, which is an object, for numeric + // values). 
+    result = Json.numincrby(client, key, "$..*", 5).get();
+    assertEquals(
+        "[11,11,null,null,15.23,null,null,null,1.7976931348623157e+308,null,null,9,17,6,8,8,null,74]",
+        result);
+
+    // Check for multiple path matches with JSONPath
+    result = Json.numincrby(client, key, "$..key1", 1).get();
+    assertEquals("[12,null,75]", result); // 11 + 1 = 12; the array match is non-numeric -> null; 74 + 1 = 75
+
+    // Check for non-existent path in JSONPath
+    result = Json.numincrby(client, key, "$.key10", 51).get();
+    assertEquals("[]", result); // Expect Empty Array
+
+    // Check for non-existent key in JSONPath
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.numincrby(client, "non_existent_key", "$.key10", 51).get());
+
+    // Check for Overflow in JSONPath
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.numincrby(client, key, "$.key9", 1.7976931348623157e308).get());
+
+    // Decrement integer value (key1) by 12
+    result = Json.numincrby(client, key, "$.key1", -12).get();
+    assertEquals("[0]", result); // Expect 12 - 12 = 0
+
+    // Decrement value (key1) by 0.5 (float)
+    result = Json.numincrby(client, key, "$.key1", -0.5).get();
+    assertEquals("[-0.5]", result); // Expect 0 - 0.5 = -0.5
+
+    // Test Legacy Path
+    // Increment float value (key1) by 5 (integer)
+    result = Json.numincrby(client, key, "key1", 5).get();
+    assertEquals("4.5", result); // Expect -0.5 + 5 = 4.5
+
+    // Decrement float value (key1) by 5.5 (float)
+    result = Json.numincrby(client, key, "key1", -5.5).get();
+    assertEquals("-1", result); // Expect 4.5 - 5.5 = -1
+
+    // Increment int value (key2) by 2.5 (a float number)
+    result = Json.numincrby(client, key, "key2", 2.5).get();
+    assertEquals("13.5", result); // Expect 11 + 2.5 = 13.5
+
+    // Increment nested value (key3.nested_key.key1[0]) by 7
+    result = Json.numincrby(client, key, "key3.nested_key.key1[0]", 7).get();
+    assertEquals("16", result); // Expect 9 + 7 = 16
+
+    // Increment array element (key4[1]) by 1
+    result = Json.numincrby(client, key, "key4[1]", 1).get();
+    assertEquals("9", result); // Expect 8 + 1 = 9
+
+    // Increment a float value (key5) by 10.2 (a float number)
+    result = Json.numincrby(client, key, "key5", 10.2).get();
+    assertEquals("25.43", result); // Expect 15.23 + 10.2 = 25.43
+
+    // Check for multiple path matches in legacy and ensure that the result of the last updated
+    // value is returned
+    result = Json.numincrby(client, key, "..key1", 1).get();
+    assertEquals("76", result);
+
+    // Check that the rest of the key1 path matches were updated, not only the last value
+    result = Json.get(client, key, new String[] {"$..key1"}).get();
+    assertEquals(
+        "[0,[16,17],76]",
+        result); // First is 0 as -1 + 1 = 0; the second doesn't change as it's an array type
+    // (non-numeric); the third is 76 as 75 + 1 = 76
+
+    // Check for non-existent path in legacy
+    assertThrows(ExecutionException.class, () -> Json.numincrby(client, key, ".key10", 51).get());
+
+    // Check for non-existent key in legacy
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.numincrby(client, "non_existent_key", ".key10", 51).get());
+
+    // Check for Overflow in legacy
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.numincrby(client, key, ".key9", 1.7976931348623157e308).get());
+
+    // Binary tests
+    // Binary integer test
+    GlideString binaryResult = Json.numincrby(client, gs(key), gs("key4[1]"), 1).get();
+    assertEquals(gs("10"), binaryResult); // Expect 9 + 1 = 10
+
+    // Binary float test
+    binaryResult = Json.numincrby(client, gs(key), gs("key5"), 1.0).get();
+    assertEquals(gs("26.43"), binaryResult); // Expect 25.43 + 1.0 = 26.43
+  }
+
+  @Test
+  @SneakyThrows
+  void nummultby() {
+    String key = UUID.randomUUID().toString();
+    var jsonValue =
+        "{"
+            + " \"key1\": 1,"
+            + " \"key2\": 3.5,"
+            + " \"key3\": {\"nested_key\": {\"key1\": [4, 5]}},"
+            + " \"key4\": [1, 2, 3],"
+            + " \"key5\": 0,"
+            + " \"key6\": \"hello\","
+            + " \"key7\": null,"
+            + " \"key8\": {\"nested_key\": {\"key1\": 69}},"
+            + " \"key9\": 3.5953862697246314e307"
+            + "}";
+
+    // Set the initial JSON document at the key
+    assertEquals("OK", Json.set(client, key, "$", jsonValue).get());
+
+    // Test JSONPath
+    // Multiply integer value (key1) by 5
+    String result = Json.nummultby(client, key, "$.key1", 5).get();
+    assertEquals("[5]", result); // Expect 1 * 5 = 5
+
+    // Multiply float value (key2) by 2.5
+    result = Json.nummultby(client, key, "$.key2", 2.5).get();
+    assertEquals("[8.75]", result); // Expect 3.5 * 2.5 = 8.75
+
+    // Multiply nested array element (key3.nested_key.key1[1]) by 7
+    result = Json.nummultby(client, key, "$.key3.nested_key.key1[1]", 7).get();
+    assertEquals("[35]", result); // Expect 5 * 7 = 35
+
+    // Multiply array element (key4[1]) by 1
+    result = Json.nummultby(client, key, "$.key4[1]", 1).get();
+    assertEquals("[2]", result); // Expect 2 * 1 = 2
+
+    // Multiply zero value (key5) by 10.23 (float number)
+    result = Json.nummultby(client, key, "$.key5", 10.23).get();
+    assertEquals("[0]", result); // Expect 0 * 10.23 = 0
+
+    // Multiply a string value (key6) by a number
+    result = Json.nummultby(client, key, "$.key6", 99).get();
+    assertEquals("[null]", result); // Expect null
+
+    // Multiply a null value (key7) by a number
+    result = Json.nummultby(client, key, "$.key7", 51).get();
+    assertEquals("[null]", result); // Expect null
+
+    // Check multiplication for all numbers in the document using JSON Path
+    // key1: 5 * 5 = 25
+    // key2: 8.75 * 5 = 43.75
+    // key3.nested_key.key1[0]: 4 * 5 = 20
+    // key3.nested_key.key1[1]: 35 * 5 = 175
+    // key4[0]: 1 * 5 = 5
+    // key4[1]: 2 * 5 = 10
+    // key4[2]: 3 * 5 = 15
+    // key5: 0 * 5 = 0
+    // key8.nested_key.key1: 69 * 5 = 345
+    // key9: 3.5953862697246314e307 * 5 = 1.7976931348623157e308
+    result = Json.nummultby(client, key, "$..*", 5).get();
+    assertEquals(
+        "[25,43.75,null,null,0,null,null,null,1.7976931348623157e+308,null,null,20,175,5,10,15,null,345]",
+        result);
+
+    // Check for multiple path matches in JSONPath
+    // key1: 25 * 2 = 50
+    // key8.nested_key.key1: 345 * 2 = 690
+    result = Json.nummultby(client, key, "$..key1", 2).get();
+    assertEquals("[50,null,690]", result); // After previous multiplications
+
+    // Check for non-existent path in JSONPath
+    result = Json.nummultby(client, key, "$.key10", 51).get();
+    assertEquals("[]", result); // Expect Empty Array
+
+    // Check for non-existent key in JSONPath
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.nummultby(client, "non_existent_key", "$.key10", 51).get());
+
+    // Check for Overflow in JSONPath
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.nummultby(client, key, "$.key9", 1.7976931348623157e308).get());
+
+    // Multiply integer value (key1) by -12
+    result = Json.nummultby(client, key, "$.key1", -12).get();
+    assertEquals("[-600]", result); // Expect 50 * -12 = -600
+
+    // Multiply integer value (key1) by -0.5
+    result = Json.nummultby(client, key, "$.key1", -0.5).get();
+    assertEquals("[300]", result); // Expect -600 * -0.5 = 300
+
+    // Test Legacy Path
+    // Multiply int value (key1) by 5 (integer)
+    result = Json.nummultby(client, key, "key1", 5).get();
+    assertEquals("1500", result); // Expect 300 * 5 = 1500
+
+    // Multiply int value (key1) by -5.5 (float number)
+    result = Json.nummultby(client, key, "key1", -5.5).get();
+    assertEquals("-8250", result); // Expect 1500 * -5.5 = -8250
+
+    // Multiply float value (key2) by 2.5 (a float number)
+    result = Json.nummultby(client, key, "key2", 2.5).get();
+    assertEquals("109.375", result); // Expect 43.75 * 2.5 = 109.375
+
+    // Multiply nested value (key3.nested_key.key1[0]) by 7
+    result = Json.nummultby(client, key, "key3.nested_key.key1[0]", 7).get();
+    assertEquals("140", result); // Expect 20 * 7 = 140
+
+    // Multiply array element (key4[1]) by 1
+    result = Json.nummultby(client, key, "key4[1]", 1).get();
+    assertEquals("10", result); // Expect 10 * 1 = 10
+
+    // Multiply a float value (key5) by 10.2 (a float number)
+    result = Json.nummultby(client, key, "key5", 10.2).get();
+    assertEquals("0", result); // Expect 0 * 10.2 = 0
+
+    // Check for multiple path matches in legacy and ensure that the result of the last updated
+    // value is returned
+    // last updated value is key8.nested_key.key1: 690 * 2 = 1380
+    result = Json.nummultby(client, key, "..key1", 2).get();
+    assertEquals("1380", result); // Expect the last updated key1 value multiplied by 2
+
+    // Check that the rest of the key1 path matches were updated, not only the last value
+    result = Json.get(client, key, new String[] {"$..key1"}).get();
+    assertEquals("[-16500,[140,175],1380]", result);
+
+    // Check for non-existent path in legacy
+    assertThrows(ExecutionException.class, () -> Json.nummultby(client, key, ".key10", 51).get());
+
+    // Check for non-existent key in legacy
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.nummultby(client, "non_existent_key", ".key10", 51).get());
+
+    // Check for Overflow in legacy
+    assertThrows(
+        ExecutionException.class,
+        () -> Json.nummultby(client, key, ".key9", 1.7976931348623157e308).get());
+
+    // Binary tests
+    // Binary integer test
+    GlideString binaryResult = Json.nummultby(client, gs(key), gs("key4[1]"), 1).get();
+    assertEquals(gs("10"), binaryResult); // Expect 10 * 1 = 10
+
+    // Binary float test
+    binaryResult = Json.nummultby(client, gs(key), gs("key5"), 10.2).get();
+    assertEquals(gs("0"), binaryResult); // Expect 0 * 10.2 = 0
+  }
+
+  @Test
+  @SneakyThrows
+  public void arrtrim() {
+    String key = UUID.randomUUID().toString();
+
+    String doc =
+        "{\"a\": [0, 1, 2, 3, 4, 5, 6, 7, 8], \"b\": {\"a\": [0, 9, 10, 11, 12, 13], \"c\": {\"a\":"
+            + " 42}}}";
+    assertEquals("OK", Json.set(client, key, "$", doc).get());
+
+    // Basic trim
+    var res = Json.arrtrim(client, key, "$..a", 1, 7).get();
+    assertArrayEquals(new Object[] {7L, 5L, null}, (Object[]) res);
+
+    String getResult = Json.get(client, key, new String[] {"$..a"}).get();
+    String expectedGetResult = "[[1, 2, 3, 4, 5, 6, 7], [9, 10, 11, 12, 13], 42]";
+    assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult));
+
+    // Test end >= size (should be treated as size-1)
+    res = Json.arrtrim(client, key, "$.a", 0, 10).get();
+    assertArrayEquals(new Object[] {7L}, (Object[]) res);
+    res = Json.arrtrim(client, key, ".a", 0, 10).get();
+    assertEquals(7L, res);
+
+    // Test negative start (should be treated as 0)
+    res = Json.arrtrim(client, key, "$.a", -1, 5).get();
+    assertArrayEquals(new Object[] {6L}, (Object[]) res);
+    res = Json.arrtrim(client, key, ".a", -1, 5).get();
+    assertEquals(6L, res);
+
+    // Test start >= size (should empty the array)
+    res = Json.arrtrim(client, key, "$.a", 7, 10).get();
+    assertArrayEquals(new Object[] 
{0L}, (Object[]) res); + + assertEquals("OK", Json.set(client, key, ".a", "[\"a\", \"b\", \"c\"]").get()); + res = Json.arrtrim(client, key, ".a", 7, 10).get(); + assertEquals(0L, res); + + // Test start > end (should empty the array) + res = Json.arrtrim(client, key, "$..a", 2, 1).get(); + assertArrayEquals(new Object[] {0L, 0L, null}, (Object[]) res); + + assertEquals("OK", Json.set(client, key, ".a", "[\"a\", \"b\", \"c\", \"d\"]").get()); + res = Json.arrtrim(client, key, "..a", 2, 1).get(); + assertEquals(0L, res); + + // Multiple path match + assertEquals("OK", Json.set(client, key, "$", doc).get()); + res = Json.arrtrim(client, key, "..a", 1, 10).get(); + assertEquals(8L, res); + + getResult = Json.get(client, key, new String[] {"$..a"}).get(); + expectedGetResult = "[[1,2,3,4,5,6,7,8], [9,10,11,12,13], 42]"; + assertEquals(JsonParser.parseString(expectedGetResult), JsonParser.parseString(getResult)); + + // Test with non-existing path + var exception = + assertThrows( + ExecutionException.class, () -> Json.arrtrim(client, key, ".non_existing", 0, 1).get()); + + res = Json.arrtrim(client, key, "$.non_existing", 0, 1).get(); + assertArrayEquals(new Object[] {}, (Object[]) res); + + // Test with non-array path + res = Json.arrtrim(client, key, "$", 0, 1).get(); + assertArrayEquals(new Object[] {null}, (Object[]) res); + + exception = + assertThrows(ExecutionException.class, () -> Json.arrtrim(client, key, ".", 0, 1).get()); + + // Test with non-existing key + exception = + assertThrows( + ExecutionException.class, + () -> Json.arrtrim(client, "non_existing_key", "$", 0, 1).get()); + + exception = + assertThrows( + ExecutionException.class, + () -> Json.arrtrim(client, "non_existing_key", ".", 0, 1).get()); + + // Test with empty array + assertEquals("OK", Json.set(client, key, "$.empty", "[]").get()); + res = Json.arrtrim(client, key, "$.empty", 0, 1).get(); + assertArrayEquals(new Object[] {0L}, (Object[]) res); + res = Json.arrtrim(client, key, ".empty", 0, 1).get(); + assertEquals(0L, res); + } + + @Test + @SneakyThrows + public void objlen() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": 1.0, \"b\": {\"a\": {\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.objlen(client, key, "$..").get(); + assertArrayEquals(new Object[] {2L, 3L, 2L}, (Object[]) res); + + res = Json.objlen(client, gs(key), gs("..b")).get(); + assertEquals(3L, res); + + // without path + res = Json.objlen(client, key).get(); + assertEquals(2L, res); + res = Json.objlen(client, gs(key)).get(); + assertEquals(2L, res); + } + + @Test + @SneakyThrows + public void json_del() { + String key = UUID.randomUUID().toString(); + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": {\"a\": 1, \"b\": 2.5, \"c\": true}}") + .get()); + assertEquals(2L, Json.del(client, key, "$..a").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$..a"}).get()); + String expectedGetResult = "{\"b\": {\"b\": 2.5, \"c\": true}}"; + String actualGetResult = Json.get(client, key).get(); + assertEquals( + JsonParser.parseString(expectedGetResult), JsonParser.parseString(actualGetResult)); + + assertEquals(1L, Json.del(client, gs(key), gs("$")).get()); + assertEquals(0L, Json.del(client, key).get()); + assertNull(Json.get(client, key, new String[] {"$"}).get()); + } + + @Test + @SneakyThrows + public void objkeys() { + String key = UUID.randomUUID().toString(); + + String doc = "{\"a\": 1.0, \"b\": {\"a\": 
{\"x\": 1, \"y\": 2}, \"b\": 2.5, \"c\": true}}"; + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + var res = Json.objkeys(client, key, "..").get(); + assertArrayEquals(new Object[] {"a", "b"}, res); + + res = Json.objkeys(client, gs(key), gs("$..b")).get(); + assertArrayEquals(new Object[][] {{gs("a"), gs("b"), gs("c")}, {}}, res); + + // without path + res = Json.objkeys(client, key).get(); + assertArrayEquals(new Object[] {"a", "b"}, res); + res = Json.objkeys(client, gs(key)).get(); + assertArrayEquals(new Object[] {gs("a"), gs("b")}, res); + } + + @Test + @SneakyThrows + public void mget() { + String key1 = UUID.randomUUID().toString(); + String key2 = UUID.randomUUID().toString(); + var data = + Map.of( + key1, "{\"a\": 1, \"b\": [\"one\", \"two\"]}", + key2, "{\"a\": 1, \"c\": false}"); + + for (var entry : data.entrySet()) { + assertEquals("OK", Json.set(client, entry.getKey(), "$", entry.getValue()).get()); + } + + var res1 = + Json.mget(client, new String[] {key1, key2, UUID.randomUUID().toString()}, "$.c").get(); + assertArrayEquals(new String[] {"[]", "[false]", null}, res1); + + var res2 = Json.mget(client, new GlideString[] {gs(key1), gs(key2)}, gs(".b[*]")).get(); + assertArrayEquals(new GlideString[] {gs("\"one\""), null}, res2); + } + + @Test + @SneakyThrows + public void json_forget() { + String key = UUID.randomUUID().toString(); + assertEquals( + OK, + Json.set(client, key, "$", "{\"a\": 1.0, \"b\": {\"a\": 1, \"b\": 2.5, \"c\": true}}") + .get()); + assertEquals(2L, Json.forget(client, key, "$..a").get()); + assertEquals("[]", Json.get(client, key, new String[] {"$..a"}).get()); + String expectedGetResult = "{\"b\": {\"b\": 2.5, \"c\": true}}"; + String actualGetResult = Json.get(client, key).get(); + assertEquals( + JsonParser.parseString(expectedGetResult), JsonParser.parseString(actualGetResult)); + + assertEquals(1L, Json.forget(client, gs(key), gs("$")).get()); + assertEquals(0L, Json.forget(client, key).get()); + assertNull(Json.get(client, key, new String[] {"$"}).get()); + } + + @Test + @SneakyThrows + public void toggle() { + String key = UUID.randomUUID().toString(); + String key2 = UUID.randomUUID().toString(); + String doc = "{\"bool\": true, \"nested\": {\"bool\": false, \"nested\": {\"bool\": 10}}}"; + + assertEquals("OK", Json.set(client, key, "$", doc).get()); + + assertArrayEquals( + new Object[] {false, true, null}, (Object[]) Json.toggle(client, key, "$..bool").get()); + + assertEquals(true, Json.toggle(client, gs(key), gs("bool")).get()); + + assertArrayEquals(new Object[] {}, (Object[]) Json.toggle(client, key, "$.non_existing").get()); + assertArrayEquals(new Object[] {null}, (Object[]) Json.toggle(client, key, "$.nested").get()); + + // testing behaviour with default path + assertEquals("OK", Json.set(client, key2, ".", "true").get()); + assertEquals(false, Json.toggle(client, key2).get()); + assertEquals(true, Json.toggle(client, gs(key2)).get()); + + // expect request errors + assertThrows(ExecutionException.class, () -> Json.toggle(client, key, "nested").get()); + assertThrows(ExecutionException.class, () -> Json.toggle(client, key, ".non_existing").get()); + assertThrows( + ExecutionException.class, () -> Json.toggle(client, "non_existing_key", "$").get()); + } + + @Test + @SneakyThrows + public void strappend() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": \"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}"; + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + 
assertArrayEquals( + new Object[] {6L, 8L, null}, + (Object[]) Json.strappend(client, key, "\"bar\"", "$..a").get()); + assertEquals(9L, (Long) Json.strappend(client, key, "\"foo\"", "a").get()); + + String jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals( + "{\"a\":\"foobarfoo\",\"nested\":{\"a\":\"hellobar\"},\"nested2\":{\"a\":31}}", jsonStr); + + assertArrayEquals( + new Object[] {null}, (Object[]) Json.strappend(client, key, "\"bar\"", "$.nested").get()); + + assertThrows( + ExecutionException.class, () -> Json.strappend(client, key, "\"bar\"", ".nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strappend(client, key, "\"bar\"").get()); + + assertArrayEquals( + new Object[] {}, + (Object[]) Json.strappend(client, key, "\"try\"", "$.non_existing_path").get()); + + assertThrows( + ExecutionException.class, + () -> Json.strappend(client, key, "\"try\"", "non_existing_path").get()); + + assertThrows( + ExecutionException.class, + () -> Json.strappend(client, "non_existing_key", "\"try\"").get()); + + // Binary test + // Binary with path + assertEquals(12L, (Long) Json.strappend(client, gs(key), gs("\"foo\""), gs("a")).get()); + jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals( + "{\"a\":\"foobarfoofoo\",\"nested\":{\"a\":\"hellobar\"},\"nested2\":{\"a\":31}}", jsonStr); + + // Binary no path + assertEquals("OK", Json.set(client, key, "$", "\"hi\"").get()); + assertEquals(5L, Json.strappend(client, gs(key), gs("\"foo\"")).get()); + jsonStr = Json.get(client, key, new String[] {"."}).get(); + assertEquals("\"hifoo\"", jsonStr); + } + + @Test + @SneakyThrows + public void strlen() { + String key = UUID.randomUUID().toString(); + String jsonValue = "{\"a\": \"foo\", \"nested\": {\"a\": \"hello\"}, \"nested2\": {\"a\": 31}}"; + assertEquals("OK", Json.set(client, key, "$", jsonValue).get()); + + assertArrayEquals( + new Object[] {3L, 5L, null}, (Object[]) Json.strlen(client, key, "$..a").get()); + assertEquals(3L, (Long) Json.strlen(client, key, "a").get()); + + assertArrayEquals(new Object[] {null}, (Object[]) Json.strlen(client, key, "$.nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strlen(client, key, "nested").get()); + + assertThrows(ExecutionException.class, () -> Json.strlen(client, key).get()); + + assertArrayEquals( + new Object[] {}, (Object[]) Json.strlen(client, key, "$.non_existing_path").get()); + assertThrows( + ExecutionException.class, () -> Json.strlen(client, key, ".non_existing_path").get()); + + assertNull(Json.strlen(client, "non_existing_key", ".").get()); + assertNull(Json.strlen(client, "non_existing_key", "$").get()); + + // Binary test + // Binary with path + assertEquals(3L, (Long) Json.strlen(client, gs(key), gs("a")).get()); + + // Binary no path + assertEquals("OK", Json.set(client, key, "$", "\"hi\"").get()); + assertEquals(2L, Json.strlen(client, gs(key)).get()); + } + + @Test + @SneakyThrows + public void json_resp() { + String key = UUID.randomUUID().toString(); + String jsonValue = + "{\"obj\":{\"a\":1, \"b\":2}, \"arr\":[1,2,3], \"str\": \"foo\", \"bool\": true, \"int\":" + + " 42, \"float\": 3.14, \"nullVal\": null}"; + assertEquals(OK, Json.set(client, key, "$", jsonValue).get()); + + Object actualResult1 = Json.resp(client, key, "$.*").get(); + Object[] expectedResult1 = + new Object[] { + new Object[] { + "{", + new Object[] {"a", 1L}, + new Object[] {"b", 2L} // leading "{" indicates JSON objects + }, + new Object[] {"[", 1L, 2L, 3L}, // leading "[" indicates 
JSON arrays
+          "foo",
+          "true",
+          42L,
+          "3.14",
+          null
+        };
+    assertInstanceOf(Object[].class, actualResult1);
+    assertArrayEquals(expectedResult1, (Object[]) actualResult1);
+
+    // multiple path match, the first will be returned
+    Object actualResult2 = Json.resp(client, key, "*").get();
+    Object[] expectedResult2 = new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}};
+    assertInstanceOf(Object[].class, actualResult2);
+    assertArrayEquals(expectedResult2, (Object[]) actualResult2);
+
+    Object actualResult3 = Json.resp(client, key, "$").get();
+    Object[] expectedResult3 =
+        new Object[] {
+          new Object[] {
+            "{",
+            new Object[] {
+              "obj", new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}}
+            },
+            new Object[] {"arr", new Object[] {"[", 1L, 2L, 3L}},
+            new Object[] {"str", "foo"},
+            new Object[] {"bool", "true"},
+            new Object[] {"int", 42L},
+            new Object[] {"float", "3.14"},
+            new Object[] {"nullVal", null}
+          }
+        };
+    assertInstanceOf(Object[].class, actualResult3);
+    assertArrayEquals(expectedResult3, (Object[]) actualResult3);
+
+    Object actualResult4 = Json.resp(client, key, ".").get();
+    Object[] expectedResult4 =
+        new Object[] {
+          "{",
+          new Object[] {"obj", new Object[] {"{", new Object[] {"a", 1L}, new Object[] {"b", 2L}}},
+          new Object[] {"arr", new Object[] {"[", 1L, 2L, 3L}},
+          new Object[] {"str", "foo"},
+          new Object[] {"bool", "true"},
+          new Object[] {"int", 42L},
+          new Object[] {"float", "3.14"},
+          new Object[] {"nullVal", null}
+        };
+    assertInstanceOf(Object[].class, actualResult4);
+    assertArrayEquals(expectedResult4, (Object[]) actualResult4);
+    // resp without path defaults to the same behavior of passing "." as path
+    Object actualResult4WithoutPath = Json.resp(client, key).get();
+    assertArrayEquals(expectedResult4, (Object[]) actualResult4WithoutPath);
+
+    Object actualResult5 = Json.resp(client, gs(key), gs("$.str")).get();
+    Object[] expectedResult5 = new Object[] {gs("foo")};
+    assertInstanceOf(Object[].class, actualResult5);
+    assertArrayEquals(expectedResult5, (Object[]) actualResult5);
+
+    Object actualResult6 = Json.resp(client, key, ".str").get();
+    String expectedResult6 = "foo";
+    assertEquals(expectedResult6, actualResult6);
+
+    assertArrayEquals(new Object[] {}, (Object[]) Json.resp(client, key, "$.nonexistent").get());
+
+    assertThrows(ExecutionException.class, () -> Json.resp(client, key, "nonexistent").get());
+
+    assertNull(Json.resp(client, "nonexistent_key", "$").get());
+    assertNull(Json.resp(client, "nonexistent_key", ".").get());
+    assertNull(Json.resp(client, "nonexistent_key").get());
+  }
+
+  @Test
+  @SneakyThrows
+  public void json_type() {
+    String key = UUID.randomUUID().toString();
+    String jsonValue =
+        "{\"key1\": \"value1\", \"key2\": 2, \"key3\": [1, 2, 3], \"key4\": {\"nested_key\":"
+            + " {\"key1\": [4, 5]}}, \"key5\": null, \"key6\": true, \"dec_key\": 2.3}";
+    assertEquals(OK, Json.set(client, key, "$", jsonValue).get());
+
+    assertArrayEquals(new Object[] {"object"}, (Object[]) Json.type(client, key, "$").get());
+    assertArrayEquals(
+        new Object[] {gs("string"), gs("array")},
+        (Object[]) Json.type(client, gs(key), gs("$..key1")).get());
+    assertArrayEquals(new Object[] {"integer"}, (Object[]) Json.type(client, key, "$.key2").get());
+    assertArrayEquals(new Object[] {"array"}, (Object[]) Json.type(client, key, "$.key3").get());
+    assertArrayEquals(new Object[] {"object"}, (Object[]) Json.type(client, key, "$.key4").get());
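+    // Note: integer values report as "integer" while decimals such as dec_key report as
+    // "number"; the $[*] check at the end of this test covers both.
+    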
assertArrayEquals( + new Object[] {"object"}, (Object[]) Json.type(client, key, "$.key4.nested_key").get()); + assertArrayEquals(new Object[] {"null"}, (Object[]) Json.type(client, key, "$.key5").get()); + assertArrayEquals(new Object[] {"boolean"}, (Object[]) Json.type(client, key, "$.key6").get()); + // Check for non-existent path in enhanced mode $.key7 + assertArrayEquals(new Object[] {}, (Object[]) Json.type(client, key, "$.key7").get()); + // Check for non-existent path within an existing key (array bound) + assertArrayEquals(new Object[] {}, (Object[]) Json.type(client, key, "$.key3[3]").get()); + // Legacy path (without $) - will return None for non-existing path + assertNull(Json.type(client, key, "key7").get()); + // Check for multiple path match in legacy + assertEquals("string", Json.type(client, key, "..key1").get()); + // Check for non-existent key with enhanced path + assertNull(Json.type(client, "non_existing_key", "$.key1").get()); + // Check for non-existent key with legacy path + assertNull(Json.type(client, "non_existing_key", "key1").get()); + // Check for all types in the JSON document using JSON Path + Object[] actualResult = (Object[]) Json.type(client, key, "$[*]").get(); + Object[] expectedResult = + new Object[] {"string", "integer", "array", "object", "null", "boolean", "number"}; + assertArrayEquals(expectedResult, actualResult); + // Check for all types in the JSON document using legacy path + assertEquals("string", Json.type(client, key, "[*]").get()); + } } diff --git a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java index 07b0946b3d..f53f7ced30 100644 --- a/java/integTest/src/test/java/glide/modules/VectorSearchTests.java +++ b/java/integTest/src/test/java/glide/modules/VectorSearchTests.java @@ -1,24 +1,825 @@ /** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ package glide.modules; +import static glide.TestUtilities.assertDeepEquals; import static glide.TestUtilities.commonClusterClientConfig; +import static glide.api.BaseClient.OK; +import static glide.api.models.GlideString.gs; +import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleMultiNodeRoute.ALL_PRIMARIES; import static glide.api.models.configuration.RequestRoutingConfiguration.SimpleSingleNodeRoute.RANDOM; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import glide.api.GlideClusterClient; +import glide.api.commands.servermodules.FT; +import glide.api.commands.servermodules.Json; +import glide.api.models.GlideString; +import glide.api.models.commands.FT.FTAggregateOptions; +import glide.api.models.commands.FT.FTAggregateOptions.Apply; +import glide.api.models.commands.FT.FTAggregateOptions.GroupBy; +import glide.api.models.commands.FT.FTAggregateOptions.GroupBy.Reducer; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortOrder; +import glide.api.models.commands.FT.FTAggregateOptions.SortBy.SortProperty; +import glide.api.models.commands.FT.FTCreateOptions; +import glide.api.models.commands.FT.FTCreateOptions.DataType; +import 
glide.api.models.commands.FT.FTCreateOptions.DistanceMetric; +import glide.api.models.commands.FT.FTCreateOptions.FieldInfo; +import glide.api.models.commands.FT.FTCreateOptions.NumericField; +import glide.api.models.commands.FT.FTCreateOptions.TagField; +import glide.api.models.commands.FT.FTCreateOptions.TextField; +import glide.api.models.commands.FT.FTCreateOptions.VectorFieldFlat; +import glide.api.models.commands.FT.FTCreateOptions.VectorFieldHnsw; +import glide.api.models.commands.FT.FTProfileOptions; +import glide.api.models.commands.FT.FTSearchOptions; +import glide.api.models.commands.FlushMode; import glide.api.models.commands.InfoOptions.Section; +import glide.api.models.exceptions.RequestException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; public class VectorSearchTests { - @Test + private static GlideClusterClient client; + + /** Waiting interval to let server process the data before querying */ + private static final int DATA_PROCESSING_TIMEOUT = 1000; // ms + + @BeforeAll @SneakyThrows - public void check_module_loaded() { - var client = + public static void init() { + client = GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(5000).build()) .get(); + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); + } + + @AfterAll + @SneakyThrows + public static void teardown() { + client.flushall(FlushMode.SYNC, ALL_PRIMARIES).get(); + client.close(); + } + + @Test + @SneakyThrows + public void check_module_loaded() { var info = client.info(new Section[] {Section.MODULES}, RANDOM).get().getSingleValue(); assertTrue(info.contains("# search_index_stats")); } + + @SneakyThrows + @Test + public void ft_create() { + // create few simple indices + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) + }) + .get()); + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo( + "$.vec", "VEC", VectorFieldFlat.builder(DistanceMetric.L2, 6).build()) + }, + FTCreateOptions.builder() + .dataType(DataType.JSON) + .prefixes(new String[] {"json:"}) + .build()) + .get()); + + // create an index with HNSW vector with additional parameters + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo( + "doc_embedding", + VectorFieldHnsw.builder(DistanceMetric.COSINE, 1536) + .numberOfEdges(40) + .vectorsExaminedOnConstruction(250) + .vectorsExaminedOnRuntime(40) + .build()) + }, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {"docs:"}) + .build()) + .get()); + + // create an index with multiple fields + assertEquals( + OK, + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("published_at", new NumericField()), + new FieldInfo("category", new TagField()) + }, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {"blog:post:"}) + .build()) + .get()); + + // create an index with multiple prefixes + var index = UUID.randomUUID().toString(); + assertEquals( + OK, + FT.create( + 
client, + index, + new FieldInfo[] { + new FieldInfo("author_id", new TagField()), + new FieldInfo("author_ids", new TagField()), + new FieldInfo("title", new TextField()), + new FieldInfo("name", new TextField()) + }, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {"author:details:", "book:details:"}) + .build()) + .get()); + + // create a duplicating index + var exception = + assertThrows( + ExecutionException.class, + () -> + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("name", new TextField()) + }) + .get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("already exists")); + + // create an index without fields + exception = + assertThrows( + ExecutionException.class, + () -> FT.create(client, UUID.randomUUID().toString(), new FieldInfo[0]).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("wrong number of arguments")); + + // duplicated field name + exception = + assertThrows( + ExecutionException.class, + () -> + FT.create( + client, + UUID.randomUUID().toString(), + new FieldInfo[] { + new FieldInfo("name", new TextField()), + new FieldInfo("name", new TextField()) + }) + .get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("already exists")); + } + + @SneakyThrows + @Test + public void ft_search() { + String prefix = "{" + UUID.randomUUID() + "}:"; + String index = prefix + "index"; + + assertEquals( + OK, + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo("vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) + }, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {prefix}) + .build()) + .get()); + + assertEquals( + 1L, + client + .hset( + gs(prefix + 0), + Map.of( + gs("vec"), + gs( + new byte[] { + (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, + (byte) 0 + }))) + .get()); + assertEquals( + 1L, + client + .hset( + gs(prefix + 1), + Map.of( + gs("vec"), + gs( + new byte[] { + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0x80, + (byte) 0xBF + }))) + .get()); + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index + + // FT.SEARCH hash_idx1 "*=>[KNN 2 @VEC $query_vec]" PARAMS 2 query_vec + // "\x00\x00\x00\x00\x00\x00\x00\x00" DIALECT 2 + var options = + FTSearchOptions.builder() + .params( + Map.of( + gs("query_vec"), + gs( + new byte[] { + (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, (byte) 0, + (byte) 0 + }))) + .build(); + var query = "*=>[KNN 2 @VEC $query_vec]"; + var ftsearch = FT.search(client, index, query, options).get(); + + assertArrayEquals( + new Object[] { + 2L, + Map.of( + gs(prefix + 0), + Map.of(gs("__VEC_score"), gs("0"), gs("vec"), gs("\0\0\0\0\0\0\0\0")), + gs(prefix + 1), + Map.of( + gs("__VEC_score"), + gs("1"), + gs("vec"), + gs( + new byte[] { + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0, + (byte) 0x80, + (byte) 0xBF + }))) + }, + ftsearch); + + // TODO more tests with json index + + var ftprofile = FT.profile(client, index, new FTProfileOptions(query, options)).get(); + assertArrayEquals(ftsearch, (Object[]) ftprofile[0]); + + // querying non-existing index + var exception = + assertThrows( + ExecutionException.class, + () -> FT.search(client, UUID.randomUUID().toString(), 
"*").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + + @SneakyThrows + @Test + public void ft_drop_and_ft_list() { + var index = gs(UUID.randomUUID().toString()); + assertEquals( + OK, + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo("vec", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build()) + }) + .get()); + + var before = Set.of(FT.list(client).get()); + + assertEquals(OK, FT.dropindex(client, index).get()); + + var after = new HashSet<>(Set.of(FT.list(client).get())); + + assertFalse(after.contains(index)); + after.add(index); + assertEquals(after, before); + + var exception = assertThrows(ExecutionException.class, () -> FT.dropindex(client, index).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index does not exist")); + } + + @SneakyThrows + @Test + public void ft_aggregate() { + var prefixBicycles = "{bicycles}:"; + var indexBicycles = prefixBicycles + UUID.randomUUID(); + var prefixMovies = "{movies}:"; + var indexMovies = prefixMovies + UUID.randomUUID(); + + // FT.CREATE idx:bicycle ON JSON PREFIX 1 bicycle: SCHEMA $.model AS model TEXT $.description AS + // description TEXT $.price AS price NUMERIC $.condition AS condition TAG SEPARATOR , + assertEquals( + OK, + FT.create( + client, + indexBicycles, + new FieldInfo[] { + new FieldInfo("$.model", "model", new TextField()), + new FieldInfo("$.price", "price", new NumericField()), + new FieldInfo("$.condition", "condition", new TagField(',')), + }, + FTCreateOptions.builder() + .dataType(DataType.JSON) + .prefixes(new String[] {prefixBicycles}) + .build()) + .get()); + + // TODO check JSON module loaded + + Json.set( + client, + prefixBicycles + 0, + ".", + "{\"brand\": \"Velorim\", \"model\": \"Jigger\", \"price\": 270, \"condition\":" + + " \"new\"}") + .get(); + Json.set( + client, + prefixBicycles + 1, + ".", + "{\"brand\": \"Bicyk\", \"model\": \"Hillcraft\", \"price\": 1200, \"condition\":" + + " \"used\"}") + .get(); + Json.set( + client, + prefixBicycles + 2, + ".", + "{\"brand\": \"Nord\", \"model\": \"Chook air 5\", \"price\": 815, \"condition\":" + + " \"used\"}") + .get(); + Json.set( + client, + prefixBicycles + 3, + ".", + "{\"brand\": \"Eva\", \"model\": \"Eva 291\", \"price\": 3400, \"condition\":" + + " \"used\"}") + .get(); + Json.set( + client, + prefixBicycles + 4, + ".", + "{\"brand\": \"Noka Bikes\", \"model\": \"Kahuna\", \"price\": 3200, \"condition\":" + + " \"used\"}") + .get(); + Json.set( + client, + prefixBicycles + 5, + ".", + "{\"brand\": \"Breakout\", \"model\": \"XBN 2.1 Alloy\", \"price\": 810, \"condition\":" + + " \"new\"}") + .get(); + Json.set( + client, + prefixBicycles + 6, + ".", + "{\"brand\": \"ScramBikes\", \"model\": \"WattBike\", \"price\": 2300, \"condition\":" + + " \"new\"}") + .get(); + Json.set( + client, + prefixBicycles + 7, + ".", + "{\"brand\": \"Peaknetic\", \"model\": \"Secto\", \"price\": 430, \"condition\":" + + " \"new\"}") + .get(); + Json.set( + client, + prefixBicycles + 8, + ".", + "{\"brand\": \"nHill\", \"model\": \"Summit\", \"price\": 1200, \"condition\":" + + " \"new\"}") + .get(); + Json.set( + client, + prefixBicycles + 9, + ".", + "{\"model\": \"ThrillCycle\", \"brand\": \"BikeShind\", \"price\": 815, \"condition\":" + + " \"refurbished\"}") + .get(); + + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index + + // FT.AGGREGATE idx:bicycle 
"*" LOAD 1 "__key" GROUPBY 1 "@condition" REDUCE COUNT 0 AS bicylces + var options = + FTAggregateOptions.builder() + .loadFields(new String[] {"__key"}) + .addClause( + new GroupBy( + new String[] {"@condition"}, + new Reducer[] {new Reducer("COUNT", new String[0], "bicycles")})) + .build(); + var aggreg = FT.aggregate(client, indexBicycles, "*", options).get(); + // elements (maps in array) could be reordered, comparing as sets + assertDeepEquals( + Set.of( + Map.of(gs("condition"), gs("new"), gs("bicycles"), 5.), + Map.of(gs("condition"), gs("used"), gs("bicycles"), 4.), + Map.of(gs("condition"), gs("refurbished"), gs("bicycles"), 1.)), + Set.of(aggreg)); + + // FT.CREATE idx:movie ON hash PREFIX 1 "movie:" SCHEMA title TEXT release_year NUMERIC rating + // NUMERIC genre TAG votes NUMERIC + assertEquals( + OK, + FT.create( + client, + indexMovies, + new FieldInfo[] { + new FieldInfo("title", new TextField()), + new FieldInfo("release_year", new NumericField()), + new FieldInfo("rating", new NumericField()), + new FieldInfo("genre", new TagField()), + new FieldInfo("votes", new NumericField()), + }, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {prefixMovies}) + .build()) + .get()); + + client + .hset( + prefixMovies + 11002, + Map.of( + "title", + "Star Wars: Episode V - The Empire Strikes Back", + "release_year", + "1980", + "genre", + "Action", + "rating", + "8.7", + "votes", + "1127635", + "imdb_id", + "tt0080684")) + .get(); + client + .hset( + prefixMovies + 11003, + Map.of( + "title", + "The Godfather", + "release_year", + "1972", + "genre", + "Drama", + "rating", + "9.2", + "votes", + "1563839", + "imdb_id", + "tt0068646")) + .get(); + client + .hset( + prefixMovies + 11004, + Map.of( + "title", + "Heat", + "release_year", + "1995", + "genre", + "Thriller", + "rating", + "8.2", + "votes", + "559490", + "imdb_id", + "tt0113277")) + .get(); + client + .hset( + prefixMovies + 11005, + Map.of( + "title", + "Star Wars: Episode VI - Return of the Jedi", + "genre", + "Action", + "votes", + "906260", + "rating", + "8.3", + "release_year", + "1983", + "ibmdb_id", + "tt0086190")) + .get(); + Thread.sleep(DATA_PROCESSING_TIMEOUT); // let server digest the data and update index + + // FT.AGGREGATE idx:movie * LOAD * APPLY ceil(@rating) as r_rating GROUPBY 1 @genre REDUCE + // COUNT 0 AS nb_of_movies REDUCE SUM 1 votes AS nb_of_votes REDUCE AVG 1 r_rating AS avg_rating + // SORTBY 4 @avg_rating DESC @nb_of_votes DESC + options = + FTAggregateOptions.builder() + .loadAll() + .addClause(new Apply("ceil(@rating)", "r_rating")) + .addClause( + new GroupBy( + new String[] {"@genre"}, + new Reducer[] { + new Reducer("COUNT", new String[0], "nb_of_movies"), + new Reducer("SUM", new String[] {"votes"}, "nb_of_votes"), + new Reducer("AVG", new String[] {"r_rating"}, "avg_rating") + })) + .addClause( + new SortBy( + new SortProperty[] { + new SortProperty("@avg_rating", SortOrder.DESC), + new SortProperty("@nb_of_votes", SortOrder.DESC) + })) + .build(); + aggreg = FT.aggregate(client, indexMovies, "*", options).get(); + // elements (maps in array) could be reordered, comparing as sets + assertDeepEquals( + Set.of( + Map.of( + gs("genre"), + gs("Drama"), + gs("nb_of_movies"), + 1., + gs("nb_of_votes"), + 1563839., + gs("avg_rating"), + 10.), + Map.of( + gs("genre"), + gs("Action"), + gs("nb_of_movies"), + 2., + gs("nb_of_votes"), + 2033895., + gs("avg_rating"), + 9.), + Map.of( + gs("genre"), + gs("Thriller"), + gs("nb_of_movies"), + 1., + gs("nb_of_votes"), + 559490., 
+ gs("avg_rating"), + 9.)), + Set.of(aggreg)); + + var ftprofile = FT.profile(client, indexMovies, new FTProfileOptions("*", options)).get(); + assertDeepEquals(aggreg, ftprofile[0]); + } + + @SuppressWarnings("unchecked") + @Test + @SneakyThrows + public void ft_info() { + var index = UUID.randomUUID().toString(); + assertEquals( + OK, + FT.create( + client, + index, + new FieldInfo[] { + new FieldInfo( + "$.vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.COSINE, 42).build()), + new FieldInfo("$.name", new TextField()), + }, + FTCreateOptions.builder() + .dataType(DataType.JSON) + .prefixes(new String[] {"123"}) + .build()) + .get()); + + var response = FT.info(client, index).get(); + assertEquals(gs(index), response.get("index_name")); + assertEquals(gs("JSON"), response.get("key_type")); + assertArrayEquals(new GlideString[] {gs("123")}, (Object[]) response.get("key_prefixes")); + var fields = (Object[]) response.get("fields"); + assertEquals(2, fields.length); + var f1 = (Map) fields[1]; + assertEquals(gs("$.vec"), f1.get(gs("identifier"))); + assertEquals(gs("VECTOR"), f1.get(gs("type"))); + assertEquals(gs("VEC"), f1.get(gs("field_name"))); + var f1params = (Map) f1.get(gs("vector_params")); + assertEquals(gs("COSINE"), f1params.get(gs("distance_metric"))); + assertEquals(42L, f1params.get(gs("dimension"))); + + assertEquals( + Map.of( + gs("identifier"), + gs("$.name"), + gs("type"), + gs("TEXT"), + gs("field_name"), + gs("$.name"), + gs("option"), + gs("")), + fields[0]); + + // querying a missing index + assertEquals(OK, FT.dropindex(client, index).get()); + var exception = assertThrows(ExecutionException.class, () -> FT.info(client, index).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + + @SneakyThrows + @Test + public void ft_aliasadd_aliasdel_aliasupdate_aliaslist() { + + var alias1 = "alias1"; + var alias2 = "a2"; + var indexName = "{" + UUID.randomUUID() + "-index}"; + + // create some indices + assertEquals( + OK, + FT.create( + client, + indexName, + new FieldInfo[] { + new FieldInfo("vec", VectorFieldFlat.builder(DistanceMetric.L2, 2).build()) + }) + .get()); + + assertEquals(0, FT.aliaslist(client).get().size()); + assertEquals(OK, FT.aliasadd(client, alias1, indexName).get()); + assertEquals(Map.of(gs(alias1), gs(indexName)), FT.aliaslist(client).get()); + + // error with adding the same alias to the same index + var exception = + assertThrows(ExecutionException.class, () -> FT.aliasadd(client, alias1, indexName).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Alias already exists")); + + assertEquals(OK, FT.aliasupdate(client, alias2, indexName).get()); + assertEquals( + Map.of(gs(alias1), gs(indexName), gs(alias2), gs(indexName)), FT.aliaslist(client).get()); + assertEquals(OK, FT.aliasdel(client, alias2).get()); + + // with GlideString: + assertEquals(OK, FT.aliasupdate(client, gs(alias1), gs(indexName)).get()); + assertEquals(OK, FT.aliasdel(client, gs(alias1)).get()); + assertEquals(OK, FT.aliasadd(client, gs(alias2), gs(indexName)).get()); + assertEquals(OK, FT.aliasdel(client, gs(alias2)).get()); + + // exception with calling `aliasdel` on an alias that doesn't exist + exception = assertThrows(ExecutionException.class, () -> FT.aliasdel(client, alias2).get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Alias does not 
exist")); + + // exception with calling `aliasadd` with a nonexisting index + exception = + assertThrows( + ExecutionException.class, () -> FT.aliasadd(client, alias1, "nonexistent_index").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index does not exist")); + } + + @SneakyThrows + @Test + public void ft_explain() { + + String indexName = UUID.randomUUID().toString(); + createIndexHelper(indexName); + + // search query containing numeric field. + String query = "@price:[0 10]"; + String result = FT.explain(client, indexName, query).get(); + assertTrue(result.contains("price")); + assertTrue(result.contains("0")); + assertTrue(result.contains("10")); + + GlideString resultGS = FT.explain(client, gs(indexName), gs(query)).get(); + assertTrue((resultGS).toString().contains("price")); + assertTrue((resultGS).toString().contains("0")); + assertTrue((resultGS).toString().contains("10")); + + // search query that returns all data. + GlideString resultGSAllData = FT.explain(client, gs(indexName), gs("*")).get(); + assertTrue(resultGSAllData.toString().contains("*")); + + assertEquals(OK, FT.dropindex(client, indexName).get()); + + // missing index throws an error. + var exception = + assertThrows( + ExecutionException.class, + () -> FT.explain(client, UUID.randomUUID().toString(), "*").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + + @SneakyThrows + @Test + public void ft_explaincli() { + + String indexName = UUID.randomUUID().toString(); + createIndexHelper(indexName); + + // search query containing numeric field. + String query = "@price:[0 10]"; + String[] result = FT.explaincli(client, indexName, query).get(); + List resultList = Arrays.stream(result).map(String::trim).collect(Collectors.toList()); + + assertTrue(resultList.contains("price")); + assertTrue(resultList.contains("0")); + assertTrue(resultList.contains("10")); + + GlideString[] resultGS = FT.explaincli(client, gs(indexName), gs(query)).get(); + List resultListGS = + Arrays.stream(resultGS) + .map(GlideString::toString) + .map(String::trim) + .collect(Collectors.toList()); + + assertTrue((resultListGS).contains("price")); + assertTrue((resultListGS).contains("0")); + assertTrue((resultListGS).contains("10")); + + // search query that returns all data. + GlideString[] resultGSAllData = FT.explaincli(client, gs(indexName), gs("*")).get(); + List resultListGSAllData = + Arrays.stream(resultGSAllData) + .map(GlideString::toString) + .map(String::trim) + .collect(Collectors.toList()); + assertTrue((resultListGSAllData).contains("*")); + + assertEquals(OK, FT.dropindex(client, indexName).get()); + + // missing index throws an error. 
+ var exception = + assertThrows( + ExecutionException.class, + () -> FT.explaincli(client, UUID.randomUUID().toString(), "*").get()); + assertInstanceOf(RequestException.class, exception.getCause()); + assertTrue(exception.getMessage().contains("Index not found")); + } + + private void createIndexHelper(String indexName) throws ExecutionException, InterruptedException { + FieldInfo numericField = new FieldInfo("price", new NumericField()); + FieldInfo textField = new FieldInfo("title", new TextField()); + + FieldInfo[] fields = new FieldInfo[] {numericField, textField}; + + String prefix = "{hash-search-" + UUID.randomUUID().toString() + "}:"; + + assertEquals( + OK, + FT.create( + client, + indexName, + fields, + FTCreateOptions.builder() + .dataType(DataType.HASH) + .prefixes(new String[] {prefix}) + .build()) + .get()); + } } diff --git a/java/integTest/src/test/java/glide/standalone/StandaloneClientTests.java b/java/integTest/src/test/java/glide/standalone/StandaloneClientTests.java index e61b97ed4e..02551a6b9c 100644 --- a/java/integTest/src/test/java/glide/standalone/StandaloneClientTests.java +++ b/java/integTest/src/test/java/glide/standalone/StandaloneClientTests.java @@ -6,6 +6,8 @@ import static glide.TestUtilities.getRandomString; import static glide.api.BaseClient.OK; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -14,10 +16,15 @@ import glide.api.models.configuration.ServerCredentials; import glide.api.models.exceptions.ClosingException; import glide.api.models.exceptions.RequestException; +import java.util.Map; +import java.util.UUID; import java.util.concurrent.ExecutionException; import lombok.SneakyThrows; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; @Timeout(10) // seconds public class StandaloneClientTests { @@ -45,7 +52,7 @@ public void can_connect_with_auth_require_pass() { GlideClient client = GlideClient.createClient(commonClientConfig().build()).get(); String password = "TEST_AUTH"; - client.customCommand(new String[] {"CONFIG", "SET", "requirepass", password}).get(); + client.configSet(Map.of("requirepass", password)).get(); // Creation of a new client without a password should fail ExecutionException exception = @@ -69,7 +76,7 @@ public void can_connect_with_auth_require_pass() { assertEquals(value, auth_client.get(key).get()); // Reset password - client.customCommand(new String[] {"CONFIG", "SET", "requirepass", ""}).get(); + client.configSet(Map.of("requirepass", "")).get(); auth_client.close(); client.close(); @@ -159,4 +166,120 @@ public void closed_client_throws_ExecutionException_with_ClosingException_as_cau assertThrows(ExecutionException.class, () -> client.set("key", "value").get()); assertTrue(executionException.getCause() instanceof ClosingException); } + + @SneakyThrows + @Test + public void update_connection_password_auth_non_valid_pass() { + // Test Client fails on call to updateConnectionPassword with invalid parameters + try (var testClient = GlideClient.createClient(commonClientConfig().build()).get()) { + var emptyPasswordException = + 
assertThrows(
+                            ExecutionException.class, () -> testClient.updateConnectionPassword("", true).get());
+            assertInstanceOf(RequestException.class, emptyPasswordException.getCause());
+
+            var noPasswordException =
+                    assertThrows(
+                            ExecutionException.class, () -> testClient.updateConnectionPassword(true).get());
+            assertInstanceOf(RequestException.class, noPasswordException.getCause());
+        }
+    }
+
+    @SneakyThrows
+    @Test
+    public void update_connection_password_no_server_auth() {
+        var pwd = UUID.randomUUID().toString();
+
+        try (var testClient = GlideClient.createClient(commonClientConfig().build()).get()) {
+            // validate that we can use the client
+            assertNotNull(testClient.info().get());
+
+            // Test that immediate re-authentication fails when no server password is set.
+            var exception =
+                    assertThrows(
+                            ExecutionException.class, () -> testClient.updateConnectionPassword(pwd, true).get());
+            assertInstanceOf(RequestException.class, exception.getCause());
+        }
+    }
+
+    @SneakyThrows
+    @Test
+    public void update_connection_password_long() {
+        var pwd = RandomStringUtils.randomAlphabetic(1000);
+
+        try (var testClient = GlideClient.createClient(commonClientConfig().build()).get()) {
+            // validate that we can use the client
+            assertNotNull(testClient.info().get());
+
+            // Test replacing connection password with a long password string.
+            assertEquals(OK, testClient.updateConnectionPassword(pwd, false).get());
+        }
+    }
+
+    @Timeout(50)
+    @SneakyThrows
+    @Test
+    public void replace_password_immediateAuth_wrong_password() {
+        var pwd = UUID.randomUUID().toString();
+        var notThePwd = UUID.randomUUID().toString();
+
+        GlideClient adminClient = GlideClient.createClient(commonClientConfig().build()).get();
+        try (var testClient = GlideClient.createClient(commonClientConfig().build()).get()) {
+            // validate that we can use the client
+            assertNotNull(testClient.info().get());
+
+            // set the password to something else
+            adminClient.configSet(Map.of("requirepass", notThePwd)).get();
+
+            // Test that re-authentication fails when using the wrong password.
+            var exception =
+                    assertThrows(
+                            ExecutionException.class, () -> testClient.updateConnectionPassword(pwd, true).get());
+            assertInstanceOf(RequestException.class, exception.getCause());
+
+            // But updating with the correct password returns OK
+            assertEquals(OK, testClient.updateConnectionPassword(notThePwd, true).get());
+        } finally {
+            adminClient.configSet(Map.of("requirepass", "")).get();
+            adminClient.close();
+        }
+    }
+
+    @SneakyThrows
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void update_connection_password_connection_lost_before_password_update(
+            boolean immediateAuth) {
+        GlideClient adminClient = GlideClient.createClient(commonClientConfig().build()).get();
+        var pwd = UUID.randomUUID().toString();
+
+        try (var testClient = GlideClient.createClient(commonClientConfig().build()).get()) {
+            // validate that we can use the client
+            assertNotNull(testClient.info().get());
+
+            // set the password and forcefully drop connection for the testClient
+            assertEquals("OK", adminClient.configSet(Map.of("requirepass", pwd)).get());
+            adminClient.customCommand(new String[] {"CLIENT", "KILL", "TYPE", "NORMAL"}).get();
+
+            /*
+             * Some explanation for the curious mind:
+             * Our library abstracts a connection (or connections), with a lot of machinery around it, making it behave like what we call a "client".
+             * When using standalone mode, the client is a single connection, so on disconnection the first thing it plans to do is reconnect.
+             *
+             * There is no reason to accept other commands and try to serve them, since serving commands requires a live connection.
+             *
+             * Hence, the client will try to reconnect and will not pick up new tasks, but will let them wait in line,
+             * so the update-connection-password request will not be able to reach the connection and will return an error.
+             * In future versions, standalone will be treated as a different animal than it is now, since standalone is not necessarily one node.
+             * It can be replicated and have a lot of nodes, being what we like to call a "one shard cluster".
+             * So, in the future, we will have many live connections, and requests can be served even while one connection is locked.
+             */
+            var exception =
+                    assertThrows(
+                            ExecutionException.class,
+                            () -> testClient.updateConnectionPassword(pwd, immediateAuth).get());
+        } finally {
+            adminClient.configSet(Map.of("requirepass", "")).get();
+            adminClient.close();
+        }
+    }
 }
diff --git a/java/src/lib.rs b/java/src/lib.rs
index bc40503971..311d9a13dc 100644
--- a/java/src/lib.rs
+++ b/java/src/lib.rs
@@ -13,6 +13,9 @@ use glide_core::STREAM as TYPE_STREAM;
 use glide_core::STRING as TYPE_STRING;
 use glide_core::ZSET as TYPE_ZSET;
 
+// Telemetry required for getStatistics
+use glide_core::Telemetry;
+
 use bytes::Bytes;
 use jni::errors::Error as JniError;
 use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString};
@@ -22,6 +25,7 @@ use redis::Value;
 use std::sync::mpsc;
 
 mod errors;
+mod linked_hashmap;
 
 use errors::{handle_errors, handle_panics, FFIError};
 
@@ -580,6 +584,38 @@ pub extern "system" fn Java_glide_ffi_resolvers_ObjectTypeResolver_getTypeStream
     safe_create_jstring(env, TYPE_STREAM, "getTypeStreamConstant")
 }
 
+/// Returns a Java `LinkedHashMap` representing the statistics collected for this process.
+///
+/// This function is meant to be invoked by Java using JNI.
+///
+/// * `env` - The JNI environment.
+/// * `_class` - The class object. Not used.
+#[no_mangle]
+pub extern "system" fn Java_glide_ffi_resolvers_StatisticsResolver_getStatistics<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+) -> JObject<'local> {
+    let Some(mut map) = linked_hashmap::new_linked_hashmap(&mut env) else {
+        return JObject::null();
+    };
+
+    linked_hashmap::put_strings(
+        &mut env,
+        &mut map,
+        "total_connections",
+        &format!("{}", Telemetry::total_connections()),
+    );
+
+    linked_hashmap::put_strings(
+        &mut env,
+        &mut map,
+        "total_clients",
+        &format!("{}", Telemetry::total_clients()),
+    );
+
+    map
+}
+
 /// Convert a Rust string to a Java String and handle errors.
 ///
 /// * `env` - The JNI environment.
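The same two counters surface in each language wrapper; the Node side of this diff exposes them through `BaseClient.getStatistics()`. A minimal sketch of reading them from TypeScript (the endpoint is a placeholder; the key names mirror the Rust code above):

```typescript
// Sketch: reading GLIDE core telemetry via the Node wrapper's getStatistics(),
// which is added later in this diff. Keys mirror the Rust side:
// total_connections and total_clients, both serialized as strings.
import { GlideClient } from "@valkey/valkey-glide";

const client = await GlideClient.createClient({
    addresses: [{ host: "localhost", port: 6379 }], // placeholder endpoint
});
const stats = client.getStatistics() as Record<string, string>;
console.log(stats["total_connections"], stats["total_clients"]);
client.close();
```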
diff --git a/java/src/linked_hashmap.rs b/java/src/linked_hashmap.rs
new file mode 100644
index 0000000000..2e1ac7fd9b
--- /dev/null
+++ b/java/src/linked_hashmap.rs
@@ -0,0 +1,70 @@
+use crate::errors;
+use jni::{objects::JObject, JNIEnv};
+
+const LINKED_HASHMAP: &str = "java/util/LinkedHashMap";
+const LINKED_HASHMAP_PUT_SIG: &str = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
+
+/// Create new Java `LinkedHashMap`
+pub fn new_linked_hashmap<'a>(env: &mut JNIEnv<'a>) -> Option<JObject<'a>> {
+    let hash_map = env.new_object(LINKED_HASHMAP, "()V", &[]);
+    let Ok(hash_map) = hash_map else {
+        errors::throw_java_exception(
+            env,
+            errors::ExceptionType::RuntimeException,
+            "Failed to allocate LinkedHashMap",
+        );
+        return None;
+    };
+    Some(hash_map)
+}
+
+/// Put `key` / `value` pair into the `map`, where both `key` and `value` are of type `&str`
+/// This method is provided for convenience
+pub fn put_strings<'a>(env: &mut JNIEnv<'a>, map: &mut JObject<'a>, key: &str, value: &str) {
+    let Some(key) = string_to_jobject(env, key) else {
+        return;
+    };
+    let Some(value) = string_to_jobject(env, value) else {
+        return;
+    };
+    put_objects(env, map, key, value)
+}
+
+/// Put `key` / `value` pair into the `map`, where both `key` and `value` are of type `JObject`
+pub fn put_objects<'a>(
+    env: &mut JNIEnv<'a>,
+    map: &mut JObject<'a>,
+    key: JObject<'a>,
+    value: JObject<'a>,
+) {
+    if env
+        .call_method(
+            &map,
+            "put",
+            LINKED_HASHMAP_PUT_SIG,
+            &[(&key).into(), (&value).into()],
+        )
+        .is_err()
+    {
+        errors::throw_java_exception(
+            env,
+            errors::ExceptionType::RuntimeException,
+            "Failed to call LinkedHashMap::put method",
+        );
+    }
+}
+
+/// Construct new Java string from Rust's `str`
+fn string_to_jobject<'a>(env: &mut JNIEnv<'a>, string: &str) -> Option<JObject<'a>> {
+    match env.new_string(string) {
+        Ok(obj) => Some(JObject::from(obj)),
+        Err(_) => {
+            errors::throw_java_exception(
+                env,
+                errors::ExceptionType::RuntimeException,
+                "Failed to create Java string",
+            );
+            None
+        }
+    }
+}
diff --git a/logger_core/src/lib.rs b/logger_core/src/lib.rs
index d1054339a8..a31e4fdbc2 100644
--- a/logger_core/src/lib.rs
+++ b/logger_core/src/lib.rs
@@ -24,6 +24,9 @@ use tracing_subscriber::{
     prelude::*,
     reload::{self, Handle},
 };
+
+use std::str::FromStr;
+
 // Layer-Filter pair determines whether a log will be collected
 type InnerFiltered = Filtered<fmt::Layer<Registry>, LevelFilter, Registry>;
 // A Reloadable pair of layer-filter
@@ -52,6 +55,7 @@ pub static INITIATE_ONCE: InitiateOnce = InitiateOnce {
 };
 
 const FILE_DIRECTORY: &str = "glide-logs";
+const ENV_GLIDE_LOG_DIR: &str = "GLIDE_LOG_DIR";
 
 /// Wraps [RollingFileAppender] to defer initialization until logging is required,
 /// allowing [init] to disable file logging on read-only filesystems.
@@ -114,6 +118,21 @@ impl Level {
     }
 }
 
+/// Attempt to read a directory path from an environment variable. If the environment variable `envname` exists
+/// and contains a valid path - this function will create and return that path. In any case of failure,
+/// this method returns `None` (e.g. the environment variable exists but contains an empty path, etc.)
+pub fn create_directory_from_env(envname: &str) -> Option<String> {
+    let Ok(dirpath) = std::env::var(envname) else {
+        return None;
+    };
+
+    if dirpath.trim().is_empty() || std::fs::create_dir_all(&dirpath).is_err() {
+        return None;
+    }
+
+    Some(dirpath)
+}
+
 // Initialize the global logger to error level on the first call only
 // In any of the calls to the function, including the first - resetting the existing loggers to the new setting
 // provided by using the global reloadable handle
@@ -128,22 +147,34 @@ pub fn init(minimal_level: Option<Level>, file_name: Option<&str>) -> Level {
 
     let (stdout_layer, stdout_reload) = reload::Layer::new(stdout_fmt);
 
+    // Check if the environment variable GLIDE_LOG_DIR is set
+    let logs_dir =
+        create_directory_from_env(ENV_GLIDE_LOG_DIR).unwrap_or(FILE_DIRECTORY.to_string());
     let file_appender = LazyRollingFileAppender::new(
         Rotation::HOURLY,
-        FILE_DIRECTORY,
+        logs_dir,
         file_name.unwrap_or("output.log"),
     );
+
     let file_fmt = tracing_subscriber::fmt::layer()
         .with_writer(file_appender)
         .with_filter(LevelFilter::OFF);
     let (file_layer, file_reload) = reload::Layer::new(file_fmt);
 
+    // If the user has set the environment variable "RUST_LOG" with a valid log verbosity, use it
+    let log_level = if let Ok(level) = std::env::var("RUST_LOG") {
+        let trace_level = tracing::Level::from_str(&level).unwrap_or(tracing::Level::TRACE);
+        LevelFilter::from(trace_level)
+    } else {
+        LevelFilter::TRACE
+    };
+
+    // Enable logging only from allowed crates
     let targets_filter = filter::Targets::new()
-        .with_target("glide", LevelFilter::TRACE)
-        .with_target("redis", LevelFilter::TRACE)
-        .with_target("logger_core", LevelFilter::TRACE)
-        .with_target(std::env!("CARGO_PKG_NAME"), LevelFilter::TRACE);
+        .with_target("glide", log_level)
+        .with_target("redis", log_level)
+        .with_target("logger_core", log_level)
+        .with_target(std::env!("CARGO_PKG_NAME"), log_level);
 
     tracing_subscriber::registry()
         .with(stdout_layer)
@@ -174,8 +205,10 @@ pub fn init(minimal_level: Option<Level>, file_name: Option<&str>) -> Level {
             });
         }
         Some(file) => {
-            let file_appender =
-                LazyRollingFileAppender::new(Rotation::HOURLY, FILE_DIRECTORY, file);
+            // Check if the environment variable GLIDE_LOG_DIR is set
+            let logs_dir =
+                create_directory_from_env(ENV_GLIDE_LOG_DIR).unwrap_or(FILE_DIRECTORY.to_string());
+            let file_appender = LazyRollingFileAppender::new(Rotation::HOURLY, logs_dir, file);
             let _ = reloads
                 .file_reload
                 .write()
@@ -237,3 +270,39 @@ pub fn log<Message: AsRef<str>, Identifier: AsRef<str>>(
         Level::Off => (),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_directory_from_env() {
+        let dir_path = format!("{}/glide-logs", std::env::temp_dir().display());
+        // Case 1: try to create an already existing folder
+        // make sure we are starting fresh
+        let _ = std::fs::remove_dir_all(&dir_path);
+        // Create the directory
+        assert!(std::fs::create_dir_all(&dir_path).is_ok());
+
+        std::env::set_var(ENV_GLIDE_LOG_DIR, &dir_path);
+        assert!(create_directory_from_env(ENV_GLIDE_LOG_DIR).is_some());
+        assert!(std::fs::metadata(&dir_path).is_ok());
+
+        // Case 2: try to create a new folder (i.e. the folder does not already exist)
+        let _ = std::fs::remove_dir_all(&dir_path);
+
+        // Create the directory
+        assert!(std::fs::create_dir_all(&dir_path).is_ok());
+        assert!(std::fs::metadata(&dir_path).is_ok());
+
+        std::env::set_var(ENV_GLIDE_LOG_DIR, &dir_path);
+        assert!(create_directory_from_env(ENV_GLIDE_LOG_DIR).is_some());
+
+        // make sure we are starting fresh
+        let _ = std::fs::remove_dir_all(&dir_path);
+
+        // Case 3: empty variable is not acceptable
+        std::env::set_var(ENV_GLIDE_LOG_DIR, "");
+        assert!(create_directory_from_env(ENV_GLIDE_LOG_DIR).is_none());
+    }
+}
diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md
index f71966862e..8878fdd91d 100644
--- a/node/DEVELOPER.md
+++ b/node/DEVELOPER.md
@@ -65,11 +65,8 @@ Before starting this step, make sure you've installed all software requirements.
    git clone https://github.com/valkey-io/valkey-glide.git
    cd valkey-glide
    ```
-2. Initialize git submodule:
-    ```bash
-    git submodule update --init --recursive
-    ```
-3. Install all node dependencies:
+2. Install all node dependencies:
+
    ```bash
    cd node
    npm i
    cd rust-client
    npm i
    cd ..
    ```
-4. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder):
+
+3. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder):

    1. Build in release mode, stripped from all debug symbols (optimized and minimized binary size):

@@ -99,14 +97,14 @@ Before starting this step, make sure you've installed all software requirements.

    Once the build completes, you'll find the compiled JavaScript code in the `./build-ts` folder.

-5. Run tests:
+4. Run tests:
    1. Ensure that you have installed the server and valkey-cli on your host. You can download Valkey at the following link: [Valkey Download page](https://valkey.io/download/).
    2. Execute the following command from the node folder:
       ```bash
       npm run build # make sure we have a debug build compiled first
       npm test
       ```
-6. Integrating the built GLIDE package into your project:
+5. Integrating the built GLIDE package into your project:
   Add the package to your project using the folder path with the command `npm install <folder path>/node`.

    - For a fast build, execute `npm run build`. This will perform a full, unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to build with the `build:release` or `build:benchmark` option when measuring performance.
@@ -128,19 +126,65 @@ To run tests, use the following command:
npm test
```

-To execute a specific test, include the [`testNamePattern`](https://jestjs.io/docs/cli#--testnamepatternregex) option. For example:
+The simplified test suite skips a few time-consuming tests and runs faster:
+
+```bash
+npm test-minimum
+```
+
+To execute a specific test, use the [`testNamePattern`](https://jestjs.io/docs/cli#--testnamepatternregex) option with the `test-dbg` script. For example:

```bash
-npm run test -- --testNamePattern="transaction"
+npm run test-dbg -- --testNamePattern="transaction"
```

The IT suite starts the server for testing - standalone and cluster installations - using the `cluster_manager` script.
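For a quick manual sanity check against such an already-running server (outside the Jest suites), a minimal sketch that can be pasted into the REPL described later in this file; the endpoint is a placeholder:

```typescript
// Minimal connectivity check (sketch) - run from the `node` folder after building.
// Host and port are placeholders for your own server.
import { GlideClient } from ".";

const client = await GlideClient.createClient({
    addresses: [{ host: "localhost", port: 6379 }],
});
console.log(await client.ping()); // "PONG"
client.close();
```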
To run the integration tests with existing servers, run the following command:

```bash
-npm run test -- --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379
+npm run test-dbg -- --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379
+
+# If those endpoints use TLS, add `--tls=true` (applies to both endpoints)
+npm run test-dbg -- --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 --tls=true
+```
+
+The parameters `cluster-endpoints`, `standalone-endpoints` and `tls` can be used with all test suites.
+
+By default, the server-modules tests do not run with `npm run test`. This test suite also does not start the server.
+In order to run these tests, use:
+
+```bash
+npm run test-modules -- --cluster-endpoints=<address>:<port>
+```
+Note: these tests currently do not run against a standalone server.
+
+### REPL (interactive shell)
+
+It is possible to run an interactive shell synced with the current client code, to test and debug it:
+
+```bash
+npx ts-node --project tsconfig.json
+```
+
+This shell allows executing TypeScript and JavaScript code line by line:
+
+```typescript
+import { GlideClient, GlideClusterClient } from ".";
+let client = await GlideClient.createClient({
+    addresses: [{ host: "localhost", port: 6379 }],
+});
+let clusterClient = await GlideClusterClient.createClient({
+    addresses: [{ host: "localhost", port: 7000 }],
+});
+await client.ping();
+```
+
+After changing the client code, you need to restart the shell.
+
+The shell has command history and bash-like search (`Ctrl+R`).
+
+The shell hangs on exit (`Ctrl+D`) if the clients are not closed. Use `Ctrl+C` to kill it, and/or close the clients before exiting.
+
 ### Submodules

 After pulling new changes, ensure that you update the submodules by running the following command:
@@ -173,7 +217,7 @@ Development on the Node wrapper may involve changes in either the TypeScript or
    # Run from the node folder
    npm run lint
    # To automatically apply ESLint and/or prettier recommendations
-    npx run lint:fix
+    npm run lint:fix
```

 2. Rust

diff --git a/node/README.md b/node/README.md
index 661e742b96..eeaf4e4a60 100644
--- a/node/README.md
+++ b/node/README.md
@@ -19,7 +19,7 @@ Linux:

 macOS:

-- macOS (12.7 and latest) (Apple silicon/aarch_64 and Intel/x86_64)
+- macOS 14.7 (Apple silicon/aarch_64)

 Alpine:

@@ -96,9 +96,9 @@ client.close();

 ### Supported platforms

-Currentlly the package is supported on:
+Currently, the package is tested on:

 | Operation systems | C lib                | Architecture      |
 | ----------------- | -------------------- | ----------------- |
 | `Linux`           | `glibc`, `musl libc` | `x86_64`, `arm64` |
-| `macOS`           | `Darwin`             | `x86_64`, `arm64` |
+| `macOS`           | `Darwin`             | `arm64`           |
diff --git a/node/index.ts b/node/index.ts
index e1f04fd9c2..ee035c2a49 100644
--- a/node/index.ts
+++ b/node/index.ts
@@ -9,4 +9,7 @@ export * from "./src/Errors";
 export * from "./src/GlideClient";
 export * from "./src/GlideClusterClient";
 export * from "./src/Logger";
+export * from "./src/server-modules/GlideJson";
+export * from "./src/server-modules/GlideFt";
+export * from "./src/server-modules/GlideFtOptions";
 export * from "./src/Transaction";
diff --git a/node/jest.config.js b/node/jest.config.js
index 6952aecfca..607c4c0830 100644
--- a/node/jest.config.js
+++ b/node/jest.config.js
@@ -29,5 +29,5 @@ module.exports = {
             },
         ],
     ],
-    setupFilesAfterEnv: ["./tests/setup.js"],
+    setupFilesAfterEnv: ["./tests/setup.ts"],
 };
diff --git a/node/npm/glide/index.ts b/node/npm/glide/index.ts
index 98171bfef4..da719c51c7 100644
--- a/node/npm/glide/index.ts
+++ b/node/npm/glide/index.ts
@@ -84,6 +84,7 @@ function initialize() {
     const {
         AggregationType,
         BaseScanOptions,
+        ScanOptions,
         ZScanOptions,
         HScanOptions,
         BitEncoding,
@@ -117,8 +118,32 @@ function initialize() {
         GlideClient,
         GlideClusterClient,
         GlideClientConfiguration,
+        GlideJson,
+        GlideFt,
+        Field,
+        TextField,
+        TagField,
+        NumericField,
+        VectorField,
+        VectorFieldAttributesFlat,
+        VectorFieldAttributesHnsw,
+        FtCreateOptions,
+        FtSearchOptions,
+        FtInfoReturnType,
+        FtAggregateOptions,
+        FtAggregateLimit,
+        FtAggregateFilter,
+        FtAggregateGroupBy,
+        FtAggregateReducer,
+        FtAggregateSortBy,
+        FtAggregateSortProperty,
+        FtAggregateApply,
+        FtAggregateReturnType,
+        FtSearchReturnType,
         GlideRecord,
         GlideString,
+        JsonGetOptions,
+
JsonArrPopOptions, SortedSetDataType, StreamEntryDataType, HashDataType, @@ -151,16 +176,13 @@ function initialize() { InfBoundary, KeyWeight, Boundary, - UpdateOptions, ProtocolVersion, RangeByIndex, RangeByScore, RangeByLex, ReadFrom, ServerCredentials, - SortClusterOptions, SortOptions, - SortedSetRange, StreamGroupOptions, StreamTrimOptions, StreamAddOptions, @@ -191,7 +213,6 @@ function initialize() { createLeakedDouble, createLeakedMap, createLeakedString, - parseInfoResponse, Script, ObjectType, ClusterScanCursor, @@ -202,11 +223,14 @@ function initialize() { ReturnTypeMap, ClusterResponse, ReturnTypeAttribute, + ReturnTypeJson, + UniversalReturnTypeJson, } = nativeBinding; module.exports = { AggregationType, BaseScanOptions, + ScanOptions, HScanOptions, ZScanOptions, BitEncoding, @@ -226,8 +250,32 @@ function initialize() { Decoder, DecoderOption, GeoAddOptions, + GlideFt, + Field, + TextField, + TagField, + NumericField, + VectorField, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + FtCreateOptions, + FtSearchOptions, + FtInfoReturnType, + FtAggregateOptions, + FtAggregateLimit, + FtAggregateFilter, + FtAggregateGroupBy, + FtAggregateReducer, + FtAggregateSortBy, + FtAggregateSortProperty, + FtAggregateApply, + FtAggregateReturnType, + FtSearchReturnType, GlideRecord, + GlideJson, GlideString, + JsonGetOptions, + JsonArrPopOptions, SortedSetDataType, StreamEntryDataType, HashDataType, @@ -276,16 +324,13 @@ function initialize() { InfBoundary, KeyWeight, Boundary, - UpdateOptions, ProtocolVersion, RangeByIndex, RangeByScore, RangeByLex, ReadFrom, ServerCredentials, - SortClusterOptions, SortOptions, - SortedSetRange, StreamGroupOptions, StreamTrimOptions, StreamAddOptions, @@ -314,7 +359,6 @@ function initialize() { createLeakedDouble, createLeakedMap, createLeakedString, - parseInfoResponse, Script, ObjectType, ClusterScanCursor, @@ -325,6 +369,8 @@ function initialize() { ReturnTypeMap, ClusterResponse, ReturnTypeAttribute, + ReturnTypeJson, + UniversalReturnTypeJson, }; globalObject = Object.assign(global, nativeBinding); diff --git a/node/npm/glide/package.json b/node/npm/glide/package.json index 9514160893..78ec8d0821 100644 --- a/node/npm/glide/package.json +++ b/node/npm/glide/package.json @@ -10,8 +10,8 @@ "lint": "eslint .", "lint:fix": "eslint . --fix", "clean": "rm -rf build-ts/", - "copy-declaration-files": "cp ../../build-ts/*.d.ts build-ts/ && cp ../../build-ts/src/*.d.ts build-ts/src/", - "build": "tsc && mkdir -p build-ts/src && npm run copy-declaration-files" + "copy-declaration-files": "cp ../../build-ts/*.d.ts build-ts/ && cp ../../build-ts/src/*.d.ts build-ts/src/ && cp ../../build-ts/src/server-modules/*.d.ts build-ts/src/server-modules/", + "build": "tsc && mkdir -p build-ts/src && mkdir -p build-ts/src/server-modules && npm run copy-declaration-files" }, "files": [ "/build-ts" diff --git a/node/package.json b/node/package.json index 9b990f84e7..52a3b68260 100644 --- a/node/package.json +++ b/node/package.json @@ -33,7 +33,8 @@ "clean": "rm -rf build-ts rust-client/target docs glide-logs rust-client/glide-rs.*.node rust-client/index.* src/ProtobufMessage.*", "fix-protobuf-file": "replace 'this\\.encode\\(message, writer\\)\\.ldelim' 'this.encode(message, writer && writer.len ? 
writer.fork() : writer).ldelim' src/ProtobufMessage.js",
         "test": "npm run build-test-utils && jest --verbose --testPathIgnorePatterns='ServerModules'",
-        "test-minimum": "npm run build-test-utils && jest --verbose --runInBand --testNamePattern='^(.(?!(GlideJson|GlideFt|pubsub|kill)))*$'",
+        "test-dbg": "npm run build-test-utils && jest --runInBand",
+        "test-minimum": "npm run build-test-utils && jest --verbose --testNamePattern='^(.(?!(GlideJson|GlideFt|pubsub|kill)))*$'",
+        "test-modules": "npm run build-test-utils && jest --verbose --testNamePattern='(GlideJson|GlideFt)'",
         "build-test-utils": "cd ../utils && npm i && npm run build",
         "lint:fix": "npm run install-linting && npx eslint -c ../eslint.config.mjs --fix && npm run prettier:format",
@@ -59,8 +60,9 @@
         "replace": "^1.2.2",
         "semver": "^7.6.3",
         "ts-jest": "^29.2.5",
-        "typescript": "^5.5.4",
-        "uuid": "^10.0.0"
+        "ts-node": "^10.9.2",
+        "typescript": "^5.6.3",
+        "uuid": "^11.0.3"
     },
     "author": "Valkey GLIDE Maintainers",
     "license": "Apache-2.0",
diff --git a/node/rust-client/Cargo.toml b/node/rust-client/Cargo.toml
index e9e2af8851..f9baaf6cc2 100644
--- a/node/rust-client/Cargo.toml
+++ b/node/rust-client/Cargo.toml
@@ -11,7 +11,7 @@ authors = ["Valkey GLIDE Maintainers"]
 crate-type = ["cdylib"]
 
 [dependencies]
-redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] }
+redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] }
 glide-core = { path = "../../glide-core", features = ["socket-layer"] }
 tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] }
 napi = {version = "2.14", features = ["napi4", "napi6"] }
diff --git a/node/rust-client/src/lib.rs b/node/rust-client/src/lib.rs
index a6e611c0f6..b15b18f521 100644
--- a/node/rust-client/src/lib.rs
+++ b/node/rust-client/src/lib.rs
@@ -1,3 +1,4 @@
+use glide_core::Telemetry;
 use redis::GlideConnectionOptions;
 /**
  * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
@@ -43,6 +44,9 @@ pub const MAX_REQUEST_ARGS_LEN: u32 = MAX_REQUEST_ARGS_LENGTH as u32;
 pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 =
     glide_core::client::DEFAULT_RESPONSE_TIMEOUT.as_millis() as u32;
 
+#[napi]
+pub const DEFAULT_INFLIGHT_REQUESTS_LIMIT: u32 = glide_core::client::DEFAULT_MAX_INFLIGHT_REQUESTS;
+
 #[napi]
 struct AsyncClient {
     #[allow(dead_code)]
@@ -481,3 +485,14 @@ impl Drop for ClusterScanCursor {
         glide_core::cluster_scan_container::remove_scan_state_cursor(self.cursor.clone());
     }
 }
+
+#[napi]
+pub fn get_statistics(env: Env) -> Result<JsObject> {
+    let total_connections = Telemetry::total_connections().to_string();
+    let total_clients = Telemetry::total_clients().to_string();
+    let mut stats: JsObject = env.create_object()?;
+    stats.set_named_property("total_connections", total_connections)?;
+    stats.set_named_property("total_clients", total_clients)?;
+
+    Ok(stats)
+}
diff --git a/node/src/BaseClient.ts b/node/src/BaseClient.ts
index 2bfb9a3bbd..71b1ffd89c 100644
--- a/node/src/BaseClient.ts
+++ b/node/src/BaseClient.ts
@@ -3,9 +3,11 @@
  */
 import {
     ClusterScanCursor,
+    DEFAULT_INFLIGHT_REQUESTS_LIMIT,
     DEFAULT_TIMEOUT_IN_MILLISECONDS,
     Script,
     StartSocketConnection,
+    getStatistics,
     valueFromSplitPointer,
 } from "glide-rs";
 import * as net from "net";
@@ -497,10 +499,74 @@ export type ReadFrom =
     | "primary"
     /** Spread the requests between all replicas in a round robin manner.
If no replica is available, route the requests to the primary.*/
-    | "preferReplica";
+    | "preferReplica"
+    /** Spread the requests between replicas in the same client's Availability Zone in a round robin manner.
+        If no replica is available, route the requests to the primary.*/
+    | "AZAffinity";
 
 /**
  * Configuration settings for creating a client. Shared settings for standalone and cluster clients.
+ *
+ * @remarks
+ * The `BaseClientConfiguration` interface defines the foundational configuration options used when creating a client to connect to a Valkey server or cluster. It includes connection details, authentication, communication protocols, and various settings that influence the client's behavior and interaction with the server.
+ *
+ * ### Connection Details
+ *
+ * - **Addresses**: Use the `addresses` property to specify the hostnames and ports of the server(s) to connect to.
+ *   - **Cluster Mode**: In cluster mode, the client will discover other nodes based on the provided addresses.
+ *   - **Standalone Mode**: In standalone mode, only the provided nodes will be used.
+ *
+ * ### Security Settings
+ *
+ * - **TLS**: Enable secure communication using `useTLS`.
+ * - **Authentication**: Provide `credentials` to authenticate with the server.
+ *
+ * ### Communication Settings
+ *
+ * - **Request Timeout**: Set `requestTimeout` to specify how long the client should wait for a request to complete.
+ * - **Protocol Version**: Choose the serialization protocol using `protocol`.
+ *
+ * ### Client Identification
+ *
+ * - **Client Name**: Set `clientName` to identify the client connection.
+ *
+ * ### Read Strategy
+ *
+ * - Use `readFrom` to specify the client's read strategy (e.g., primary, preferReplica, AZAffinity).
+ *
+ * ### Availability Zone
+ *
+ * - Use `clientAz` to specify the client's availability zone, which can influence read operations when using `readFrom: 'AZAffinity'`.
+ *
+ * ### Decoder Settings
+ *
+ * - **Default Decoder**: Set `defaultDecoder` to specify how responses are decoded by default.
+ *
+ * ### Concurrency Control
+ *
+ * - **Inflight Requests Limit**: Control the number of concurrent requests using `inflightRequestsLimit`.
+ *
+ * @example
+ * ```typescript
+ * const config: BaseClientConfiguration = {
+ *   addresses: [
+ *     { host: 'redis-node-1.example.com', port: 6379 },
+ *     { host: 'redis-node-2.example.com' }, // Defaults to port 6379
+ *   ],
+ *   useTLS: true,
+ *   credentials: {
+ *     username: 'myUser',
+ *     password: 'myPassword',
+ *   },
+ *   requestTimeout: 5000, // 5 seconds
+ *   protocol: ProtocolVersion.RESP3,
+ *   clientName: 'myValkeyClient',
+ *   readFrom: 'AZAffinity',
+ *   clientAz: 'us-east-1a',
+ *   defaultDecoder: Decoder.String,
+ *   inflightRequestsLimit: 1000,
+ * };
+ * ```
 */
 export interface BaseClientConfiguration {
     /**
@@ -563,6 +629,25 @@ export interface BaseClientConfiguration {
      * If not set, 'Decoder.String' will be used.
      */
     defaultDecoder?: Decoder;
+    /**
+     * The maximum number of concurrent requests allowed to be in-flight (sent but not yet completed).
+     * This limit is used to control the memory usage and prevent the client from overwhelming the
+     * server or getting stuck in case of a queue backlog. If not set, a default value of 1000 will be
+     * used.
+     */
+    inflightRequestsLimit?: number;
+    /**
+     * Availability Zone of the client.
+     * If the ReadFrom strategy is AZAffinity, this setting ensures that readonly commands are directed to replicas within the specified AZ, if they exist.
+     *
+     * @example
+     * ```typescript
+     * // Example configuration for setting client availability zone and read strategy
+     * configuration.clientAz = 'us-east-1a'; // Sets the client's availability zone
+     * configuration.readFrom = 'AZAffinity'; // Directs read operations to nodes within the same AZ
+     * ```
+     */
+    clientAz?: string;
 }
 
 /**
@@ -707,6 +792,8 @@ export class BaseClient {
     protected defaultDecoder = Decoder.String;
     private readonly pubsubFutures: [PromiseFunction, ErrorFunction][] = [];
     private pendingPushNotification: response.Response[] = [];
+    private readonly inflightRequestsLimit: number;
+    private readonly clientAz: string | undefined;
     private config: BaseClientConfiguration | undefined;
 
     protected configurePubsub(
@@ -873,6 +960,8 @@ export class BaseClient {
             this.close();
         });
         this.defaultDecoder = options?.defaultDecoder ?? Decoder.String;
+        this.inflightRequestsLimit =
+            options?.inflightRequestsLimit ?? DEFAULT_INFLIGHT_REQUESTS_LIMIT;
     }
 
     protected getCallbackIndex(): number {
@@ -904,7 +993,8 @@
             | command_request.Command
             | command_request.Command[]
             | command_request.ScriptInvocation
-            | command_request.ClusterScan,
+            | command_request.ClusterScan
+            | command_request.UpdateConnectionPassword,
         options: WritePromiseOptions = {},
     ): Promise<T> {
         const route = toProtobufRoute(options?.route);
@@ -974,7 +1064,8 @@
             | command_request.Command
             | command_request.Command[]
             | command_request.ScriptInvocation
-            | command_request.ClusterScan,
+            | command_request.ClusterScan
+            | command_request.UpdateConnectionPassword,
         route?: command_request.Routes,
     ) {
         const message = Array.isArray(command)
@@ -994,10 +1085,15 @@
                   callbackIdx,
                   clusterScan: command,
               })
-            : command_request.CommandRequest.create({
-                  callbackIdx,
-                  scriptInvocation: command,
-              });
+            : command instanceof command_request.UpdateConnectionPassword
+              ? command_request.CommandRequest.create({
+                    callbackIdx,
+                    updateConnectionPassword: command,
+                })
+              : command_request.CommandRequest.create({
+                    callbackIdx,
+                    scriptInvocation: command,
+                });
         message.route = route;
 
         this.writeOrBufferRequest(
@@ -1385,6 +1481,14 @@
      *
      * @see {@link https://valkey.io/commands/del/|valkey.io} for details.
      *
+     * @remarks In cluster mode, if keys in `keys` map to different hash slots,
+     * the command will be split across these slots and executed separately for each.
+     * This means the command is atomic only at the slot level.
If one or more slot-specific + * requests fail, the entire call will return the first encountered error, even + * though some requests may have succeeded while others did not. + * If this behavior impacts your application logic, consider splitting the + * request into sub-requests per slot to ensure atomicity. * * @param keys - A list of keys to retrieve values for. * @param options - (Optional) See {@link DecoderOption}. @@ -1511,10 +1622,18 @@ export class BaseClient { /** Set multiple keys to multiple values in a single operation. * * @see {@link https://valkey.io/commands/mset/|valkey.io} for details. - * @remarks When in cluster mode, the command may route to multiple nodes when keys in `keyValueMap` map to different hash slots. + * + * @remarks In cluster mode, if keys in `keyValueMap` map to different hash slots, + * the command will be split across these slots and executed separately for each. + * This means the command is atomic only at the slot level. If one or more slot-specific + * requests fail, the entire call will return the first encountered error, even + * though some requests may have succeeded while others did not. + * If this behavior impacts your application logic, consider splitting the + * request into sub-requests per slot to ensure atomicity. * * @param keysAndValues - A list of key-value pairs to set. - * @returns always "OK". + * + * @returns A simple "OK" response. * * @example * ```typescript @@ -2275,7 +2394,7 @@ export class BaseClient { * * @param key - The key of the set. * @param cursor - The cursor that points to the next iteration of results. A value of `"0"` indicates the start of the search. - * @param options - (Optional) The {@link HScanOptions}. + * @param options - (Optional) See {@link HScanOptions} and {@link DecoderOption}. * @returns An array of the `cursor` and the subset of the hash held by `key`. * The first element is always the `cursor` for the next iteration of results. `"0"` will be the `cursor` * returned on the last iteration of the hash. The second element is always an array of the subset of the @@ -3423,6 +3542,14 @@ export class BaseClient { /** * Returns the number of keys in `keys` that exist in the database. * + * @remarks In cluster mode, if keys in `keys` map to different hash slots, + * the command will be split across these slots and executed separately for each. + * This means the command is atomic only at the slot level. If one or more slot-specific + * requests fail, the entire call will return the first encountered error, even + * though some requests may have succeeded while others did not. + * If this behavior impacts your application logic, consider splitting the + * request into sub-requests per slot to ensure atomicity. + * * @see {@link https://valkey.io/commands/exists/|valkey.io} for details. * * @param keys - The keys list to check. @@ -3445,6 +3572,14 @@ export class BaseClient { * This command, similar to {@link del}, removes specified keys and ignores non-existent ones. * However, this command does not block the server, while {@link https://valkey.io/commands/del|`DEL`} does. * + * @remarks In cluster mode, if keys in `keys` map to different hash slots, + * the command will be split across these slots and executed separately for each. + * This means the command is atomic only at the slot level. If one or more slot-specific + * requests fail, the entire call will return the first encountered error, even + * though some requests may have succeeded while others did not. 
+ * If this behavior impacts your application logic, consider splitting the + * request into sub-requests per slot to ensure atomicity. + * * @see {@link https://valkey.io/commands/unlink/|valkey.io} for details. * * @param keys - The keys we wanted to unlink. @@ -5499,7 +5634,6 @@ export class BaseClient { * attributes of a consumer group for the stream at `key`. * @example * ```typescript - *
      {@code
            * const result = await client.xinfoGroups("my_stream");
            * console.log(result); // Output:
            * // [
      @@ -5937,6 +6071,7 @@ export class BaseClient {
           > = {
               primary: connection_request.ReadFrom.Primary,
               preferReplica: connection_request.ReadFrom.PreferReplica,
      +        AZAffinity: connection_request.ReadFrom.AZAffinity,
           };
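The mapping above is what carries the new `"AZAffinity"` value into the core's connection request. A minimal sketch of opting in from client code (the endpoint and AZ name are placeholders):

```typescript
// Sketch: AZ-affinity reads. Read-only commands prefer replicas in the client's
// availability zone; if none is available, they fall back to the primary.
import { GlideClusterClient } from "@valkey/valkey-glide";

const client = await GlideClusterClient.createClient({
    addresses: [{ host: "localhost", port: 7000 }], // placeholder seed node
    readFrom: "AZAffinity",
    clientAz: "us-east-1a", // placeholder availability zone
});
```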
       
           /**
      @@ -5952,13 +6087,11 @@ export class BaseClient {
            *
            * @example
            * ```typescript
      -     *  
      {@code
            * const entryId = await client.xadd("mystream", ["myfield", "mydata"]);
            * // read messages from streamId
            * const readResult = await client.xreadgroup(["myfield", "mydata"], "mygroup", "my0consumer");
            * // acknowledge messages on stream
            * console.log(await client.xack("mystream", "mygroup", [entryId])); // Output: 1
      -     * 
* ```
     */
    public async xack(
@@ -7073,7 +7206,14 @@
     * Updates the last access time of the specified keys.
     *
     * @see {@link https://valkey.io/commands/touch/|valkey.io} for more details.
-     * @remarks When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots.
+     *
+     * @remarks In cluster mode, if keys in `keys` map to different hash slots,
+     * the command will be split across these slots and executed separately for each.
+     * This means the command is atomic only at the slot level. If one or more slot-specific
+     * requests fail, the entire call will return the first encountered error, even
+     * though some requests may have succeeded while others did not.
+     * If this behavior impacts your application logic, consider splitting the
+     * request into sub-requests per slot to ensure atomicity.
     *
     * @param keys - The keys to update the last access time of.
     * @returns The number of keys that were updated. A key is ignored if it doesn't exist.
@@ -7096,7 +7236,14 @@
     * transaction. Executing a transaction will automatically flush all previously watched keys.
     *
     * @see {@link https://valkey.io/commands/watch/|valkey.io} and {@link https://valkey.io/topics/transactions/#cas|Valkey Glide Wiki} for more details.
-     * @remarks When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots.
+     *
+     * @remarks In cluster mode, if keys in `keys` map to different hash slots,
+     * the command will be split across these slots and executed separately for each.
+     * This means the command is atomic only at the slot level. If one or more slot-specific
+     * requests fail, the entire call will return the first encountered error, even
+     * though some requests may have succeeded while others did not.
+     * If this behavior impacts your application logic, consider splitting the
+     * request into sub-requests per slot to ensure atomicity.
     *
     * @param keys - The keys to watch.
     * @returns A simple `"OK"` response.
@@ -7505,6 +7652,8 @@
             clusterModeEnabled: false,
             readFrom,
             authenticationInfo,
+            inflightRequestsLimit: options.inflightRequestsLimit,
+            clientAz: options.clientAz ?? null,
         };
     }
@@ -7610,4 +7759,58 @@
             throw err;
         }
     }
+
+    /**
+     * Update the current connection with a new password.
+     *
+     * This method is useful in scenarios where the server password has changed or when utilizing short-lived passwords for enhanced security.
+     * It allows the client to update its password to reconnect upon disconnection without the need to recreate the client instance.
+     * This ensures that the internal reconnection mechanism can handle reconnection seamlessly, preventing the loss of in-flight commands.
+     *
+     * This method updates the client's internal password configuration and does not perform password rotation on the server side.
+     *
+     * @param password - `String | null`. The new password to update the current password, or `null` to remove the current password.
+     * @param immediateAuth - A `boolean` flag. If `true`, the client will authenticate immediately with the new password against all connections, using the `AUTH` command.
+     *     If the supplied password is an empty string, the client will not perform auth; instead, a warning will be returned.
+     *     The default is `false`.
+     *
+     * @example
+     * ```typescript
+     * await client.updateConnectionPassword("newPassword", true) // "OK"
+     * ```
+     */
+    async updateConnectionPassword(
+        password: string | null,
+        immediateAuth = false,
+    ) {
+        const updateConnectionPassword =
+            command_request.UpdateConnectionPassword.create({
+                password,
+                immediateAuth,
+            });
+
+        const response = await this.createWritePromise(
+            updateConnectionPassword,
+        );
+
+        if (response === "OK" && !this.config?.credentials) {
+            this.config = {
+                ...this.config!,
+                credentials: {
+                    ...this.config!.credentials,
+                    password: password ? password : "",
+                },
+            };
+        }
+
+        return response;
+    }
+
+    /**
+     * Returns statistics collected internally by GLIDE core.
+     *
+     * @returns An object containing the statistics collected internally by GLIDE core.
+     */
+    public getStatistics(): object {
+        return getStatistics();
+    }
 }
diff --git a/node/src/GlideClient.ts b/node/src/GlideClient.ts
index 3feaabdc56..acec0c377f 100644
--- a/node/src/GlideClient.ts
+++ b/node/src/GlideClient.ts
@@ -99,6 +99,43 @@ export namespace GlideClientConfiguration {
     }
 }
 
+/**
+ * Configuration options for creating a {@link GlideClient | GlideClient}.
+ *
+ * Extends `BaseClientConfiguration` with properties specific to `GlideClient`, such as database selection,
+ * reconnection strategies, and Pub/Sub subscription settings.
+ *
+ * @remarks
+ * This configuration allows you to tailor the client's behavior when connecting to a standalone Valkey Glide server.
+ *
+ * - **Database Selection**: Use `databaseId` to specify which logical database to connect to.
+ * - **Reconnection Strategy**: Customize how the client should attempt reconnections using `connectionBackoff`.
+ *   - `numberOfRetries`: The maximum number of retry attempts with increasing delays.
+ *     - After this limit is reached, the retry interval becomes constant.
+ *   - `factor`: A multiplier applied to the base delay between retries (e.g., `500` means a 500ms base delay).
+ *   - `exponentBase`: The exponential growth factor for delays (e.g., `2` means the delay doubles with each retry).
+ * - **Pub/Sub Subscriptions**: Predefine Pub/Sub channels and patterns to subscribe to upon connection establishment.
+ *
+ * @example
+ * ```typescript
+ * const config: GlideClientConfiguration = {
+ *   databaseId: 1,
+ *   connectionBackoff: {
+ *     numberOfRetries: 10, // Maximum retries before delay becomes constant
+ *     factor: 500, // Base delay in milliseconds
+ *     exponentBase: 2, // Delay doubles with each retry (2^N)
+ *   },
+ *   pubsubSubscriptions: {
+ *     channelsAndPatterns: {
+ *       [GlideClientConfiguration.PubSubChannelModes.Pattern]: new Set(['news.*']),
+ *     },
+ *     callback: (msg) => {
+ *       console.log(`Received message on ${msg.channel}:`, msg.payload);
+ *     },
+ *   },
+ * };
+ * ```
+ */
 export type GlideClientConfiguration = BaseClientConfiguration & {
     /**
      * index of the logical database to connect to.
@@ -154,7 +191,53 @@ export class GlideClient extends BaseClient {
         this.configurePubsub(options, configuration);
         return configuration;
     }
-
+    /**
+     * Creates a new `GlideClient` instance and establishes a connection to a standalone Valkey Glide server.
+     *
+     * @param options - The configuration options for the client, including server addresses, authentication credentials, TLS settings, database selection, reconnection strategy, and Pub/Sub subscriptions.
+     * @returns A promise that resolves to a connected `GlideClient` instance.
+     *
+     * @remarks
+     * Use this static method to create and connect a `GlideClient` to a standalone Valkey Glide server.
The client will automatically handle connection establishment, including any authentication and TLS configurations.
+     *
+     * @example
+     * ```typescript
+     * // Connecting to a Standalone Server
+     * import { GlideClient, GlideClientConfiguration } from '@valkey/valkey-glide';
+     *
+     * const client = await GlideClient.createClient({
+     *   addresses: [
+     *     { host: 'primary.example.com', port: 6379 },
+     *     { host: 'replica1.example.com', port: 6379 },
+     *   ],
+     *   databaseId: 1,
+     *   credentials: {
+     *     username: 'user1',
+     *     password: 'passwordA',
+     *   },
+     *   useTLS: true,
+     *   connectionBackoff: {
+     *     numberOfRetries: 5,
+     *     factor: 1000,
+     *     exponentBase: 2,
+     *   },
+     *   pubsubSubscriptions: {
+     *     channelsAndPatterns: {
+     *       [GlideClientConfiguration.PubSubChannelModes.Exact]: new Set(['updates']),
+     *     },
+     *     callback: (msg) => {
+     *       console.log(`Received message: ${msg.payload}`);
+     *     },
+     *   },
+     * });
+     * ```
+     *
+     * @remarks
+     * - **Authentication**: If `credentials` are provided, the client will attempt to authenticate using the specified username and password.
+     * - **TLS**: If `useTLS` is set to `true`, the client will establish a secure connection using TLS.
+     * - **Reconnection Strategy**: The `connectionBackoff` settings define how the client will attempt to reconnect in case of disconnections.
+     * - **Pub/Sub Subscriptions**: Any channels or patterns specified in `pubsubSubscriptions` will be subscribed to upon connection.
+     */
     public static async createClient(
         options: GlideClientConfiguration,
     ): Promise<GlideClient> {
@@ -164,7 +247,9 @@ export class GlideClient extends BaseClient {
             new GlideClient(socket, options),
         );
     }
-
+    /**
+     * @internal
+     */
     static async __createClient(
         options: BaseClientConfiguration,
         connectedSocket: net.Socket,
+ *
+ * @example
+ * ```typescript
+ * const config: GlideClusterClientConfiguration = {
+ *   periodicChecks: {
+ *     duration_in_sec: 30, // Perform periodic checks every 30 seconds
+ *   },
+ *   pubsubSubscriptions: {
+ *     channelsAndPatterns: {
+ *       [GlideClusterClientConfiguration.PubSubChannelModes.Pattern]: new Set(['cluster.*']),
+ *     },
+ *     callback: (msg) => {
+ *       console.log(`Received message on ${msg.channel}:`, msg.payload);
+ *     },
+ *   },
+ * };
+ * ```
+ */
 export type GlideClusterClientConfiguration = BaseClientConfiguration & {
     /**
      * Configure the periodic topology checks.
@@ -194,6 +229,37 @@ function convertClusterGlideRecord<T>(
         ? (res as T)
         : convertGlideRecordToRecord(res as GlideRecord<T>);
 }
+
+/**
+ * Routing configuration for commands based on a specific slot ID in a Valkey cluster.
+ *
+ * @remarks
+ * This interface allows you to specify routing of a command to a node responsible for a particular slot ID in the cluster.
+ * Valkey clusters use hash slots to distribute data across multiple shards. There are 16,384 slots in total, and each shard
+ * manages a range of slots.
+ *
+ * - **Slot ID**: A number between 0 and 16383 representing a hash slot.
+ * - **Routing Type**:
+ *   - `"primarySlotId"`: Routes the command to the primary node responsible for the specified slot ID.
+ *   - `"replicaSlotId"`: Routes the command to a replica node responsible for the specified slot ID, overriding the `readFrom` configuration.
+ *
+ * @example
+ * ```typescript
+ * // Route command to the primary node responsible for slot ID 12345
+ * const routeBySlotId: SlotIdTypes = {
+ *   type: "primarySlotId",
+ *   id: 12345,
+ * };
+ *
+ * // Route command to a replica node responsible for slot ID 12345
+ * const routeToReplicaBySlotId: SlotIdTypes = {
+ *   type: "replicaSlotId",
+ *   id: 12345,
+ * };
+ *
+ * // Use the routing configuration when executing a command
+ * const result = await client.get("mykey", { route: routeBySlotId });
+ * ```
+ */
 export interface SlotIdTypes {
     /**
@@ -207,7 +273,36 @@
      */
     id: number;
 }
-
+/**
+ * Routing configuration for commands based on a key in a Valkey cluster.
+ *
+ * @remarks
+ * This interface allows you to specify routing of a command to a node responsible for the slot that a specific key hashes to.
+ * Valkey clusters use consistent hashing to map keys to hash slots, which are then managed by different shards in the cluster.
+ *
+ * - **Key**: The key whose hash slot will determine the routing of the command.
+ * - **Routing Type**:
+ *   - `"primarySlotKey"`: Routes the command to the primary node responsible for the key's slot.
+ *   - `"replicaSlotKey"`: Routes the command to a replica node responsible for the key's slot, overriding the `readFrom` configuration.
+ *
+ * @example
+ * ```typescript
+ * // Route command to the primary node responsible for the key's slot
+ * const routeByKey: SlotKeyTypes = {
+ *   type: "primarySlotKey",
+ *   key: "user:1001",
+ * };
+ *
+ * // Route command to a replica node responsible for the key's slot
+ * const routeToReplicaByKey: SlotKeyTypes = {
+ *   type: "replicaSlotKey",
+ *   key: "user:1001",
+ * };
+ *
+ * // Use the routing configuration when executing a command
+ * const result = await client.get("user:1001", { route: routeByKey });
+ * ```
+ */
 export interface SlotKeyTypes {
     /**
      * `replicaSlotKey` overrides the `readFrom` configuration. If it's used the request
@@
      */
     key: string;
 }
 
-/// Route command to specific node.
+/** + * Routing configuration to send a command to a specific node by its address and port. + * + * @remarks + * This interface allows you to specify routing of a command to a node in the Valkey cluster by providing its network address and port. + * It's useful when you need to direct a command to a particular node. + * + * - **Type**: Must be set to `"routeByAddress"` to indicate that the routing should be based on the provided address. + * - **Host**: The endpoint of the node. + * - If `port` is not provided, `host` should be in the format `${address}:${port}`, where `address` is the preferred endpoint as shown in the output of the `CLUSTER SLOTS` command. + * - If `port` is provided, `host` should be the address or hostname of the node without the port. + * - **Port**: (Optional) The port to access on the node. + * - If `port` is not provided, `host` is assumed to include the port number. + * + * @example + * ```typescript + * // Route command to a node at '192.168.1.10:6379' + * const routeByAddress: RouteByAddress = { + * type: "routeByAddress", + * host: "192.168.1.10", + * port: 6379, + * }; + * + * // Alternatively, include the port in the host string + * const routeByAddressWithPortInHost: RouteByAddress = { + * type: "routeByAddress", + * host: "192.168.1.10:6379", + * }; + * + * // Use the routing configuration when executing a command + * const result = await client.ping({ route: routeByAddress }); + * ``` + */ export interface RouteByAddress { type: "routeByAddress"; /** @@ -232,7 +359,52 @@ export interface RouteByAddress { */ port?: number; } - +/** + * Defines the routing configuration for a command in a Valkey cluster. + * + * @remarks + * The `Routes` type allows you to specify how a command should be routed in a Valkey cluster. + * Commands can be routed to a single node or broadcast to multiple nodes depending on the routing strategy. + * + * **Routing Options**: + * + * - **Single Node Routing** (`SingleNodeRoute`): + * - **"randomNode"**: Route the command to a random node in the cluster. + * - **`SlotIdTypes`**: Route based on a specific slot ID. + * - **`SlotKeyTypes`**: Route based on the slot of a specific key. + * - **`RouteByAddress`**: Route to a specific node by its address and port. + * - **Broadcast Routing**: + * - **"allPrimaries"**: Route the command to all primary nodes in the cluster. + * - **"allNodes"**: Route the command to all nodes (both primaries and replicas) in the cluster. + * + * @example + * ```typescript + * // Route command to a random node + * const routeRandom: Routes = "randomNode"; + * + * // Route command to all primary nodes + * const routeAllPrimaries: Routes = "allPrimaries"; + * + * // Route command to all nodes + * const routeAllNodes: Routes = "allNodes"; + * + * // Route command to a node by slot key + * const routeByKey: Routes = { + * type: "primarySlotKey", + * key: "myKey", + * }; + * + * // Route command to a specific node by address + * const routeByAddress: Routes = { + * type: "routeByAddress", + * host: "192.168.1.10", + * port: 6379, + * }; + * + * // Use the routing configuration when executing a command + * const result = await client.ping({ route: routeByAddress }); + * ``` + */ export type Routes = | SingleNodeRoute /** @@ -243,7 +415,48 @@ export type Routes = * Route request to all nodes. */ | "allNodes"; - +/** + * Defines the routing configuration to a single node in the Valkey cluster. + * + * @remarks + * The `SingleNodeRoute` type allows you to specify routing of a command to a single node in the cluster. 
+ * This can be based on various criteria such as a random node, a node responsible for a specific slot, or a node identified by its address. + * + * **Options**: + * + * - **"randomNode"**: Route the command to a random node in the cluster. + * - **`SlotIdTypes`**: Route to the node responsible for a specific slot ID. + * - **`SlotKeyTypes`**: Route to the node responsible for the slot of a specific key. + * - **`RouteByAddress`**: Route to a specific node by its address and port. + * + * @example + * ```typescript + * // Route to a random node + * const routeRandomNode: SingleNodeRoute = "randomNode"; + * + * // Route based on slot ID + * const routeBySlotId: SingleNodeRoute = { + * type: "primarySlotId", + * id: 12345, + * }; + * + * // Route based on key + * const routeByKey: SingleNodeRoute = { + * type: "primarySlotKey", + * key: "myKey", + * }; + * + * // Route to a specific node by address + * const routeByAddress: SingleNodeRoute = { + * type: "routeByAddress", + * host: "192.168.1.10", + * port: 6379, + * }; + * + * // Use the routing configuration when executing a command + * const result = await client.get("myKey", { route: routeByKey }); + * ``` + */ export type SingleNodeRoute = /** * Route request to a random node. @@ -293,7 +506,51 @@ export class GlideClusterClient extends BaseClient { this.configurePubsub(options, configuration); return configuration; } - + /** + * Creates a new `GlideClusterClient` instance and establishes connections to a Valkey GLIDE Cluster. + * + * @param options - The configuration options for the client, including cluster addresses, authentication credentials, TLS settings, periodic checks, and Pub/Sub subscriptions. + * @returns A promise that resolves to a connected `GlideClusterClient` instance. + * + * @remarks + * Use this static method to create and connect a `GlideClusterClient` to a Valkey GLIDE Cluster. The client will automatically handle connection establishment, including cluster topology discovery and handling of authentication and TLS configurations. + * + * ### Example - Connecting to a Cluster + * ```typescript + * import { GlideClusterClient, GlideClusterClientConfiguration } from '@valkey/valkey-glide'; + * + * const client = await GlideClusterClient.createClient({ + * addresses: [ + * { host: 'address1.example.com', port: 6379 }, + * { host: 'address2.example.com', port: 6379 }, + * ], + * credentials: { + * username: 'user1', + * password: 'passwordA', + * }, + * useTLS: true, + * periodicChecks: { + * duration_in_sec: 30, // Perform periodic checks every 30 seconds + * }, + * pubsubSubscriptions: { + * channelsAndPatterns: { + * [GlideClusterClientConfiguration.PubSubChannelModes.Exact]: new Set(['updates']), + * [GlideClusterClientConfiguration.PubSubChannelModes.Sharded]: new Set(['sharded_channel']), + * }, + * callback: (msg) => { + * console.log(`Received message: ${msg.payload}`); + * }, + * }, + * }); + * ``` + * + * @remarks + * - **Cluster Topology Discovery**: The client will automatically discover the cluster topology based on the seed addresses provided. + * - **Authentication**: If `credentials` are provided, the client will attempt to authenticate using the specified username and password. + * - **TLS**: If `useTLS` is set to `true`, the client will establish secure connections using TLS. + * - **Periodic Checks**: The `periodicChecks` setting allows you to configure how often the client checks for cluster topology changes. 
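+ *   Passing `periodicChecks: "disabled"` turns these checks off entirely (see {@link GlideClusterClientConfiguration}).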
+ * - **Pub/Sub Subscriptions**: Any channels or patterns specified in `pubsubSubscriptions` will be subscribed to upon connection. + */ public static async createClient( options: GlideClusterClientConfiguration, ): Promise { @@ -303,7 +560,9 @@ export class GlideClusterClient extends BaseClient { new GlideClusterClient(socket, options), ); } - + /** + * @internal + */ static async __createClient( options: BaseClientConfiguration, connectedSocket: net.Socket, @@ -706,7 +965,7 @@ export class GlideClusterClient extends BaseClient { * @example * ```typescript * // Example usage of configSet method to set multiple configuration parameters - * const result = await client.configSet({ timeout: "1000", maxmemory, "1GB" }); + * const result = await client.configSet({ timeout: "1000", maxmemory: "1GB" }); * console.log(result); // Output: 'OK' * ``` */ @@ -1462,7 +1721,7 @@ export class GlideClusterClient extends BaseClient { * @example * ```typescript * const luaScript = new Script("return { ARGV[1] }"); - * const result = await invokeScript(luaScript, { args: ["bar"] }); + * const result = await client.invokeScript(luaScript, { args: ["bar"] }); * console.log(result); // Output: ['bar'] * ``` */ diff --git a/node/src/server-modules/GlideFt.ts b/node/src/server-modules/GlideFt.ts new file mode 100644 index 0000000000..61e7c36a9f --- /dev/null +++ b/node/src/server-modules/GlideFt.ts @@ -0,0 +1,892 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ + +import { + convertGlideRecordToRecord, + Decoder, + DecoderOption, + GlideRecord, + GlideReturnType, + GlideString, +} from "../BaseClient"; +import { GlideClient } from "../GlideClient"; +import { GlideClusterClient } from "../GlideClusterClient"; +import { + Field, + FtAggregateOptions, + FtCreateOptions, + FtSearchOptions, +} from "./GlideFtOptions"; + +/** Response type of {@link GlideFt.info | ft.info} command. */ +export type FtInfoReturnType = Record< + string, + | GlideString + | number + | GlideString[] + | Record[]> +>; + +/** + * Response type for the {@link GlideFt.search | ft.search} command. + */ +export type FtSearchReturnType = [ + number, + GlideRecord>, +]; + +/** + * Response type for the {@link GlideFt.aggregate | ft.aggregate} command. + */ +export type FtAggregateReturnType = GlideRecord[]; + +/** Module for Vector Search commands. */ +export class GlideFt { + /** + * Creates an index and initiates a backfill of that index. + * + * @param client - The client to execute the command. + * @param indexName - The index name for the index to be created. + * @param schema - The fields of the index schema, specifying the fields and their types. + * @param options - (Optional) Options for the `FT.CREATE` command. See {@link FtCreateOptions}. + * @returns If the index is successfully created, returns "OK". 
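+     *
+     * @example
+     * ```typescript
+     * // A minimal additional sketch (index name, field names, and prefix are illustrative):
+     * // create a HASH index with a TEXT field and a NUMERIC field over keys prefixed "product:".
+     * await GlideFt.create(client, "hash_idx1", [
+     *     { type: "TEXT", name: "title" },
+     *     { type: "NUMERIC", name: "price" },
+     * ], {
+     *     dataType: "HASH",
+     *     prefixes: ["product:"],
+     * });
+     * ```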
+ * + * @example + * ```typescript + * // Example usage of FT.CREATE to create a 6-dimensional JSON index using the HNSW algorithm + * await GlideFt.create(client, "json_idx1", [{ + * type: "VECTOR", + * name: "$.vec", + * alias: "VEC", + * attributes: { + * algorithm: "HNSW", + * type: "FLOAT32", + * dimension: 6, + * distanceMetric: "L2", + * numberOfEdges: 32, + * }, + * }], { + * dataType: "JSON", + * prefixes: ["json:"] + * }); + * ``` + */ + static async create( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + schema: Field[], + options?: FtCreateOptions, + ): Promise<"OK"> { + const args: GlideString[] = ["FT.CREATE", indexName]; + + if (options) { + if ("dataType" in options) { + args.push("ON", options.dataType); + } + + if ("prefixes" in options && options.prefixes) { + args.push( + "PREFIX", + options.prefixes.length.toString(), + ...options.prefixes, + ); + } + } + + args.push("SCHEMA"); + + schema.forEach((f) => { + args.push(f.name); + + if (f.alias) { + args.push("AS", f.alias); + } + + args.push(f.type); + + switch (f.type) { + case "TAG": { + if (f.separator) { + args.push("SEPARATOR", f.separator); + } + + if (f.caseSensitive) { + args.push("CASESENSITIVE"); + } + + break; + } + + case "VECTOR": { + if (f.attributes) { + args.push(f.attributes.algorithm); + + const attributes: GlideString[] = []; + + // all VectorFieldAttributes attributes + if (f.attributes.dimensions) { + attributes.push( + "DIM", + f.attributes.dimensions.toString(), + ); + } + + if (f.attributes.distanceMetric) { + attributes.push( + "DISTANCE_METRIC", + f.attributes.distanceMetric.toString(), + ); + } + + if (f.attributes.type) { + attributes.push( + "TYPE", + f.attributes.type.toString(), + ); + } else { + attributes.push("TYPE", "FLOAT32"); + } + + if (f.attributes.initialCap) { + attributes.push( + "INITIAL_CAP", + f.attributes.initialCap.toString(), + ); + } + + // VectorFieldAttributesHnsw attributes + if ("m" in f.attributes && f.attributes.m) { + attributes.push("M", f.attributes.m.toString()); + } + + if ( + "efContruction" in f.attributes && + f.attributes.efContruction + ) { + attributes.push( + "EF_CONSTRUCTION", + f.attributes.efContruction.toString(), + ); + } + + if ( + "efRuntime" in f.attributes && + f.attributes.efRuntime + ) { + attributes.push( + "EF_RUNTIME", + f.attributes.efRuntime.toString(), + ); + } + + args.push(attributes.length.toString(), ...attributes); + } + + break; + } + + default: + // no-op + } + }); + + return _handleCustomCommand(client, args, { + decoder: Decoder.String, + }) as Promise<"OK">; + } + + /** + * Deletes an index and associated content. Indexed document keys are unaffected. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @returns "OK" + * + * @example + * ```typescript + * // Example usage of FT.DROPINDEX to drop an index + * await GlideFt.dropindex(client, "json_idx1"); // "OK" + * ``` + */ + static async dropindex( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + ): Promise<"OK"> { + const args: GlideString[] = ["FT.DROPINDEX", indexName]; + + return _handleCustomCommand(client, args, { + decoder: Decoder.String, + }) as Promise<"OK">; + } + + /** + * Lists all indexes. + * + * @param client - The client to execute the command. + * @param options - (Optional) See {@link DecoderOption}. + * @returns An array of index names. 
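+     *     Entries are returned as `GlideString`s; a `{ decoder: Decoder.String }` option (see {@link DecoderOption}) requests plain strings instead of `Buffer`s.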
+ * + * @example + * ```typescript + * console.log(await GlideFt.list(client)); // Output: ["index1", "index2"] + * ``` + */ + static async list( + client: GlideClient | GlideClusterClient, + options?: DecoderOption, + ): Promise { + return _handleCustomCommand(client, ["FT._LIST"], options) as Promise< + GlideString[] + >; + } + + /** + * Runs a search query on an index, and perform aggregate transformations on the results. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param query - The text query to search. + * @param options - Additional parameters for the command - see {@link FtAggregateOptions} and {@link DecoderOption}. + * @returns Results of the last stage of the pipeline. + * + * @example + * ```typescript + * const options: FtAggregateOptions = { + * loadFields: ["__key"], + * clauses: [ + * { + * type: "GROUPBY", + * properties: ["@condition"], + * reducers: [ + * { + * function: "TOLIST", + * args: ["__key"], + * name: "bicycles", + * }, + * ], + * }, + * ], + * }; + * const result = await GlideFt.aggregate(client, "myIndex", "*", options); + * console.log(result); // Output: + * // [ + * // [ + * // { + * // key: "condition", + * // value: "refurbished" + * // }, + * // { + * // key: "bicycles", + * // value: [ "bicycle:9" ] + * // } + * // ], + * // [ + * // { + * // key: "condition", + * // value: "used" + * // }, + * // { + * // key: "bicycles", + * // value: [ "bicycle:1", "bicycle:2", "bicycle:3" ] + * // } + * // ], + * // [ + * // { + * // key: "condition", + * // value: "new" + * // }, + * // { + * // key: "bicycles", + * // value: [ "bicycle:0", "bicycle:5" ] + * // } + * // ] + * // ] + * ``` + */ + static async aggregate( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: DecoderOption & FtAggregateOptions, + ): Promise { + const args: GlideString[] = [ + "FT.AGGREGATE", + indexName, + query, + ..._addFtAggregateOptions(options), + ]; + + return _handleCustomCommand( + client, + args, + options, + ) as Promise; + } + + /** + * Returns information about a given index. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param options - (Optional) See {@link DecoderOption}. + * @returns Nested maps with info about the index. See example for more details. 
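+     *     The top-level reply is converted with `convertGlideRecordToRecord`, so fields are accessed as plain properties, e.g. `info.index_name`.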
+ * + * @example + * ```typescript + * const info = await GlideFt.info(client, "myIndex"); + * console.log(info); // Output: + * // { + * // index_name: 'myIndex', + * // index_status: 'AVAILABLE', + * // key_type: 'JSON', + * // creation_timestamp: 1728348101728771, + * // key_prefixes: [ 'json:' ], + * // num_indexed_vectors: 0, + * // space_usage: 653471, + * // num_docs: 0, + * // vector_space_usage: 653471, + * // index_degradation_percentage: 0, + * // fulltext_space_usage: 0, + * // current_lag: 0, + * // fields: [ + * // { + * // identifier: '$.vec', + * // type: 'VECTOR', + * // field_name: 'VEC', + * // option: '', + * // vector_params: { + * // data_type: 'FLOAT32', + * // initial_capacity: 1000, + * // current_capacity: 1000, + * // distance_metric: 'L2', + * // dimension: 6, + * // block_size: 1024, + * // algorithm: 'FLAT' + * // } + * // }, + * // { + * // identifier: 'name', + * // type: 'TEXT', + * // field_name: 'name', + * // option: '' + * // }, + * // ] + * // } + * ``` + */ + static async info( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + options?: DecoderOption, + ): Promise { + const args: GlideString[] = ["FT.INFO", indexName]; + + return ( + _handleCustomCommand(client, args, options) as Promise< + GlideRecord + > + ).then(convertGlideRecordToRecord); + } + + /** + * Parse a query and return information about how that query was parsed. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param query - The text query to search. It is the same as the query passed as + * an argument to {@link search | FT.SEARCH} or {@link aggregate | FT.AGGREGATE}. + * @param options - (Optional) See {@link DecoderOption}. + * @returns A query execution plan. + * + * @example + * ```typescript + * const result = GlideFt.explain(client, "myIndex", "@price:[0 10]"); + * console.log(result); // Output: "Field {\n\tprice\n\t0\n\t10\n}" + * ``` + */ + static explain( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: DecoderOption, + ): Promise { + const args = ["FT.EXPLAIN", indexName, query]; + + return _handleCustomCommand(client, args, options); + } + + /** + * Parse a query and return information about how that query was parsed. + * Same as {@link explain | FT.EXPLAIN}, except that the results are + * displayed in a different format. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param query - The text query to search. It is the same as the query passed as + * an argument to {@link search | FT.SEARCH} or {@link aggregate | FT.AGGREGATE}. + * @param options - (Optional) See {@link DecoderOption}. + * @returns A query execution plan. + * + * @example + * ```typescript + * const result = GlideFt.explaincli(client, "myIndex", "@price:[0 10]"); + * console.log(result); // Output: ["Field {", "price", "0", "10", "}"] + * ``` + */ + static explaincli( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: DecoderOption, + ): Promise { + const args = ["FT.EXPLAINCLI", indexName, query]; + + return _handleCustomCommand(client, args, options); + } + + /** + * Uses the provided query expression to locate keys within an index. Once located, the count + * and/or content of indexed fields within those keys can be returned. + * + * @param client - The client to execute the command. + * @param indexName - The index name to search into. 
+ * @param query - The text query to search. + * @param options - (Optional) See {@link FtSearchOptions} and {@link DecoderOption}. + * @returns A two-element array, where the first element is the number of documents in the result set, and the + * second element has the format: `GlideRecord>`: + * a mapping between document names and a map of their attributes. + * + * If `count` or `limit` with values `{offset: 0, count: 0}` is + * set, the command returns array with only one element: the number of documents. + * + * @example + * ```typescript + * // + * const vector = Buffer.alloc(24); + * const result = await GlideFt.search(client, "json_idx1", "*=>[KNN 2 @VEC $query_vec]", {params: [{key: "query_vec", value: vector}]}); + * console.log(result); // Output: + * // [ + * // 2, + * // [ + * // { + * // key: "json:2", + * // value: [ + * // { + * // key: "$", + * // value: '{"vec":[1.1,1.2,1.3,1.4,1.5,1.6]}', + * // }, + * // { + * // key: "__VEC_score", + * // value: "11.1100006104", + * // }, + * // ], + * // }, + * // { + * // key: "json:0", + * // value: [ + * // { + * // key: "$", + * // value: '{"vec":[1,2,3,4,5,6]}', + * // }, + * // { + * // key: "__VEC_score", + * // value: "91", + * // }, + * // ], + * // }, + * // ], + * // ] + * ``` + */ + static async search( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: FtSearchOptions & DecoderOption, + ): Promise { + const args: GlideString[] = [ + "FT.SEARCH", + indexName, + query, + ..._addFtSearchOptions(options), + ]; + + return _handleCustomCommand(client, args, options) as Promise< + [number, GlideRecord>] + >; + } + + /** + * Runs a search query and collects performance profiling information. + * + * @param client - The client to execute the command. + * @param indexName - The index name. + * @param query - The text query to search. + * @param options - (Optional) See {@link FtSearchOptions} and {@link DecoderOption}. Additionally: + * - `limited` (Optional) - Either provide a full verbose output or some brief version. + * + * @returns A two-element array. The first element contains results of the search query being profiled, the + * second element stores profiling information. + * + * @example + * ```typescript + * // Example of running profile on a search query + * const vector = Buffer.alloc(24); + * const result = await GlideFt.profileSearch(client, "json_idx1", "*=>[KNN 2 @VEC $query_vec]", {params: [{key: "query_vec", value: vector}]}); + * console.log(result); // Output: + * // result[0] contains `FT.SEARCH` response with the given query + * // result[1] contains profiling data as a `Record` + * ``` + */ + static async profileSearch( + client: GlideClient | GlideClusterClient, + indexName: GlideString, + query: GlideString, + options?: DecoderOption & + FtSearchOptions & { + limited?: boolean; + }, + ): Promise<[FtSearchReturnType, Record]> { + const args: GlideString[] = ["FT.PROFILE", indexName, "SEARCH"]; + + if (options?.limited) { + args.push("LIMITED"); + } + + args.push("QUERY", query); + + if (options) { + args.push(..._addFtSearchOptions(options)); + } + + return ( + _handleCustomCommand( + client, + args, + options as DecoderOption, + ) as Promise<[FtSearchReturnType, GlideRecord]> + ).then((v) => [v[0], convertGlideRecordToRecord(v[1])]); + } + + /** + * Runs an aggregate query and collects performance profiling information. + * + * @param client - The client to execute the command. + * @param indexName - The index name. 
+     * @param query - The text query to search.
+     * @param options - (Optional) See {@link FtAggregateOptions} and {@link DecoderOption}. Additionally:
+     * - `limited` (Optional) - Either provide a full verbose output or some brief version.
+     *
+     * @returns A two-element array. The first element contains results of the aggregate query being profiled, the
+     *     second element stores profiling information.
+     *
+     * @example
+     * ```typescript
+     * // Example of running profile on an aggregate query
+     * const options: FtAggregateOptions = {
+     *     loadFields: ["__key"],
+     *     clauses: [
+     *         {
+     *             type: "GROUPBY",
+     *             properties: ["@condition"],
+     *             reducers: [
+     *                 {
+     *                     function: "TOLIST",
+     *                     args: ["__key"],
+     *                     name: "bicycles",
+     *                 },
+     *             ],
+     *         },
+     *     ],
+     * };
+     * const result = await GlideFt.profileAggregate(client, "myIndex", "*", options);
+     * console.log(result); // Output:
+     * // result[0] contains the `FT.AGGREGATE` response for the given query
+     * // result[1] contains profiling data as a `Record<string, number>`
+     * ```
+     */
+    static async profileAggregate(
+        client: GlideClient | GlideClusterClient,
+        indexName: GlideString,
+        query: GlideString,
+        options?: DecoderOption &
+            FtAggregateOptions & {
+                limited?: boolean;
+            },
+    ): Promise<[FtAggregateReturnType, Record<string, number>]> {
+        const args: GlideString[] = ["FT.PROFILE", indexName, "AGGREGATE"];
+
+        if (options?.limited) {
+            args.push("LIMITED");
+        }
+
+        args.push("QUERY", query);
+
+        if (options) {
+            args.push(..._addFtAggregateOptions(options));
+        }
+
+        return (
+            _handleCustomCommand(
+                client,
+                args,
+                options as DecoderOption,
+            ) as Promise<[FtAggregateReturnType, GlideRecord<number>]>
+        ).then((v) => [v[0], convertGlideRecordToRecord(v[1])]);
+    }
+
+    /**
+     * Adds an alias for an index. The new alias name can be used anywhere that an index name is required.
+     *
+     * @param client - The client to execute the command.
+     * @param indexName - The index name for which the alias will be added.
+     * @param alias - The alias to be added to the index.
+     * @returns `"OK"`
+     *
+     * @example
+     * ```typescript
+     * // Example usage of FT.ALIASADD to add an alias for an index.
+     * await GlideFt.aliasadd(client, "index", "alias"); // "OK"
+     * ```
+     */
+    static async aliasadd(
+        client: GlideClient | GlideClusterClient,
+        indexName: GlideString,
+        alias: GlideString,
+    ): Promise<"OK"> {
+        const args: GlideString[] = ["FT.ALIASADD", alias, indexName];
+        return _handleCustomCommand(client, args, {
+            decoder: Decoder.String,
+        }) as Promise<"OK">;
+    }
+
+    /**
+     * Deletes an existing alias for an index.
+     *
+     * @param client - The client to execute the command.
+     * @param alias - The existing alias to be deleted for an index.
+     * @returns `"OK"`
+     *
+     * @example
+     * ```typescript
+     * // Example usage of FT.ALIASDEL to delete an existing alias.
+     * await GlideFt.aliasdel(client, "alias"); // "OK"
+     * ```
+     */
+    static async aliasdel(
+        client: GlideClient | GlideClusterClient,
+        alias: GlideString,
+    ): Promise<"OK"> {
+        const args: GlideString[] = ["FT.ALIASDEL", alias];
+        return _handleCustomCommand(client, args, {
+            decoder: Decoder.String,
+        }) as Promise<"OK">;
+    }
+
+    /**
+     * Updates an existing alias to point to a different physical index. This command only affects future references to the alias.
+     *
+     * @param client - The client to execute the command.
+     * @param alias - The alias name. This alias will now be pointed to a different index.
+     * @param indexName - The index name for which an existing alias has to be updated.
+ * @returns `"OK"` + * + * @example + * ```typescript + * // Example usage of FT.ALIASUPDATE to update an alias to point to a different index. + * await GlideFt.aliasupdate(client, "newAlias", "index"); // "OK" + * ``` + */ + static async aliasupdate( + client: GlideClient | GlideClusterClient, + alias: GlideString, + indexName: GlideString, + ): Promise<"OK"> { + const args: GlideString[] = ["FT.ALIASUPDATE", alias, indexName]; + return _handleCustomCommand(client, args, { + decoder: Decoder.String, + }) as Promise<"OK">; + } + + /** + * List the index aliases. + * + * @param client - The client to execute the command. + * @param options - (Optional) See {@link DecoderOption}. + * @returns A map of index aliases for indices being aliased. + * + * @example + * ```typescript + * // Example usage of FT._ALIASLIST to query index aliases + * const result = await GlideFt.aliaslist(client); + * console.log(result); // Output: + * //[{"key": "alias1", "value": "index1"}, {"key": "alias2", "value": "index2"}] + * ``` + */ + static async aliaslist( + client: GlideClient | GlideClusterClient, + options?: DecoderOption, + ): Promise> { + const args: GlideString[] = ["FT._ALIASLIST"]; + return _handleCustomCommand(client, args, options); + } +} + +/** + * @internal + */ +function _addFtAggregateOptions(options?: FtAggregateOptions): GlideString[] { + if (!options) return []; + + const args: GlideString[] = []; + + if (options.loadAll) args.push("LOAD", "*"); + else if (options.loadFields) + args.push( + "LOAD", + options.loadFields.length.toString(), + ...options.loadFields, + ); + + if (options.timeout) args.push("TIMEOUT", options.timeout.toString()); + + if (options.params) { + args.push( + "PARAMS", + (options.params.length * 2).toString(), + ...options.params.flatMap((param) => [param.key, param.value]), + ); + } + + if (options.clauses) { + for (const clause of options.clauses) { + switch (clause.type) { + case "LIMIT": + args.push( + clause.type, + clause.offset.toString(), + clause.count.toString(), + ); + break; + case "FILTER": + args.push(clause.type, clause.expression); + break; + case "GROUPBY": + args.push( + clause.type, + clause.properties.length.toString(), + ...clause.properties, + ); + + for (const reducer of clause.reducers) { + args.push( + "REDUCE", + reducer.function, + reducer.args.length.toString(), + ...reducer.args, + ); + if (reducer.name) args.push("AS", reducer.name); + } + + break; + case "SORTBY": + args.push( + clause.type, + (clause.properties.length * 2).toString(), + ); + for (const property of clause.properties) + args.push(property.property, property.order); + if (clause.max) args.push("MAX", clause.max.toString()); + break; + case "APPLY": + args.push( + clause.type, + clause.expression, + "AS", + clause.name, + ); + break; + default: + throw new Error( + "Unknown clause type in FtAggregateOptions", + ); + } + } + } + + return args; +} + +/** + * @internal + */ +function _addFtSearchOptions(options?: FtSearchOptions): GlideString[] { + if (!options) return []; + + const args: GlideString[] = []; + + // RETURN + if (options.returnFields) { + const returnFields: GlideString[] = []; + options.returnFields.forEach((returnField) => + returnField.alias + ? 
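+                      // With an alias the field serializes as `<fieldIdentifier> AS <alias>`,
+                      // otherwise the identifier is pushed on its own.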
returnFields.push( + returnField.fieldIdentifier, + "AS", + returnField.alias, + ) + : returnFields.push(returnField.fieldIdentifier), + ); + args.push("RETURN", returnFields.length.toString(), ...returnFields); + } + + // TIMEOUT + if (options.timeout) { + args.push("TIMEOUT", options.timeout.toString()); + } + + // PARAMS + if (options.params) { + args.push( + "PARAMS", + (options.params.length * 2).toString(), + ...options.params.flatMap((param) => [param.key, param.value]), + ); + } + + // LIMIT + if (options.limit) { + args.push( + "LIMIT", + options.limit.offset.toString(), + options.limit.count.toString(), + ); + } + + // COUNT + if (options.count) { + args.push("COUNT"); + } + + return args; +} + +/** + * @internal + */ +async function _handleCustomCommand( + client: GlideClient | GlideClusterClient, + args: GlideString[], + decoderOption: DecoderOption = {}, +): Promise { + return client instanceof GlideClient + ? ((client as GlideClient).customCommand( + args, + decoderOption, + ) as Promise) + : ((client as GlideClusterClient).customCommand( + args, + decoderOption, + ) as Promise); +} diff --git a/node/src/server-modules/GlideFtOptions.ts b/node/src/server-modules/GlideFtOptions.ts new file mode 100644 index 0000000000..6d9e8f4528 --- /dev/null +++ b/node/src/server-modules/GlideFtOptions.ts @@ -0,0 +1,272 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ + +import { SortOrder } from "src/Commands"; +import { GlideRecord, GlideString } from "../BaseClient"; +import { GlideFt } from "./GlideFt"; // eslint-disable-line @typescript-eslint/no-unused-vars + +interface BaseField { + /** The name of the field. */ + name: GlideString; + /** An alias for field. */ + alias?: GlideString; +} + +/** + * Field contains any blob of data. + */ +export type TextField = BaseField & { + /** Field identifier */ + type: "TEXT"; +}; + +/** + * Tag fields are similar to full-text fields, but they interpret the text as a simple list of + * tags delimited by a separator character. + * + * For HASH fields, separator default is a comma (`,`). For JSON fields, there is no default + * separator; you must declare one explicitly if needed. + */ +export type TagField = BaseField & { + /** Field identifier */ + type: "TAG"; + /** Specify how text in the attribute is split into individual tags. Must be a single character. */ + separator?: GlideString; + /** Preserve the original letter cases of tags. If set to `false`, characters are converted to lowercase by default. */ + caseSensitive?: boolean; +}; + +/** + * Field contains a number. + */ +export type NumericField = BaseField & { + /** Field identifier */ + type: "NUMERIC"; +}; + +/** + * Superclass for vector field implementations, contains common logic. + */ +export type VectorField = BaseField & { + /** Field identifier */ + type: "VECTOR"; + /** Additional attributes to be passed with the vector field after the algorithm name. */ + attributes: VectorFieldAttributesFlat | VectorFieldAttributesHnsw; +}; + +/** + * Base class for defining vector field attributes to be used after the vector algorithm name. + */ +export interface VectorFieldAttributes { + /** Number of dimensions in the vector. Equivalent to `DIM` in the module API. */ + dimensions: number; + /** + * The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` in the module API. + */ + distanceMetric: "L2" | "IP" | "COSINE"; + /** Vector type. The only supported type is FLOAT32. 
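+     * When omitted, {@link GlideFt.create} falls back to pushing `TYPE FLOAT32`.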
*/ + type?: "FLOAT32"; + /** + * Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` in the module API. + */ + initialCap?: number; +} + +/** + * Vector field that supports vector search by FLAT (brute force) algorithm. + * + * The algorithm is a brute force linear processing of each vector in the index, yielding exact + * answers within the bounds of the precision of the distance computations. + */ +export type VectorFieldAttributesFlat = VectorFieldAttributes & { + algorithm: "FLAT"; +}; + +/** + * Vector field that supports vector search by HNSM (Hierarchical Navigable Small World) algorithm. + * + * The algorithm provides an approximation of the correct answer in exchange for substantially + * lower execution times. + */ +export type VectorFieldAttributesHnsw = VectorFieldAttributes & { + algorithm: "HNSW"; + /** + * Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is `16`, maximum is `512`. + * Equivalent to `M` in the module API. + */ + numberOfEdges?: number; + /** + * Controls the number of vectors examined during index construction. Default value is `200`, Maximum value is `4096`. + * Equivalent to `EF_CONSTRUCTION` in the module API. + */ + vectorsExaminedOnConstruction?: number; + /** + * Controls the number of vectors examined during query operations. Default value is `10`, Maximum value is `4096`. + * Equivalent to `EF_RUNTIME` in the module API. + */ + vectorsExaminedOnRuntime?: number; +}; + +export type Field = TextField | TagField | NumericField | VectorField; + +/** + * Represents the input options to be used in the {@link GlideFt.create | FT.CREATE} command. + * All fields in this class are optional inputs for FT.CREATE. + */ +export interface FtCreateOptions { + /** The type of data to be indexed using FT.CREATE. */ + dataType: "JSON" | "HASH"; + /** The prefix of the key to be indexed. */ + prefixes?: GlideString[]; +} + +/** Additional parameters for {@link GlideFt.aggregate | FT.AGGREGATE} command. */ +export type FtAggregateOptions = { + /** Query timeout in milliseconds. */ + timeout?: number; + /** + * {@link FtAggregateFilter | FILTER}, {@link FtAggregateLimit | LIMIT}, {@link FtAggregateGroupBy | GROUPBY}, + * {@link FtAggregateSortBy | SORTBY} and {@link FtAggregateApply | APPLY} clauses, that can be repeated + * multiple times in any order and be freely intermixed. They are applied in the order specified, + * with the output of one clause feeding the input of the next clause. + */ + clauses?: ( + | FtAggregateLimit + | FtAggregateFilter + | FtAggregateGroupBy + | FtAggregateSortBy + | FtAggregateApply + )[]; + /** + * Query parameters, which could be referenced in the query by `$` sign, followed by + * the parameter name. + */ + params?: GlideRecord; +} & ( + | { + /** List of fields to load from the index. */ + loadFields?: GlideString[]; + /** `loadAll` and `loadFields` are mutually exclusive. */ + loadAll?: never; + } + | { + /** Option to load all fields declared in the index */ + loadAll?: boolean; + /** `loadAll` and `loadFields` are mutually exclusive. */ + loadFields?: never; + } +); + +/** A clause for limiting the number of retained records. */ +export interface FtAggregateLimit { + type: "LIMIT"; + /** Starting point from which the records have to be retained. */ + offset: number; + /** The total number of records to be retained. 
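+     * For example, `{ type: "LIMIT", offset: 0, count: 10 }` keeps only the first ten records of the pipeline output.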
*/ + count: number; +} + +/** + * A clause for filtering the results using predicate expression relating to values in each result. + * It is applied post query and relate to the current state of the pipeline. + */ +export interface FtAggregateFilter { + type: "FILTER"; + /** The expression to filter the results. */ + expression: GlideString; +} + +/** A clause for grouping the results in the pipeline based on one or more properties. */ +export interface FtAggregateGroupBy { + type: "GROUPBY"; + /** The list of properties to be used for grouping the results in the pipeline. */ + properties: GlideString[]; + /** The list of functions that handles the group entries by performing multiple aggregate operations. */ + reducers: FtAggregateReducer[]; +} + +/** + * A clause for reducing the matching results in each group using a reduction function. + * The matching results are reduced into a single record. + */ +export interface FtAggregateReducer { + /** The reduction function name for the respective group. */ + function: string; + /** The list of arguments for the reducer. */ + args: GlideString[]; + /** User defined property name for the reducer. */ + name?: GlideString; +} + +/** A clause for sorting the pipeline up until the point of SORTBY, using a list of properties. */ +export interface FtAggregateSortBy { + type: "SORTBY"; + /** A list of sorting parameters for the sort operation. */ + properties: FtAggregateSortProperty[]; + /** The MAX value for optimizing the sorting, by sorting only for the n-largest elements. */ + max?: number; +} + +/** A single property for the {@link FtAggregateSortBy | SORTBY} clause. */ +export interface FtAggregateSortProperty { + /** The sorting parameter. */ + property: GlideString; + /** The order for the sorting. */ + order: SortOrder; +} + +/** + * A clause for applying a 1-to-1 transformation on one or more properties and stores the result + * as a new property down the pipeline or replaces any property using this transformation. + */ +export interface FtAggregateApply { + type: "APPLY"; + /** The transformation expression. */ + expression: GlideString; + /** The new property name to store the result of apply. This name can be referenced by further operations down the pipeline. */ + name: GlideString; +} + +/** + * Represents the input options to be used in the FT.SEARCH command. + * All fields in this class are optional inputs for FT.SEARCH. + */ +export type FtSearchOptions = { + /** Query timeout in milliseconds. */ + timeout?: number; + + /** + * Add a field to be returned. + * @param fieldIdentifier field name to return. + * @param alias optional alias for the field name to return. + */ + returnFields?: { fieldIdentifier: GlideString; alias?: GlideString }[]; + + /** + * Query parameters, which could be referenced in the query by `$` sign, followed by + * the parameter name. + */ + params?: GlideRecord; +} & ( + | { + /** + * Configure query pagination. By default only first 10 documents are returned. + * + * @param offset Zero-based offset. + * @param count Number of elements to return. + */ + limit?: { offset: number; count: number }; + /** `limit` and `count` are mutually exclusive. */ + count?: never; + } + | { + /** + * Once set, the query will return only the number of documents in the result set without actually + * returning them. + */ + count?: boolean; + /** `limit` and `count` are mutually exclusive. 
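+     * Setting `count: true` is equivalent to `limit: { offset: 0, count: 0 }`: the reply contains only the number of matching documents.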
*/ + limit?: never; + } +); diff --git a/node/src/server-modules/GlideJson.ts b/node/src/server-modules/GlideJson.ts new file mode 100644 index 0000000000..c5e69b9993 --- /dev/null +++ b/node/src/server-modules/GlideJson.ts @@ -0,0 +1,1159 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ + +import { BaseClient, DecoderOption, GlideString } from "../BaseClient"; +import { ConditionalChange } from "../Commands"; +import { GlideClient } from "../GlideClient"; +import { GlideClusterClient, RouteOption } from "../GlideClusterClient"; + +export type ReturnTypeJson = T | (T | null)[]; +export type UniversalReturnTypeJson = T | T[]; + +/** + * Represents options for formatting JSON data, to be used in the {@link GlideJson.get | JSON.GET} command. + */ +export interface JsonGetOptions { + /** The path or list of paths within the JSON document. Default is root `$`. */ + path?: GlideString | GlideString[]; + /** Sets an indentation string for nested levels. */ + indent?: GlideString; + /** Sets a string that's printed at the end of each line. */ + newline?: GlideString; + /** Sets a string that's put between a key and a value. */ + space?: GlideString; + /** Optional, allowed to be present for legacy compatibility and has no other effect */ + noescape?: boolean; +} + +/** Additional options for {@link GlideJson.arrpop | JSON.ARRPOP} command. */ +export interface JsonArrPopOptions { + /** The path within the JSON document. */ + path: GlideString; + /** The index of the element to pop. Out of boundary indexes are rounded to their respective array boundaries. */ + index?: number; +} + +/** + * @internal + */ +function _jsonGetOptionsToArgs(options: JsonGetOptions): GlideString[] { + const result: GlideString[] = []; + + if (options.path) { + if (Array.isArray(options.path)) { + result.push(...options.path); + } else { + result.push(options.path); + } + } + + if (options.indent) { + result.push("INDENT", options.indent); + } + + if (options.newline) { + result.push("NEWLINE", options.newline); + } + + if (options.space) { + result.push("SPACE", options.space); + } + + if (options.noescape) { + result.push("NOESCAPE"); + } + + return result; +} + +/** + * @internal + */ +function _executeCommand( + client: BaseClient, + args: GlideString[], + options?: RouteOption & DecoderOption, +): Promise { + if (client instanceof GlideClient) { + return (client as GlideClient).customCommand( + args, + options, + ) as Promise; + } else { + return (client as GlideClusterClient).customCommand( + args, + options, + ) as Promise; + } +} + +/** Module for JSON commands. */ +export class GlideJson { + /** + * Sets the JSON value at the specified `path` stored at `key`. + * + * @param client The client to execute the command. + * @param key - The key of the JSON document. + * @param path - Represents the path within the JSON document where the value will be set. + * The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added. + * @param value - The value to set at the specific path, in JSON formatted bytes or str. + * @param options - (Optional) Additional parameters: + * - (Optional) `conditionalChange` - Set the value only if the given condition is met (within the key or path). + * Equivalent to [`XX` | `NX`] in the module API. + * - (Optional) `decoder`: see {@link DecoderOption}. + * + * @returns If the value is successfully set, returns `"OK"`. 
+     * If `value` isn't set because of `conditionalChange`, returns `null`.
+     *
+     * @example
+     * ```typescript
+     * const value = {a: 1.0, b: 2};
+     * const jsonStr = JSON.stringify(value);
+     * const result = await GlideJson.set(client, "doc", "$", jsonStr);
+     * console.log(result); // 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     *
+     * const jsonGetStr = await GlideJson.get(client, "doc", {path: "$"}); // Returns the value at path '$' in the JSON document stored at `doc` as a JSON string.
+     * console.log(jsonGetStr); // '[{"a":1.0,"b":2}]'
+     * console.log(JSON.parse(jsonGetStr as string)); // Output: [{"a": 1.0, "b": 2}] - JSON object retrieved from the key `doc`
+     * ```
+     */
+    static async set(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        value: GlideString,
+        options?: { conditionalChange: ConditionalChange } & DecoderOption,
+    ): Promise<"OK" | null> {
+        const args: GlideString[] = ["JSON.SET", key, path, value];
+
+        if (options?.conditionalChange !== undefined) {
+            args.push(options.conditionalChange);
+        }
+
+        return _executeCommand<"OK" | null>(client, args, options);
+    }
+
+    /**
+     * Retrieves the JSON value at the specified `paths` stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) Options for formatting the byte representation of the JSON data. See {@link JsonGetOptions}.
+     * - (Optional) `decoder`: see {@link DecoderOption}.
+     * @returns
+     * - If one path is given:
+     *   - For JSONPath (path starts with `$`):
+     *     Returns a stringified JSON list of bytes replies for every possible path,
+     *     or a byte string representation of an empty array, if the path doesn't exist.
+     *     If `key` doesn't exist, returns `null`.
+     *   - For legacy path (path doesn't start with `$`):
+     *     Returns a byte string representation of the value in `path`.
+     *     If `path` doesn't exist, an error is raised.
+     *     If `key` doesn't exist, returns `null`.
+     * - If multiple paths are given:
+     *   Returns a stringified JSON object in bytes, in which each path is a key and its corresponding value is the value that a single-path execution of the command would have produced.
+     *   If the given paths mix JSONPath and legacy syntax, the command behaves as if all were JSONPath paths.
+     *
+     * @example
+     * ```typescript
+     * const jsonStr = await GlideJson.get(client, 'doc', {path: '$'});
+     * console.log(JSON.parse(jsonStr as string));
+     * // Output: [{"a": 1.0, "b": 2}] - JSON object retrieved from the key `doc`.
+     *
+     * const jsonData = await GlideJson.get(client, 'doc', {path: '$'});
+     * console.log(jsonData);
+     * // Output: '[{"a":1.0,"b":2}]' - Returns the value at path '$' in the JSON document stored at `doc`.
+     *
+     * const formattedJson = await GlideJson.get(client, 'doc', {
+     *     path: ['$.a', '$.b'],
+     *     indent: " ",
+     *     newline: "\n",
+     *     space: " "
+     * });
+     * console.log(formattedJson);
+     * // Output: "{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}" - Returns values at paths '$.a' and '$.b' with custom format.
+     *
+     * const nonExistingPath = await GlideJson.get(client, 'doc', {path: '$.non_existing_path'});
+     * console.log(nonExistingPath);
+     * // Output: "[]" - Empty array since the path does not exist in the JSON document.
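+     * // A brief additional sketch: legacy paths (no leading '$') return the value itself rather than an array.
+     * const legacyValue = await GlideJson.get(client, 'doc', {path: '.b'});
+     * console.log(legacyValue); // Output: "2"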
+ * ``` + */ + static async get( + client: BaseClient, + key: GlideString, + options?: JsonGetOptions & DecoderOption, + ): Promise> { + const args = ["JSON.GET", key]; + + if (options) { + const optionArgs = _jsonGetOptionsToArgs(options); + args.push(...optionArgs); + } + + return _executeCommand(client, args, options); + } + + /** + * Retrieves the JSON values at the specified `path` stored at multiple `keys`. + * + * @remarks When in cluster mode, if keys in `keyValueMap` map to different hash slots, the command + * will be split across these slots and executed separately for each. This means the command + * is atomic only at the slot level. If one or more slot-specific requests fail, the entire + * call will return the first encountered error, even though some requests may have succeeded + * while others did not. If this behavior impacts your application logic, consider splitting + * the request into sub-requests per slot to ensure atomicity. + * + * @param client - The client to execute the command. + * @param keys - The keys of the JSON documents. + * @param path - The path within the JSON documents. + * @param options - (Optional) See {@link DecoderOption}. + * @returns + * - For JSONPath (path starts with `$`): + * Returns a stringified JSON list replies for every possible path, or a string representation + * of an empty array, if path doesn't exist. + * - For legacy path (path doesn't start with `$`): + * Returns a string representation of the value in `path`. If `path` doesn't exist, + * the corresponding array element will be `null`. + * - If a `key` doesn't exist, the corresponding array element will be `null`. + * + * @example + * ```typescript + * await GlideJson.set(client, "doc1", "$", '{"a": 1, "b": ["one", "two"]}'); + * await GlideJson.set(client, "doc2", "$", '{"a": 1, "c": false}'); + * const res = await GlideJson.mget(client, [ "doc1", "doc2", "doc3" ], "$.c"); + * console.log(res); // Output: ["[]", "[false]", null] + * ``` + */ + static async mget( + client: BaseClient, + keys: GlideString[], + path: GlideString, + options?: DecoderOption, + ): Promise { + const args = ["JSON.MGET", ...keys, path]; + return _executeCommand(client, args, options); + } + + /** + * Inserts one or more values into the array at the specified `path` within the JSON + * document stored at `key`, before the given `index`. + * + * @param client - The client to execute the command. + * @param key - The key of the JSON document. + * @param path - The path within the JSON document. + * @param index - The array index before which values are inserted. + * @param values - The JSON values to be inserted into the array. + * JSON string values must be wrapped with quotes. For example, to insert `"foo"`, pass `"\"foo\""`. + * @returns + * - For JSONPath (path starts with `$`): + * Returns an array with a list of integers for every possible path, + * indicating the new length of the array, or `null` for JSON values matching + * the path that are not an array. If `path` does not exist, an empty array + * will be returned. + * - For legacy path (path doesn't start with `$`): + * Returns an integer representing the new length of the array. If multiple paths are + * matched, returns the length of the first modified array. If `path` doesn't + * exist or the value at `path` is not an array, an error is raised. + * - If the index is out of bounds or `key` doesn't exist, an error is raised. 
+     *
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]');
+     * const result = await GlideJson.arrinsert(client, "doc", "$[*]", 0, ['"c"', '{"key": "value"}', "true", "null", '["bar"]']);
+     * console.log(result); // Output: [5, 6, 7]
+     * const doc = await GlideJson.get(client, "doc");
+     * console.log(doc); // Output: '[["c",{"key":"value"},true,null,["bar"]],["c",{"key":"value"},true,null,["bar"],"a"],["c",{"key":"value"},true,null,["bar"],"a","b"]]'
+     * ```
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]');
+     * const result = await GlideJson.arrinsert(client, "doc", ".", 0, ['"c"']);
+     * console.log(result); // Output: 4
+     * const doc = await GlideJson.get(client, "doc");
+     * console.log(doc); // Output: '[\"c\",[],[\"a\"],[\"a\",\"b\"]]'
+     * ```
+     */
+    static async arrinsert(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        index: number,
+        values: GlideString[],
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.ARRINSERT", key, path, index.toString(), ...values];
+
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
+
+    /**
+     * Pops an element from the array located at `path` in the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) See {@link JsonArrPopOptions} and {@link DecoderOption}.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   Returns an array of strings for every possible path, representing the popped JSON
+     *   values, or `null` for JSON values matching the path that are not an array
+     *   or are an empty array.
+     * - For legacy path (path doesn't start with `$`):
+     *   Returns a string representing the popped JSON value, or `null` if the
+     *   array at `path` is empty. If multiple paths are matched, the value from
+     *   the first matching array that is not empty is returned. If `path` doesn't
+     *   exist or the value at `path` is not an array, an error is raised.
+     * - If `key` doesn't exist, an error is raised.
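+     * @remarks An out-of-bounds `index` in {@link JsonArrPopOptions} is rounded to the nearest array boundary rather than raising an error.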
+     *
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}');
+     * let result = await GlideJson.arrpop(client, "doc", { path: "$.a", index: 1 });
+     * console.log(result); // Output: ['2'] - Popped second element from array at path `$.a`
+     * result = await GlideJson.arrpop(client, "doc", { path: "$..a" });
+     * console.log(result); // Output: ['true', '5', null] - Popped last elements from all arrays matching path `$..a`
+     *
+     * result = await GlideJson.arrpop(client, "doc", { path: "..a" });
+     * console.log(result); // Output: "1" - First match popped (from array at path ..a)
+     * // Even though only one value is returned from `..a`, subsequent arrays are also affected
+     * console.log(await GlideJson.get(client, "doc", { path: "$..a" })); // Output: "[[], [3, 4], 42]"
+     * ```
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b", "c"]]');
+     * let result = await GlideJson.arrpop(client, "doc", { path: ".", index: -1 });
+     * console.log(result); // Output: '["a","b","c"]' - Popped the last element at path `.`
+     * ```
+     */
+    static async arrpop(
+        client: BaseClient,
+        key: GlideString,
+        options?: JsonArrPopOptions & DecoderOption,
+    ): Promise<ReturnTypeJson<GlideString>> {
+        const args = ["JSON.ARRPOP", key];
+        if (options?.path) args.push(options?.path);
+        if (options && "index" in options && options.index)
+            args.push(options?.index.toString());
+
+        return _executeCommand<ReturnTypeJson<GlideString>>(client, args, options);
+    }
+
+    /**
+     * Retrieves the length of the array at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document. Defaults to the root (`"."`) if not specified.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   Returns an array with a list of integers for every possible path,
+     *   indicating the length of the array, or `null` for JSON values matching
+     *   the path that are not an array. If `path` does not exist, an empty array
+     *   will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   Returns an integer representing the length of the array. If multiple paths are
+     *   matched, returns the length of the first matching array. If `path` doesn't
+     *   exist or the value at `path` is not an array, an error is raised.
+     * - If `key` doesn't exist, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}');
+     * console.log(await GlideJson.arrlen(client, "doc", { path: "$" })); // Output: [null] - No array at the root path.
+     * console.log(await GlideJson.arrlen(client, "doc", { path: "$.a" })); // Output: [3] - Retrieves the length of the array at path $.a.
+     * console.log(await GlideJson.arrlen(client, "doc", { path: "$..a" })); // Output: [3, 2, null] - Retrieves lengths of arrays found at all levels of the path `$..a`.
+     * console.log(await GlideJson.arrlen(client, "doc", { path: "..a" })); // Output: 3 - Legacy path retrieves the first array match at path `..a`.
+     * ```
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '[1, 2, 3, 4]');
+     * console.log(await GlideJson.arrlen(client, "doc")); // Output: 4 - the length of the array at root.
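+     * // (If the root held a non-array value, an error would be raised instead.)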
+     * ```
+     */
+    static async arrlen(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.ARRLEN", key];
+        if (options?.path) args.push(options?.path);
+
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
+
+    /**
+     * Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive.
+     * If `start` < 0, it is treated as 0.
+     * If `end` >= size (size of the array), it is treated as size-1.
+     * If `start` >= size or `start` > `end`, the array is emptied and 0 is returned.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param start - The start index, inclusive.
+     * @param end - The end index, inclusive.
+     * @returns
+     * - For JSONPath (`path` starts with `$`):
+     *   - Returns a list of integer replies for every possible path, indicating the new length of the array,
+     *     or `null` for JSON values matching the path that are not an array.
+     *   - If the array is empty, its corresponding return value is 0.
+     *   - If `path` doesn't exist, an empty array will be returned.
+     *   - If an index argument is out of bounds, an error is raised.
+     * - For legacy path (`path` doesn't start with `$`):
+     *   - Returns an integer representing the new length of the array.
+     *   - If the array is empty, its corresponding return value is 0.
+     *   - If multiple paths match, the length of the first trimmed array match is returned.
+     *   - If `path` doesn't exist, or the value at `path` is not an array, an error is raised.
+     *   - If an index argument is out of bounds, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * const result = await GlideJson.arrtrim(client, "doc", "$[*]", 0, 1);
+     * console.log(result);
+     * // Output: [0, 1, 2, 2]
+     * console.log(await GlideJson.get(client, "doc", { path: "$" }));
+     * // Output: '[[],["a"],["a","b"],["a","b"]]' - Returns the value at path '$' in the JSON document stored at `doc`.
+     * ```
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * const result = await GlideJson.arrtrim(client, "doc", ".children", 0, 1);
+     * console.log(result);
+     * // Output: 2
+     * console.log(await GlideJson.get(client, "doc", { path: ".children" }));
+     * // Output: '["John", "Jack"]' - Returns the value at path '.children' in the JSON document stored at `doc`.
+     * ```
+     */
+    static async arrtrim(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        start: number,
+        end: number,
+    ): Promise<ReturnTypeJson<number>> {
+        const args: GlideString[] = [
+            "JSON.ARRTRIM",
+            key,
+            path,
+            start.toString(),
+            end.toString(),
+        ];
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
+
+    /**
+     * Searches for the first occurrence of a `scalar` JSON value in the arrays at the `path`.
+     * Out of range errors are treated by rounding the index to the array's `start` and `end`.
+     * If `start` > `end`, returns `-1` (not found).
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param scalar - The scalar value to search for.
+
+    /**
+     * Searches for the first occurrence of a `scalar` JSON value in the arrays at the `path`.
+     * Out-of-range indices are handled by rounding the index to the array's `start` and `end`.
+     * If `start` > `end`, `-1` is returned (not found).
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param scalar - The scalar value to search for.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `start`: The start index, inclusive. Defaults to 0 if not provided.
+     * - (Optional) `end`: The end index, exclusive. Defaults to 0 if not provided;
+     *   0 or -1 means the last element is included.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   Returns an array with a list of integers for every possible path,
+     *   indicating the index of the matching element. The value is `-1` if not found.
+     *   If a value is not an array, its corresponding return value is `null`.
+     * - For legacy path (path doesn't start with `$`):
+     *   Returns an integer representing the index of the matching element, or `-1` if
+     *   not found. If the value at the `path` is not an array, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '{"a": ["value", 3], "b": {"a": [3, ["value", false], 5]}}');
+     * console.log(await GlideJson.arrindex(client, "doc", "$..a", 3, { start: 3, end: 3 })); // Output: [2, -1]
+     * ```
+     */
+    static async arrindex(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        scalar: GlideString | number | boolean | null,
+        options?: { start: number; end?: number },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.ARRINDEX", key, path];
+
+        if (typeof scalar === "number") {
+            args.push(scalar.toString());
+        } else if (typeof scalar === "boolean") {
+            args.push(scalar ? "true" : "false");
+        } else if (scalar !== null) {
+            args.push(scalar);
+        } else {
+            args.push("null");
+        }
+
+        if (options?.start !== undefined) args.push(options?.start.toString());
+        if (options?.end !== undefined) args.push(options?.end.toString());
+
+        return _executeCommand(client, args);
+    }
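The scalar-encoding branch above can be summarized as a small standalone function. This sketch is illustrative only; it mirrors the logic rather than being exported by the module:

```typescript
// Numbers and booleans are serialized, null becomes "null", and strings are
// passed through as-is (the caller supplies them already JSON-quoted).
function encodeScalar(scalar: string | number | boolean | null): string {
    if (typeof scalar === "number") return scalar.toString();
    if (typeof scalar === "boolean") return scalar ? "true" : "false";
    if (scalar === null) return "null";
    return scalar;
}

console.log(encodeScalar(3)); // "3"
console.log(encodeScalar(false)); // "false"
console.log(encodeScalar('"foo"')); // '"foo"' - string scalars arrive pre-quoted
```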
+
+    /**
+     * Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document. Defaults to the root (`"."`) if not specified.
+     * @returns - For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value,
+     * or `null` for JSON values matching the path that are not boolean.
+     * - For legacy path (`path` doesn't start with `$`), returns the value of the toggled boolean in `path`.
+     * - Note that for legacy path syntax, if `path` doesn't exist or the value at `path` isn't a boolean, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * const value = { bool: true, nested: { bool: false, nested: { bool: 10 } } };
+     * const jsonStr = JSON.stringify(value);
+     * const resultSet = await GlideJson.set(client, "doc", "$", jsonStr);
+     * // Output: 'OK'
+     *
+     * const resultToggle = await GlideJson.toggle(client, "doc", { path: "$..bool" });
+     * // Output: [false, true, null] - Indicates successful toggling of the Boolean values at path '$..bool' in the key stored at `doc`.
+     *
+     * const resultToggleLegacy = await GlideJson.toggle(client, "doc", { path: "bool" });
+     * // Output: true - Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`.
+     *
+     * const jsonGetStr = await GlideJson.get(client, "doc", { path: "$" });
+     * console.log(JSON.stringify(jsonGetStr));
+     * // Output: [{bool: true, nested: {bool: true, nested: {bool: 10}}}] - The updated JSON value in the key stored at `doc`.
+     *
+     * // Without specifying a path, the path defaults to root.
+     * console.log(await GlideJson.set(client, "doc2", ".", "true")); // Output: "OK"
+     * console.log(await GlideJson.toggle(client, "doc2")); // Output: false
+     * console.log(await GlideJson.toggle(client, "doc2")); // Output: true
+     * ```
+     */
+    static async toggle(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<boolean>> {
+        const args = ["JSON.TOGGLE", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Deletes the JSON value at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: If `null`, deletes the entire JSON document at `key`.
+     * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}'));
+     * // Output: "OK" - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.del(client, "doc", { path: "$..a" }));
+     * // Output: 2 - Indicates successful deletion of the specific values in the key stored at `doc`.
+     * console.log(await GlideJson.get(client, "doc", { path: "$" }));
+     * // Output: '[{"nested":{"b":3}}]' - Returns the value at path '$' in the JSON document stored at `doc`.
+     * console.log(await GlideJson.del(client, "doc"));
+     * // Output: 1 - Deletes the entire JSON document stored at `doc`.
+     * ```
+     */
+    static async del(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<number> {
+        const args = ["JSON.DEL", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Deletes the JSON value at the specified `path` within the JSON document stored at `key`. This command is
+     * an alias of {@link del}.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: If `null`, deletes the entire JSON document at `key`.
+     * @returns - The number of elements removed. If `key` or `path` doesn't exist, returns 0.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}'));
+     * // Output: "OK" - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.forget(client, "doc", { path: "$..a" }));
+     * // Output: 2 - Indicates successful deletion of the specific values in the key stored at `doc`.
+     * console.log(await GlideJson.get(client, "doc", { path: "$" }));
+     * // Output: '[{"nested":{"b":3}}]' - Returns the value at path '$' in the JSON document stored at `doc`.
+     * console.log(await GlideJson.forget(client, "doc"));
+     * // Output: 1 - Deletes the entire JSON document stored at `doc`.
+     * ```
+     */
+    static async forget(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<number> {
+        const args = ["JSON.FORGET", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Reports the type of values at the given path.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: Defaults to root (`"."`) if not provided.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns an array of strings that represents the type of value at each path.
+     *     The type is one of "null", "boolean", "string", "number", "integer", "object" and "array".
+     *   - If a path does not exist, its corresponding return value is `null`.
+     *   - Empty array if the document key does not exist.
+     * - For legacy path (path doesn't start with `$`):
+     *   - String that represents the type of the value.
+     *   - `null` if the document key does not exist.
+     *   - `null` if the JSON path is invalid or does not exist.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '[1, 2.3, "foo", true, null, {}, []]'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * const result = await GlideJson.type(client, "doc", { path: "$[*]" });
+     * console.log(result);
+     * // Output: ["integer", "number", "string", "boolean", null, "object", "array"]
+     * console.log(await GlideJson.set(client, "doc2", ".", '{"Name": "John", "Age": 27}'));
+     * console.log(await GlideJson.type(client, "doc2")); // Output: "object"
+     * console.log(await GlideJson.type(client, "doc2", { path: ".Age" })); // Output: "integer"
+     * ```
+     */
+    static async type(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<GlideString>> {
+        const args = ["JSON.TYPE", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand<ReturnTypeJson<GlideString>>(client, args);
+    }
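Because `type` resolves to a plain string for legacy paths, it can gate follow-up mutations. A hypothetical helper combining `type` and `del` (the helper name and flow are assumptions, not part of the PR; imports are the same assumed package root as above):

```typescript
import { GlideClient, GlideJson } from "@valkey/valkey-glide"; // assumed import path

// Deletes `path` only when it currently holds an object; returns the number
// of removed elements (0 when the type check fails).
async function delIfObject(client: GlideClient, key: string, path: string): Promise<number> {
    const t = await GlideJson.type(client, key, { path }); // legacy path: a single type string or null
    return t === "object" ? GlideJson.del(client, key, { path }) : 0;
}
```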
+
+    /**
+     * Clears arrays or objects at the specified JSON path in the document stored at `key`.
+     * Numeric values are set to `0`, boolean values are set to `false`, and string values are converted to empty strings.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The JSON path to the arrays or objects to be cleared. Defaults to root if not provided.
+     * @returns The number of containers cleared, numeric values zeroed, booleans set to `false`,
+     * and string values converted to empty strings.
+     * If `path` doesn't exist, or the value at `path` is already empty (e.g., an empty array, object, or string), `0` is returned.
+     * If `key` doesn't exist, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"obj":{"a":1, "b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}'));
+     * // Output: 'OK' - JSON document is successfully set.
+     * console.log(await GlideJson.clear(client, "doc", {path: "$.*"}));
+     * // Output: 6 - 6 values are cleared (arrays/objects/strings/numbers/booleans), but `null` remains as is.
+     * console.log(await GlideJson.get(client, "doc", { path: "$" }));
+     * // Output: '[{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}]'
+     * console.log(await GlideJson.clear(client, "doc", {path: "$.*"}));
+     * // Output: 0 - No further clearing needed since the containers are already empty and the values are defaults.
+     *
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}'));
+     * // Output: 'OK'
+     * console.log(await GlideJson.clear(client, "doc", {path: "b.a[1:3]"}));
+     * // Output: 2 - 2 elements (`6` and `7`) are cleared.
+     * console.log(await GlideJson.clear(client, "doc", {path: "b.a[1:3]"}));
+     * // Output: 0 - No elements cleared since the specified slice has already been cleared.
+     * console.log(await GlideJson.get(client, "doc", {path: "$..a"}));
+     * // Output: '[1,[5,0,0],true,"value",3.5,{"foo":"foo"}]'
+     *
+     * console.log(await GlideJson.clear(client, "doc", {path: "$..a"}));
+     * // Output: 6 - All numeric, boolean, and string values across paths are cleared.
+     * console.log(await GlideJson.get(client, "doc", {path: "$..a"}));
+     * // Output: '[0,[],false,"",0.0,{}]'
+     * ```
+     */
+    static async clear(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.CLEAR", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
+
+    /**
+     * Retrieves the JSON value at the specified `path` within the JSON document stored at `key`.
+     * The returned result is in the Valkey or Redis OSS Serialization Protocol (RESP).
+     * JSON null is mapped to the RESP Null Bulk String.
+     * JSON Booleans are mapped to RESP Simple Strings.
+     * JSON integers are mapped to RESP Integers.
+     * JSON doubles are mapped to RESP Bulk Strings.
+     * JSON strings are mapped to RESP Bulk Strings.
+     * JSON arrays are represented as RESP arrays, where the first element is the simple string `[`, followed by the array's elements.
+     * JSON objects are represented as RESP arrays, where the first element is the simple string `{`, followed by key-value pairs, each of which is a RESP bulk string.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document, defaults to root (`"."`) if not provided.
+     * - (Optional) `decoder`: see {@link DecoderOption}.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns an array of replies for every possible path, indicating the RESP form of the JSON value.
+     *     If `path` doesn't exist, returns an empty array.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns a single reply for the JSON value at the specified `path`, in its RESP form.
+     *     If multiple paths match, the value of the first JSON value match is returned. If `path` doesn't exist, an error is raised.
+     * - If `key` doesn't exist, `null` is returned.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", ".", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '.' in the key stored at `doc`.
+     * const result = await GlideJson.resp(client, "doc", { path: "$..a" });
+     * console.log(result);
+     * // Output: [["[", 1, 2, 3], ["[", 1, 2], [42]]
+     * console.log(await GlideJson.resp(client, "doc", { path: "..a" })); // Output: ["[", 1, 2, 3]
+     * ```
+     */
+    static async resp(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString } & DecoderOption,
+    ): Promise<
+        UniversalReturnTypeJson<
+            (number | GlideString) | (number | GlideString | null) | null
+        >
+    > {
+        const args = ["JSON.RESP", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args, options);
+    }
+
+    /**
+     * Returns the length of the JSON string value stored at the specified `path` within
+     * the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document. Defaults to root (`"."`) if not provided.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns a list of integer replies for every possible path, indicating the length of
+     *     the JSON string value, or `null` for JSON values matching the path that
+     *     are not string.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns the length of the JSON value at `path`.
+     *   - If multiple paths match, the length of the first matched string is returned.
+     *   - If the JSON value at `path` is not a string or if `path` doesn't exist, an error is raised.
+     * - If `key` doesn't exist, `null` is returned.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.strlen(client, "doc", { path: "$..a" }));
+     * // Output: [3, 5, null] - The length of the string values at path '$..a' in the key stored at `doc`.
+     *
+     * console.log(await GlideJson.strlen(client, "doc", { path: "nested.a" }));
+     * // Output: 5 - The length of the JSON value at path 'nested.a' in the key stored at `doc`.
+     *
+     * console.log(await GlideJson.strlen(client, "doc", { path: "$" }));
+     * // Output: [null] - Returns an array with `null` since the value at the root path of the JSON document stored at `doc` is not a string.
+     *
+     * console.log(await GlideJson.strlen(client, "non_existent_key", { path: "." }));
+     * // Output: null - `null` is returned since the key does not exist.
+     * ```
+     */
+    static async strlen(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.STRLEN", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Appends the specified `value` to the string stored at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param value - The value to append to the string. Must be wrapped with single quotes. For example, to append "foo", pass '"foo"'.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document, defaults to root (`"."`) if not provided.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns a list of integer replies for every possible path, indicating the length of the resulting string after appending `value`,
+     *     or `null` for JSON values matching the path that are not string.
+     *   - If `key` doesn't exist, an error is raised.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns the length of the resulting string after appending `value` to the string at `path`.
+     *   - If multiple paths match, the length of the last updated string is returned.
+     *   - If the JSON value at `path` is not a string or if `path` doesn't exist, an error is raised.
+     *   - If `key` doesn't exist, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.strappend(client, "doc", JSON.stringify("baz"), { path: "$..a" }));
+     * // Output: [6, 8, null] - The new length of the string values at path '$..a' in the key stored at `doc` after the append operation.
+     *
+     * console.log(await GlideJson.strappend(client, "doc", '"foo"', { path: "nested.a" }));
+     * // Output: 11 - The length of the string value after appending "foo" to the string at path 'nested.a' in the key stored at `doc`.
+     *
+     * const result = JSON.parse(await GlideJson.get(client, "doc", { path: "$" }));
+     * console.log(result);
+     * // Output: [{"a":"foobaz", "nested": {"a": "hellobazfoo"}, "nested2": {"a": 31}}] - The updated JSON value in the key stored at `doc`.
+     * ```
+     */
+    static async strappend(
+        client: BaseClient,
+        key: GlideString,
+        value: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.STRAPPEND", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        args.push(value);
+
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
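One point worth underlining: `strappend` takes a JSON-encoded string, so callers typically wrap raw text in `JSON.stringify`. A speculative round trip under that assumption (connected client and package import path assumed, as before):

```typescript
import { GlideClient, GlideJson } from "@valkey/valkey-glide"; // assumed import path

async function appendAndMeasure(client: GlideClient): Promise<void> {
    await GlideJson.set(client, "doc", "$", '{"greeting": "hello"}');
    // the appended value must itself be valid JSON, hence JSON.stringify
    const newLen = await GlideJson.strappend(client, "doc", JSON.stringify(" world"), {
        path: ".greeting",
    });
    console.log(newLen); // expected: 11 ("hello world")
    console.log(await GlideJson.strlen(client, "doc", { path: ".greeting" })); // expected: 11
}
```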
+
+    /**
+     * Appends one or more `values` to the JSON array at the specified `path` within the JSON
+     * document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param values - The JSON values to be appended to the array.
+     * JSON string values must be wrapped with quotes. For example, to append `"foo"`, pass `"\"foo\""`.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   Returns an array with a list of integers for every possible path,
+     *   indicating the new length of the array, or `null` for JSON values matching
+     *   the path that are not an array. If `path` does not exist, an empty array
+     *   will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   Returns an integer representing the new length of the array. If multiple paths are
+     *   matched, returns the length of the first modified array. If `path` doesn't
+     *   exist or the value at `path` is not an array, an error is raised.
+     * - If `key` doesn't exist, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * await GlideJson.set(client, "doc", "$", '{"a": 1, "b": ["one", "two"]}');
+     * const result = await GlideJson.arrappend(client, "doc", "$.b", ['"three"']);
+     * console.log(result); // Output: [3] - the new length of the array at path '$.b' after appending the value.
+     * const result2 = await GlideJson.arrappend(client, "doc", ".b", ['"four"']);
+     * console.log(result2); // Output: 4 - the new length of the array at path '.b' after appending the value.
+     * const doc = await GlideJson.get(client, "doc");
+     * console.log(doc); // Output: '{"a": 1, "b": ["one", "two", "three", "four"]}'
+     * ```
+     */
+    static async arrappend(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        values: GlideString[],
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.ARRAPPEND", key, path, ...values];
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Reports memory usage in bytes of a JSON object at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document; returns total memory usage if no path is given.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns an array of numbers for every possible path, indicating the memory usage.
+     *     If `path` does not exist, an empty array will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns an integer representing the memory usage. If multiple paths are matched,
+     *     returns the data of the first matching object. If `path` doesn't exist, an error is raised.
+     * - If `key` doesn't exist, returns `null`.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '[1, 2.3, "foo", true, null, {}, [], {"a": 1, "b": 2}, [1, 2, 3]]'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.debugMemory(client, "doc", { path: ".." }));
+     * // Output: 258
+     * ```
+     */
+    static async debugMemory(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.DEBUG", "MEMORY", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Reports the number of fields at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document; returns the total number of fields if no path is given.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns an array of numbers for every possible path, indicating the number of fields.
+     *     If `path` does not exist, an empty array will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns an integer representing the number of fields. If multiple paths are matched,
+     *     returns the data of the first matching object. If `path` doesn't exist, an error is raised.
+     * - If `key` doesn't exist, returns `null`.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '[1, 2.3, "foo", true, null, {}, [], {"a": 1, "b": 2}, [1, 2, 3]]'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.debugFields(client, "doc", { path: "$[*]" }));
+     * // Output: [1, 1, 1, 1, 1, 0, 0, 2, 3]
+     * ```
+     */
+    static async debugFields(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.DEBUG", "FIELDS", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Increments or decrements the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param num - The number to increment or decrement by.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns a string representation of an array of strings, indicating the new values after incrementing for each matched `path`.
+     *     If a value is not a number, its corresponding return value will be `null`.
+     *     If `path` doesn't exist, a string representation of an empty array will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns a string representation of the resulting value after the increment or decrement.
+     *     If multiple paths match, the result of the last updated value is returned.
+     *     If the value at the `path` is not a number or `path` doesn't exist, an error is raised.
+     * - If `key` does not exist, an error is raised.
+     * - If the result is out of the range of 64-bit IEEE double, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.numincrby(client, "doc", "$.d[*]", 10));
+     * // Output: '[11,12,13]' - Increments each element in the `d` array by 10.
+     *
+     * console.log(await GlideJson.numincrby(client, "doc", ".c[1]", 10));
+     * // Output: '12' - Increments the second element in the `c` array by 10.
+     * ```
+     */
+    static async numincrby(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        num: number,
+    ): Promise<GlideString> {
+        const args = ["JSON.NUMINCRBY", key, path, num.toString()];
+        return _executeCommand(client, args);
+    }
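Since `numincrby` replies with a JSON-encoded string rather than numbers, a caller that needs numeric values can parse the reply. A sketch under that assumption; `bumpAll` is hypothetical and presumes the key holds `{"counters": [1, 2, 3]}` and the default (string) decoder:

```typescript
import { GlideClient, GlideJson } from "@valkey/valkey-glide"; // assumed import path

async function bumpAll(client: GlideClient): Promise<number[]> {
    const raw = await GlideJson.numincrby(client, "doc", "$.counters[*]", 1);
    return JSON.parse(raw as string); // e.g. '[2,3,4]' -> [2, 3, 4]
}
```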
+
+    /**
+     * Multiplies the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param path - The path within the JSON document.
+     * @param num - The number to multiply by.
+     * @returns
+     * - For JSONPath (path starts with `$`):
+     *   - Returns a GlideString representation of an array of strings, indicating the new values after multiplication for each matched `path`.
+     *     If a value is not a number, its corresponding return value will be `null`.
+     *     If `path` doesn't exist, a string representation of an empty array will be returned.
+     * - For legacy path (path doesn't start with `$`):
+     *   - Returns a GlideString representation of the resulting value after multiplication.
+     *     If multiple paths match, the result of the last updated value is returned.
+     *     If the value at the `path` is not a number or `path` doesn't exist, an error is raised.
+     * - If `key` does not exist, an error is raised.
+     * - If the result is out of the range of 64-bit IEEE double, an error is raised.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}'));
+     * // Output: 'OK' - Indicates successful setting of the value at path '$' in the key stored at `doc`.
+     * console.log(await GlideJson.nummultby(client, "doc", "$.d[*]", 2));
+     * // Output: '[2,4,6]' - Multiplies each element in the `d` array by 2.
+     *
+     * console.log(await GlideJson.nummultby(client, "doc", ".c[1]", 2));
+     * // Output: '4' - Multiplies the second element in the `c` array by 2.
+     * ```
+     */
+    static async nummultby(
+        client: BaseClient,
+        key: GlideString,
+        path: GlideString,
+        num: number,
+    ): Promise<GlideString> {
+        const args = ["JSON.NUMMULTBY", key, path, num.toString()];
+        return _executeCommand(client, args);
+    }
+
+    /**
+     * Retrieves the number of key-value pairs in the object stored at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document. Defaults to root (`"."`) if not provided.
+     * @returns ReturnTypeJson:
+     * - For JSONPath (`path` starts with `$`):
+     *   - Returns a list of integer replies for every possible path, indicating the length of the object,
+     *     or `null` for JSON values matching the path that are not an object.
+     *   - If `path` doesn't exist, an empty array will be returned.
+     * - For legacy path (`path` doesn't start with `$`):
+     *   - Returns the length of the object at `path`.
+     *   - If multiple paths match, the length of the first object match is returned.
+     *   - If the JSON value at `path` is not an object or if `path` doesn't exist, an error is raised.
+     * - If `key` doesn't exist, `null` is returned.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}'));
+     * // Output: 'OK' - Indicates successful setting of the value at the root path '$' in the key `doc`.
+     * console.log(await GlideJson.objlen(client, "doc", { path: "$" }));
+     * // Output: [2] - Returns the number of key-value pairs at the root object, which has 2 keys: 'a' and 'b'.
+     * console.log(await GlideJson.objlen(client, "doc", { path: "." }));
+     * // Output: 2 - Returns the number of key-value pairs for the object matching the path '.', which has 2 keys: 'a' and 'b'.
+     * ```
+     */
+    static async objlen(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString },
+    ): Promise<ReturnTypeJson<number>> {
+        const args = ["JSON.OBJLEN", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand<ReturnTypeJson<number>>(client, args);
+    }
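For legacy paths `objlen` resolves to a single number or `null`, which makes a narrow wrapper straightforward. A hypothetical example, not part of the module:

```typescript
import { GlideClient, GlideJson } from "@valkey/valkey-glide"; // assumed import path

// Returns the size of the root object, or null when the key is missing.
async function rootObjectSize(client: GlideClient, key: string): Promise<number | null> {
    const len = await GlideJson.objlen(client, key, { path: "." });
    return typeof len === "number" ? len : null;
}
```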
+
+    /**
+     * Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`.
+     *
+     * @param client - The client to execute the command.
+     * @param key - The key of the JSON document.
+     * @param options - (Optional) Additional parameters:
+     * - (Optional) `path`: The path within the JSON document where the key names will be retrieved. Defaults to root (`"."`) if not provided.
+     * @returns ReturnTypeJson:
+     * - For JSONPath (`path` starts with `$`):
+     *   - Returns a list of arrays containing key names for each matching object.
+     *   - If a value matching the path is not an object, an empty array is returned.
+     *   - If `path` doesn't exist, an empty array is returned.
+     * - For legacy path (`path` doesn't start with `$`):
+     *   - Returns a list of key names for the object value matching the path.
+     *   - If multiple objects match the path, the key names of the first object are returned.
+     *   - If a value matching the path is not an object, an error is raised.
+     *   - If `path` doesn't exist, `null` is returned.
+     * - If `key` doesn't exist, `null` is returned.
+     *
+     * @example
+     * ```typescript
+     * console.log(await GlideJson.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}'));
+     * // Output: 'OK' - Indicates successful setting of the value at the root path '$' in the key `doc`.
+     * console.log(await GlideJson.objkeys(client, "doc", { path: "$" }));
+     * // Output: [["a", "b"]] - Returns a list of arrays containing the key names for objects matching the path '$'.
+     * console.log(await GlideJson.objkeys(client, "doc", { path: "." }));
+     * // Output: ["a", "b"] - Returns key names for the object matching the path '.' as it is the only match.
+     * ```
+     */
+    static async objkeys(
+        client: BaseClient,
+        key: GlideString,
+        options?: { path: GlideString } & DecoderOption,
+    ): Promise<ReturnTypeJson<GlideString[]>> {
+        const args = ["JSON.OBJKEYS", key];
+
+        if (options) {
+            args.push(options.path);
+        }
+
+        return _executeCommand(client, args, options);
+    }
+}
diff --git a/node/tests/AuthTest.test.ts b/node/tests/AuthTest.test.ts
new file mode 100644
index 0000000000..3466199ba8
--- /dev/null
+++ b/node/tests/AuthTest.test.ts
@@ -0,0 +1,330 @@
+/**
+ * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+ */
+
+import {
+    afterAll,
+    afterEach,
+    beforeAll,
+    describe,
+    expect,
+    it,
+} from "@jest/globals";
+import {
+    BaseClientConfiguration,
+    GlideClient,
+    GlideClusterClient,
+    ProtocolVersion,
+    RequestError,
+} from "..";
+import { ValkeyCluster } from "../../utils/TestUtils";
+import {
+    flushAndCloseClient,
+    getServerVersion,
+    parseEndpoints,
+} from "./TestUtilities";
+
+type BaseClient = GlideClient | GlideClusterClient;
+
+const TIMEOUT = 50000;
+
+type AddressEntry = [string, number];
+
+describe("Auth tests", () => {
+    let cmeCluster: ValkeyCluster;
+    let cmdCluster: ValkeyCluster;
+    let managementClient: BaseClient;
+    let client: BaseClient;
+    beforeAll(async () => {
+        const standaloneAddresses = global.STAND_ALONE_ENDPOINT;
+        const clusterAddresses = global.CLUSTER_ENDPOINTS;
+
+        // Connect to cluster or create a new one based on the parsed addresses
+        cmdCluster = standaloneAddresses
+            ? await ValkeyCluster.initFromExistingCluster(
+                  false,
+                  parseEndpoints(standaloneAddresses),
+                  getServerVersion,
+              )
+            : await ValkeyCluster.createCluster(false, 1, 1, getServerVersion);
+
+        cmeCluster = clusterAddresses
+            ? await ValkeyCluster.initFromExistingCluster(
+                  true,
+                  parseEndpoints(clusterAddresses),
+                  getServerVersion,
+              )
+            : await ValkeyCluster.createCluster(true, 3, 1, getServerVersion);
+
+        // Create appropriate client based on mode
+        const isStandaloneMode = !!standaloneAddresses;
+        const activeCluster = isStandaloneMode ? cmdCluster : cmeCluster;
+        const ClientClass = isStandaloneMode
+            ? GlideClient
+            : GlideClusterClient;
+
+        managementClient = await ClientClass.createClient({
+            addresses: formatAddresses(activeCluster.getAddresses()),
+        });
+    }, 40000);
+
+    const formatAddresses = (
+        addresses: AddressEntry[],
+    ): { host: string; port: number }[] =>
+        addresses.map(([host, port]) => ({ host, port }));
+
+    afterEach(async () => {
+        if (managementClient) {
+            try {
+                await managementClient.customCommand(["AUTH", "new_password"]);
+                await managementClient.configSet({ requirepass: "" });
+            } catch {
+                // Ignore errors
+            }
+
+            await managementClient.flushall();
+
+            try {
+                await client.updateConnectionPassword("");
+            } catch {
+                // Ignore errors
+            }
+        }
+
+        if (cmdCluster) {
+            await flushAndCloseClient(false, cmdCluster.getAddresses());
+        }
+
+        if (cmeCluster) {
+            await flushAndCloseClient(true, cmeCluster.getAddresses());
+        }
+    });
+
+    afterAll(async () => {
+        await cmdCluster?.close();
+        await cmeCluster?.close();
+        managementClient?.close();
+    });
+
+    const runTest = async (
+        test: (client: BaseClient) => Promise<void>,
+        protocol: ProtocolVersion,
+        configOverrides?: Partial<BaseClientConfiguration>,
+    ) => {
+        const isStandaloneMode = configOverrides?.addresses?.length === 1;
+        const activeCluster = isStandaloneMode ? cmdCluster : cmeCluster;
+
+        if (!activeCluster) {
+            throw new Error(
+                `${isStandaloneMode ? "Standalone" : "Cluster"} mode not configured`,
+            );
+        }
+
+        const ClientClass = isStandaloneMode ? GlideClient : GlideClusterClient;
+        const addresses = formatAddresses(activeCluster.getAddresses());
+
+        client = await ClientClass.createClient({
+            addresses,
+            protocol,
+            ...configOverrides,
+        });
+
+        try {
+            await test(client);
+        } finally {
+            client.close();
+        }
+    };
+
+    describe.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
+        "update_connection_password_%p",
+        (protocol) => {
+            const NEW_PASSWORD = "new_password";
+            const WRONG_PASSWORD = "wrong_password";
+
+            /**
+             * Test replacing connection password with immediate re-authentication using a non-valid password.
+             * Verifies that immediate re-authentication fails when the password is not valid.
+             */
+            it("test_update_connection_password_auth_non_valid_pass", async () => {
+                await runTest(async (client: BaseClient) => {
+                    await expect(
+                        client.updateConnectionPassword(null, true),
+                    ).rejects.toThrow(RequestError);
+                    await expect(
+                        client.updateConnectionPassword("", true),
+                    ).rejects.toThrow(RequestError);
+                }, protocol);
+            });
+
+            /**
+             * Test replacing the connection password without immediate re-authentication.
+             * Verifies that:
+             * 1. The client can update its internal password
+             * 2. The client remains connected with current auth
+             * 3. The client can reconnect using the new password after server password change
+             * Currently, this test is only supported for cluster mode,
+             * since standalone mode doesn't have multiple connections to manage;
+             * there the client will try to reconnect and will not pick up new tasks.
+             */
+            it(
+                "test_update_connection_password",
+                async () => {
+                    await runTest(async (client: BaseClient) => {
+                        if (client instanceof GlideClient) {
+                            return;
+                        }
+
+                        // Update password without re-authentication
+                        const result = await client.updateConnectionPassword(
+                            NEW_PASSWORD,
+                            false,
+                        );
+                        expect(result).toEqual("OK");
+
+                        // Verify client still works with old auth
+                        await client.set("test_key", "test_value");
+                        const value = await client.get("test_key");
+                        expect(value).toEqual("test_value");
+
+                        // Update server password
+                        await client.configSet({ requirepass: NEW_PASSWORD });
+
+                        // Kill all other clients to force reconnection
+                        await managementClient.customCommand([
+                            "CLIENT",
+                            "KILL",
+                            "TYPE",
+                            "normal",
+                        ]);
+
+                        // Verify client auto-reconnects with new password
+                        await client.set("test_key2", "test_value2");
+                        const value2 = await client.get("test_key2");
+                        expect(value2).toEqual("test_value2");
+                    }, protocol);
+                },
+                TIMEOUT,
+            );
+
+            /**
+             * Test that immediate re-authentication fails when no server password is set.
+             */
+            it("test_update_connection_password_no_server_auth", async () => {
+                await runTest(async (client: BaseClient) => {
+                    try {
+                        await expect(
+                            client.updateConnectionPassword(NEW_PASSWORD, true),
+                        ).rejects.toThrow(RequestError);
+                    } finally {
+                        client?.close();
+                    }
+                }, protocol);
+            });
+
+            /**
+             * Test replacing connection password with a long password string.
+             */
+            it("test_update_connection_password_long", async () => {
+                await runTest(async (client: BaseClient) => {
+                    const longPassword = "p".repeat(1000);
+                    expect(
+                        await client.updateConnectionPassword(
+                            longPassword,
+                            false,
+                        ),
+                    ).toEqual("OK");
+                    await client.configSet({
+                        requirepass: "",
+                    });
+                }, protocol);
+            });
+
+            /**
+             * Test that re-authentication fails when using wrong password.
+             */
+            it("test_replace_password_immediateAuth_wrong_password", async () => {
+                await runTest(async (client: BaseClient) => {
+                    await client.configSet({
+                        requirepass: NEW_PASSWORD,
+                    });
+                    await expect(
+                        client.updateConnectionPassword(WRONG_PASSWORD, true),
+                    ).rejects.toThrow(RequestError);
+                    await expect(
+                        client.updateConnectionPassword(NEW_PASSWORD, true),
+                    ).resolves.toBe("OK");
+                }, protocol);
+            });
+
+            /**
+             * Test replacing connection password with immediate re-authentication.
+             */
+            it(
+                "test_update_connection_password_with_immediateAuth",
+                async () => {
+                    await runTest(async (client: BaseClient) => {
+                        // Set server password
+                        await client.configSet({ requirepass: NEW_PASSWORD });
+
+                        // Update client password with re-auth
+                        expect(
+                            await client.updateConnectionPassword(
+                                NEW_PASSWORD,
+                                true,
+                            ),
+                        ).toEqual("OK");
+
+                        // Verify client works with new auth
+                        await client.set("test_key", "test_value");
+                        const value = await client.get("test_key");
+                        expect(value).toEqual("test_value");
+                    }, protocol);
+                },
+                TIMEOUT,
+            );
+
+            /**
+             * Test changing the server password when the connection is lost before the password update.
+             * Verifies that the client will not be able to reach the connection under the abstraction and will return an error.
+             *
+             * **Note: This test is only supported for standalone mode; an explanation follows.**
+             *
+             * Some explanation for the curious mind:
+             * Our library abstracts a connection or connections, with a lot of machinery around them, making the whole behave like what we call a "client".
+             * In standalone mode, the client is a single connection, so on disconnection the first thing it plans to do is reconnect.
+             * There is no reason to accept further commands and handle them, since serving commands requires being connected.
+             * Hence, the client will try to reconnect and will not take on new tasks but will let them wait in line,
+             * so the connection-password update will not be able to reach the connection and will return an error.
+             * In future versions, standalone will be treated as a different animal than it is now, since standalone is not necessarily one node.
+             * It can be replicated and have many nodes, forming what we like to call a "one-shard cluster".
+             * So, in the future, we will have many existing connections, and a request can be handled even while one connection is locked.
+             */
+            it("test_update_connection_password_connection_lost_before_password_update", async () => {
+                await runTest(async (client: BaseClient) => {
+                    if (client instanceof GlideClusterClient) {
+                        return;
+                    }
+
+                    // Set a key to ensure connection is established
+                    await client.set("test_key", "test_value");
+                    // Update server password
+                    await client.configSet({ requirepass: NEW_PASSWORD });
+                    // Kill client connections
+                    await managementClient.customCommand([
+                        "CLIENT",
+                        "KILL",
+                        "TYPE",
+                        "normal",
+                    ]);
+                    // Try updating the client password both without and with immediate re-auth; both should fail
+                    await expect(
+                        client.updateConnectionPassword(NEW_PASSWORD, false),
+                    ).rejects.toThrow(RequestError);
+                    await expect(
+                        client.updateConnectionPassword(NEW_PASSWORD, true),
+                    ).rejects.toThrow(RequestError);
+                }, protocol);
+            });
+        },
+    );
+});
diff --git a/node/tests/GlideClient.test.ts b/node/tests/GlideClient.test.ts
index 858ad87943..0c77bde519 100644
--- a/node/tests/GlideClient.test.ts
+++ b/node/tests/GlideClient.test.ts
@@ -13,7 +13,6 @@ import {
 import { BufferReader, BufferWriter } from "protobufjs";
 import { v4 as uuidv4 } from "uuid";
 import {
-    convertGlideRecordToRecord,
     Decoder,
     FlushMode,
     FunctionRestorePolicy,
@@ -25,6 +24,7 @@ import {
     RequestError,
     Script,
     Transaction,
+    convertGlideRecordToRecord,
 } from "..";
 import { ValkeyCluster } from "../../utils/TestUtils.js";
 import { command_request } from "../src/ProtobufMessage";
@@ -35,18 +35,15 @@ import {
     convertStringArrayToBuffer,
     createLongRunningLuaScript,
     createLuaLibWithLongRunningFunction,
-    DumpAndRestoreTest,
     encodableTransactionTest,
     flushAndCloseClient,
     generateLuaLibCode,
     getClientConfigurationOption,
     getServerVersion,
-    parseCommandLineArgs,
     parseEndpoints,
     transactionTest,
     validateTransactionResponse,
     waitForNotBusy,
-    waitForScriptNotBusy,
 } from "./TestUtilities";

 const TIMEOUT = 50000;
@@ -54,10 +51,11 @@ const TIMEOUT = 50000;
 describe("GlideClient", () => {
     let testsFailed = 0;
     let cluster: ValkeyCluster;
+    let azCluster: ValkeyCluster;
     let client: GlideClient;
+    let azClient: GlideClient;
     beforeAll(async () => {
-        const standaloneAddresses =
-            parseCommandLineArgs()["standalone-endpoints"];
+        const standaloneAddresses = global.STAND_ALONE_ENDPOINT;
         cluster = standaloneAddresses
             ? await ValkeyCluster.initFromExistingCluster(
                   false,
                   parseEndpoints(standaloneAddresses),
                   getServerVersion,
               )
             : await ValkeyCluster.createCluster(false, 1, 1, getServerVersion);
+
+        azCluster = standaloneAddresses
+            ?
await ValkeyCluster.initFromExistingCluster( + false, + parseEndpoints(standaloneAddresses), + getServerVersion, + ) + : await ValkeyCluster.createCluster(false, 1, 1, getServerVersion); }, 20000); afterEach(async () => { await flushAndCloseClient(false, cluster.getAddresses(), client); + await flushAndCloseClient(false, azCluster.getAddresses(), azClient); }); afterAll(async () => { if (testsFailed === 0) { await cluster.close(); + await azCluster.close(); } else { await cluster.close(true); + await azCluster.close(); } }, TIMEOUT); @@ -286,32 +295,27 @@ describe("GlideClient", () => { client = await GlideClient.createClient( getClientConfigurationOption(cluster.getAddresses(), protocol), ); - const bytesTransaction = new Transaction(); - await client.set("key", "value"); - const dumpValue: Buffer = (await client.dump("key")) as Buffer; - await client.del(["key"]); - const expectedBytesRes = await DumpAndRestoreTest( - bytesTransaction, - dumpValue, - ); - bytesTransaction.select(0); - const result = await client.exec(bytesTransaction, { - decoder: Decoder.Bytes, - }); - expectedBytesRes.push(["select(0)", "OK"]); - - validateTransactionResponse(result, expectedBytesRes); + const key1 = uuidv4(); + const key2 = uuidv4(); + const value = uuidv4(); - const stringTransaction = new Transaction(); - await DumpAndRestoreTest(stringTransaction, dumpValue); - stringTransaction.select(0); + const transaction1 = new Transaction().set(key1, value).dump(key1); // Since DUMP gets binary results, we cannot use the string decoder here, so we expected to get an error. await expect( - client.exec(stringTransaction, { decoder: Decoder.String }), - ).rejects.toThrowError( - /invalid utf-8 sequence of 1 bytes from index/, - ); + client.exec(transaction1, { decoder: Decoder.String }), + ).rejects.toThrow("invalid utf-8 sequence of"); + + const result = await client.exec(transaction1, { + decoder: Decoder.Bytes, + }); + expect(result?.[0]).toEqual("OK"); + const dump = result?.[1] as Buffer; + + const transaction2 = new Transaction().restore(key2, 0, dump); + expect(await client.exec(transaction2)).toEqual(["OK"]); + + expect(value).toEqual(await client.get(key2)); client.close(); }, @@ -1462,70 +1466,53 @@ describe("GlideClient", () => { TIMEOUT, ); - it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( - "script kill killable test_%p", - async (protocol) => { + it.each([ + [ProtocolVersion.RESP2, 5], + [ProtocolVersion.RESP2, 100], + [ProtocolVersion.RESP2, 1500], + [ProtocolVersion.RESP3, 5], + [ProtocolVersion.RESP3, 100], + [ProtocolVersion.RESP3, 1500], + ])( + "test inflight requests limit of %p with protocol %p", + async (protocol, inflightRequestsLimit) => { const config = getClientConfigurationOption( cluster.getAddresses(), protocol, - { requestTimeout: 10000 }, + { inflightRequestsLimit }, ); - const client1 = await GlideClient.createClient(config); - const client2 = await GlideClient.createClient(config); + const client = await GlideClient.createClient(config); try { - // Verify that script kill raises an error when no script is running - await expect(client1.scriptKill()).rejects.toThrow( - "No scripts in execution right now", - ); + const key1 = `{nonexistinglist}:1-${uuidv4()}`; + const tasks: Promise<[GlideString, GlideString] | null>[] = []; - // Create a long-running script - const longScript = new Script( - createLongRunningLuaScript(5, false), - ); - - try { - // call the script without await - const promise = client2 - .invokeScript(longScript) - .catch((e) => - expect((e as 
Error).message).toContain( - "Script killed", - ), - ); - - let killed = false; - let timeout = 4000; - await new Promise((resolve) => setTimeout(resolve, 1000)); - - while (timeout >= 0) { - try { - expect(await client1.scriptKill()).toEqual("OK"); - killed = true; - break; - } catch { - // do nothing - } + // Start inflightRequestsLimit blocking tasks + for (let i = 0; i < inflightRequestsLimit; i++) { + tasks.push(client.blpop([key1], 0)); + } - await new Promise((resolve) => - setTimeout(resolve, 500), - ); - timeout -= 500; - } + // This task should immediately fail due to reaching the limit + await expect(client.blpop([key1], 0)).rejects.toThrow( + RequestError, + ); - expect(killed).toBeTruthy(); - await promise; - } finally { - await waitForScriptNotBusy(client1); - } + // Verify that all previous tasks are still pending + const timeoutPromise = new Promise((resolve) => + setTimeout(resolve, 100), + ); + const allTasksStatus = await Promise.race([ + Promise.any( + tasks.map((task) => task.then(() => "resolved")), + ), + timeoutPromise.then(() => "pending"), + ]); + expect(allTasksStatus).toBe("pending"); } finally { - expect(await client1.scriptFlush()).toEqual("OK"); - client1.close(); - client2.close(); + await client.close(); } }, ); - runBaseTests({ init: async (protocol, configOverrides) => { const config = getClientConfigurationOption( @@ -1533,10 +1520,18 @@ describe("GlideClient", () => { protocol, configOverrides, ); + client = await GlideClient.createClient(config); + + const configNew = getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + configOverrides, + ); testsFailed += 1; + azClient = await GlideClient.createClient(configNew); client = await GlideClient.createClient(config); - return { client, cluster }; + return { client, cluster, azClient, azCluster }; }, close: (testSucceeded: boolean) => { if (testSucceeded) { diff --git a/node/tests/GlideClusterClient.test.ts b/node/tests/GlideClusterClient.test.ts index 8cde021e00..163989ccdd 100644 --- a/node/tests/GlideClusterClient.test.ts +++ b/node/tests/GlideClusterClient.test.ts @@ -15,7 +15,6 @@ import { v4 as uuidv4 } from "uuid"; import { BitwiseOperation, ClusterTransaction, - convertRecordToGlideRecord, Decoder, FlushMode, FunctionListResponse, @@ -24,15 +23,18 @@ import { GeoUnit, GlideClusterClient, GlideReturnType, + GlideString, InfoOptions, ListDirection, ProtocolVersion, + ReadFrom, RequestError, Routes, ScoreFilter, Script, SlotKeyTypes, SortOrder, + convertRecordToGlideRecord, } from ".."; import { ValkeyCluster } from "../../utils/TestUtils.js"; import { runBaseTests } from "./SharedTests"; @@ -49,12 +51,10 @@ import { getServerVersion, intoArray, intoString, - parseCommandLineArgs, parseEndpoints, transactionTest, validateTransactionResponse, waitForNotBusy, - waitForScriptNotBusy, } from "./TestUtilities"; const TIMEOUT = 50000; @@ -62,45 +62,80 @@ const TIMEOUT = 50000; describe("GlideClusterClient", () => { let testsFailed = 0; let cluster: ValkeyCluster; + let azCluster: ValkeyCluster; let client: GlideClusterClient; + let azClient: GlideClusterClient; beforeAll(async () => { - const clusterAddresses = parseCommandLineArgs()["cluster-endpoints"]; - // Connect to cluster or create a new one based on the parsed addresses - cluster = clusterAddresses - ? 
await ValkeyCluster.initFromExistingCluster( - true, - parseEndpoints(clusterAddresses), - getServerVersion, - ) - : // setting replicaCount to 1 to facilitate tests routed to replicas - await ValkeyCluster.createCluster(true, 3, 1, getServerVersion); - }, 20000); + const clusterAddresses = global.CLUSTER_ENDPOINTS; + + if (clusterAddresses) { + // Initialize current cluster from existing addresses + cluster = await ValkeyCluster.initFromExistingCluster( + true, + parseEndpoints(clusterAddresses), + getServerVersion, + ); + + // Initialize cluster from existing addresses for AzAffinity test + azCluster = await ValkeyCluster.initFromExistingCluster( + true, + parseEndpoints(clusterAddresses), + getServerVersion, + ); + } else { + cluster = await ValkeyCluster.createCluster( + true, + 3, + 1, + getServerVersion, + ); + + azCluster = await ValkeyCluster.createCluster( + true, + 3, + 4, + getServerVersion, + ); + } + }, 120000); afterEach(async () => { await flushAndCloseClient(true, cluster.getAddresses(), client); + await flushAndCloseClient(true, azCluster.getAddresses(), azClient); }); afterAll(async () => { if (testsFailed === 0) { - await cluster.close(); + if (cluster) await cluster.close(); + if (azCluster) await azCluster.close(); } else { - await cluster.close(true); + if (cluster) await cluster.close(true); + if (azCluster) await azCluster.close(true); } }); runBaseTests({ init: async (protocol, configOverrides) => { - const config = getClientConfigurationOption( + const configCurrent = getClientConfigurationOption( cluster.getAddresses(), protocol, configOverrides, ); + client = await GlideClusterClient.createClient(configCurrent); + + const configNew = getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + configOverrides, + ); + azClient = await GlideClusterClient.createClient(configNew); testsFailed += 1; - client = await GlideClusterClient.createClient(config); return { client, + azClient, cluster, + azCluster, }; }, close: (testSucceeded: boolean) => { @@ -247,7 +282,7 @@ describe("GlideClusterClient", () => { expect(await client.set(key, value)).toEqual("OK"); // Since DUMP gets binary results, we cannot use the default decoder (string) here, so we expected to get an error. 
         await expect(client.customCommand(["DUMP", key])).rejects.toThrow(
-            "invalid utf-8 sequence of 1 bytes from index",
+            "invalid utf-8 sequence",
         );

         const dumpResult = await client.customCommand(["DUMP", key], {
@@ -406,7 +441,6 @@
                 client.sdiffstore("abc", ["zxy", "lkn"]),
                 client.sortStore("abc", "zyx"),
                 client.sortStore("abc", "zyx", { isAlpha: true }),
-                ...lmpopArr,
                 client.bzpopmax(["abc", "def"], 0.5),
                 client.bzpopmin(["abc", "def"], 0.5),
                 client.xread({ abc: "0-0", zxy: "0-0", lkn: "0-0" }),
@@ -449,6 +483,12 @@
                 client.lcs("abc", "xyz"),
                 client.lcsLen("abc", "xyz"),
                 client.lcsIdx("abc", "xyz"),
+                client.lmpop(["abc", "def"], ListDirection.LEFT, {
+                    count: 1,
+                }),
+                client.blmpop(["abc", "def"], ListDirection.RIGHT, 0.1, {
+                    count: 1,
+                }),
             );
         }

@@ -1929,68 +1969,481 @@
             TIMEOUT,
         );

-    it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
-        "script kill killable test_%p",
-        async (protocol) => {
+    it.each([
+        [ProtocolVersion.RESP2, 5],
+        [ProtocolVersion.RESP2, 100],
+        [ProtocolVersion.RESP2, 1500],
+        [ProtocolVersion.RESP3, 5],
+        [ProtocolVersion.RESP3, 100],
+        [ProtocolVersion.RESP3, 1500],
+    ])(
+        "test inflight requests limit of %p with protocol %p",
+        async (protocol, inflightRequestsLimit) => {
             const config = getClientConfigurationOption(
                 cluster.getAddresses(),
                 protocol,
-                { requestTimeout: 10000 },
+                { inflightRequestsLimit },
             );
-            const client1 = await GlideClusterClient.createClient(config);
-            const client2 = await GlideClusterClient.createClient(config);
+            const client = await GlideClusterClient.createClient(config);

             try {
-                // Verify that script kill raises an error when no script is running
-                await expect(client1.scriptKill()).rejects.toThrow(
-                    "No scripts in execution right now",
+                const key1 = `{nonexistinglist}:1-${uuidv4()}`;
+                const tasks: Promise<[GlideString, GlideString] | null>[] = [];
+
+                // Start inflightRequestsLimit blocking tasks
+                for (let i = 0; i < inflightRequestsLimit; i++) {
+                    tasks.push(client.blpop([key1], 0));
+                }
+
+                // This task should immediately fail due to reaching the limit
+                await expect(client.blpop([key1], 0)).rejects.toThrow(
+                    RequestError,
                 );

-                // Create a long-running script
-                const longScript = new Script(
-                    createLongRunningLuaScript(5, false),
+                // Verify that all previous tasks are still pending
+                const timeoutPromise = new Promise((resolve) =>
+                    setTimeout(resolve, 100),
+                );
+                const allTasksStatus = await Promise.race([
+                    Promise.any(
+                        tasks.map((task) => task.then(() => "resolved")),
+                    ),
+                    timeoutPromise.then(() => "pending"),
+                ]);
+                expect(allTasksStatus).toBe("pending");
+            } finally {
+                await client.close();
+            }
+        },
+    );
+    describe("GlideClusterClient - AZAffinity Read Strategy Test", () => {
+        async function getNumberOfReplicas(
+            azClient: GlideClusterClient,
+        ): Promise<number> {
+            const replicationInfo = await azClient.customCommand([
+                "INFO",
+                "REPLICATION",
+            ]);
+
+            if (Array.isArray(replicationInfo)) {
+                // Handle array response from cluster (CME Mode)
+                let totalReplicas = 0;
+
+                for (const node of replicationInfo) {
+                    const nodeInfo = node as {
+                        key: string;
+                        value: string | string[] | null;
+                    };
+
+                    if (typeof nodeInfo.value === "string") {
+                        const lines = nodeInfo.value.split(/\r?\n/);
+                        const connectedReplicasLine = lines.find(
+                            (line) =>
+                                line.startsWith("connected_slaves:") ||
+                                line.startsWith("connected_replicas:"),
+                        );
+
+                        if (connectedReplicasLine) {
+                            const parts =
connectedReplicasLine.split(":"); + const numReplicas = parseInt(parts[1], 10); + + if (!isNaN(numReplicas)) { + // Sum up replicas from each primary node + totalReplicas += numReplicas; + } + } + } + } + + if (totalReplicas > 0) { + return totalReplicas; + } + + throw new Error( + "Could not find replica information in any node's response", ); + } + + throw new Error( + "Unexpected response format from INFO REPLICATION command", + ); + } + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "should route GET commands to all replicas with the same AZ using protocol %p", + async (protocol) => { + const az = "us-east-1a"; + const GET_CALLS_PER_REPLICA = 3; + + let client_for_config_set; + let client_for_testing_az; try { - // call the script without await - const promise = client2 - .invokeScript(longScript) - .catch((e) => - expect((e as Error).message).toContain( - "Script killed", + // Stage 1: Configure nodes + client_for_config_set = + await GlideClusterClient.createClient( + getClientConfigurationOption( + azCluster.getAddresses(), + protocol, ), ); - let killed = false; - let timeout = 4000; - await new Promise((resolve) => setTimeout(resolve, 1000)); + // Skip test if version is below 8.0.0 + if (cluster.checkIfServerVersionLessThan("8.0.0")) { + console.log( + "Skipping test: requires Valkey 8.0.0 or higher", + ); + return; + } - while (timeout >= 0) { - try { - expect(await client1.scriptKill()).toEqual("OK"); - killed = true; - break; - } catch { - // do nothing - } + await client_for_config_set.customCommand([ + "CONFIG", + "RESETSTAT", + ]); + await client_for_config_set.customCommand( + ["CONFIG", "SET", "availability-zone", az], + { route: "allNodes" }, + ); + + // Retrieve the number of replicas dynamically + const n_replicas = await getNumberOfReplicas( + client_for_config_set, + ); + + if (n_replicas === 0) { + throw new Error( + "No replicas found in the cluster. 
Test requires at least one replica.", + ); + } + + const GET_CALLS = GET_CALLS_PER_REPLICA * n_replicas; + const get_cmdstat = `calls=${GET_CALLS_PER_REPLICA}`; + + // Stage 2: Create AZ affinity client and verify configuration + client_for_testing_az = + await GlideClusterClient.createClient( + getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + { + readFrom: "AZAffinity" as ReadFrom, + clientAz: az, + }, + ), + ); + + const azs = await client_for_testing_az.customCommand( + ["CONFIG", "GET", "availability-zone"], + { route: "allNodes" }, + ); + + if (Array.isArray(azs)) { + const allAZsMatch = azs.every((node) => { + const nodeResponse = node as { + key: string; + value: string | number; + }; + + if (protocol === ProtocolVersion.RESP2) { + // RESP2: Direct array format ["availability-zone", "us-east-1a"] + return ( + Array.isArray(nodeResponse.value) && + nodeResponse.value[1] === az + ); + } else { + // RESP3: Nested object format [{ key: "availability-zone", value: "us-east-1a" }] + return ( + Array.isArray(nodeResponse.value) && + nodeResponse.value[0]?.key === + "availability-zone" && + nodeResponse.value[0]?.value === az + ); + } + }); + expect(allAZsMatch).toBe(true); + } else { + throw new Error( + "Unexpected response format from CONFIG GET command", + ); + } + + // Stage 3: Set test data and perform GET operations + await client_for_testing_az.set("foo", "testvalue"); + + for (let i = 0; i < GET_CALLS; i++) { + await client_for_testing_az.get("foo"); + } + + // Stage 4: Verify GET commands were routed correctly + const info_result = + await client_for_testing_az.customCommand( + ["INFO", "ALL"], // Get both replication and commandstats info + { route: "allNodes" }, + ); + + if (Array.isArray(info_result)) { + const matching_entries_count = info_result.filter( + (node) => { + const nodeInfo = node as { + key: string; + value: string | string[] | null; + }; + const infoStr = + nodeInfo.value?.toString() || ""; + + // Check if this is a replica node AND it has the expected number of GET calls + const isReplicaNode = + infoStr.includes("role:slave") || + infoStr.includes("role:replica"); + + return ( + isReplicaNode && + infoStr.includes(get_cmdstat) + ); + }, + ).length; + + expect(matching_entries_count).toBe(n_replicas); // Should expect 12 as the cluster was created with 3 primary and 4 replicas, totalling 12 replica nodes + } else { + throw new Error( + "Unexpected response format from INFO command", + ); + } + } finally { + // Cleanup + await client_for_config_set?.close(); + await client_for_testing_az?.close(); + } + }, + ); + }); + describe("GlideClusterClient - AZAffinity Routing to 1 replica", () => { + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "should route commands to single replica with AZ using protocol %p", + async (protocol) => { + const az = "us-east-1a"; + const GET_CALLS = 3; + const get_cmdstat = `calls=${GET_CALLS}`; + let client_for_config_set; + let client_for_testing_az; + + try { + // Stage 1: Configure nodes + client_for_config_set = + await GlideClusterClient.createClient( + getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + ), + ); - await new Promise((resolve) => - setTimeout(resolve, 500), + // Skip test if version is below 8.0.0 + if (cluster.checkIfServerVersionLessThan("8.0.0")) { + console.log( + "Skipping test: requires Valkey 8.0.0 or higher", ); - timeout -= 500; + return; } - expect(killed).toBeTruthy(); - await promise; + await client_for_config_set.customCommand( + ["CONFIG", "SET", 
"availability-zone", ""], + { route: "allNodes" }, + ); + + await client_for_config_set.customCommand([ + "CONFIG", + "RESETSTAT", + ]); + + await client_for_config_set.customCommand( + ["CONFIG", "SET", "availability-zone", az], + { route: { type: "replicaSlotId", id: 12182 } }, + ); + + // Stage 2: Create AZ affinity client and verify configuration + client_for_testing_az = + await GlideClusterClient.createClient( + getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + { + readFrom: "AZAffinity", + clientAz: az, + }, + ), + ); + await client_for_testing_az.set("foo", "testvalue"); + + for (let i = 0; i < GET_CALLS; i++) { + await client_for_testing_az.get("foo"); + } + + // Stage 4: Verify GET commands were routed correctly + const info_result = + await client_for_testing_az.customCommand( + ["INFO", "ALL"], + { route: "allNodes" }, + ); + + // Process the info_result to check that only one replica has the GET calls + if (Array.isArray(info_result)) { + // Count the number of nodes where both get_cmdstat and az are present + const matching_entries_count = info_result.filter( + (node) => { + const nodeInfo = node as { + key: string; + value: string | string[] | null; + }; + const infoStr = + nodeInfo.value?.toString() || ""; + return ( + infoStr.includes(get_cmdstat) && + infoStr.includes(`availability_zone:${az}`) + ); + }, + ).length; + + expect(matching_entries_count).toBe(1); + + // Check that only one node has the availability zone set to az + const changed_az_count = info_result.filter((node) => { + const nodeInfo = node as { + key: string; + value: string | string[] | null; + }; + const infoStr = nodeInfo.value?.toString() || ""; + return infoStr.includes(`availability_zone:${az}`); + }).length; + + expect(changed_az_count).toBe(1); + } else { + throw new Error( + "Unexpected response format from INFO command", + ); + } } finally { - await waitForScriptNotBusy(client1); + await client_for_config_set?.close(); + await client_for_testing_az?.close(); } - } finally { - expect(await client1.scriptFlush()).toEqual("OK"); - client1.close(); - client2.close(); - } - }, - TIMEOUT, - ); + }, + ); + }); + describe("GlideClusterClient - AZAffinity with Non-existing AZ", () => { + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "should route commands to a replica when AZ does not exist using protocol %p", + async (protocol) => { + const GET_CALLS = 4; + const replica_calls = 1; + const get_cmdstat = `cmdstat_get:calls=${replica_calls}`; + let client_for_testing_az; + + try { + // Skip test if server version is below 8.0.0 + if (azCluster.checkIfServerVersionLessThan("8.0.0")) { + console.log( + "Skipping test: requires Valkey 8.0.0 or higher", + ); + return; + } + + // Create a client configured for AZAffinity with a non-existing AZ + client_for_testing_az = + await GlideClusterClient.createClient( + getClientConfigurationOption( + azCluster.getAddresses(), + protocol, + { + readFrom: "AZAffinity", + clientAz: "non-existing-az", + requestTimeout: 2000, + }, + ), + ); + + // Reset command stats on all nodes + await client_for_testing_az.customCommand( + ["CONFIG", "RESETSTAT"], + { route: "allNodes" }, + ); + + // Issue GET commands + for (let i = 0; i < GET_CALLS; i++) { + await client_for_testing_az.get("foo"); + } + + // Fetch command stats from all nodes + const info_result = + await client_for_testing_az.customCommand( + ["INFO", "COMMANDSTATS"], + { route: "allNodes" }, + ); + + // Inline matching logic + let matchingEntriesCount = 0; + + if ( + typeof 
info_result === "object" &&
+                        info_result !== null
+                    ) {
+                        const nodeResponses = Object.values(info_result);
+
+                        for (const response of nodeResponses) {
+                            if (
+                                response &&
+                                typeof response === "object" &&
+                                "value" in response &&
+                                response.value.includes(get_cmdstat)
+                            ) {
+                                matchingEntriesCount++;
+                            }
+                        }
+                    } else {
+                        throw new Error(
+                            "Unexpected response format from INFO command",
+                        );
+                    }
+
+                    // Validate that the GET calls were distributed across 4 replicas, one call each
+                    expect(matchingEntriesCount).toBe(4);
+                } finally {
+                    // Cleanup: Close the client after test execution
+                    await client_for_testing_az?.close();
+                }
+            },
+        );
+    });
+    describe("GlideClusterClient - Get Statistics", () => {
+        it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
+            "should return valid statistics using protocol %p",
+            async (protocol) => {
+                let glideClientForTesting;
+
+                try {
+                    // Create a GlideClusterClient instance for testing
+                    glideClientForTesting =
+                        await GlideClusterClient.createClient(
+                            getClientConfigurationOption(
+                                cluster.getAddresses(),
+                                protocol,
+                                {
+                                    requestTimeout: 2000,
+                                },
+                            ),
+                        );
+
+                    // Fetch statistics using the getStatistics method
+                    const stats = await glideClientForTesting.getStatistics();
+
+                    // Assertions to check if stats object has correct structure
+                    expect(typeof stats).toBe("object");
+                    expect(stats).toHaveProperty("total_connections");
+                    expect(stats).toHaveProperty("total_clients");
+                    expect(Object.keys(stats)).toHaveLength(2);
+                } finally {
+                    // Ensure the client is properly closed
+                    await glideClientForTesting?.close();
+                }
+            },
+        );
+    });
 });
diff --git a/node/tests/PubSub.test.ts b/node/tests/PubSub.test.ts
index 53fbba0a60..12bb3d800a 100644
--- a/node/tests/PubSub.test.ts
+++ b/node/tests/PubSub.test.ts
@@ -28,7 +28,6 @@ import ValkeyCluster from "../../utils/TestUtils";
 import {
     flushAndCloseClient,
     getServerVersion,
-    parseCommandLineArgs,
     parseEndpoints,
 } from "./TestUtilities";
 
@@ -60,9 +59,8 @@ describe("PubSub", () => {
     let cmeCluster: ValkeyCluster;
     let cmdCluster: ValkeyCluster;
     beforeAll(async () => {
-        const standaloneAddresses =
-            parseCommandLineArgs()["standalone-endpoints"];
-        const clusterAddresses = parseCommandLineArgs()["cluster-endpoints"];
+        const standaloneAddresses = global.STAND_ALONE_ENDPOINT;
+        const clusterAddresses = global.CLUSTER_ENDPOINTS;
         // Connect to cluster or create a new one based on the parsed addresses
         const [_cmdCluster, _cmeCluster] = await Promise.all([
             standaloneAddresses
@@ -84,20 +82,22 @@ describe("PubSub", () => {
         cmeCluster = _cmeCluster;
     }, 40000);
     afterEach(async () => {
-        await Promise.all([
-            cmdCluster
-                ? flushAndCloseClient(false, cmdCluster.getAddresses())
-                : Promise.resolve(),
-            cmeCluster
-                ? flushAndCloseClient(true, cmeCluster.getAddresses())
-                : Promise.resolve(),
-        ]);
+        if (cmdCluster) {
+            await flushAndCloseClient(false, cmdCluster.getAddresses());
+        }
+
+        if (cmeCluster) {
+            await flushAndCloseClient(true, cmeCluster.getAddresses());
+        }
     });
     afterAll(async () => {
-        await Promise.all([
-            cmdCluster ? cmdCluster.close() : Promise.resolve(),
-            cmeCluster ?
cmeCluster.close() : Promise.resolve(), - ]); + if (cmdCluster) { + await cmdCluster.close(); + } + + if (cmeCluster) { + await cmeCluster.close(); + } }); async function createClients( diff --git a/node/tests/ScanTest.test.ts b/node/tests/ScanTest.test.ts index 5c975cacdc..bff90bab36 100644 --- a/node/tests/ScanTest.test.ts +++ b/node/tests/ScanTest.test.ts @@ -18,7 +18,6 @@ import { flushAndCloseClient, getClientConfigurationOption, getServerVersion, - parseCommandLineArgs, parseEndpoints, } from "./TestUtilities"; @@ -30,7 +29,7 @@ describe("Scan GlideClusterClient", () => { let cluster: ValkeyCluster; let client: GlideClusterClient; beforeAll(async () => { - const clusterAddresses = parseCommandLineArgs()["cluster-endpoints"]; + const clusterAddresses = global.CLUSTER_ENDPOINTS; // Connect to cluster or create a new one based on the parsed addresses cluster = clusterAddresses ? await ValkeyCluster.initFromExistingCluster( @@ -385,8 +384,7 @@ describe("Scan GlideClient", () => { let cluster: ValkeyCluster; let client: GlideClient; beforeAll(async () => { - const standaloneAddresses = - parseCommandLineArgs()["standalone-endpoints"]; + const standaloneAddresses = global.STAND_ALONE_ENDPOINT; cluster = standaloneAddresses ? await ValkeyCluster.initFromExistingCluster( false, diff --git a/node/tests/ServerModules.test.ts b/node/tests/ServerModules.test.ts new file mode 100644 index 0000000000..df16ce89e7 --- /dev/null +++ b/node/tests/ServerModules.test.ts @@ -0,0 +1,3457 @@ +/** + * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + */ +import { + afterAll, + afterEach, + beforeAll, + describe, + expect, + it, +} from "@jest/globals"; +import { v4 as uuidv4 } from "uuid"; +import { + ConditionalChange, + convertGlideRecordToRecord, + Decoder, + FtAggregateOptions, + FtAggregateReturnType, + FtSearchOptions, + FtSearchReturnType, + GlideClusterClient, + GlideFt, + GlideJson, + GlideRecord, + GlideString, + InfoOptions, + JsonGetOptions, + ProtocolVersion, + RequestError, + SortOrder, + VectorField, +} from ".."; +import { ValkeyCluster } from "../../utils/TestUtils"; +import { + flushAndCloseClient, + getClientConfigurationOption, + getServerVersion, + parseEndpoints, +} from "./TestUtilities"; + +const TIMEOUT = 50000; +/** Waiting interval to let server process the data before querying */ +const DATA_PROCESSING_TIMEOUT = 1000; + +describe("Server Module Tests", () => { + let cluster: ValkeyCluster; + + beforeAll(async () => { + const clusterAddresses = global.CLUSTER_ENDPOINTS; + cluster = await ValkeyCluster.initFromExistingCluster( + true, + parseEndpoints(clusterAddresses), + getServerVersion, + ); + }, 40000); + + afterAll(async () => { + await cluster.close(); + }, TIMEOUT); + + describe.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "GlideJson", + (protocol) => { + let client: GlideClusterClient; + + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); + + it("check modules loaded", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const info = await client.info({ + sections: [InfoOptions.Modules], + route: "randomNode", + }); + expect(info).toContain("# json_core_metrics"); + expect(info).toContain("# search_index_stats"); + }); + + it("json.set and json.get tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + 
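+            // Each test uses a fresh UUID key so concurrent runs do not collide.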
const key = uuidv4();
+            const jsonValue = { a: 1.0, b: 2 };
+
+            // JSON.set
+            expect(
+                await GlideJson.set(
+                    client,
+                    key,
+                    "$",
+                    JSON.stringify(jsonValue),
+                ),
+            ).toBe("OK");
+
+            // JSON.get
+            let result = await GlideJson.get(client, key, { path: "." });
+            expect(JSON.parse(result.toString())).toEqual(jsonValue);
+
+            // binary buffer test
+            result = await GlideJson.get(client, Buffer.from(key), {
+                path: Buffer.from("."),
+                decoder: Decoder.Bytes,
+            });
+            expect(result).toEqual(Buffer.from(JSON.stringify(jsonValue)));
+
+            expect(
+                await GlideJson.set(
+                    client,
+                    Buffer.from(key),
+                    Buffer.from("$"),
+                    Buffer.from(JSON.stringify({ a: 1.0, b: 3 })),
+                ),
+            ).toBe("OK");
+
+            // JSON.get with array of paths
+            result = await GlideJson.get(client, key, {
+                path: ["$.a", "$.b"],
+            });
+            expect(JSON.parse(result.toString())).toEqual({
+                "$.a": [1.0],
+                "$.b": [3],
+            });
+
+            // JSON.get with non-existing key
+            expect(
+                await GlideJson.get(client, "non_existing_key", {
+                    path: ["$"],
+                }),
+            ).toBeNull();
+
+            // JSON.get with non-existing path
+            result = await GlideJson.get(client, key, { path: "$.d" });
+            expect(result).toEqual("[]");
+        });
+
+        it("json.set and json.get tests with multiple values", async () => {
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+
+            // JSON.set with complex object
+            expect(
+                await GlideJson.set(
+                    client,
+                    key,
+                    "$",
+                    JSON.stringify({
+                        a: { c: 1, d: 4 },
+                        b: { c: 2 },
+                        c: true,
+                    }),
+                ),
+            ).toBe("OK");
+
+            // JSON.get with deep path
+            let result = await GlideJson.get(client, key, {
+                path: "$..c",
+            });
+            expect(JSON.parse(result.toString())).toEqual([true, 1, 2]);
+
+            // JSON.set with deep path
+            expect(
+                await GlideJson.set(client, key, "$..c", '"new_value"'),
+            ).toBe("OK");
+
+            // verify JSON.set result
+            result = await GlideJson.get(client, key, { path: "$..c" });
+            expect(JSON.parse(result.toString())).toEqual([
+                "new_value",
+                "new_value",
+                "new_value",
+            ]);
+        });
+
+        it("json.set conditional set", async () => {
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+            const value = JSON.stringify({ a: 1.0, b: 2 });
+
+            expect(
+                await GlideJson.set(client, key, "$", value, {
+                    conditionalChange: ConditionalChange.ONLY_IF_EXISTS,
+                }),
+            ).toBeNull();
+
+            expect(
+                await GlideJson.set(client, key, "$", value, {
+                    conditionalChange:
+                        ConditionalChange.ONLY_IF_DOES_NOT_EXIST,
+                }),
+            ).toBe("OK");
+
+            expect(
+                await GlideJson.set(client, key, "$.a", "4.5", {
+                    conditionalChange:
+                        ConditionalChange.ONLY_IF_DOES_NOT_EXIST,
+                }),
+            ).toBeNull();
+            let result = await GlideJson.get(client, key, {
+                path: ".a",
+            });
+            expect(result).toEqual("1");
+
+            expect(
+                await GlideJson.set(client, key, "$.a", "4.5", {
+                    conditionalChange: ConditionalChange.ONLY_IF_EXISTS,
+                }),
+            ).toBe("OK");
+            result = await GlideJson.get(client, key, { path: ".a" });
+            expect(result).toEqual("4.5");
+        });
+
+        it("json.get formatting", async () => {
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+            // Set initial JSON value
+            expect(
+                await GlideJson.set(
+                    client,
+                    key,
+                    "$",
+                    JSON.stringify({ a: 1.0, b: 2, c: { d: 3, e: 4 } }),
+                ),
+            ).toBe("OK");
+            // JSON.get with formatting options
+            let result = await GlideJson.get(client, key, {
+                path: "$",
+                indent: " ",
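+                // indent, newline and space choose the strings used to pretty-print the reply.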
newline: "\n", + space: " ", + } as JsonGetOptions); + + const expectedResult1 = + '[\n {\n "a": 1,\n "b": 2,\n "c": {\n "d": 3,\n "e": 4\n }\n }\n]'; + expect(result).toEqual(expectedResult1); + // JSON.get with different formatting options + result = await GlideJson.get(client, key, { + path: "$", + indent: "~", + newline: "\n", + space: "*", + } as JsonGetOptions); + + const expectedResult2 = + '[\n~{\n~~"a":*1,\n~~"b":*2,\n~~"c":*{\n~~~"d":*3,\n~~~"e":*4\n~~}\n~}\n]'; + expect(result).toEqual(expectedResult2); + + // binary buffer test + const result3 = await GlideJson.get(client, Buffer.from(key), { + path: Buffer.from("$"), + indent: Buffer.from("~"), + newline: Buffer.from("\n"), + space: Buffer.from("*"), + } as JsonGetOptions); + expect(result3).toEqual(expectedResult2); + }); + + it("json.mget", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key1 = uuidv4(); + const key2 = uuidv4(); + const data = { + [key1]: '{"a": 1, "b": ["one", "two"]}', + [key2]: '{"a": 1, "c": false}', + }; + + for (const key of Object.keys(data)) { + await GlideJson.set(client, key, ".", data[key]); + } + + expect( + await GlideJson.mget( + client, + [key1, key2, uuidv4()], + Buffer.from("$.c"), + ), + ).toEqual(["[]", "[false]", null]); + expect( + await GlideJson.mget( + client, + [Buffer.from(key1), key2], + ".b[*]", + { decoder: Decoder.Bytes }, + ), + ).toEqual([Buffer.from('"one"'), null]); + }); + + it("json.arrinsert", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const doc = { + a: [], + b: { a: [1, 2, 3, 4] }, + c: { a: "not an array" }, + d: [{ a: ["x", "y"] }, { a: [["foo"]] }], + e: [{ a: 42 }, { a: {} }], + f: { a: [true, false, null] }, + }; + expect( + await GlideJson.set(client, key, "$", JSON.stringify(doc)), + ).toBe("OK"); + + const result = await GlideJson.arrinsert( + client, + key, + "$..a", + 0, + [ + '"string_value"', + "123", + '{"key": "value"}', + "true", + "null", + '["bar"]', + ], + ); + expect(result).toEqual([6, 10, null, 8, 7, null, null, 9]); + + const expected = { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + ], + b: { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + 1, + 2, + 3, + 4, + ], + }, + c: { a: "not an array" }, + d: [ + { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + "x", + "y", + ], + }, + { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + ["foo"], + ], + }, + ], + e: [{ a: 42 }, { a: {} }], + f: { + a: [ + "string_value", + 123, + { key: "value" }, + true, + null, + ["bar"], + true, + false, + null, + ], + }, + }; + expect( + JSON.parse((await GlideJson.get(client, key)) as string), + ).toEqual(expected); + + // Binary buffer test + expect( + JSON.parse( + (await GlideJson.get( + client, + Buffer.from(key), + )) as string, + ), + ).toEqual(expected); + }); + + it("json.arrpop", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + let doc = + '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}'; + expect(await GlideJson.set(client, key, "$", doc)).toBe("OK"); + + let res = await GlideJson.arrpop(client, key, { + path: "$.a", + index: 1, + }); + 
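+            // With a JSONPath ("$"-prefixed) path, arrpop returns one popped element per matching array.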
expect(res).toEqual(["2"]); + + res = await GlideJson.arrpop(client, Buffer.from(key), { + path: "$..a", + }); + expect(res).toEqual(["true", "5", null]); + + res = await GlideJson.arrpop(client, key, { + path: "..a", + decoder: Decoder.Bytes, + }); + expect(res).toEqual(Buffer.from("1")); + + // Even if only one array element was returned, ensure second array at `..a` was popped + doc = (await GlideJson.get(client, key, { + path: ["$..a"], + })) as string; + expect(doc).toEqual("[[],[3,4],42]"); + + // Out of index + res = await GlideJson.arrpop(client, key, { + path: Buffer.from("$..a"), + index: 10, + }); + expect(res).toEqual([null, "4", null]); + + // pop without options + expect(await GlideJson.set(client, key, "$", doc)).toEqual( + "OK", + ); + expect(await GlideJson.arrpop(client, key)).toEqual("42"); + + // Binary buffer test + expect( + await GlideJson.arrpop(client, Buffer.from(key)), + ).toEqual("[3,4]"); + }); + + it("json.arrlen", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const doc = + '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}'; + expect(await GlideJson.set(client, key, "$", doc)).toBe("OK"); + + expect( + await GlideJson.arrlen(client, key, { path: "$.a" }), + ).toEqual([3]); + expect( + await GlideJson.arrlen(client, key, { path: "$..a" }), + ).toEqual([3, 2, null]); + // Legacy path retrieves the first array match at ..a + expect( + await GlideJson.arrlen(client, key, { path: "..a" }), + ).toEqual(3); + // Value at path is not an array + expect( + await GlideJson.arrlen(client, key, { path: "$" }), + ).toEqual([null]); + + await expect( + GlideJson.arrlen(client, key, { path: "." }), + ).rejects.toThrow(); + + expect( + await GlideJson.set(client, key, "$", "[1, 2, 3, 4]"), + ).toBe("OK"); + expect(await GlideJson.arrlen(client, key)).toEqual(4); + + // Binary buffer test + expect( + await GlideJson.arrlen(client, Buffer.from(key)), + ).toEqual(4); + }); + + it("json.arrindex", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key1 = uuidv4(); + const key2 = uuidv4(); + const doc1 = + '{"a": [1, 3, true, "hello"], "b": {"a": [3, 4, [3, false], 5], "c": {"a": 42}}}'; + + expect(await GlideJson.set(client, key1, "$", doc1)).toBe("OK"); + + // Verify scalar type + expect( + await GlideJson.arrindex(client, key1, "$..a", true), + ).toEqual([2, -1, null]); + expect( + await GlideJson.arrindex(client, key1, "..a", true), + ).toEqual(2); + + expect( + await GlideJson.arrindex(client, key1, "$..a", 3), + ).toEqual([1, 0, null]); + expect( + await GlideJson.arrindex(client, key1, "..a", 3), + ).toEqual(1); + + expect( + await GlideJson.arrindex(client, key1, "$..a", '"hello"'), + ).toEqual([3, -1, null]); + expect( + await GlideJson.arrindex(client, key1, "..a", '"hello"'), + ).toEqual(3); + + expect( + await GlideJson.arrindex(client, key1, "$..a", null), + ).toEqual([-1, -1, null]); + expect( + await GlideJson.arrindex(client, key1, "..a", null), + ).toEqual(-1); + + // Value at the path is not an array + expect( + await GlideJson.arrindex(client, key1, "$..c", 42), + ).toEqual([null]); + await expect( + GlideJson.arrindex(client, key1, "..c", 42), + ).rejects.toThrow(RequestError); + + const doc2 = + '{"a": [1, 3, true, "foo", "meow", "m", "foo", "lol", false],' + + ' "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}'; + + 
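+            // doc2 repeats '"foo"' at indices 3 and 6 so the start/end search window options can be exercised.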
expect(await GlideJson.set(client, key2, "$", doc2)).toBe("OK"); + + // Verify optional `start` and `end` + expect( + await GlideJson.arrindex(client, key2, "$..a", '"foo"', { + start: 6, + end: 8, + }), + ).toEqual([6, -1, null]); + expect( + await GlideJson.arrindex(client, key2, "$..a", '"foo"', { + start: 2, + end: 8, + }), + ).toEqual([3, -1, null]); + expect( + await GlideJson.arrindex(client, key2, "..a", '"meow"', { + start: 2, + end: 8, + }), + ).toEqual(4); + + // Verify without optional `end` + expect( + await GlideJson.arrindex(client, key2, "$..a", '"foo"', { + start: 6, + }), + ).toEqual([6, -1, null]); + expect( + await GlideJson.arrindex(client, key2, "..a", '"foo"', { + start: 6, + }), + ).toEqual(6); + + // Verify optional `end` with 0 or -1 (means the last element is included) + expect( + await GlideJson.arrindex(client, key2, "$..a", '"foo"', { + start: 6, + end: 0, + }), + ).toEqual([6, -1, null]); + expect( + await GlideJson.arrindex(client, key2, "..a", '"foo"', { + start: 6, + end: 0, + }), + ).toEqual(6); + expect( + await GlideJson.arrindex(client, key2, "$..a", '"foo"', { + start: 6, + end: -1, + }), + ).toEqual([6, -1, null]); + expect( + await GlideJson.arrindex(client, key2, "..a", '"foo"', { + start: 6, + end: -1, + }), + ).toEqual(6); + + // Test with binary input + expect( + await GlideJson.arrindex( + client, + Buffer.from(key2), + Buffer.from("$..a"), + Buffer.from('"foo"'), + { + start: 6, + end: -1, + }, + ), + ).toEqual([6, -1, null]); + expect( + await GlideJson.arrindex( + client, + Buffer.from(key2), + Buffer.from("..a"), + Buffer.from('"foo"'), + { + start: 6, + end: -1, + }, + ), + ).toEqual(6); + + // Test with non-existent path + expect( + await GlideJson.arrindex( + client, + key2, + "$.nonexistent", + true, + ), + ).toEqual([]); + await expect( + GlideJson.arrindex(client, key2, "nonexistent", true), + ).rejects.toThrow(RequestError); + + // Test with non-existent key + await expect( + GlideJson.arrindex(client, "non_existing_key", "$", true), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.arrindex(client, "non_existing_key", ".", true), + ).rejects.toThrow(RequestError); + }); + + it("json.toggle tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const key2 = uuidv4(); + const jsonValue = { + bool: true, + nested: { bool: false, nested: { bool: 10 } }, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.toggle(client, key, { path: "$..bool" }), + ).toEqual([false, true, null]); + expect( + await GlideJson.toggle(client, key, { path: "bool" }), + ).toBe(true); + expect( + await GlideJson.toggle(client, key, { + path: "$.non_existing", + }), + ).toEqual([]); + expect( + await GlideJson.toggle(client, key, { path: "$.nested" }), + ).toEqual([null]); + + // testing behavior with default pathing + expect(await GlideJson.set(client, key2, ".", "true")).toBe( + "OK", + ); + expect(await GlideJson.toggle(client, key2)).toBe(false); + expect(await GlideJson.toggle(client, key2)).toBe(true); + + // expect request errors + await expect( + GlideJson.toggle(client, key, { path: "nested" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, key, { path: ".non_existing" }), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.toggle(client, "non_existing_key", { path: "$" }), + 
).rejects.toThrow(RequestError); + + // Binary buffer test + expect(await GlideJson.toggle(client, Buffer.from(key2))).toBe( + false, + ); + }); + + it("json.del tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: { a: 1, b: 2.5, c: true } }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // non-existing paths + expect( + await GlideJson.del(client, key, { path: "$..path" }), + ).toBe(0); + expect( + await GlideJson.del(client, key, { path: "..path" }), + ).toBe(0); + + // deleting existing path + expect(await GlideJson.del(client, key, { path: "$..a" })).toBe( + 2, + ); + expect(await GlideJson.get(client, key, { path: "$..a" })).toBe( + "[]", + ); + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key, { path: "..a" })).toBe( + 2, + ); + await expect( + GlideJson.get(client, key, { path: "..a" }), + ).rejects.toThrow(RequestError); + + // verify result + const result = await GlideJson.get(client, key, { + path: "$", + }); + expect(JSON.parse(result as string)).toEqual([ + { b: { b: 2.5, c: true } }, + ]); + + // test root deletion operations + expect(await GlideJson.del(client, key, { path: "$" })).toBe(1); + + // reset and test dot deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key, { path: "." })).toBe(1); + + // reset and test key deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.del(client, key)).toBe(1); + expect(await GlideJson.del(client, key)).toBe(0); + expect( + await GlideJson.get(client, key, { path: "$" }), + ).toBeNull(); + + // Binary buffer test + expect(await GlideJson.del(client, Buffer.from(key))).toBe(0); + + // non-existing keys + expect( + await GlideJson.del(client, "non_existing_key", { + path: "$", + }), + ).toBe(0); + expect( + await GlideJson.del(client, "non_existing_key", { + path: ".", + }), + ).toBe(0); + }); + + it("json.forget tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { a: 1.0, b: { a: 1, b: 2.5, c: true } }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // non-existing paths + expect( + await GlideJson.forget(client, key, { path: "$..path" }), + ).toBe(0); + expect( + await GlideJson.forget(client, key, { path: "..path" }), + ).toBe(0); + + // deleting existing paths + expect( + await GlideJson.forget(client, key, { path: "$..a" }), + ).toBe(2); + expect(await GlideJson.get(client, key, { path: "$..a" })).toBe( + "[]", + ); + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.forget(client, key, { path: "..a" }), + ).toBe(2); + await expect( + GlideJson.get(client, key, { path: "..a" }), + ).rejects.toThrow(RequestError); + + // verify result + const result = await GlideJson.get(client, key, { + path: "$", + }); + expect(JSON.parse(result as string)).toEqual([ + { b: { b: 2.5, c: true } }, + ]); + + // test root deletion operations + 
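+            // JSON.FORGET is an alias of JSON.DEL, so deleting at the root removes the key and reports one deleted value.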
expect(await GlideJson.forget(client, key, { path: "$" })).toBe( + 1, + ); + + // reset and test dot deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.forget(client, key, { path: "." })).toBe( + 1, + ); + + // reset and test key deletion + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect(await GlideJson.forget(client, key)).toBe(1); + expect(await GlideJson.forget(client, key)).toBe(0); + expect( + await GlideJson.get(client, key, { path: "$" }), + ).toBeNull(); + + // Binary buffer test + expect(await GlideJson.forget(client, Buffer.from(key))).toBe( + 0, + ); + + // non-existing keys + expect( + await GlideJson.forget(client, "non_existing_key", { + path: "$", + }), + ).toBe(0); + expect( + await GlideJson.forget(client, "non_existing_key", { + path: ".", + }), + ).toBe(0); + }); + + it("json.type tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = [1, 2.3, "foo", true, null, {}, []]; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.type(client, key, { path: "$[*]" }), + ).toEqual([ + "integer", + "number", + "string", + "boolean", + "null", + "object", + "array", + ]); + expect( + await GlideJson.type(client, "non_existing", { + path: "$[*]", + }), + ).toBeNull(); + expect( + await GlideJson.type(client, key, { + path: "$non_existing", + }), + ).toEqual([]); + + const key2 = uuidv4(); + const jsonValue2 = { Name: "John", Age: 27 }; + // setup + expect( + await GlideJson.set( + client, + key2, + "$", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + expect( + await GlideJson.type(client, key2, { path: "." }), + ).toEqual("object"); + expect( + await GlideJson.type(client, key2, { path: ".Age" }), + ).toEqual("integer"); + expect( + await GlideJson.type(client, key2, { path: ".Job" }), + ).toBeNull(); + expect( + await GlideJson.type(client, "non_existing", { path: "." 
}), + ).toBeNull(); + + // Binary buffer test + expect( + await GlideJson.type(client, Buffer.from(key2), { + path: Buffer.from(".Age"), + }), + ).toEqual("integer"); + }); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.clear tests", + async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + obj: { a: 1, b: 2 }, + arr: [1, 2, 3], + str: "foo", + bool: true, + int: 42, + float: 3.14, + nullVal: null, + }; + + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.clear(client, key, { path: "$.*" }), + ).toBe(6); + + const result = await GlideJson.get(client, key, { + path: ["$"], + }); + + expect(JSON.parse(result as string)).toEqual([ + { + obj: {}, + arr: [], + str: "", + bool: false, + int: 0, + float: 0.0, + nullVal: null, + }, + ]); + + expect( + await GlideJson.clear(client, key, { path: "$.*" }), + ).toBe(0); + + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.clear(client, key, { path: "*" }), + ).toBe(6); + + const jsonValue2 = { + a: 1, + b: { a: [5, 6, 7], b: { a: true } }, + c: { a: "value", b: { a: 3.5 } }, + d: { a: { foo: "foo" } }, + nullVal: null, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + + expect( + await GlideJson.clear(client, key, { + path: "b.a[1:3]", + }), + ).toBe(2); + + expect( + await GlideJson.clear(client, key, { + path: "b.a[1:3]", + }), + ).toBe(0); + + expect( + JSON.parse( + (await GlideJson.get(client, key, { + path: ["$..a"], + })) as string, + ), + ).toEqual([ + 1, + [5, 0, 0], + true, + "value", + 3.5, + { foo: "foo" }, + ]); + + expect( + await GlideJson.clear(client, key, { path: "..a" }), + ).toBe(6); + + expect( + JSON.parse( + (await GlideJson.get(client, key, { + path: ["$..a"], + })) as string, + ), + ).toEqual([0, [], false, "", 0.0, {}]); + + expect( + await GlideJson.clear(client, key, { path: "$..a" }), + ).toBe(0); + + // Path doesn't exist + expect( + await GlideJson.clear(client, key, { path: "$.path" }), + ).toBe(0); + + expect( + await GlideJson.clear(client, key, { path: "path" }), + ).toBe(0); + + // Key doesn't exist + await expect( + GlideJson.clear(client, "non_existing_key"), + ).rejects.toThrow(RequestError); + + await expect( + GlideJson.clear(client, "non_existing_key", { + path: "$", + }), + ).rejects.toThrow(RequestError); + + await expect( + GlideJson.clear(client, "non_existing_key", { + path: ".", + }), + ).rejects.toThrow(RequestError); + }, + ); + + it("json.resp tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + obj: { a: 1, b: 2 }, + arr: [1, 2, 3], + str: "foo", + bool: true, + int: 42, + float: 3.14, + nullVal: null, + }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.resp(client, key, { path: "$.*" }), + ).toEqual([ + ["{", ["a", 1], ["b", 2]], + ["[", 1, 2, 3], + "foo", + "true", + 42, + "3.14", + null, + ]); // leading "{" - JSON objects, leading "[" - JSON arrays + + // multiple path match, the first will be returned + expect( + await GlideJson.resp(client, key, { path: "*" }), + 
).toEqual(["{", ["a", 1], ["b", 2]]); + + // testing $ path + expect( + await GlideJson.resp(client, key, { path: "$" }), + ).toEqual([ + [ + "{", + ["obj", ["{", ["a", 1], ["b", 2]]], + ["arr", ["[", 1, 2, 3]], + ["str", "foo"], + ["bool", "true"], + ["int", 42], + ["float", "3.14"], + ["nullVal", null], + ], + ]); + + // testing . path + expect( + await GlideJson.resp(client, key, { path: "." }), + ).toEqual([ + "{", + ["obj", ["{", ["a", 1], ["b", 2]]], + ["arr", ["[", 1, 2, 3]], + ["str", "foo"], + ["bool", "true"], + ["int", 42], + ["float", "3.14"], + ["nullVal", null], + ]); + + // $.str and .str + expect( + await GlideJson.resp(client, key, { path: "$.str" }), + ).toEqual(["foo"]); + expect( + await GlideJson.resp(client, key, { path: ".str" }), + ).toEqual("foo"); + + // setup new json value + const jsonValue2 = { + a: [1, 2, 3], + b: { a: [1, 2], c: { a: 42 } }, + }; + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + + expect( + await GlideJson.resp(client, key, { path: "..a" }), + ).toEqual(["[", 1, 2, 3]); + + expect( + await GlideJson.resp(client, key, { + path: "$.nonexistent", + }), + ).toEqual([]); + + // error case + await expect( + GlideJson.resp(client, key, { path: "nonexistent" }), + ).rejects.toThrow(RequestError); + + // non-existent key + expect( + await GlideJson.resp(client, "nonexistent_key", { + path: "$", + }), + ).toBeNull(); + expect( + await GlideJson.resp(client, "nonexistent_key", { + path: ".", + }), + ).toBeNull(); + expect( + await GlideJson.resp(client, "nonexistent_key"), + ).toBeNull(); + + // binary buffer test + expect( + await GlideJson.resp(client, Buffer.from(key), { + path: Buffer.from("..a"), + decoder: Decoder.Bytes, + }), + ).toEqual([Buffer.from("["), 1, 2, 3]); + }); + + it("json.arrtrim tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const key = uuidv4(); + const jsonValue = { + a: [0, 1, 2, 3, 4, 5, 6, 7, 8], + b: { a: [0, 9, 10, 11, 12, 13], c: { a: 42 } }, + }; + + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + // Basic trim + expect( + await GlideJson.arrtrim(client, key, "$..a", 1, 7), + ).toEqual([7, 5, null]); + + // Test end >= size (should be treated as size-1) + expect( + await GlideJson.arrtrim(client, key, "$.a", 0, 10), + ).toEqual([7]); + expect( + await GlideJson.arrtrim(client, key, ".a", 0, 10), + ).toEqual(7); + + // Test negative start (should be treated as 0) + expect( + await GlideJson.arrtrim(client, key, "$.a", -1, 5), + ).toEqual([6]); + expect( + await GlideJson.arrtrim(client, key, ".a", -1, 5), + ).toEqual(6); + + // Test start >= size (should empty the array) + expect( + await GlideJson.arrtrim(client, key, "$.a", 7, 10), + ).toEqual([0]); + const jsonValue2 = ["a", "b", "c"]; + expect( + await GlideJson.set( + client, + key, + ".a", + JSON.stringify(jsonValue2), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, ".a", 7, 10), + ).toEqual(0); + + // Test start > end (should empty the array) + expect( + await GlideJson.arrtrim(client, key, "$..a", 2, 1), + ).toEqual([0, 0, null]); + const jsonValue3 = ["a", "b", "c", "d"]; + expect( + await GlideJson.set( + client, + key, + "..a", + JSON.stringify(jsonValue3), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "..a", 2, 1), + ).toEqual(0); + + // Multiple path match + expect( + await 
GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "..a", 1, 10), + ).toEqual(8); + + // Test with non-existent path + await expect( + GlideJson.arrtrim(client, key, "nonexistent", 0, 1), + ).rejects.toThrow(RequestError); + expect( + await GlideJson.arrtrim(client, key, "$.nonexistent", 0, 1), + ).toEqual([]); + + // Test with non-array path + expect(await GlideJson.arrtrim(client, key, "$", 0, 1)).toEqual( + [null], + ); + await expect( + GlideJson.arrtrim(client, key, ".", 0, 1), + ).rejects.toThrow(RequestError); + + // Test with non-existent key + await expect( + GlideJson.arrtrim(client, "non_existing_key", "$", 0, 1), + ).rejects.toThrow(RequestError); + await expect( + GlideJson.arrtrim(client, "non_existing_key", ".", 0, 1), + ).rejects.toThrow(RequestError); + + // Test empty array + expect( + await GlideJson.set( + client, + key, + "$.empty", + JSON.stringify([]), + ), + ).toBe("OK"); + expect( + await GlideJson.arrtrim(client, key, "$.empty", 0, 1), + ).toEqual([0]); + expect( + await GlideJson.arrtrim(client, key, ".empty", 0, 1), + ).toEqual(0); + }); + + it("json.strlen tests", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + a: "foo", + nested: { a: "hello" }, + nested2: { a: 31 }, + }; + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.strlen(client, key, { path: "$..a" }), + ).toEqual([3, 5, null]); + expect(await GlideJson.strlen(client, key, { path: "a" })).toBe( + 3, + ); + + expect( + await GlideJson.strlen(client, key, { + path: "$.nested", + }), + ).toEqual([null]); + expect( + await GlideJson.strlen(client, key, { path: "$..a" }), + ).toEqual([3, 5, null]); + + expect( + await GlideJson.strlen(client, "non_existing_key", { + path: ".", + }), + ).toBeNull(); + expect( + await GlideJson.strlen(client, "non_existing_key", { + path: "$", + }), + ).toBeNull(); + expect( + await GlideJson.strlen(client, key, { + path: "$.non_existing_path", + }), + ).toEqual([]); + + // error case + await expect( + GlideJson.strlen(client, key, { path: "nested" }), + ).rejects.toThrow(RequestError); + await expect(GlideJson.strlen(client, key)).rejects.toThrow( + RequestError, + ); + // Binary buffer test + expect( + await GlideJson.strlen(client, Buffer.from(key), { + path: Buffer.from("$..a"), + }), + ).toEqual([3, 5, null]); + }); + + it("json.arrappend", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + let doc = { a: 1, b: ["one", "two"] }; + expect( + await GlideJson.set(client, key, "$", JSON.stringify(doc)), + ).toBe("OK"); + + expect( + await GlideJson.arrappend(client, key, Buffer.from("$.b"), [ + '"three"', + ]), + ).toEqual([3]); + expect( + await GlideJson.arrappend(client, key, ".b", [ + '"four"', + '"five"', + ]), + ).toEqual(5); + doc = JSON.parse( + (await GlideJson.get(client, key, { path: "." 
})) as string,
+            );
+            expect(doc).toEqual({
+                a: 1,
+                b: ["one", "two", "three", "four", "five"],
+            });
+
+            expect(
+                await GlideJson.arrappend(client, key, "$.a", ['"value"']),
+            ).toEqual([null]);
+        });
+
+        it("json.strappend tests", async () => {
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+            const jsonValue = {
+                a: "foo",
+                nested: { a: "hello" },
+                nested2: { a: 31 },
+            };
+            // setup
+            expect(
+                await GlideJson.set(
+                    client,
+                    key,
+                    "$",
+                    JSON.stringify(jsonValue),
+                ),
+            ).toBe("OK");
+
+            expect(
+                await GlideJson.strappend(client, key, '"bar"', {
+                    path: "$..a",
+                }),
+            ).toEqual([6, 8, null]);
+            expect(
+                await GlideJson.strappend(
+                    client,
+                    key,
+                    JSON.stringify("foo"),
+                    {
+                        path: "a",
+                    },
+                ),
+            ).toBe(9);
+
+            expect(await GlideJson.get(client, key, { path: "." })).toEqual(
+                JSON.stringify({
+                    a: "foobarfoo",
+                    nested: { a: "hellobar" },
+                    nested2: { a: 31 },
+                }),
+            );
+
+            // Binary buffer test
+            expect(
+                await GlideJson.strappend(
+                    client,
+                    Buffer.from(key),
+                    Buffer.from(JSON.stringify("foo")),
+                    {
+                        path: Buffer.from("a"),
+                    },
+                ),
+            ).toBe(12);
+
+            expect(
+                await GlideJson.strappend(
+                    client,
+                    key,
+                    JSON.stringify("bar"),
+                    {
+                        path: "$.nested",
+                    },
+                ),
+            ).toEqual([null]);
+
+            await expect(
+                GlideJson.strappend(client, key, JSON.stringify("bar"), {
+                    path: ".nested",
+                }),
+            ).rejects.toThrow(RequestError);
+            await expect(
+                GlideJson.strappend(client, key, JSON.stringify("bar")),
+            ).rejects.toThrow(RequestError);
+
+            expect(
+                await GlideJson.strappend(
+                    client,
+                    key,
+                    JSON.stringify("try"),
+                    {
+                        path: "$.non_existing_path",
+                    },
+                ),
+            ).toEqual([]);
+
+            await expect(
+                GlideJson.strappend(client, key, JSON.stringify("try"), {
+                    path: ".non_existing_path",
+                }),
+            ).rejects.toThrow(RequestError);
+            await expect(
+                GlideJson.strappend(
+                    client,
+                    "non_existing_key",
+                    JSON.stringify("try"),
+                ),
+            ).rejects.toThrow(RequestError);
+        });
+
+        it("json.numincrby tests", async () => {
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+            const jsonValue = {
+                key1: 1,
+                key2: 3.5,
+                key3: { nested_key: { key1: [4, 5] } },
+                key4: [1, 2, 3],
+                key5: 0,
+                key6: "hello",
+                key7: null,
+                key8: { nested_key: { key1: 69 } },
+                key9: 1.7976931348623157e308,
+            };
+            // setup
+            expect(
+                await GlideJson.set(
+                    client,
+                    key,
+                    "$",
+                    JSON.stringify(jsonValue),
+                ),
+            ).toBe("OK");
+
+            // Increment integer value (key1) by 5
+            expect(
+                await GlideJson.numincrby(client, key, "$.key1", 5),
+            ).toBe("[6]"); // 1 + 5 = 6
+
+            // Increment float value (key2) by 2.5
+            expect(
+                await GlideJson.numincrby(client, key, "$.key2", 2.5),
+            ).toBe("[6]"); // 3.5 + 2.5 = 6
+
+            // Increment nested array element (key3.nested_key.key1[1]) by 7
+            expect(
+                await GlideJson.numincrby(
+                    client,
+                    key,
+                    "$.key3.nested_key.key1[1]",
+                    7,
+                ),
+            ).toBe("[12]"); // 5 + 7 = 12
+
+            // Increment array element (key4[1]) by 1
+            expect(
+                await GlideJson.numincrby(client, key, "$.key4[1]", 1),
+            ).toBe("[3]"); // 2 + 1 = 3
+
+            // Increment zero value (key5) by 10.23 (float number)
+            expect(
+                await GlideJson.numincrby(client, key, "$.key5", 10.23),
+            ).toBe("[10.23]"); // 0 + 10.23 = 10.23
+
+            // Increment a string value (key6) by a number
+            expect(
+                await GlideJson.numincrby(client, key, "$.key6", 99),
+            ).toBe("[null]"); // null
+
+            // Increment a null value (key7) by a number
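+            // Like the string case above, a JSON null is non-numeric, so the JSONPath reply is [null] rather than an error.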
expect(
+                await GlideJson.numincrby(client, key, "$.key7", 51),
+            ).toBe("[null]"); // null
+
+            // Check increment for all numbers in the document using JSON Path (first null: key3 as an entire object; second null: the path also matches key3's object value, which is non-numeric).
+            expect(await GlideJson.numincrby(client, key, "$..*", 5)).toBe(
+                "[11,11,null,null,15.23,null,null,null,1.7976931348623157e+308,null,null,9,17,6,8,8,null,74]",
+            );
+
+            // Check for multiple path matches in JSONPath
+            expect(
+                await GlideJson.numincrby(client, key, "$..key1", 1),
+            ).toBe("[12,null,75]");
+
+            // Check for non-existent path in JSONPath
+            expect(
+                await GlideJson.numincrby(client, key, "$.key10", 51),
+            ).toBe("[]"); // empty array
+
+            // Check for non-existent key in JSONPath
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    "non_existing_key",
+                    "$.key10",
+                    51,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Check for overflow in JSONPath
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    key,
+                    "$.key9",
+                    1.7976931348623157e308,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Decrement integer value (key1) by 12
+            expect(
+                await GlideJson.numincrby(client, key, "$.key1", -12),
+            ).toBe("[0]"); // 12 - 12 = 0
+
+            // Decrement integer value (key1) by 0.5
+            expect(
+                await GlideJson.numincrby(client, key, "$.key1", -0.5),
+            ).toBe("[-0.5]"); // 0 - 0.5 = -0.5
+
+            // Test Legacy Path
+            // Increment float value (key1) by 5 (integer)
+            expect(await GlideJson.numincrby(client, key, "key1", 5)).toBe(
+                "4.5",
+            ); // -0.5 + 5 = 4.5
+
+            // Decrement float value (key1) by 5.5 (a float number)
+            expect(
+                await GlideJson.numincrby(client, key, "key1", -5.5),
+            ).toBe("-1"); // 4.5 - 5.5 = -1
+
+            // Increment int value (key2) by 2.5 (a float number)
+            expect(
+                await GlideJson.numincrby(client, key, "key2", 2.5),
+            ).toBe("13.5"); // 11 + 2.5 = 13.5
+
+            // Increment nested value (key3.nested_key.key1[0]) by 7
+            expect(
+                await GlideJson.numincrby(
+                    client,
+                    key,
+                    "key3.nested_key.key1[0]",
+                    7,
+                ),
+            ).toBe("16"); // 9 + 7 = 16
+
+            // Increment array element (key4[1]) by 1
+            expect(
+                await GlideJson.numincrby(client, key, "key4[1]", 1),
+            ).toBe("9"); // 8 + 1 = 9
+
+            // Increment a float value (key5) by 10.2 (a float number)
+            expect(
+                await GlideJson.numincrby(client, key, "key5", 10.2),
+            ).toBe("25.43"); // 15.23 + 10.2 = 25.43
+
+            // Check for multiple path matches in legacy and ensure that the result of the last updated value is returned
+            expect(
+                await GlideJson.numincrby(client, key, "..key1", 1),
+            ).toBe("76");
+
+            // Check that the rest of the key1 path matches were updated and not only the last value
+            expect(
+                await GlideJson.get(client, key, { path: "$..key1" }),
+            ).toBe("[0,[16,17],76]");
+            // First is 0 as -1 + 1 = 0; the second match is an array (non-numeric), so it is unchanged; third is 76 as 75 + 1 = 76
+
+            // Check for non-existent path in legacy
+            await expect(
+                GlideJson.numincrby(client, key, ".key10", 51),
+            ).rejects.toThrow(RequestError);
+
+            // Check for non-existent key in legacy
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    "non_existent_key",
+                    ".key10",
+                    51,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Check for overflow in legacy
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    key,
+                    ".key9",
+                    1.7976931348623157e308,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // binary buffer test
+            expect(
+                await GlideJson.numincrby(
+                    client,
+                    Buffer.from(key),
+                    Buffer.from("key5"),
+                    1,
+                ),
+            ).toBe("26.43");
+        });
+
+        it("json.nummultby tests", async () => {
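+            // NUMMULTBY mirrors the NUMINCRBY flow above but multiplies the value at each matching path.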
+            client = await GlideClusterClient.createClient(
+                getClientConfigurationOption(
+                    cluster.getAddresses(),
+                    protocol,
+                ),
+            );
+            const key = uuidv4();
+            const jsonValue =
+                "{" +
+                ' "key1": 1,' +
+                ' "key2": 3.5,' +
+                ' "key3": {"nested_key": {"key1": [4, 5]}},' +
+                ' "key4": [1, 2, 3],' +
+                ' "key5": 0,' +
+                ' "key6": "hello",' +
+                ' "key7": null,' +
+                ' "key8": {"nested_key": {"key1": 69}},' +
+                ' "key9": 3.5953862697246314e307' +
+                "}";
+            // setup
+            expect(await GlideJson.set(client, key, "$", jsonValue)).toBe(
+                "OK",
+            );
+
+            // Test JSONPath
+            // Multiply integer value (key1) by 5
+            expect(
+                await GlideJson.nummultby(client, key, "$.key1", 5),
+            ).toBe("[5]"); // 1 * 5 = 5
+
+            // Multiply float value (key2) by 2.5
+            expect(
+                await GlideJson.nummultby(client, key, "$.key2", 2.5),
+            ).toBe("[8.75]"); // 3.5 * 2.5 = 8.75
+
+            // Multiply nested array element (key3.nested_key.key1[1]) by 7
+            expect(
+                await GlideJson.nummultby(
+                    client,
+                    key,
+                    "$.key3.nested_key.key1[1]",
+                    7,
+                ),
+            ).toBe("[35]"); // 5 * 7 = 35
+
+            // Multiply array element (key4[1]) by 1
+            expect(
+                await GlideJson.nummultby(client, key, "$.key4[1]", 1),
+            ).toBe("[2]"); // 2 * 1 = 2
+
+            // Multiply zero value (key5) by 10.23 (float number)
+            expect(
+                await GlideJson.nummultby(client, key, "$.key5", 10.23),
+            ).toBe("[0]"); // 0 * 10.23 = 0
+
+            // Multiply a string value (key6) by a number
+            expect(
+                await GlideJson.nummultby(client, key, "$.key6", 99),
+            ).toBe("[null]");
+
+            // Multiply a null value (key7) by a number
+            expect(
+                await GlideJson.nummultby(client, key, "$.key7", 51),
+            ).toBe("[null]");
+
+            // Check multiplication for all numbers in the document using JSON Path
+            // key1: 5 * 5 = 25
+            // key2: 8.75 * 5 = 43.75
+            // key3.nested_key.key1[0]: 4 * 5 = 20
+            // key3.nested_key.key1[1]: 35 * 5 = 175
+            // key4[0]: 1 * 5 = 5
+            // key4[1]: 2 * 5 = 10
+            // key4[2]: 3 * 5 = 15
+            // key5: 0 * 5 = 0
+            // key8.nested_key.key1: 69 * 5 = 345
+            // key9: 3.5953862697246314e307 * 5 = 1.7976931348623157e308
+            expect(await GlideJson.nummultby(client, key, "$..*", 5)).toBe(
+                "[25,43.75,null,null,0,null,null,null,1.7976931348623157e+308,null,null,20,175,5,10,15,null,345]",
+            );
+
+            // Check for multiple path matches in JSONPath
+            // key1: 25 * 2 = 50
+            // key8.nested_key.key1: 345 * 2 = 690
+            expect(
+                await GlideJson.nummultby(client, key, "$..key1", 2),
+            ).toBe("[50,null,690]"); // After previous multiplications
+
+            // Check for non-existent path in JSONPath
+            expect(
+                await GlideJson.nummultby(client, key, "$.key10", 51),
+            ).toBe("[]"); // Empty Array
+
+            // Check for non-existent key in JSONPath
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    "non_existent_key",
+                    "$.key10",
+                    51,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Check for overflow in JSONPath
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    key,
+                    "$.key9",
+                    1.7976931348623157e308,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Multiply integer value (key1) by -12
+            expect(
+                await GlideJson.nummultby(client, key, "$.key1", -12),
+            ).toBe("[-600]"); // 50 * -12 = -600
+
+            // Multiply integer value (key1) by -0.5
+            expect(
+                await GlideJson.nummultby(client, key, "$.key1", -0.5),
+            ).toBe("[300]"); // -600 * -0.5 = 300
+
+            // Test Legacy Path
+            // Multiply int value (key1) by 5 (integer)
+            expect(await GlideJson.nummultby(client, key, "key1", 5)).toBe(
+                "1500",
+            ); // 300 * 5 = 1500
+
+            // Multiply int value (key1) by -5.5 (float number)
+            expect(
+                await GlideJson.nummultby(client, key, "key1", -5.5),
+            ).toBe("-8250"); // 1500 * -5.5 = -8250
+            // Multiply float value (key2) by 2.5 (a float number)
+            expect(
+                await GlideJson.nummultby(client, key, "key2", 2.5),
+            ).toBe("109.375"); // 43.75 * 2.5 = 109.375
+
+            // Multiply nested value (key3.nested_key.key1[0]) by 7
+            expect(
+                await GlideJson.nummultby(
+                    client,
+                    key,
+                    "key3.nested_key.key1[0]",
+                    7,
+                ),
+            ).toBe("140"); // 20 * 7 = 140
+
+            // Multiply array element (key4[1]) by 1
+            expect(
+                await GlideJson.nummultby(client, key, "key4[1]", 1),
+            ).toBe("10"); // 10 * 1 = 10
+
+            // Multiply a float value (key5) by 10.2 (a float number)
+            expect(
+                await GlideJson.nummultby(client, key, "key5", 10.2),
+            ).toBe("0"); // 0 * 10.2 = 0
+
+            // Check for multiple path matches in legacy and ensure that the result of the last updated value is returned
+            // last updated value is key8.nested_key.key1: 690 * 2 = 1380
+            expect(
+                await GlideJson.nummultby(client, key, "..key1", 2),
+            ).toBe("1380"); // the last updated key1 value multiplied by 2
+
+            // Check that the rest of the key1 path matches were updated and not only the last value
+            expect(
+                await GlideJson.get(client, key, { path: "$..key1" }),
+            ).toBe("[-16500,[140,175],1380]");
+
+            // Check for non-existent path in legacy
+            await expect(
+                GlideJson.numincrby(client, key, ".key10", 51),
+            ).rejects.toThrow(RequestError);
+
+            // Check for non-existent key in legacy
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    "non_existent_key",
+                    ".key10",
+                    51,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // Check for overflow in legacy
+            await expect(
+                GlideJson.numincrby(
+                    client,
+                    key,
+                    ".key9",
+                    1.7976931348623157e308,
+                ),
+            ).rejects.toThrow(RequestError);
+
+            // binary buffer tests
+            expect(
+                await GlideJson.nummultby(
+                    client,
+                    Buffer.from(key),
+                    Buffer.from("key5"),
+                    10.2,
+                ),
+            ).toBe("0"); // 0 * 10.2 = 0
+        });
+
+        it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
+            "json.debug tests",
+            async (protocol) => {
+                client = await GlideClusterClient.createClient(
+                    getClientConfigurationOption(
+                        cluster.getAddresses(),
+                        protocol,
+                    ),
+                );
+                const key = uuidv4();
+                const jsonValue =
+                    '{ "key1": 1, "key2": 3.5, "key3": {"nested_key": {"key1": [4, 5]}}, "key4":' +
+                    ' [1, 2, 3], "key5": 0, "key6": "hello", "key7": null, "key8":' +
+                    ' {"nested_key": {"key1": 3.5953862697246314e307}}, "key9":' +
+                    ' 3.5953862697246314e307, "key10": true }';
+                // setup
+                expect(
+                    await GlideJson.set(client, key, "$", jsonValue),
+                ).toBe("OK");
+
+                expect(
+                    await GlideJson.debugFields(client, key, {
+                        path: "$.key1",
+                    }),
+                ).toEqual([1]);
+
+                expect(
+                    await GlideJson.debugFields(client, key, {
+                        path: "$.key3.nested_key.key1",
+                    }),
+                ).toEqual([2]);
+
+                expect(
+                    await GlideJson.debugMemory(client, key, {
+                        path: "$.key4[2]",
+                    }),
+                ).toEqual([16]);
+
+                expect(
+                    await GlideJson.debugMemory(client, key, {
+                        path: ".key6",
+                    }),
+                ).toEqual(16);
+
+                expect(await GlideJson.debugMemory(client, key)).toEqual(
+                    504,
+                );
+
+                expect(await GlideJson.debugFields(client, key)).toEqual(
+                    19,
+                );
+
+                // testing binary input
+                expect(
+                    await GlideJson.debugMemory(client, Buffer.from(key)),
+                ).toEqual(504);
+
+                expect(
+                    await GlideJson.debugFields(client, Buffer.from(key)),
+                ).toEqual(19);
+            },
+        );
+
+        it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])(
+            "json.objlen tests",
+            async (protocol) => {
+                client = await GlideClusterClient.createClient(
+                    getClientConfigurationOption(
+                        cluster.getAddresses(),
+                        protocol,
+                    ),
+                );
+                const key = uuidv4();
+                const jsonValue = {
+                    a: 1.0,
+                    b: { a: { x: 1, y: 2 }, b: 2.5, c: true },
+                };
+
+                // setup
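+                // OBJLEN counts the top-level keys of the object at the given path; non-objects yield null or an error.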
+ expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.objlen(client, key, { path: "$" }), + ).toEqual([2]); + + expect( + await GlideJson.objlen(client, key, { path: "." }), + ).toEqual(2); + + expect( + await GlideJson.objlen(client, key, { path: "$.." }), + ).toEqual([2, 3, 2]); + + expect( + await GlideJson.objlen(client, key, { path: ".." }), + ).toEqual(2); + + expect( + await GlideJson.objlen(client, key, { path: "$..b" }), + ).toEqual([3, null]); + + expect( + await GlideJson.objlen(client, key, { path: "..b" }), + ).toEqual(3); + + expect( + await GlideJson.objlen(client, Buffer.from(key), { + path: Buffer.from("..a"), + }), + ).toEqual(2); + + expect(await GlideJson.objlen(client, key)).toEqual(2); + + // path doesn't exist + expect( + await GlideJson.objlen(client, key, { + path: "$.non_existing_path", + }), + ).toEqual([]); + + await expect( + GlideJson.objlen(client, key, { + path: "non_existing_path", + }), + ).rejects.toThrow(RequestError); + + // Value at path isn't an object + expect( + await GlideJson.objlen(client, key, { + path: "$.non_existing_path", + }), + ).toEqual([]); + + await expect( + GlideJson.objlen(client, key, { path: ".a" }), + ).rejects.toThrow(RequestError); + + // Non-existing key + expect( + await GlideJson.objlen(client, "non_existing_key", { + path: "$", + }), + ).toBeNull(); + + expect( + await GlideJson.objlen(client, "non_existing_key", { + path: ".", + }), + ).toBeNull(); + + expect( + await GlideJson.set( + client, + key, + "$", + '{"a": 1, "b": 2, "c":3, "d":4}', + ), + ).toBe("OK"); + expect(await GlideJson.objlen(client, key)).toEqual(4); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "json.objkeys tests", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const key = uuidv4(); + const jsonValue = { + a: 1.0, + b: { a: { x: 1, y: 2 }, b: 2.5, c: true }, + }; + + // setup + expect( + await GlideJson.set( + client, + key, + "$", + JSON.stringify(jsonValue), + ), + ).toBe("OK"); + + expect( + await GlideJson.objkeys(client, key, { path: "$" }), + ).toEqual([["a", "b"]]); + + expect( + await GlideJson.objkeys(client, key, { + path: ".", + decoder: Decoder.Bytes, + }), + ).toEqual([Buffer.from("a"), Buffer.from("b")]); + + expect( + await GlideJson.objkeys(client, Buffer.from(key), { + path: Buffer.from("$.."), + }), + ).toEqual([ + ["a", "b"], + ["a", "b", "c"], + ["x", "y"], + ]); + + expect( + await GlideJson.objkeys(client, key, { path: ".."
}), + ).toEqual(["a", "b"]); + + expect( + await GlideJson.objkeys(client, key, { path: "$..b" }), + ).toEqual([["a", "b", "c"], []]); + + expect( + await GlideJson.objkeys(client, key, { path: "..b" }), + ).toEqual(["a", "b", "c"]); + + // path doesn't exist + expect( + await GlideJson.objkeys(client, key, { + path: "$.non_existing_path", + }), + ).toEqual([]); + + expect( + await GlideJson.objkeys(client, key, { + path: "non_existing_path", + }), + ).toBeNull(); + + // Value at path isn't an object + expect( + await GlideJson.objkeys(client, key, { path: "$.a" }), + ).toEqual([[]]); + + await expect( + GlideJson.objkeys(client, key, { path: ".a" }), + ).rejects.toThrow(RequestError); + + // Non-existing key + expect( + await GlideJson.objkeys(client, "non_existing_key", { + path: "$", + }), + ).toBeNull(); + + expect( + await GlideJson.objkeys(client, "non_existing_key", { + path: ".", + }), + ).toBeNull(); + }, + ); + }, + ); + + describe("GlideFt", () => { + let client: GlideClusterClient; + + afterEach(async () => { + await flushAndCloseClient(true, cluster.getAddresses(), client); + }); + + it("ServerModules check Vector Search module is loaded", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + const info = await client.info({ + sections: [InfoOptions.Modules], + route: "randomNode", + }); + expect(info).toContain("# search_index_stats"); + }); + + it("FT.CREATE test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + // Create a few simple indices: + const vectorField_1: VectorField = { + type: "VECTOR", + name: "vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimensions: 2, + distanceMetric: "L2", + }, + }; + expect( + await GlideFt.create(client, uuidv4(), [vectorField_1]), + ).toEqual("OK"); + + expect( + await GlideFt.create( + client, + "json_idx1", + [ + { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimensions: 6, + distanceMetric: "L2", + numberOfEdges: 32, + }, + }, + ], + { + dataType: "JSON", + prefixes: ["json:"], + }, + ), + ).toEqual("OK"); + + const vectorField_2: VectorField = { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "FLAT", + type: "FLOAT32", + dimensions: 6, + distanceMetric: "L2", + }, + }; + expect( + await GlideFt.create(client, uuidv4(), [vectorField_2]), + ).toEqual("OK"); + + // create an index with HNSW vector with additional parameters + const vectorField_3: VectorField = { + type: "VECTOR", + name: "doc_embedding", + attributes: { + algorithm: "HNSW", + type: "FLOAT32", + dimensions: 1536, + distanceMetric: "COSINE", + numberOfEdges: 40, + vectorsExaminedOnConstruction: 250, + vectorsExaminedOnRuntime: 40, + }, + }; + expect( + await GlideFt.create(client, uuidv4(), [vectorField_3], { + dataType: "HASH", + prefixes: ["docs:"], + }), + ).toEqual("OK"); + + // create an index with multiple fields + expect( + await GlideFt.create( + client, + uuidv4(), + [ + { type: "TEXT", name: "title" }, + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ], + { dataType: "HASH", prefixes: ["blog:post:"] }, + ), + ).toEqual("OK"); + + // create an index with multiple prefixes + const name = uuidv4(); + expect( + await GlideFt.create( + client, + name, + [ + { type: "TAG", name: "author_id"
}, + { type: "TAG", name: "author_ids" }, + { type: "TEXT", name: "title" }, + { type: "TEXT", name: "name" }, + ], + { + dataType: "HASH", + prefixes: ["author:details:", "book:details:"], + }, + ), + ).toEqual("OK"); + + // create a duplicating index - expect a RequestError + try { + expect( + await GlideFt.create(client, name, [ + { type: "TEXT", name: "title" }, + { type: "TEXT", name: "name" }, + ]), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("already exists"); + } + + // create an index without fields - expect a RequestError + try { + expect( + await GlideFt.create(client, uuidv4(), []), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain( + "wrong number of arguments", + ); + } + + // duplicated field name - expect a RequestError + try { + expect( + await GlideFt.create(client, uuidv4(), [ + { type: "TEXT", name: "name" }, + { type: "TEXT", name: "name" }, + ]), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("already exists"); + } + }); + + it("FT.DROPINDEX FT._LIST FT.LIST", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + // create an index + const index = uuidv4(); + expect( + await GlideFt.create(client, index, [ + { + type: "VECTOR", + name: "vec", + attributes: { + algorithm: "HNSW", + distanceMetric: "L2", + dimensions: 2, + }, + }, + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + + const before = await GlideFt.list(client); + expect(before).toContain(index); + + // DROP it + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + + const after = await GlideFt.list(client); + expect(after).not.toContain(index); + + // dropping the index again results in an error + try { + expect( + await GlideFt.dropindex(client, index), + ).rejects.toThrow(); + } catch (e) { + expect((e as Error).message).toContain("Index does not exist"); + } + }); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.INFO ft.info", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const index = uuidv4(); + expect( + await GlideFt.create( + client, + Buffer.from(index), + [ + { + type: "VECTOR", + name: "$.vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + distanceMetric: "COSINE", + dimensions: 42, + }, + }, + { type: "TEXT", name: "$.name" }, + ], + { dataType: "JSON", prefixes: ["123"] }, + ), + ).toEqual("OK"); + + let response = await GlideFt.info(client, Buffer.from(index)); + + expect(response).toMatchObject({ + index_name: index, + key_type: "JSON", + key_prefixes: ["123"], + fields: [ + { + identifier: "$.name", + type: "TEXT", + field_name: "$.name", + option: "", + }, + { + identifier: "$.vec", + type: "VECTOR", + field_name: "VEC", + option: "", + vector_params: { + distance_metric: "COSINE", + dimension: 42, + }, + }, + ], + }); + + response = await GlideFt.info(client, index, { + decoder: Decoder.Bytes, + }); + expect(response).toMatchObject({ + index_name: Buffer.from(index), + }); + + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + // querying a missing index + await expect(GlideFt.info(client, index)).rejects.toThrow( + "Index not found", + ); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.AGGREGATE on JSON", + async (protocol) => { + client = 
await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const isResp3 = protocol == ProtocolVersion.RESP3; + const prefixBicycles = "{bicycles}:"; + const indexBicycles = prefixBicycles + uuidv4(); + const query = "*"; + + // FT.CREATE idx:bicycle ON JSON PREFIX 1 bicycle: SCHEMA $.model AS model TEXT $.description AS + // description TEXT $.price AS price NUMERIC $.condition AS condition TAG SEPARATOR , + expect( + await GlideFt.create( + client, + indexBicycles, + [ + { type: "TEXT", name: "$.model", alias: "model" }, + { + type: "TEXT", + name: "$.description", + alias: "description", + }, + { + type: "NUMERIC", + name: "$.price", + alias: "price", + }, + { + type: "TAG", + name: "$.condition", + alias: "condition", + separator: ",", + }, + ], + { prefixes: [prefixBicycles], dataType: "JSON" }, + ), + ).toEqual("OK"); + + // TODO check JSON module loaded + expect( + await GlideJson.set( + client, + prefixBicycles + 0, + ".", + '{"brand": "Velorim", "model": "Jigger", "price": 270, "condition": "new"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 1, + ".", + '{"brand": "Bicyk", "model": "Hillcraft", "price": 1200, "condition": "used"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 2, + ".", + '{"brand": "Nord", "model": "Chook air 5", "price": 815, "condition": "used"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 3, + ".", + '{"brand": "Eva", "model": "Eva 291", "price": 3400, "condition": "used"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 4, + ".", + '{"brand": "Noka Bikes", "model": "Kahuna", "price": 3200, "condition": "used"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 5, + ".", + '{"brand": "Breakout", "model": "XBN 2.1 Alloy", "price": 810, "condition": "new"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 6, + ".", + '{"brand": "ScramBikes", "model": "WattBike", "price": 2300, "condition": "new"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 7, + ".", + '{"brand": "Peaknetic", "model": "Secto", "price": 430, "condition": "new"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 8, + ".", + '{"brand": "nHill", "model": "Summit", "price": 1200, "condition": "new"}', + ), + ).toEqual("OK"); + + expect( + await GlideJson.set( + client, + prefixBicycles + 9, + ".", + '{"model": "ThrillCycle", "brand": "BikeShind", "price": 815, "condition": "refurbished"}', + ), + ).toEqual("OK"); + + // let server digest the data and update index + await new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + + // FT.AGGREGATE idx:bicycle * LOAD 1 __key GROUPBY 1 @condition REDUCE COUNT 0 AS bicycles + const options: FtAggregateOptions = { + loadFields: ["__key"], + clauses: [ + { + type: "GROUPBY", + properties: ["@condition"], + reducers: [ + { + function: "COUNT", + args: [], + name: "bicycles", + }, + ], + }, + ], + }; + const aggreg = await GlideFt.aggregate( + client, + indexBicycles, + query, + options, + ); + const expectedAggreg = [ + { + condition: "new", + bicycles: isResp3 ? 5 : "5", + }, + { + condition: "refurbished", + bicycles: isResp3 ? 1 : "1", + }, + { + condition: "used", + bicycles: isResp3 ? 
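+                        // RESP3 returns aggregate numbers as native numbers,
+                        // while RESP2 returns them as bulk strings - hence the
+                        // protocol-dependent expected values in these fixtures.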
4 : "4", + }, + ]; + expect( + aggreg + .map(convertGlideRecordToRecord) + // elements (records in array) could be reordered + .sort((a, b) => + a["condition"]! > b["condition"]! ? 1 : -1, + ), + ).toEqual(expectedAggreg); + + const aggregProfile: [ + FtAggregateReturnType, + Record, + ] = await GlideFt.profileAggregate( + client, + indexBicycles, + "*", + options, + ); + // profile metrics and categories are subject to change + expect(aggregProfile[1]).toBeTruthy(); + expect( + aggregProfile[0] + .map(convertGlideRecordToRecord) + // elements (records in array) could be reordered + .sort((a, b) => + a["condition"]! > b["condition"]! ? 1 : -1, + ), + ).toEqual(expectedAggreg); + + await GlideFt.dropindex(client, indexBicycles); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.AGGREGATE on HASH", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const isResp3 = protocol == ProtocolVersion.RESP3; + const prefixMovies = "{movies}:"; + const indexMovies = prefixMovies + uuidv4(); + const query = "*"; + + // FT.CREATE idx:movie ON hash PREFIX 1 "movie:" SCHEMA title TEXT release_year NUMERIC + // rating NUMERIC genre TAG votes NUMERIC + expect( + await GlideFt.create( + client, + indexMovies, + [ + { type: "TEXT", name: "title" }, + { type: "NUMERIC", name: "release_year" }, + { type: "NUMERIC", name: "rating" }, + { type: "TAG", name: "genre" }, + { type: "NUMERIC", name: "votes" }, + ], + { prefixes: [prefixMovies], dataType: "HASH" }, + ), + ).toEqual("OK"); + + await client.hset(prefixMovies + 11002, { + title: "Star Wars: Episode V - The Empire Strikes Back", + release_year: "1980", + genre: "Action", + rating: "8.7", + votes: "1127635", + imdb_id: "tt0080684", + }); + + await client.hset(prefixMovies + 11003, { + title: "The Godfather", + release_year: "1972", + genre: "Drama", + rating: "9.2", + votes: "1563839", + imdb_id: "tt0068646", + }); + + await client.hset(prefixMovies + 11004, { + title: "Heat", + release_year: "1995", + genre: "Thriller", + rating: "8.2", + votes: "559490", + imdb_id: "tt0113277", + }); + + await client.hset(prefixMovies + 11005, { + title: "Star Wars: Episode VI - Return of the Jedi", + release_year: "1983", + genre: "Action", + rating: "8.3", + votes: "906260", + imdb_id: "tt0086190", + }); + + // let server digest the data and update index + await new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + + // FT.AGGREGATE idx:movie * LOAD * APPLY ceil(@rating) as r_rating GROUPBY 1 @genre REDUCE + // COUNT 0 AS nb_of_movies REDUCE SUM 1 votes AS nb_of_votes REDUCE AVG 1 r_rating AS avg_rating + // SORTBY 4 @avg_rating DESC @nb_of_votes DESC + const options: FtAggregateOptions = { + loadAll: true, + clauses: [ + { + type: "APPLY", + expression: "ceil(@rating)", + name: "r_rating", + }, + { + type: "GROUPBY", + properties: ["@genre"], + reducers: [ + { + function: "COUNT", + args: [], + name: "nb_of_movies", + }, + { + function: "SUM", + args: ["votes"], + name: "nb_of_votes", + }, + { + function: "AVG", + args: ["r_rating"], + name: "avg_rating", + }, + ], + }, + { + type: "SORTBY", + properties: [ + { + property: "@avg_rating", + order: SortOrder.DESC, + }, + { + property: "@nb_of_votes", + order: SortOrder.DESC, + }, + ], + }, + ], + }; + const aggreg = await GlideFt.aggregate( + client, + indexMovies, + query, + options, + ); + const expectedAggreg = [ + { + genre: "Action", + nb_of_movies: isResp3 ? 
2.0 : "2", + nb_of_votes: isResp3 ? 2033895.0 : "2033895", + avg_rating: isResp3 ? 9.0 : "9", + }, + { + genre: "Drama", + nb_of_movies: isResp3 ? 1.0 : "1", + nb_of_votes: isResp3 ? 1563839.0 : "1563839", + avg_rating: isResp3 ? 10.0 : "10", + }, + { + genre: "Thriller", + nb_of_movies: isResp3 ? 1.0 : "1", + nb_of_votes: isResp3 ? 559490.0 : "559490", + avg_rating: isResp3 ? 9.0 : "9", + }, + ]; + expect( + aggreg + .map(convertGlideRecordToRecord) + // elements (records in array) could be reordered + .sort((a, b) => (a["genre"]! > b["genre"]! ? 1 : -1)), + ).toEqual(expectedAggreg); + + const aggregProfile: [ + FtAggregateReturnType, + Record, + ] = await GlideFt.profileAggregate( + client, + indexMovies, + query, + options, + ); + // profile metrics and categories are subject to change + expect(aggregProfile[1]).toBeTruthy(); + expect( + aggregProfile[0] + .map(convertGlideRecordToRecord) + // elements (records in array) could be reordered + .sort((a, b) => (a["genre"]! > b["genre"]! ? 1 : -1)), + ).toEqual(expectedAggreg); + + await GlideFt.dropindex(client, indexMovies); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.SEARCH binary on HASH", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + const prefix = "{" + uuidv4() + "}:"; + const index = prefix + "index"; + const query = "*=>[KNN 2 @VEC $query_vec]"; + + // setup a hash index: + expect( + await GlideFt.create( + client, + index, + [ + { + type: "VECTOR", + name: "vec", + alias: "VEC", + attributes: { + algorithm: "HNSW", + distanceMetric: "L2", + dimensions: 2, + }, + }, + ], + { + dataType: "HASH", + prefixes: [prefix], + }, + ), + ).toEqual("OK"); + + const binaryValue1 = Buffer.alloc(8); + expect( + await client.hset(Buffer.from(prefix + "0"), [ + // value of + { field: "vec", value: binaryValue1 }, + ]), + ).toEqual(1); + + const binaryValue2: Buffer = Buffer.alloc(8); + binaryValue2[6] = 0x80; + binaryValue2[7] = 0xbf; + expect( + await client.hset(Buffer.from(prefix + "1"), [ + // value of + { field: "vec", value: binaryValue2 }, + ]), + ).toEqual(1); + + // let server digest the data and update index + const sleep = new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + await sleep; + + // With the `COUNT` parameters - returns only the count + const optionsWithCount: FtSearchOptions = { + params: [{ key: "query_vec", value: binaryValue1 }], + timeout: 10000, + count: true, + }; + const binaryResultCount: FtSearchReturnType = + await GlideFt.search(client, index, query, { + decoder: Decoder.Bytes, + ...optionsWithCount, + }); + expect(binaryResultCount).toEqual([2]); + + const options: FtSearchOptions = { + params: [{ key: "query_vec", value: binaryValue1 }], + timeout: 10000, + }; + const binaryResult: FtSearchReturnType = await GlideFt.search( + client, + index, + query, + { + decoder: Decoder.Bytes, + ...options, + }, + ); + + const expectedBinaryResult: FtSearchReturnType = [ + 2, + [ + { + key: Buffer.from(prefix + "1"), + value: [ + { + key: Buffer.from("vec"), + value: binaryValue2, + }, + { + key: Buffer.from("__VEC_score"), + value: Buffer.from("1"), + }, + ], + }, + { + key: Buffer.from(prefix + "0"), + value: [ + { + key: Buffer.from("vec"), + value: binaryValue1, + }, + { + key: Buffer.from("__VEC_score"), + value: Buffer.from("0"), + }, + ], + }, + ], + ]; + expect(binaryResult).toEqual(expectedBinaryResult); + + const binaryProfileResult: [ + 
FtSearchReturnType, + Record, + ] = await GlideFt.profileSearch(client, index, query, { + decoder: Decoder.Bytes, + ...options, + }); + // profile metrics and categories are subject to change + expect(binaryProfileResult[1]).toBeTruthy(); + expect(binaryProfileResult[0]).toEqual(expectedBinaryResult); + }, + ); + + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "FT.SEARCH binary on JSON", + async (protocol) => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + protocol, + ), + ); + + const prefix = "{" + uuidv4() + "}:"; + const index = prefix + "index"; + const query = "*"; + + // set string values + expect( + await GlideJson.set( + client, + prefix + "1", + "$", + '[{"arr": 42}, {"val": "hello"}, {"val": "world"}]', + ), + ).toEqual("OK"); + + // setup a json index: + expect( + await GlideFt.create( + client, + index, + [ + { + type: "NUMERIC", + name: "$..arr", + alias: "arr", + }, + { + type: "TEXT", + name: "$..val", + alias: "val", + }, + ], + { + dataType: "JSON", + prefixes: [prefix], + }, + ), + ).toEqual("OK"); + + // let server digest the data and update index + const sleep = new Promise((resolve) => + setTimeout(resolve, DATA_PROCESSING_TIMEOUT), + ); + await sleep; + + const optionsWithLimit: FtSearchOptions = { + returnFields: [ + { fieldIdentifier: "$..arr", alias: "myarr" }, + { fieldIdentifier: "$..val", alias: "myval" }, + ], + timeout: 10000, + limit: { offset: 0, count: 2 }, + }; + const stringResult: FtSearchReturnType = await GlideFt.search( + client, + index, + query, + optionsWithLimit, + ); + const expectedStringResult: FtSearchReturnType = [ + 1, + [ + { + key: prefix + "1", + value: [ + { + key: "myarr", + value: "42", + }, + { + key: "myval", + value: "hello", + }, + ], + }, + ], + ]; + expect(stringResult).toEqual(expectedStringResult); + + const stringProfileResult: [ + FtSearchReturnType, + Record, + ] = await GlideFt.profileSearch( + client, + index, + query, + optionsWithLimit, + ); + // profile metrics and categories are subject to change + expect(stringProfileResult[1]).toBeTruthy(); + expect(stringProfileResult[0]).toEqual(expectedStringResult); + }, + ); + + it("FT.EXPLAIN ft.explain FT.EXPLAINCLI ft.explaincli", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + + const index = uuidv4(); + expect( + await GlideFt.create(client, index, [ + { type: "NUMERIC", name: "price" }, + { type: "TEXT", name: "title" }, + ]), + ).toEqual("OK"); + + let explain = await GlideFt.explain( + client, + Buffer.from(index), + "@price:[0 10]", + ); + expect(explain).toContain("price"); + expect(explain).toContain("10"); + + explain = ( + (await GlideFt.explain(client, index, "@price:[0 10]", { + decoder: Decoder.Bytes, + })) as Buffer + ).toString(); + expect(explain).toContain("price"); + expect(explain).toContain("10"); + + explain = await GlideFt.explain(client, index, "*"); + expect(explain).toContain("*"); + + let explaincli = ( + await GlideFt.explaincli( + client, + Buffer.from(index), + "@price:[0 10]", + ) + ).map((s) => (s as string).trim()); + expect(explaincli).toContain("price"); + expect(explaincli).toContain("0"); + expect(explaincli).toContain("10"); + + explaincli = ( + await GlideFt.explaincli(client, index, "@price:[0 10]", { + decoder: Decoder.Bytes, + }) + ).map((s) => (s as Buffer).toString().trim()); + expect(explaincli).toContain("price"); + 
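+                // Aside on the binary KNN tests above: query vectors are raw
+                // little-endian FLOAT32 buffers, e.g. bytes 00 00 80 bf in the
+                // upper half of `binaryValue2` encode the vector [0, -1]. A
+                // hypothetical helper (not part of the client API) makes such
+                // buffers easier to build:
+                //
+                //     function toFloat32Buffer(vector: number[]): Buffer {
+                //         const buf = Buffer.alloc(vector.length * 4);
+                //         vector.forEach((v, i) => buf.writeFloatLE(v, i * 4));
+                //         return buf;
+                //     }
+                //
+                //     // toFloat32Buffer([0, 0]).equals(binaryValue1) === true
+                //     // toFloat32Buffer([0, -1]).equals(binaryValue2) === true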
expect(explaincli).toContain("0"); + expect(explaincli).toContain("10"); + + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + // querying a missing index + await expect(GlideFt.explain(client, index, "*")).rejects.toThrow( + "Index not found", + ); + await expect( + GlideFt.explaincli(client, index, "*"), + ).rejects.toThrow("Index not found"); + }); + + it("FT.ALIASADD, FT.ALIASUPDATE and FT.ALIASDEL test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + const index = uuidv4(); + const alias = uuidv4() + "-alias"; + + // Create an index. + expect( + await GlideFt.create(client, index, [ + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + // Check if the index created successfully. + expect(await client.customCommand(["FT._LIST"])).toContain(index); + + // Add an alias to the index. + expect(await GlideFt.aliasadd(client, index, alias)).toEqual("OK"); + + const newIndex = uuidv4(); + const newAlias = uuidv4(); + + // Create a second index. + expect( + await GlideFt.create(client, newIndex, [ + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + // Check if the second index created successfully. + expect(await client.customCommand(["FT._LIST"])).toContain( + newIndex, + ); + + // Add an alias to second index and also test addalias for bytes type input. + expect( + await GlideFt.aliasadd( + client, + Buffer.from(newIndex), + Buffer.from(newAlias), + ), + ).toEqual("OK"); + + // Test if updating an already existing alias to point to an existing index returns "OK". + expect(await GlideFt.aliasupdate(client, newAlias, index)).toEqual( + "OK", + ); + // Test alias update for byte type input. + expect( + await GlideFt.aliasupdate( + client, + Buffer.from(alias), + Buffer.from(newIndex), + ), + ).toEqual("OK"); + + // Test if an existing alias is deleted successfully. + expect(await GlideFt.aliasdel(client, alias)).toEqual("OK"); + + // Test if an existing alias is deleted successfully for bytes type input. + expect( + await GlideFt.aliasdel(client, Buffer.from(newAlias)), + ).toEqual("OK"); + + // Drop both indexes. + expect(await GlideFt.dropindex(client, index)).toEqual("OK"); + expect(await client.customCommand(["FT._LIST"])).not.toContain( + index, + ); + expect(await GlideFt.dropindex(client, newIndex)).toEqual("OK"); + expect(await client.customCommand(["FT._LIST"])).not.toContain( + newIndex, + ); + }); + + it("FT._ALIASLIST test", async () => { + client = await GlideClusterClient.createClient( + getClientConfigurationOption( + cluster.getAddresses(), + ProtocolVersion.RESP3, + ), + ); + const index1 = uuidv4(); + const alias1 = uuidv4() + "-alias"; + const index2 = uuidv4(); + const alias2 = uuidv4() + "-alias"; + + //Create the 2 test indexes. + expect( + await GlideFt.create(client, index1, [ + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + expect( + await GlideFt.create(client, index2, [ + { type: "NUMERIC", name: "published_at" }, + { type: "TAG", name: "category" }, + ]), + ).toEqual("OK"); + + //Check if the two indexes created successfully. + expect(await client.customCommand(["FT._LIST"])).toContain(index1); + expect(await client.customCommand(["FT._LIST"])).toContain(index2); + + //Add aliases to the 2 indexes. 
+ expect(await GlideFt.aliasadd(client, index1, alias1)).toBe("OK"); + expect(await GlideFt.aliasadd(client, index2, alias2)).toBe("OK"); + + //Test if the aliaslist command return the added alias. + const result = await GlideFt.aliaslist(client); + const expected: GlideRecord = [ + { + key: alias2, + value: index2, + }, + { + key: alias1, + value: index1, + }, + ]; + + const compareFunction = function ( + a: { key: GlideString; value: GlideString }, + b: { key: GlideString; value: GlideString }, + ) { + return a.key.toString().localeCompare(b.key.toString()) > 0 + ? 1 + : -1; + }; + + expect(result.sort(compareFunction)).toEqual( + expected.sort(compareFunction), + ); + }); + }); +}); diff --git a/node/tests/SharedTests.ts b/node/tests/SharedTests.ts index 51aebc03e1..6cada7b66f 100644 --- a/node/tests/SharedTests.ts +++ b/node/tests/SharedTests.ts @@ -20,7 +20,6 @@ import { BitOverflowControl, BitmapIndexType, BitwiseOperation, - ClosingError, ClusterTransaction, ConditionalChange, Decoder, @@ -127,13 +126,9 @@ export function runBaseTests(config: { await runTest(async (client: BaseClient) => { client.close(); - try { - expect(await client.set("foo", "bar")).toThrow(); - } catch (e) { - expect((e as ClosingError).message).toMatch( - "Unable to execute requests; the client is closed. Please create a new client.", - ); - } + await expect(client.set("foo", "bar")).rejects.toThrow( + "Unable to execute requests; the client is closed. Please create a new client.", + ); }, protocol); }, config.timeout, @@ -300,14 +295,10 @@ export function runBaseTests(config: { if (conf_file.length > 0) { expect(await client.configRewrite()).toEqual("OK"); } else { - try { - /// We expect Valkey to return an error since the test cluster doesn't use redis.conf file - expect(await client.configRewrite()).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "The server is running without a config file", - ); - } + /// We expect Valkey to return an error since the test cluster doesn't use redis.conf file + await expect(client.configRewrite()).rejects.toThrow( + "The server is running without a config file", + ); } }, protocol); }, @@ -505,29 +496,16 @@ export function runBaseTests(config: { const key = uuidv4(); expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.incr(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "value is not an integer", - ); - } - - try { - expect(await client.incrBy(key, 1)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "value is not an integer", - ); - } + await expect(client.incr(key)).rejects.toThrow( + "value is not an integer", + ); - try { - expect(await client.incrByFloat(key, 1.5)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "value is not a valid float", - ); - } + await expect(client.incrBy(key, 1)).rejects.toThrow( + "value is not an integer", + ); + await expect(client.incrByFloat(key, 1.5)).rejects.toThrow( + "value is not a valid float", + ); }, protocol); }, config.timeout, @@ -617,21 +595,13 @@ export function runBaseTests(config: { const key = uuidv4(); expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.decr(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "value is not an integer", - ); - } + await expect(client.decr(key)).rejects.toThrow( + "value is not an integer", + ); - try { - expect(await client.decrBy(key, 3)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "value is not 
an integer", - ); - } + await expect(client.decrBy(key, 3)).rejects.toThrow( + "value is not an integer", + ); }, protocol); }, config.timeout, @@ -1594,7 +1564,7 @@ export function runBaseTests(config: { // Test count with match returns a non-empty list result = await client.hscan(key1, initialCursor, { match: "1*", - count: 30, + count: 1000, }); expect(result[resultCursorIndex]).not.toEqual(initialCursor); expect(result[resultCollectionIndex].length).toBeGreaterThan(0); @@ -1929,23 +1899,12 @@ export function runBaseTests(config: { }; expect(await client.hset(key, fieldValueMap)).toEqual(1); - try { - expect(await client.hincrBy(key, field, 2)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "hash value is not an integer", - ); - } - - try { - expect( - await client.hincrByFloat(key, field, 1.5), - ).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "hash value is not a float", - ); - } + await expect(client.hincrBy(key, field, 2)).rejects.toThrow( + "hash value is not an integer", + ); + await expect( + client.hincrByFloat(key, field, 1.5), + ).rejects.toThrow("hash value is not a float"); }, protocol); }, config.timeout, @@ -2198,29 +2157,15 @@ export function runBaseTests(config: { const key = uuidv4(); expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.lpush(key, ["bar"])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - expect(await client.lpop(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - expect(await client.lrange(key, 0, -1)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.lpush(key, ["bar"])).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.lpop(key)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.lrange(key, 0, -1)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); }, protocol); }, config.timeout, @@ -2292,13 +2237,9 @@ export function runBaseTests(config: { expect(await client.set(key2, "foo")).toEqual("OK"); - try { - expect(await client.llen(key2)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.llen(key2)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); }, protocol); }, config.timeout, @@ -2660,13 +2601,9 @@ export function runBaseTests(config: { expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.ltrim(key, 0, 1)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.ltrim(key, 0, 1)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); //test for binary key as input to the command const key2 = uuidv4(); @@ -2774,21 +2711,12 @@ export function runBaseTests(config: { const key = uuidv4(); expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.rpush(key, ["bar"])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - 
expect(await client.rpop(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.rpush(key, ["bar"])).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.rpop(key)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); }, protocol); }, config.timeout, @@ -2974,37 +2902,18 @@ export function runBaseTests(config: { const key = uuidv4(); expect(await client.set(key, "foo")).toEqual("OK"); - try { - expect(await client.sadd(key, ["bar"])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - expect(await client.srem(key, ["bar"])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - expect(await client.scard(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } - - try { - expect(await client.smembers(key)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.sadd(key, ["bar"])).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.srem(key, ["bar"])).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.scard(key)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); + await expect(client.smembers(key)).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); }, protocol); }, config.timeout, @@ -3036,13 +2945,9 @@ export function runBaseTests(config: { ).toEqual(new Set([Buffer.from("c"), Buffer.from("d")])); // invalid argument - key list must not be empty - try { - expect(await client.sinter([])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "ResponseError: wrong number of arguments", - ); - } + await expect(client.sinter([])).rejects.toThrow( + "wrong number of arguments", + ); // non-existing key returns empty set expect(await client.sinter([key1, non_existing_key])).toEqual( @@ -3052,13 +2957,9 @@ export function runBaseTests(config: { // non-set key expect(await client.set(key2, "value")).toEqual("OK"); - try { - expect(await client.sinter([key2])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "Operation against a key holding the wrong kind of value", - ); - } + await expect(client.sinter([key2])).rejects.toThrow( + "Operation against a key holding the wrong kind of value", + ); }, protocol); }, config.timeout, @@ -4232,6 +4133,18 @@ export function runBaseTests(config: { config.timeout, ); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( + "script kill killable test_%p", + async (protocol) => { + await runTest(async (client: BaseClient) => { + await expect(client.scriptKill()).rejects.toThrow( + "No scripts in execution right now", + ); + }, protocol); + }, + config.timeout, + ); + it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( `zadd and zaddIncr with NX XX test_%p`, async (protocol) => { @@ -6432,7 +6345,7 @@ export function runBaseTests(config: { expect( await client.bzpopmax( [key3], - cluster.checkIfServerVersionLessThan("6.0.0") + cluster.checkIfServerVersionLessThan("7.0.0") ? 
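+                        // blocking timeout for the server-side pop: use a more
+                        // generous window on older engines to avoid flakiness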
1.0 : 0.01, ), @@ -6482,7 +6395,7 @@ export function runBaseTests(config: { expect( await client.bzpopmin( [key3], - cluster.checkIfServerVersionLessThan("6.0.0") + cluster.checkIfServerVersionLessThan("7.0.0") ? 1.0 : 0.01, ), @@ -6692,19 +6605,6 @@ export function runBaseTests(config: { // key exists, but it is not a list await client.set("foo", "bar"); await expect(client.brpop(["foo"], 0.1)).rejects.toThrow(); - - // Same-slot requirement - if (client instanceof GlideClusterClient) { - try { - expect( - await client.brpop(["abc", "zxy", "lkn"], 0.1), - ).toThrow(); - } catch (e) { - expect((e as Error).message.toLowerCase()).toMatch( - "crossslot", - ); - } - } }, protocol); }, config.timeout, @@ -6735,19 +6635,6 @@ export function runBaseTests(config: { // key exists, but it is not a list await client.set("foo", "bar"); await expect(client.blpop(["foo"], 0.1)).rejects.toThrow(); - - // Same-slot requirement - if (client instanceof GlideClusterClient) { - try { - expect( - await client.blpop(["abc", "zxy", "lkn"], 0.1), - ).toThrow(); - } catch (e) { - expect((e as Error).message.toLowerCase()).toMatch( - "crossslot", - ); - } - } }, protocol); }, config.timeout, @@ -7525,126 +7412,152 @@ export function runBaseTests(config: { it.each([ProtocolVersion.RESP2, ProtocolVersion.RESP3])( `xinfo stream xinfosream test_%p`, async (protocol) => { - await runTest( - async (client: BaseClient, cluster: ValkeyCluster) => { - const key = uuidv4(); - const groupName = `group-${uuidv4()}`; - const consumerName = `consumer-${uuidv4()}`; - const streamId0_0 = "0-0"; - const streamId1_0 = "1-0"; - const streamId1_1 = "1-1"; + await runTest(async (client: BaseClient, cluster) => { + const key = uuidv4(); + const groupName = `group-${uuidv4()}`; + const consumerName = `consumer-${uuidv4()}`; + const streamId0_0 = "0-0"; + const streamId1_0 = "1-0"; + const streamId1_1 = "1-1"; - expect( - await client.xadd( - key, - [ - ["a", "b"], - ["c", "d"], - ], - { id: streamId1_0 }, - ), - ).toEqual(streamId1_0); + expect( + await client.xadd( + key, + [ + ["a", "b"], + ["c", "d"], + ], + { id: streamId1_0 }, + ), + ).toEqual(streamId1_0); - expect( - await client.xgroupCreate(key, groupName, streamId0_0), - ).toEqual("OK"); + expect( + await client.xgroupCreate(key, groupName, streamId0_0), + ).toEqual("OK"); - await client.xreadgroup(groupName, consumerName, { - [key]: ">", - }); + await client.xreadgroup(groupName, consumerName, { + [key]: ">", + }); - const result = (await client.xinfoStream(key)) as { - length: number; - "radix-tree-keys": number; - "radix-tree-nodes": number; - "last-generated-id": string; - "max-deleted-entry-id": string; - "entries-added": number; - "recorded-first-entry-id": string; - "first-entry": (string | number | string[])[]; - "last-entry": (string | number | string[])[]; - groups: number; - }; + const result = (await client.xinfoStream(key)) as { + length: number; + "radix-tree-keys": number; + "radix-tree-nodes": number; + "last-generated-id": string; + "max-deleted-entry-id": string; + "entries-added": number; + "recorded-first-entry-id": string; + "first-entry": (string | number | string[])[]; + "last-entry": (string | number | string[])[]; + groups: number; + }; - expect(result.length).toEqual(1); - const expectedFirstEntry = ["1-0", ["a", "b", "c", "d"]]; - expect(result["first-entry"]).toEqual(expectedFirstEntry); - expect(result["last-entry"]).toEqual(expectedFirstEntry); - expect(result.groups).toEqual(1); + expect(result.length).toEqual(1); + const expectedFirstEntry = 
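+                // XINFO STREAM returns each entry as [id, [field, value, ...]]
+                // with the field/value pairs flattened into one array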
["1-0", ["a", "b", "c", "d"]]; + expect(result["first-entry"]).toEqual(expectedFirstEntry); + expect(result["last-entry"]).toEqual(expectedFirstEntry); + expect(result.groups).toEqual(1); - expect( - await client.xadd(key, [["foo", "bar"]], { - id: streamId1_1, - }), - ).toEqual(streamId1_1); - const fullResult = (await client.xinfoStream( - Buffer.from(key), + expect( + await client.xadd(key, [["foo", "bar"]], { + id: streamId1_1, + }), + ).toEqual(streamId1_1); + const fullResult = (await client.xinfoStream(Buffer.from(key), { + fullOptions: 1, + })) as { + length: number; + "radix-tree-keys": number; + "radix-tree-nodes": number; + "last-generated-id": string; + "max-deleted-entry-id": string; + "entries-added": number; + "recorded-first-entry-id": string; + entries: (string | number | string[])[][]; + groups: [ { - fullOptions: 1, + name: string; + "last-delivered-id": string; + "entries-read": number; + lag: number; + "pel-count": number; + pending: (string | number)[][]; + consumers: [ + { + name: string; + "seen-time": number; + "active-time": number; + "pel-count": number; + pending: (string | number)[][]; + }, + ]; }, - )) as { - length: number; - "radix-tree-keys": number; - "radix-tree-nodes": number; - "last-generated-id": string; - "max-deleted-entry-id": string; - "entries-added": number; - "recorded-first-entry-id": string; - entries: (string | number | string[])[][]; - groups: [ - { - name: string; - "last-delivered-id": string; - "entries-read": number; - lag: number; - "pel-count": number; - pending: (string | number)[][]; - consumers: [ - { - name: string; - "seen-time": number; - "active-time": number; - "pel-count": number; - pending: (string | number)[][]; - }, - ]; - }, - ]; - }; + ]; + }; - expect(fullResult.length).toEqual(2); + // verify full result like: + // { + // length: 2, + // 'radix-tree-keys': 1, + // 'radix-tree-nodes': 2, + // 'last-generated-id': '1-1', + // 'max-deleted-entry-id': '0-0', + // 'entries-added': 2, + // 'recorded-first-entry-id': '1-0', + // entries: [ [ '1-0', ['a', 'b', ...] 
] ], + // groups: [ { + // name: 'group', + // 'last-delivered-id': '1-0', + // 'entries-read': 1, + // lag: 1, + // 'pel-count': 1, + // pending: [ [ '1-0', 'consumer', 1722624726802, 1 ] ], + // consumers: [ { + // name: 'consumer', + // 'seen-time': 1722624726802, + // 'active-time': 1722624726802, + // 'pel-count': 1, + // pending: [ [ '1-0', 'consumer', 1722624726802, 1 ] ], + // } + // ] + // } + // ] + // } + expect(fullResult.length).toEqual(2); - if (cluster.checkIfServerVersionLessThan("7.0.0")) { - expect( - fullResult["max-deleted-entry-id"], - ).toBeUndefined(); - expect(fullResult["entries-added"]).toBeUndefined(); - expect( - fullResult.groups[0]["entries-read"], - ).toBeUndefined(); - expect(fullResult.groups[0]["lag"]).toBeUndefined(); - } else if (cluster.checkIfServerVersionLessThan("7.2.0")) { - expect(fullResult["recorded-first-entry-id"]).toEqual( - streamId1_0, - ); + if (!cluster.checkIfServerVersionLessThan("7.0.0")) { + expect(fullResult["recorded-first-entry-id"]).toEqual( + streamId1_0, + ); + } - expect( - fullResult.groups[0].consumers[0]["active-time"], - ).toBeUndefined(); - expect( - fullResult.groups[0].consumers[0]["seen-time"], - ).toBeDefined(); - } else { - expect( - fullResult.groups[0].consumers[0]["active-time"], - ).toBeDefined(); - expect( - fullResult.groups[0].consumers[0]["seen-time"], - ).toBeDefined(); - } - }, - protocol, - ); + if (cluster.checkIfServerVersionLessThan("7.0.0")) { + expect(fullResult["max-deleted-entry-id"]).toBeUndefined(); + expect(fullResult["entries-added"]).toBeUndefined(); + expect( + fullResult.groups[0]["entries-read"], + ).toBeUndefined(); + expect(fullResult.groups[0]["lag"]).toBeUndefined(); + } else if (cluster.checkIfServerVersionLessThan("7.2.0")) { + expect(fullResult["recorded-first-entry-id"]).toEqual( + streamId1_0, + ); + + expect( + fullResult.groups[0].consumers[0]["active-time"], + ).toBeUndefined(); + expect( + fullResult.groups[0].consumers[0]["seen-time"], + ).toBeDefined(); + } else { + expect( + fullResult.groups[0].consumers[0]["active-time"], + ).toBeDefined(); + expect( + fullResult.groups[0].consumers[0]["seen-time"], + ).toBeDefined(); + } + }, protocol); }, config.timeout, ); @@ -7736,11 +7649,9 @@ export function runBaseTests(config: { const key3 = `{key}-3-${uuidv4()}`; // renamenx missing key - try { - expect(await client.renamenx(key1, key2)).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch("no such key"); - } + await expect(client.renamenx(key1, key2)).rejects.toThrow( + "no such key", + ); // renamenx a string await client.set(key1, "key1"); @@ -7969,13 +7880,9 @@ export function runBaseTests(config: { expect(await client.pfcount([key3])).toEqual(0); // invalid argument - key list must not be empty - try { - expect(await client.pfcount([])).toThrow(); - } catch (e) { - expect((e as Error).message).toMatch( - "ResponseError: wrong number of arguments", - ); - } + await expect(client.pfcount([])).rejects.toThrow( + "ResponseError: wrong number of arguments", + ); // key exists, but it is not a HyperLogLog expect(await client.set(stringKey, "value")).toEqual("OK"); @@ -9246,7 +9153,6 @@ export function runBaseTests(config: { { sortOrder: SortOrder.DESC, storeDist: true }, ), ).toEqual(3); - // TODO deep close to https://github.com/maasencioh/jest-matcher-deep-close-to expect( await client.zrangeWithScores( key2, @@ -9843,7 +9749,7 @@ export function runBaseTests(config: { // Test count with match returns a non-empty list result = await client.zscan(key1, initialCursor, { 
match: "member1*", - count: 20, + count: 1000, }); expect(result[resultCursorIndex]).not.toEqual("0"); expect( diff --git a/node/tests/TestUtilities.ts b/node/tests/TestUtilities.ts index d83fb56d98..a2e28a0a6d 100644 --- a/node/tests/TestUtilities.ts +++ b/node/tests/TestUtilities.ts @@ -4,7 +4,6 @@ import { expect } from "@jest/globals"; import { exec } from "child_process"; -import parseArgs from "minimist"; import { gte } from "semver"; import { v4 as uuidv4 } from "uuid"; import { @@ -177,7 +176,6 @@ export function flushallOnPort(port: number): Promise { */ export const parseEndpoints = (endpointsStr: string): [string, number][] => { try { - console.log(endpointsStr); const endpoints: string[][] = endpointsStr .split(",") .map((endpoint) => endpoint.split(":")); @@ -330,40 +328,6 @@ export function createLongRunningLuaScript( return script.replaceAll("$timeout", timeout.toString()); } -export async function waitForScriptNotBusy( - client: GlideClusterClient | GlideClient, -) { - // If function wasn't killed, and it didn't time out - it blocks the server and cause rest test to fail. - let isBusy = true; - - do { - try { - await client.scriptKill(); - } catch (err) { - // should throw `notbusy` error, because the function should be killed before - if ((err as Error).message.toLowerCase().includes("notbusy")) { - isBusy = false; - } - } - } while (isBusy); -} - -/** - * Parses the command-line arguments passed to the Node.js process. - * - * @returns Parsed command-line arguments. - * - * @example - * ```typescript - * // Command: node script.js --name="John Doe" --age=30 - * const args = parseCommandLineArgs(); - * // args = { name: 'John Doe', age: 30 } - * ``` - */ -export function parseCommandLineArgs() { - return parseArgs(process.argv.slice(2)); -} - export async function testTeardown( cluster_mode: boolean, option: BaseClientConfiguration, @@ -387,6 +351,8 @@ export const getClientConfigurationOption = ( port, })), protocol, + useTLS: global.TLS ?? false, + requestTimeout: 1000, ...configOverrides, }; }; @@ -598,46 +564,6 @@ export async function encodableTransactionTest( return responseData; } -/** Populates a transaction with dump and restore commands - * - * @param baseTransaction - A transaction - * @param valueResponse - Represents the encoded response of "value" to compare - * @returns Array of tuples, where first element is a test name/description, second - expected return value. 
- */ -export async function DumpAndRestoreTest( - baseTransaction: Transaction, - dumpValue: Buffer | null, -): Promise<[string, GlideReturnType][]> { - if (dumpValue == null) { - throw new Error("dumpValue is null"); - } - - const key = "{key}-" + uuidv4(); // string - const buffValue = Buffer.from("value"); - // array of tuples - first element is test name/description, second - expected return value - const responseData: [string, GlideReturnType][] = []; - - baseTransaction.set(key, "value"); - responseData.push(["set(key, stringValue)", "OK"]); - baseTransaction.customCommand(["DUMP", key]); - responseData.push(['customCommand(["DUMP", key])', dumpValue]); - baseTransaction.get(key); - responseData.push(["get(key)", buffValue]); - baseTransaction.del([key]); - responseData.push(["del(key)", 1]); - baseTransaction.get(key); - responseData.push(["get(key)", null]); - baseTransaction.customCommand(["RESTORE", key, "0", dumpValue]); - responseData.push([ - 'customCommand(["RESTORE", buffKey, "0", stringValue])', - "OK", - ]); - baseTransaction.get(key); - responseData.push(["get(key)", buffValue]); - - return responseData; -} - /** * Populates a transaction with commands to test. * @param baseTransaction - A transaction. @@ -834,14 +760,14 @@ export async function transactionTest( baseTransaction.lmpop([key24], ListDirection.LEFT); responseData.push([ "lmpop([key22], ListDirection.LEFT)", - [{ key: key24, value: [field + "2"] }], + convertRecordToGlideRecord({ [key24]: [field + "2"] }), ]); baseTransaction.lpush(key24, [field + "2"]); responseData.push(["lpush(key22, [2])", 2]); baseTransaction.blmpop([key24], ListDirection.LEFT, 0.1, 1); responseData.push([ "blmpop([key22], ListDirection.LEFT, 0.1, 1)", - [{ key: key24, value: [field + "2"] }], + convertRecordToGlideRecord({ [key24]: [field + "2"] }), ]); } @@ -1270,15 +1196,6 @@ export async function transactionTest( "xpending(key9, groupName1)", [1, "0-2", "0-2", [[consumer, "1"]]], ]); - baseTransaction.xpendingWithOptions(key9, groupName1, { - start: InfBoundary.NegativeInfinity, - end: InfBoundary.PositiveInfinity, - count: 10, - }); - responseData.push([ - "xpendingWithOptions(key9, groupName1, -, +, 10)", - [["0-2", consumer, 0, 1]], - ]); baseTransaction.xclaim(key9, groupName1, consumer, 0, ["0-2"]); responseData.push([ 'xclaim(key9, groupName1, consumer, 0, ["0-2"])', @@ -1346,7 +1263,6 @@ export async function transactionTest( responseData.push(["xgroupDestroy(key9, groupName1)", true]); baseTransaction.xgroupDestroy(key9, groupName2); responseData.push(["xgroupDestroy(key9, groupName2)", true]); - baseTransaction.rename(key9, key10); responseData.push(["rename(key9, key10)", "OK"]); baseTransaction.exists([key10]); @@ -1711,9 +1627,6 @@ export async function transactionTest( responseData.push(["sortReadOnly(key21)", ["1", "2", "3"]]); } - baseTransaction.wait(1, 200); - if (gte(version, "7.0.0")) responseData.push(["wait(1, 200)", 1]); - else responseData.push(["wait(1, 200)", 0]); return responseData; } diff --git a/node/tests/setup.ts b/node/tests/setup.ts new file mode 100644 index 0000000000..337d5430d2 --- /dev/null +++ b/node/tests/setup.ts @@ -0,0 +1,23 @@ +/* eslint-disable no-var */ +import { beforeAll } from "@jest/globals"; +import minimist from "minimist"; +import { Logger } from "../build-ts"; + +beforeAll(() => { + Logger.init("error", "log.log"); + // Logger.setLoggerConfig("off"); +}); + +declare global { + var CLI_ARGS: Record; + var CLUSTER_ENDPOINTS: string; + var STAND_ALONE_ENDPOINT: string; + var TLS: 
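+    // set below from the --tls CLI flag and consumed by
+    // getClientConfigurationOption (useTLS) in TestUtilities.ts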
boolean; +} + +const args = minimist(process.argv.slice(2)); +// Make the arguments available globally +global.CLI_ARGS = args; +global.CLUSTER_ENDPOINTS = args["cluster-endpoints"] as string; +global.STAND_ALONE_ENDPOINT = args["standalone-endpoints"] as string; +global.TLS = !!args.tls; diff --git a/node/tests/tsconfig.json b/node/tests/tsconfig.json index f7b57f9a00..ab48e3f902 100644 --- a/node/tests/tsconfig.json +++ b/node/tests/tsconfig.json @@ -3,5 +3,5 @@ "compilerOptions": { "rootDir": "../../" }, - "include": ["*.ts", "./*.test.ts"] + "include": ["*.ts", "./*.test.ts", "setup.ts"] } diff --git a/node/tsconfig.json b/node/tsconfig.json index 4cd744701c..a1824416be 100644 --- a/node/tsconfig.json +++ b/node/tsconfig.json @@ -25,6 +25,15 @@ ] /* Specify a set of bundled library declaration files that describe the target runtime environment. */, "outDir": "./build-ts" /* Specify an output folder for all emitted files.*/ }, + "ts-node": { + "transpileOnly": true, + "compilerOptions": { + "module": "CommonJS", + "target": "ES2018", + "esModuleInterop": true + }, + "esm": true + }, "compileOnSave": false, "include": ["./*.ts", "src/*.ts", "src/*.js"], "exclude": ["node_modules", "build-ts"] diff --git a/package.json b/package.json index fa682d7107..3f61298feb 100644 --- a/package.json +++ b/package.json @@ -3,10 +3,10 @@ "@eslint/js": "^9.10.0", "@types/eslint__js": "^8.42.3", "@types/eslint-config-prettier": "^6.11.3", - "eslint": "^9.10.0", + "eslint": "9.14.0", "eslint-config-prettier": "^9.1.0", "prettier": "^3.3.3", "typescript": "^5.6.2", - "typescript-eslint": "^8.5.0" + "typescript-eslint": "^8.13" } } diff --git a/python/Cargo.toml b/python/Cargo.toml index cb82b7e5e4..aaf49762df 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -16,8 +16,8 @@ pyo3 = { version = "^0.22", features = [ "num-bigint", "gil-refs", ] } -bytes = { version = "1.6.0" } -redis = { path = "../submodules/redis-rs/redis", features = [ +bytes = { version = "^1.8" } +redis = { path = "../glide-core/redis-rs/redis", features = [ "aio", "tokio-comp", "connection-manager", diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index abf12dc9a3..66127913c3 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -2,15 +2,13 @@ This document describes how to set up your development environment to build and test the Valkey GLIDE Python wrapper. -### Development Overview - The Valkey GLIDE Python wrapper consists of both Python and Rust code. Rust bindings for Python are implemented using [PyO3](https://github.com/PyO3/pyo3), and the Python package is built using [maturin](https://github.com/PyO3/maturin). The Python and Rust components communicate using the [protobuf](https://github.com/protocolbuffers/protobuf) protocol. -### Build from source +# Prerequisites +--- -#### Prerequisites +Before building the package from source, make sure that you have installed the dependencies listed below: -Software Dependencies - python3 virtualenv - git @@ -21,7 +19,10 @@ Software Dependencies - openssl-dev - rustup -**Dependencies installation for Ubuntu** +For your convenience, we wrapped the steps in "copy-paste" code blocks for common operating systems: + +
      +Ubuntu / Debian ```bash sudo apt update -y @@ -42,7 +43,10 @@ export PATH="$PATH:$HOME/.local/bin" protoc --version ``` -**Dependencies installation for CentOS** +
      + +
      +CentOS ```bash sudo yum update -y @@ -62,7 +66,10 @@ export PATH="$PATH:$HOME/.local/bin" protoc --version ``` -**Dependencies installation for MacOS** +
      + +
      +MacOS ```bash brew update @@ -80,112 +87,108 @@ source /Users/$USER/.bash_profile protoc --version ``` -#### Building and installation steps +
      -Before starting this step, make sure you've installed all software requirments. +# Building +--- -1. Clone the repository: - ```bash - git clone https://github.com/valkey-io/valkey-glide.git - cd valkey-glide - ``` -2. Initialize git submodule: - ```bash - git submodule update --init --recursive - ``` -3. Generate protobuf files: - ```bash - GLIDE_ROOT_FOLDER_PATH=. - protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto - ``` -4. Create a virtual environment: - ```bash - cd python - python3 -m venv .env - ``` -5. Activate the virtual environment: - ```bash - source .env/bin/activate - ``` -6. Install requirements: - ```bash - pip install -r requirements.txt - ``` -7. Build the Python wrapper in release mode: - ``` - maturin develop --release --strip - ``` - > **Note:** To build the wrapper binary with debug symbols remove the --strip flag. -8. Run tests: - 1. Ensure that you have installed redis-server or valkey-server and redis-cli or valkey-cli on your host. You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). - 2. Validate the activation of the virtual environment from step 4 by ensuring its name (`.env`) is displayed next to your command prompt. - 3. Execute the following command from the python folder: - ```bash - pytest --asyncio-mode=auto - ``` - > **Note:** To run Valkey modules tests, add -k "test_server_modules.py". - -- Install Python development requirements with: +Before starting this step, make sure you've installed all software requirements. - ```bash - pip install -r python/dev_requirements.txt - ``` +## Prepare your environment -- For a fast build, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the "--release" flag when measuring performance. - -### Test +```bash +mkdir -p $HOME/src +cd $_ +git clone https://github.com/valkey-io/valkey-glide.git +cd valkey-glide +GLIDE_ROOT=$(pwd) +protoc -Iprotobuf=${GLIDE_ROOT}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT}/python/python/glide \ + ${GLIDE_ROOT}/glide-core/src/protobuf/*.proto +cd python +python3 -m venv .env +source .env/bin/activate +pip install -r requirements.txt +pip install -r dev_requirements.txt +``` -To run tests, use the following command: +## Build the package (in release mode): ```bash -pytest --asyncio-mode=auto +maturin develop --release --strip ``` -To execute a specific test, include the `-k ` option. For example: +> **Note:** to build the wrapper binary with debug symbols remove the `--strip` flag. + +> **Note 2:** for a faster build time, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the `--release` flag when measuring performance. + +# Running tests +--- + +Ensure that you have installed `redis-server` or `valkey-server` along with `redis-cli` or `valkey-cli` on your host. 
You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). + +From a terminal, change directory to the GLIDE source folder and type: ```bash -pytest --asyncio-mode=auto -k test_socket_set_and_get +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto ``` -IT suite starts the server for testing - standalone and cluster installation using `cluster_manager` script. -If you want IT to use already started servers, use the following command line from `python/python` dir: +To run modules tests: ```bash -pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto -k "test_server_modules.py" ``` -### Submodules +**TIP:** to run a specific test, append `-k <test_name>` to the `pytest` execution line. -After pulling new changes, ensure that you update the submodules by running the following command: +To run tests against already running servers, change the `pytest` line above to this: ```bash -git submodule update +pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 ``` -### Generate protobuf files +# Generate protobuf files +--- -During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made to the protobuf definition files (.proto files located in `glide-core/src/protofuf`), it becomes necessary to regenerate the Python protobuf files. To do so, run: +During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made +to the protobuf definition files (`.proto` files located in `glide-core/src/protobuf`), it becomes necessary to +regenerate the Python protobuf files. To do so, run: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide -protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. +protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ```
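+
+After regenerating, you can sanity-check that the generated modules import cleanly. This is a minimal, optional check (it assumes the virtual environment is active and the package was built with `maturin develop`, so that `glide` is importable; it makes no assumption about the generated module names):
+
+```python
+# Lists the protobuf modules that protoc generated into glide/protobuf.
+import pkgutil
+
+import glide.protobuf as pb
+
+print([module.name for module in pkgutil.iter_modules(pb.__path__)])
+```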
-#### Protobuf interface files +## Protobuf interface files To generate the protobuf files with Python Interface files (pyi) for type-checking purposes, ensure you have installed `mypy-protobuf` with pip, and then execute the following command: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. MYPY_PROTOC_PATH=`which protoc-gen-mypy` -protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} \ + -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ``` -### Linters +# Linters +--- Development on the Python wrapper may involve changes in either the Python or Rust code. Each language has distinct linter tests that must be passed before committing changes. -#### Language-specific Linters +## Language-specific Linters **Python:** @@ -199,31 +202,37 @@ Development on the Python wrapper may involve changes in either the Python or Ru - clippy - fmt -#### Running the linters +## Running the linters Run from the main `/python` folder 1. Python - > Note: make sure to [generate protobuf with interface files]("#protobuf-interface-files") before running mypy linter + > Note: make sure to [generate protobuf with interface files](#protobuf-interface-files) before running the `mypy` linter ```bash + cd $HOME/src/valkey-glide/python + source .env/bin/activate pip install -r dev_requirements.txt isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env black . --exclude python/glide/protobuf --exclude .env - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 - flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \ + --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 + flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \ + --statistics --exclude=python/glide/protobuf,.env/* \ + --extend-ignore=E230 # run type check mypy . ``` + 2. Rust ```bash rustup component add clippy rustfmt cargo clippy --all-features --all-targets -- -D warnings cargo fmt --manifest-path ./Cargo.toml --all - ``` -### Recommended extensions for VS Code +# Recommended extensions for VS Code +--- - [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) - [isort](https://marketplace.visualstudio.com/items?itemName=ms-python.isort) diff --git a/python/Pipfile b/python/Pipfile index 5d44a4887c..d582066226 100644 --- a/python/Pipfile +++ b/python/Pipfile @@ -8,4 +8,4 @@ name = "pypi" [dev-packages] [requires] -python_version = "3.8" +python_version = "3.9" diff --git a/python/README.md b/python/README.md index 89c4bec560..aa8ee70a63 100644 --- a/python/README.md +++ b/python/README.md @@ -10,17 +10,26 @@ Refer to the [Supported Engine Versions table](https://github.com/valkey-io/valk ## System Requirements -The beta release of Valkey GLIDE was tested on Intel x86_64 using Ubuntu 22.04.1, Amazon Linux 2023 (AL2023), and macOS 12.7.
+The release of Valkey GLIDE was tested on the following platforms: + +Linux: + +- Ubuntu 22.04.1 (x86_64 and aarch64) +- Amazon Linux 2023 (AL2023) (x86_64) + +macOS: + +- macOS 14.7 (Apple silicon/aarch64) ## Python Supported Versions | Python Version | |----------------| -| 3.8 | | 3.9 | | 3.10 | | 3.11 | | 3.12 | +| 3.13 | ## Installation and Setup diff --git a/python/dev_requirements.txt b/python/dev_requirements.txt index 36f3438740..02e9c4fd53 100644 --- a/python/dev_requirements.txt +++ b/python/dev_requirements.txt @@ -4,3 +4,4 @@ isort == 5.10 mypy == 1.2 mypy-protobuf == 3.5 packaging >= 22.0 +pyrsistent diff --git a/python/pyproject.toml b/python/pyproject.toml index 4f4e79b91d..013a4b0e57 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,14 +1,14 @@ [build-system] -requires = ["maturin>=0.13,<0.14"] +requires = ["maturin==0.14.17"] build-backend = "maturin" [project] name = "valkey-glide" -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "async-timeout>=4.0.2; python_version < '3.11'", "typing-extensions>=4.8.0; python_version < '3.11'", - "protobuf>=3.20" + "protobuf>=3.20", ] classifiers = [ "Topic :: Database", @@ -30,7 +30,7 @@ max-line-length = 127 extend-ignore = ['E203'] [tool.black] -target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +target-version = ['py39', 'py310', 'py311', 'py312', 'py313'] [tool.mypy] -exclude = [ 'submodules', 'utils/release-candidate-testing' ] +exclude = ['submodules', 'utils/release-candidate-testing'] diff --git a/python/pytest.ini b/python/pytest.ini index 0624078cea..81fc7f8fce 100644 --- a/python/pytest.ini +++ b/python/pytest.ini @@ -1,4 +1,5 @@ [pytest] markers = smoke_test: mark a test as a build verification testing. + skip_if_version_below: parametrized mark for quickly skipping tests based on server version.
addopts = -k "not server_modules and not pubsub" diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index cf817c128a..8f6ceac47b 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -32,7 +32,48 @@ InsertPosition, UpdateOptions, ) -from glide.async_commands.server_modules import json +from glide.async_commands.server_modules import ft, glide_json +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateApply, + FtAggregateClause, + FtAggregateFilter, + FtAggregateGroupBy, + FtAggregateLimit, + FtAggregateOptions, + FtAggregateReducer, + FtAggregateSortBy, + FtAggregateSortProperty, +) +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + DistanceMetricType, + Field, + FieldType, + FtCreateOptions, + NumericField, + TagField, + TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributes, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + VectorType, +) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, + QueryType, +) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchLimit, + FtSearchOptions, + ReturnField, +) +from glide.async_commands.server_modules.glide_json import ( + JsonArrIndexOptions, + JsonArrPopOptions, + JsonGetOptions, +) from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -83,11 +124,17 @@ from glide.constants import ( OK, TOK, + FtAggregateResponse, + FtInfoResponse, + FtProfileResponse, + FtSearchResponse, TClusterResponse, TEncodable, TFunctionListResponse, TFunctionStatsFullResponse, TFunctionStatsSingleNodeResponse, + TJsonResponse, + TJsonUniversalResponse, TResult, TSingleNodeRoute, TXInfoStreamFullResponse, @@ -145,10 +192,16 @@ "TFunctionListResponse", "TFunctionStatsFullResponse", "TFunctionStatsSingleNodeResponse", + "TJsonResponse", + "TJsonUniversalResponse", "TOK", "TResult", "TXInfoStreamFullResponse", "TXInfoStreamResponse", + "FtAggregateResponse", + "FtInfoResponse", + "FtProfileResponse", + "FtSearchResponse", # Commands "BitEncoding", "BitFieldGet", @@ -184,7 +237,7 @@ "InfBound", "InfoSection", "InsertPosition", - "json", + "ft", "LexBoundary", "Limit", "ListDirection", @@ -209,9 +262,14 @@ "TrimByMaxLen", "TrimByMinId", "UpdateOptions", - "ClusterScanCursor" + "ClusterScanCursor", # PubSub "PubSubMsg", + # Json + "glide_json", + "JsonGetOptions", + "JsonArrIndexOptions", + "JsonArrPopOptions", # Logger "Logger", "LogLevel", @@ -233,4 +291,33 @@ "GlideError", "RequestError", "TimeoutError", + # Ft + "DataType", + "DistanceMetricType", + "Field", + "FieldType", + "FtCreateOptions", + "NumericField", + "TagField", + "TextField", + "VectorAlgorithm", + "VectorField", + "VectorFieldAttributes", + "VectorFieldAttributesFlat", + "VectorFieldAttributesHnsw", + "VectorType", + "FtSearchLimit", + "ReturnField", + "FtSearchOptions", + "FtAggregateApply", + "FtAggregateFilter", + "FtAggregateClause", + "FtAggregateLimit", + "FtAggregateOptions", + "FtAggregateGroupBy", + "FtAggregateReducer", + "FtAggregateSortBy", + "FtAggregateSortProperty", + "FtProfileOptions", + "QueryType", ] diff --git a/python/python/glide/async_commands/command_args.py b/python/python/glide/async_commands/command_args.py index 92e0100665..05efe925ab 100644 --- a/python/python/glide/async_commands/command_args.py +++ b/python/python/glide/async_commands/command_args.py @@ -36,6 +36,7 @@ class OrderBy(Enum): This enum is 
used for the following commands: - `SORT`: General sorting in ascending or descending order. - `GEOSEARCH`: Sorting items based on their proximity to a center point. + - `FT.AGGREGATE`: Used in the SortBy clause of the FT.AGGREGATE command. """ ASC = "ASC" diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 7acb44ca60..6ebc8d2ab6 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -392,6 +392,44 @@ async def _cluster_scan( type: Optional[ObjectType] = ..., ) -> TResult: ... + async def _update_connection_password( + self, password: Optional[str], immediate_auth: bool + ) -> TResult: ... + + async def update_connection_password( + self, password: Optional[str], immediate_auth=False + ) -> TOK: + """ + Update the current connection password with a new password. + + **Note:** This method updates the client's internal password configuration and does + not perform password rotation on the server side. + + This method is useful in scenarios where the server password has changed or when + utilizing short-lived passwords for enhanced security. It allows the client to + update its password to reconnect upon disconnection without the need to recreate + the client instance. This ensures that the internal reconnection mechanism can + handle reconnection seamlessly, preventing the loss of in-flight commands. + + Args: + password (`Optional[str]`): The new password to use for the connection; + if `None`, the password will be removed. + immediate_auth (`bool`): + - `True`: The client will authenticate immediately with the new password against all connections, using the `AUTH` command. + If the supplied password is an empty string, authentication will not be performed and a warning will be returned. + The default is `False`. + + Returns: + TOK: A simple OK response. If `immediate_auth=True`, returns OK if the re-authentication succeeds. + + Example: + >>> await client.update_connection_password("new_password", immediate_auth=True) + 'OK' + """ + return cast( + TOK, await self._update_connection_password(password, immediate_auth) + ) + async def set( self, key: TEncodable,
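The docstring above separates the client-side configuration update from server-side rotation. For completeness, a sketch of a full rotation (illustrative only, not part of this change; it assumes the server password is managed with `CONFIG SET requirepass` and that `client` is a connected GLIDE client):

```python
# Hypothetical rotation flow: rotate the secret on the server first, then
# update the client so its reconnection logic authenticates with it.
await client.custom_command(["CONFIG", "SET", "requirepass", "new_password"])
await client.update_connection_password("new_password", immediate_auth=True)
```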
@@ -613,7 +651,13 @@ async def delete(self, keys: List[TEncodable]) -> int: See https://valkey.io/commands/del/ for details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): A list of keys to be deleted from the database. @@ -730,7 +774,13 @@ async def mset(self, key_value_map: Mapping[TEncodable, TEncodable]) -> TOK: See https://valkey.io/commands/mset/ for more details. Note: - When in cluster mode, the command may route to multiple nodes when keys in `key_value_map` map to different hash slots. + In cluster mode, if keys in `key_value_map` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: key_value_map (Mapping[TEncodable, TEncodable]): A map of key value pairs. @@ -783,8 +833,13 @@ async def mget(self, keys: List[TEncodable]) -> List[Optional[bytes]]: See https://valkey.io/commands/mget/ for more details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. - + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): A list of keys to retrieve values for. @@ -850,7 +905,14 @@ async def touch(self, keys: List[TEncodable]) -> int: See https://valkey.io/commands/touch/ for details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): The keys to update last access time. @@ -2303,7 +2365,13 @@ async def exists(self, keys: List[TEncodable]) -> int: See https://valkey.io/commands/exists/ for more details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): The list of keys to check. @@ -2327,7 +2395,13 @@ async def unlink(self, keys: List[TEncodable]) -> int: See https://valkey.io/commands/unlink/ for more details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level.
If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): The list of keys to unlink. @@ -6360,7 +6434,13 @@ async def watch(self, keys: List[TEncodable]) -> TOK: See https://valkey.io/commands/watch for more details. Note: - When in cluster mode, the command may route to multiple nodes when `keys` map to different hash slots. + In cluster mode, if keys in `keys` map to different hash slots, + the command will be split across these slots and executed separately for each. + This means the command is atomic only at the slot level. If one or more slot-specific + requests fail, the entire call will return the first encountered error, even + though some requests may have succeeded while others did not. + If this behavior impacts your application logic, consider splitting the + request into sub-requests per slot to ensure atomicity. Args: keys (List[TEncodable]): The keys to watch.
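The per-slot splitting suggested in these notes can be done client-side before issuing the command. A minimal sketch (not part of this change; `key_hash_slot` reimplements the standard Valkey cluster slot mapping, CRC16/XMODEM over the key or its `{hash tag}`, modulo 16384):

```python
from collections import defaultdict
from typing import Dict, List


def key_hash_slot(key: bytes) -> int:
    # Use the {hash tag} substring when present and non-empty.
    start = key.find(b"{")
    if start != -1:
        end = key.find(b"}", start + 1)
        if end > start + 1:
            key = key[start + 1 : end]
    # CRC16/XMODEM (polynomial 0x1021, initial value 0), as used for cluster key slots.
    crc = 0
    for byte in key:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) & 0xFFFF if crc & 0x8000 else (crc << 1) & 0xFFFF
    return crc % 16384


def group_keys_by_slot(keys: List[bytes]) -> Dict[int, List[bytes]]:
    # Each group can then be sent as its own `mget`/`delete`/`unlink` call,
    # so a failure in one slot does not discard results from another.
    groups: Dict[int, List[bytes]] = defaultdict(list)
    for key in keys:
        groups[key_hash_slot(key)].append(key)
    return dict(groups)
```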
diff --git a/python/python/glide/async_commands/server_modules/ft.py b/python/python/glide/async_commands/server_modules/ft.py new file mode 100644 index 0000000000..9dbd79592c --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft.py @@ -0,0 +1,395 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +""" +Module for `vector search` commands. +""" + +from typing import List, Mapping, Optional, cast + +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateOptions, +) +from glide.async_commands.server_modules.ft_options.ft_constants import ( + CommandNames, + FtCreateKeywords, +) +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + Field, + FtCreateOptions, +) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchOptions, +) +from glide.constants import ( + TOK, + FtAggregateResponse, + FtInfoResponse, + FtProfileResponse, + FtSearchResponse, + TEncodable, +) +from glide.glide_client import TGlideClient + + +async def create( + client: TGlideClient, + index_name: TEncodable, + schema: List[Field], + options: Optional[FtCreateOptions] = None, +) -> TOK: + """ + Creates an index and initiates a backfill of that index. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name. + schema (List[Field]): Fields to populate into the index. Equivalent to the `SCHEMA` block in the module API. + options (Optional[FtCreateOptions]): Optional arguments for the FT.CREATE command. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide import ft + >>> schema: List[Field] = [TextField("title")] + >>> prefixes: List[str] = ["blog:post:"] + >>> await ft.create(glide_client, "my_idx1", schema, FtCreateOptions(DataType.HASH, prefixes)) + 'OK' # Indicates successful creation of the index named 'my_idx1' + """ + args: List[TEncodable] = [CommandNames.FT_CREATE, index_name] + if options: + args.extend(options.to_args()) + if schema: + args.append(FtCreateKeywords.SCHEMA) + for field in schema: + args.extend(field.to_args()) + return cast(TOK, await client.custom_command(args)) + + +async def dropindex(client: TGlideClient, index_name: TEncodable) -> TOK: + """ + Drops an index. The index definition and associated content are deleted. Keys are unaffected. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name for the index to be dropped. + + Returns: + TOK: A simple "OK" response. + + Examples: + For the following example to work, an index named 'idx' must already be created. If not, you will get an error. + >>> from glide import ft + >>> index_name = "idx" + >>> await ft.dropindex(glide_client, index_name) + 'OK' # Indicates successful deletion/dropping of index named 'idx' + """ + args: List[TEncodable] = [CommandNames.FT_DROPINDEX, index_name] + return cast(TOK, await client.custom_command(args)) + + +async def list(client: TGlideClient) -> List[TEncodable]: + """ + Lists all indexes. + + Args: + client (TGlideClient): The client to execute the command. + + Returns: + List[TEncodable]: An array of index names. + + Examples: + >>> from glide import ft + >>> await ft.list(glide_client) + [b"index1", b"index2"] + """ + return cast(List[TEncodable], await client.custom_command([CommandNames.FT_LIST])) + + +async def search( + client: TGlideClient, + index_name: TEncodable, + query: TEncodable, + options: Optional[FtSearchOptions], +) -> FtSearchResponse: + """ + Uses the provided query expression to locate keys within an index. Once located, the count and/or the content of indexed fields within those keys can be returned. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name to search into. + query (TEncodable): The text query to search. + options (Optional[FtSearchOptions]): The search options. + + Returns: + FtSearchResponse: A two-element array, where the first element is the count of documents in the result set, and the second element, which has the format Mapping[TEncodable, Mapping[TEncodable, TEncodable]], is a mapping between document names and maps of their attributes. + If `count` (an option in `FtSearchOptions`) is set to `True` or `limit` (an option in `FtSearchOptions`) is set to `FtSearchLimit(0, 0)`, the command returns an array with only one element: the count of the documents. + + Examples: + For the following example to work the following must already exist: + - An index named "idx", with fields having identifiers as "a" and "b" and prefix as "{json:}" + - A key named {json:}1 with value {"a":1, "b":2} + + >>> from glide import ft + >>> await ft.search(glide_client, "idx", "*", options=FtSearchOptions(return_fields=[ReturnField(field_identifier="first"), ReturnField(field_identifier="second")])) + [1, { b'json:1': { b'first': b'42', b'second': b'33' } }] # The first element, 1, is the number of keys returned in the search result. The second element is a map of data queried per key.
+ """ + args: List[TEncodable] = [CommandNames.FT_SEARCH, index_name, query] + if options: + args.extend(options.to_args()) + return cast(FtSearchResponse, await client.custom_command(args)) + + +async def aliasadd( + client: TGlideClient, alias: TEncodable, index_name: TEncodable +) -> TOK: + """ + Adds an alias for an index. The new alias name can be used anywhere that an index name is required. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The alias to be added to an index. + index_name (TEncodable): The index name for which the alias has to be added. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide import ft + >>> await ft.aliasadd(glide_client, "myalias", "myindex") + 'OK' # Indicates the successful addition of the alias named "myalias" for the index. + """ + args: List[TEncodable] = [CommandNames.FT_ALIASADD, alias, index_name] + return cast(TOK, await client.custom_command(args)) + + +async def aliasdel(client: TGlideClient, alias: TEncodable) -> TOK: + """ + Deletes an existing alias for an index. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The existing alias to be deleted for an index. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide import ft + >>> await ft.aliasdel(glide_client, "myalias") + 'OK' # Indicates the successful deletion of the alias named "myalias" + """ + args: List[TEncodable] = [CommandNames.FT_ALIASDEL, alias] + return cast(TOK, await client.custom_command(args)) + + +async def aliasupdate( + client: TGlideClient, alias: TEncodable, index_name: TEncodable +) -> TOK: + """ + Updates an existing alias to point to a different physical index. This command only affects future references to the alias. + + Args: + client (TGlideClient): The client to execute the command. + alias (TEncodable): The alias name. This alias will now be pointed to a different index. + index_name (TEncodable): The index name for which an existing alias has to be updated. + + Returns: + TOK: A simple "OK" response. + + Examples: + >>> from glide import ft + >>> await ft.aliasupdate(glide_client, "myalias", "myindex") + 'OK' # Indicates the successful update of the alias to point to the index named "myindex" + """ + args: List[TEncodable] = [CommandNames.FT_ALIASUPDATE, alias, index_name] + return cast(TOK, await client.custom_command(args)) + + +async def info(client: TGlideClient, index_name: TEncodable) -> FtInfoResponse: + """ + Returns information about a given index. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name for which the information has to be returned. + + Returns: + FtInfoResponse: Nested maps with info about the index. See example for more details. + + Examples: + An index named 'myIndex', with 1 text field and 1 vector field, must already be created to get the output of this example.
+ >>> from glide import ft + >>> await ft.info(glide_client, "myIndex") + [ + b'index_name', + b'myIndex', + b'creation_timestamp', 1729531116945240, + b'key_type', b'JSON', + b'key_prefixes', [b'key-prefix'], + b'fields', [ + [ + b'identifier', b'$.vec', + b'field_name', b'VEC', + b'type', b'VECTOR', + b'option', b'', + b'vector_params', [ + b'algorithm', b'HNSW', b'data_type', b'FLOAT32', b'dimension', 2, b'distance_metric', b'L2', b'initial_capacity', 1000, b'current_capacity', 1000, b'maximum_edges', 16, b'ef_construction', 200, b'ef_runtime', 10, b'epsilon', b'0.01' + ] + ], + [ + b'identifier', b'$.text-field', + b'field_name', b'text-field', + b'type', b'TEXT', + b'option', b'' + ] + ], + b'space_usage', 653351, + b'fulltext_space_usage', 0, + b'vector_space_usage', 653351, + b'num_docs', 0, + b'num_indexed_vectors', 0, + b'current_lag', 0, + b'index_status', b'AVAILABLE', + b'index_degradation_percentage', 0 + ] + """ + args: List[TEncodable] = [CommandNames.FT_INFO, index_name] + return cast(FtInfoResponse, await client.custom_command(args)) + + +async def explain( + client: TGlideClient, index_name: TEncodable, query: TEncodable +) -> TEncodable: + """ + Parse a query and return information about how that query was parsed. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + + Returns: + TEncodable: A string containing the parsed results representing the execution plan. + + Examples: + >>> from glide import ft + >>> await ft.explain(glide_client, index_name="myIndex", query="@price:[0 10]") + b'Field {\n price\n 0\n 10\n}\n' # Parsed results. + """ + args: List[TEncodable] = [CommandNames.FT_EXPLAIN, index_name, query] + return cast(TEncodable, await client.custom_command(args)) + + +async def explaincli( + client: TGlideClient, index_name: TEncodable, query: TEncodable +) -> List[TEncodable]: + """ + Same as the FT.EXPLAIN command, except that the results are displayed in a different format that is more useful for CLI output. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + + Returns: + List[TEncodable]: An array containing the execution plan. + + Examples: + >>> from glide import ft + >>> await ft.explaincli(glide_client, index_name="myIndex", query="@price:[0 10]") + [b'Field {', b' price', b' 0', b' 10', b'}', b''] # Parsed results. + """ + args: List[TEncodable] = [CommandNames.FT_EXPLAINCLI, index_name, query] + return cast(List[TEncodable], await client.custom_command(args)) + + +async def aggregate( + client: TGlideClient, + index_name: TEncodable, + query: TEncodable, + options: Optional[FtAggregateOptions], +) -> FtAggregateResponse: + """ + A superset of the FT.SEARCH command; it allows substantial additional processing of the keys selected by the query expression. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name for which the query is written. + query (TEncodable): The search query, same as the query passed as an argument to FT.SEARCH. + options (Optional[FtAggregateOptions]): The optional arguments for the command.
+ + Returns: + FtAggregateResponse: An array of mappings between field names and associated values, as returned after the last stage of the command. + + Examples: + >>> from glide import ft + >>> await ft.aggregate(glide_client, "myIndex", "*", FtAggregateOptions(loadFields=["__key"], clauses=[FtAggregateGroupBy(["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")])])) + [{b'condition': b'refurbished', b'bicycles': b'1'}, {b'condition': b'new', b'bicycles': b'5'}, {b'condition': b'used', b'bicycles': b'4'}] + """ + args: List[TEncodable] = [CommandNames.FT_AGGREGATE, index_name, query] + if options: + args.extend(options.to_args()) + return cast(FtAggregateResponse, await client.custom_command(args)) + + +async def profile( + client: TGlideClient, index_name: TEncodable, options: FtProfileOptions +) -> FtProfileResponse: + """ + Runs a search or aggregation query and collects performance profiling information. + + Args: + client (TGlideClient): The client to execute the command. + index_name (TEncodable): The index name + options (FtProfileOptions): Options for the command. + + Returns: + FtProfileResponse: A two-element array. The first element contains the results of the query being profiled; the second element stores the profiling information. + + Examples: + >>> ftSearchOptions = FtSearchOptions(return_fields=[ReturnField(field_identifier="a", alias="a_new"), ReturnField(field_identifier="b", alias="b_new")]) + >>> await ft.profile(glide_client, "myIndex", FtProfileOptions.from_query_options(query="*", query_options=ftSearchOptions)) + [ + [ + 2, + { + b'key1': { + b'a': b'11111', + b'b': b'2' + }, + b'key2': { + b'a': b'22222', + b'b': b'2' + } + } + ], + { + b'all.count': 2, + b'sync.time': 1, + b'query.time': 7, + b'result.count': 2, + b'result.time': 0 + } + ] + """ + args: List[TEncodable] = [CommandNames.FT_PROFILE, index_name] + options.to_args() + return cast(FtProfileResponse, await client.custom_command(args)) + + +async def aliaslist(client: TGlideClient) -> Mapping[TEncodable, TEncodable]: + """ + List the index aliases. + + Args: + client (TGlideClient): The client to execute the command. + + Returns: + Mapping[TEncodable, TEncodable]: A map of index aliases for indices being aliased. + + Examples: + >>> from glide import ft + >>> await ft.aliaslist(glide_client) + {b'alias': b'index1', b'alias-bytes': b'index2'} + """ + args: List[TEncodable] = [CommandNames.FT_ALIASLIST] + return cast(Mapping, await client.custom_command(args))
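For reference, the options object used in the `ft.aggregate` example above assembles the raw command arguments through the `to_args` methods of the clause classes defined in the file below. A small illustration (the expected output follows directly from those implementations):

```python
from glide import FtAggregateGroupBy, FtAggregateOptions, FtAggregateReducer

options = FtAggregateOptions(
    loadFields=["__key"],
    clauses=[
        FtAggregateGroupBy(
            ["@condition"], [FtAggregateReducer("COUNT", [], name="bicycles")]
        )
    ],
)
print(options.to_args())
# ['LOAD', '1', '__key', 'GROUPBY', '1', '@condition',
#  'REDUCE', 'COUNT', '0', 'AS', 'bicycles']
```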
diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py new file mode 100644 index 0000000000..c121c10985 --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_aggregate_options.py @@ -0,0 +1,293 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Mapping, Optional + +from glide.async_commands.command_args import OrderBy +from glide.async_commands.server_modules.ft_options.ft_constants import ( + FtAggregateKeywords, +) +from glide.constants import TEncodable + + +class FtAggregateClause(ABC): + """ + Abstract base class for the FT.AGGREGATE command clauses. + """ + + @abstractmethod + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the clause of the FT.AGGREGATE command. + + Returns: + List[TEncodable]: A list of arguments for the clause of the FT.AGGREGATE command. + """ + args: List[TEncodable] = [] + return args + + +class FtAggregateLimit(FtAggregateClause): + """ + A clause for limiting the number of retained records. + """ + + def __init__(self, offset: int, count: int): + """ + Initialize a new FtAggregateLimit instance. + + Args: + offset (int): Starting point from which the records have to be retained. + count (int): The total number of records to be retained. + """ + self.offset = offset + self.count = count + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the Limit clause. + + Returns: + List[TEncodable]: A list of Limit clause arguments. + """ + return [FtAggregateKeywords.LIMIT, str(self.offset), str(self.count)] + + +class FtAggregateFilter(FtAggregateClause): + """ + A clause for filtering the results using a predicate expression relating to values in each result. It is applied post-query and relates to the current state of the pipeline. + """ + + def __init__(self, expression: TEncodable): + """ + Initialize a new FtAggregateFilter instance. + + Args: + expression (TEncodable): The expression to filter the results. + """ + self.expression = expression + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the Filter clause. + + Returns: + List[TEncodable]: A list of arguments for the Filter clause. + """ + return [FtAggregateKeywords.FILTER, self.expression] + + +class FtAggregateReducer: + """ + A clause for reducing the matching results in each group using a reduction function. The matching results are reduced into a single record. + """ + + def __init__( + self, + function: TEncodable, + args: List[TEncodable], + name: Optional[TEncodable] = None, + ): + """ + Initialize a new FtAggregateReducer instance. + + Args: + function (TEncodable): The reduction function name for the respective group. + args (List[TEncodable]): The list of arguments for the reducer. + name (Optional[TEncodable]): User-defined property name for the reducer. + """ + self.function = function + self.args = args + self.name = name + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the Reducer. + + Returns: + List[TEncodable]: A list of arguments for the reducer. + """ + args: List[TEncodable] = [ + FtAggregateKeywords.REDUCE, + self.function, + str(len(self.args)), + ] + self.args + if self.name: + args.extend([FtAggregateKeywords.AS, self.name]) + return args + + +class FtAggregateGroupBy(FtAggregateClause): + """ + A clause for grouping the results in the pipeline based on one or more properties. + """ + + def __init__( + self, properties: List[TEncodable], reducers: List[FtAggregateReducer] + ): + """ + Initialize a new FtAggregateGroupBy instance. + + Args: + properties (List[TEncodable]): The list of properties to be used for grouping the results in the pipeline. + reducers (List[FtAggregateReducer]): The list of functions that handle the group entries by performing multiple aggregate operations. + """ + self.properties = properties + self.reducers = reducers + + def to_args(self) -> List[TEncodable]: + args = [ + FtAggregateKeywords.GROUPBY, + str(len(self.properties)), + ] + self.properties + if self.reducers: + for reducer in self.reducers: + args.extend(reducer.to_args()) + return args + + +class FtAggregateSortProperty: + """ + This class represents a single property for the SortBy clause. + """ + + def __init__(self, property: TEncodable, order: OrderBy): + """ + Initialize a new FtAggregateSortProperty instance. + + Args: + property (TEncodable): The sorting parameter.
+ order (OrderBy): The order for the sorting. This option can be added for each property. + """ + self.property = property + self.order = order + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the SortBy clause property. + + Returns: + List[TEncodable]: A list of arguments for the SortBy clause property. + """ + return [self.property, self.order.value] + + +class FtAggregateSortBy(FtAggregateClause): + """ + A clause for sorting the pipeline up until the point of SORTBY, using a list of properties. + """ + + def __init__( + self, properties: List[FtAggregateSortProperty], max: Optional[int] = None + ): + """ + Initialize a new FtAggregateSortBy instance. + + Args: + properties (List[FtAggregateSortProperty]): A list of sorting parameters for the sort operation. + max (Optional[int]): The MAX value for optimizing the sorting, by sorting only for the n-largest elements. + """ + self.properties = properties + self.max = max + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the SortBy clause. + + Returns: + List[TEncodable]: A list of arguments for the SortBy clause. + """ + args: List[TEncodable] = [ + FtAggregateKeywords.SORTBY, + str(len(self.properties) * 2), + ] + for property in self.properties: + args.extend(property.to_args()) + if self.max: + args.extend([FtAggregateKeywords.MAX, str(self.max)]) + return args + + +class FtAggregateApply(FtAggregateClause): + """ + A clause for applying a 1-to-1 transformation on one or more properties and storing the result as a new property down the pipeline, or replacing any property using this transformation. + """ + + def __init__(self, expression: TEncodable, name: TEncodable): + """ + Initialize a new FtAggregateApply instance. + + Args: + expression (TEncodable): The expression to be transformed. + name (TEncodable): The new property name to store the result of apply. This name can be referenced by further APPLY/SORTBY/GROUPBY/REDUCE operations down the pipeline. + """ + self.expression = expression + self.name = name + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the Apply clause. + + Returns: + List[TEncodable]: A list of arguments for the Apply clause. + """ + return [ + FtAggregateKeywords.APPLY, + self.expression, + FtAggregateKeywords.AS, + self.name, + ] + + +class FtAggregateOptions: + """ + This class represents the optional arguments for the FT.AGGREGATE command. + """ + + def __init__( + self, + loadAll: Optional[bool] = False, + loadFields: Optional[List[TEncodable]] = [], + timeout: Optional[int] = None, + params: Optional[Mapping[TEncodable, TEncodable]] = {}, + clauses: Optional[List[FtAggregateClause]] = [], + ): + """ + Initialize a new FtAggregateOptions instance. + + Args: + loadAll (Optional[bool]): An option to load all fields declared in the index. + loadFields (Optional[List[TEncodable]]): An option to load only the fields passed in this list. + timeout (Optional[int]): Overrides the timeout parameter of the module. + params (Optional[Mapping[TEncodable, TEncodable]]): Key/value pairs that can be referenced from within the query expression. + clauses (Optional[List[FtAggregateClause]]): FILTER, LIMIT, GROUPBY, SORTBY and APPLY clauses, that can be repeated multiple times in any order and be freely intermixed. They are applied in the order specified, with the output of one clause feeding the input of the next clause.
+ """ + self.loadAll = loadAll + self.loadFields = loadFields + self.timeout = timeout + self.params = params + self.clauses = clauses + + def to_args(self) -> List[TEncodable]: + """ + Get the optional arguments for the FT.AGGREGATE command. + + Returns: + List[TEncodable]: A list of optional arguments for the FT.AGGREGATE command. + """ + args: List[TEncodable] = [] + if self.loadAll: + args.extend([FtAggregateKeywords.LOAD, "*"]) + elif self.loadFields: + args.extend([FtAggregateKeywords.LOAD, str(len(self.loadFields))]) + args.extend(self.loadFields) + if self.timeout: + args.extend([FtAggregateKeywords.TIMEOUT, str(self.timeout)]) + if self.params: + args.extend([FtAggregateKeywords.PARAMS, str(len(self.params) * 2)]) + for [name, value] in self.params.items(): + args.extend([name, value]) + if self.clauses: + for clause in self.clauses: + args.extend(clause.to_args()) + return args diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py new file mode 100644 index 0000000000..b3ef48e687 --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_constants.py @@ -0,0 +1,84 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + + +class CommandNames: + """ + Command name constants for vector search. + """ + + FT_CREATE = "FT.CREATE" + FT_DROPINDEX = "FT.DROPINDEX" + FT_LIST = "FT._LIST" + FT_SEARCH = "FT.SEARCH" + FT_INFO = "FT.INFO" + FT_ALIASADD = "FT.ALIASADD" + FT_ALIASDEL = "FT.ALIASDEL" + FT_ALIASUPDATE = "FT.ALIASUPDATE" + FT_EXPLAIN = "FT.EXPLAIN" + FT_EXPLAINCLI = "FT.EXPLAINCLI" + FT_AGGREGATE = "FT.AGGREGATE" + FT_PROFILE = "FT.PROFILE" + FT_ALIASLIST = "FT._ALIASLIST" + + +class FtCreateKeywords: + """ + Keywords used in the FT.CREATE command. + """ + + SCHEMA = "SCHEMA" + AS = "AS" + SORTABLE = "SORTABLE" + UNF = "UNF" + NO_INDEX = "NOINDEX" + ON = "ON" + PREFIX = "PREFIX" + SEPARATOR = "SEPARATOR" + CASESENSITIVE = "CASESENSITIVE" + DIM = "DIM" + DISTANCE_METRIC = "DISTANCE_METRIC" + TYPE = "TYPE" + INITIAL_CAP = "INITIAL_CAP" + M = "M" + EF_CONSTRUCTION = "EF_CONSTRUCTION" + EF_RUNTIME = "EF_RUNTIME" + + +class FtSearchKeywords: + """ + Keywords used in the FT.SEARCH command. + """ + + RETURN = "RETURN" + TIMEOUT = "TIMEOUT" + PARAMS = "PARAMS" + LIMIT = "LIMIT" + COUNT = "COUNT" + AS = "AS" + + +class FtAggregateKeywords: + """ + Keywords used in the FT.AGGREGATE command. + """ + + LIMIT = "LIMIT" + FILTER = "FILTER" + GROUPBY = "GROUPBY" + REDUCE = "REDUCE" + AS = "AS" + SORTBY = "SORTBY" + MAX = "MAX" + APPLY = "APPLY" + LOAD = "LOAD" + TIMEOUT = "TIMEOUT" + PARAMS = "PARAMS" + + +class FtProfileKeywords: + """ + Keywords used in the FT.PROFILE command. 
+ """ + + QUERY = "QUERY" + LIMITED = "LIMITED" diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py new file mode 100644 index 0000000000..551c160641 --- /dev/null +++ b/python/python/glide/async_commands/server_modules/ft_options/ft_create_options.py @@ -0,0 +1,409 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Optional + +from glide.async_commands.server_modules.ft_options.ft_constants import FtCreateKeywords +from glide.constants import TEncodable + + +class FieldType(Enum): + """ + All possible values for the data type of field identifier for the SCHEMA option. + """ + + TEXT = "TEXT" + """ + If the field contains any blob of data. + """ + TAG = "TAG" + """ + If the field contains a tag field. + """ + NUMERIC = "NUMERIC" + """ + If the field contains a number. + """ + VECTOR = "VECTOR" + """ + If the field is a vector field that supports vector search. + """ + + +class VectorAlgorithm(Enum): + """ + Algorithm for vector type fields used for vector similarity search. + """ + + HNSW = "HNSW" + """ + Hierarchical Navigable Small World algorithm. + """ + FLAT = "FLAT" + """ + Flat algorithm or the brute force algorithm. + """ + + +class DistanceMetricType(Enum): + """ + Distance metrics to measure the degree of similarity between two vectors. + + These metrics calculate the distance between two vectors, where the smaller the value is, the + closer the two vectors are in the vector space. + """ + + L2 = "L2" + """ + Euclidean distance + """ + IP = "IP" + """ + Inner product + """ + COSINE = "COSINE" + """ + Cosine distance + """ + + +class VectorType(Enum): + """ + The data type of the vector field. + """ + + FLOAT32 = "FLOAT32" + """ + FLOAT32 type of vector. The only supported type. + """ + + +class Field(ABC): + """ + Abstract base class for a vector search field. + """ + + @abstractmethod + def __init__( + self, + name: TEncodable, + type: FieldType, + alias: Optional[TEncodable] = None, + ): + """ + Initialize a new field instance. + + Args: + name (TEncodable): The name of the field. + type (FieldType): The type of the field. + alias (Optional[TEncodable]): An alias for the field. + """ + self.name = name + self.type = type + self.alias = alias + + @abstractmethod + def to_args(self) -> List[TEncodable]: + """ + Get the arguments representing the field. + + Returns: + List[TEncodable]: A list of field arguments. + """ + args = [self.name] + if self.alias: + args.extend([FtCreateKeywords.AS, self.alias]) + args.append(self.type.value) + return args + + +class TextField(Field): + """ + Field contains any blob of data. + """ + + def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): + """ + Initialize a new TextField instance. + + Args: + name (TEncodable): The name of the text field. + alias (Optional[TEncodable]): An alias for the field. + """ + super().__init__(name, FieldType.TEXT, alias) + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + return args + + +class TagField(Field): + """ + Tag fields are similar to full-text fields, but they interpret the text as a simple list of + tags delimited by a separator character. + + For `HASH` fields, the default separator is a comma (`,`). For `JSON` fields, there is no + default separator; you must declare one explicitly if needed.
+ """ + + def __init__( + self, + name: TEncodable, + alias: Optional[TEncodable] = None, + separator: Optional[TEncodable] = None, + case_sensitive: bool = False, + ): + """ + Initialize a new TagField instance. + + Args: + name (TEncodable): The name of the tag field. + alias (Optional[TEncodable]): An alias for the field. + separator (Optional[TEncodable]): Specify how text in the attribute is split into individual tags. Must be a single character. + case_sensitive (bool): Preserve the original letter cases of tags. If set to False, characters are converted to lowercase by default. + """ + super().__init__(name, FieldType.TAG, alias) + self.separator = separator + self.case_sensitive = case_sensitive + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + if self.separator: + args.extend([FtCreateKeywords.SEPARATOR, self.separator]) + if self.case_sensitive: + args.append(FtCreateKeywords.CASESENSITIVE) + return args + + +class NumericField(Field): + """ + Field contains a number. + """ + + def __init__(self, name: TEncodable, alias: Optional[TEncodable] = None): + """ + Initialize a new NumericField instance. + + Args: + name (TEncodable): The name of the numeric field. + alias (Optional[TEncodable]): An alias for the field. + """ + super().__init__(name, FieldType.NUMERIC, alias) + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + return args + + +class VectorFieldAttributes(ABC): + """ + Abstract base class for defining vector field attributes to be used after the vector algorithm name. + """ + + @abstractmethod + def __init__( + self, dimensions: int, distance_metric: DistanceMetricType, type: VectorType + ): + """ + Initialize a new vector field attributes instance. + + Args: + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. + """ + self.dimensions = dimensions + self.distance_metric = distance_metric + self.type = type + + @abstractmethod + def to_args(self) -> List[TEncodable]: + """ + Get the arguments to be used for the algorithm of the vector field. + + Returns: + List[TEncodable]: A list of arguments. + """ + args: List[TEncodable] = [] + if self.dimensions: + args.extend([FtCreateKeywords.DIM, str(self.dimensions)]) + if self.distance_metric: + args.extend([FtCreateKeywords.DISTANCE_METRIC, self.distance_metric.name]) + if self.type: + args.extend([FtCreateKeywords.TYPE, self.type.name]) + return args + + +class VectorFieldAttributesFlat(VectorFieldAttributes): + """ + Get the arguments to be used for the FLAT algorithm of the vector field. + """ + + def __init__( + self, + dimensions: int, + distance_metric: DistanceMetricType, + type: VectorType, + initial_cap: Optional[int] = None, + ): + """ + Initialize a new flat vector field attributes instance. + + Args: + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. 
+ initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` on the module API. + """ + super().__init__(dimensions, distance_metric, type) + self.initial_cap = initial_cap + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + if self.initial_cap: + args.extend([FtCreateKeywords.INITIAL_CAP, str(self.initial_cap)]) + return args + + +class VectorFieldAttributesHnsw(VectorFieldAttributes): + """ + Get the arguments to be used for the HNSW algorithm of the vector field. + """ + + def __init__( + self, + dimensions: int, + distance_metric: DistanceMetricType, + type: VectorType, + initial_cap: Optional[int] = None, + number_of_edges: Optional[int] = None, + vectors_examined_on_construction: Optional[int] = None, + vectors_examined_on_runtime: Optional[int] = None, + ): + """ + Initialize a new HNSW vector field attributes instance. + + Args: + dimensions (int): Number of dimensions in the vector. Equivalent to `DIM` on the module API. + distance_metric (DistanceMetricType): The distance metric used in vector type field. Can be one of `[L2 | IP | COSINE]`. Equivalent to `DISTANCE_METRIC` on the module API. + type (VectorType): Vector type. The only supported type is `FLOAT32`. Equivalent to `TYPE` on the module API. + initial_cap (Optional[int]): Initial vector capacity in the index affecting memory allocation size of the index. Defaults to `1024`. Equivalent to `INITIAL_CAP` on the module API. + number_of_edges (Optional[int]): Number of maximum allowed outgoing edges for each node in the graph in each layer. Default is `16`, maximum is `512`. Equivalent to `M` on the module API. + vectors_examined_on_construction (Optional[int]): Controls the number of vectors examined during index construction. Default value is `200`, maximum value is `4096`. Equivalent to `EF_CONSTRUCTION` on the module API. + vectors_examined_on_runtime (Optional[int]): Controls the number of vectors examined during query operations. Default value is `10`, maximum value is `4096`. Equivalent to `EF_RUNTIME` on the module API. + """ + super().__init__(dimensions, distance_metric, type) + self.initial_cap = initial_cap + self.number_of_edges = number_of_edges + self.vectors_examined_on_construction = vectors_examined_on_construction + self.vectors_examined_on_runtime = vectors_examined_on_runtime + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + if self.initial_cap: + args.extend([FtCreateKeywords.INITIAL_CAP, str(self.initial_cap)]) + if self.number_of_edges: + args.extend([FtCreateKeywords.M, str(self.number_of_edges)]) + if self.vectors_examined_on_construction: + args.extend( + [ + FtCreateKeywords.EF_CONSTRUCTION, + str(self.vectors_examined_on_construction), + ] + ) + if self.vectors_examined_on_runtime: + args.extend( + [FtCreateKeywords.EF_RUNTIME, str(self.vectors_examined_on_runtime)] + ) + return args + + +class VectorField(Field): + """ + Class for defining vector field in a schema. + """ + + def __init__( + self, + name: TEncodable, + algorithm: VectorAlgorithm, + attributes: VectorFieldAttributes, + alias: Optional[TEncodable] = None, + ): + """ + Initialize a new VectorField instance. + + Args: + name (TEncodable): The name of the vector field. + algorithm (VectorAlgorithm): The vector indexing algorithm. + alias (Optional[TEncodable]): An alias for the field.
+ attributes (VectorFieldAttributes): Additional attributes to be passed with the vector field after the algorithm name. + """ + super().__init__(name, FieldType.VECTOR, alias) + self.algorithm = algorithm + self.attributes = attributes + + def to_args(self) -> List[TEncodable]: + args = super().to_args() + args.append(self.algorithm.value) + if self.attributes: + attribute_list = self.attributes.to_args() + args.append(str(len(attribute_list))) + args.extend(attribute_list) + return args + + +class DataType(Enum): + """ + Type of the index dataset. + """ + + HASH = "HASH" + """ + Data stored in hashes, so field identifiers are field names within the hashes. + """ + JSON = "JSON" + """ + Data stored as a JSON document, so field identifiers are JSON Path expressions. + """ + + +class FtCreateOptions: + """ + This class represents the input options to be used in the FT.CREATE command. + All fields in this class are optional inputs for FT.CREATE. + """ + + def __init__( + self, + data_type: Optional[DataType] = None, + prefixes: Optional[List[TEncodable]] = None, + ): + """ + Initialize the FT.CREATE optional fields. + + Args: + data_type (Optional[DataType]): The index data type. If not defined, a `HASH` index is created. + prefixes (Optional[List[TEncodable]]): A list of prefixes of index definitions. + """ + self.data_type = data_type + self.prefixes = prefixes + + def to_args(self) -> List[TEncodable]: + """ + Get the optional arguments for the FT.CREATE command. + + Returns: + List[TEncodable]: + List of FT.CREATE optional arguments. + """ + args: List[TEncodable] = [] + if self.data_type: + args.append(FtCreateKeywords.ON) + args.append(self.data_type.value) + if self.prefixes: + args.append(FtCreateKeywords.PREFIX) + args.append(str(len(self.prefixes))) + for prefix in self.prefixes: + args.append(prefix) + return args
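To make the schema serialization concrete, here is an illustration of how a vector field and the create options assemble into raw FT.CREATE arguments (outputs derived from the `to_args` implementations above):

```python
from glide import (
    DataType,
    DistanceMetricType,
    FtCreateOptions,
    VectorAlgorithm,
    VectorField,
    VectorFieldAttributesHnsw,
    VectorType,
)

field = VectorField(
    name="$.vec",
    algorithm=VectorAlgorithm.HNSW,
    attributes=VectorFieldAttributesHnsw(
        dimensions=2,
        distance_metric=DistanceMetricType.L2,
        type=VectorType.FLOAT32,
    ),
    alias="VEC",
)
print(FtCreateOptions(DataType.JSON, ["key-prefix"]).to_args())
# ['ON', 'JSON', 'PREFIX', '1', 'key-prefix']
print(field.to_args())
# ['$.vec', 'AS', 'VEC', 'VECTOR', 'HNSW', '6',
#  'DIM', '2', 'DISTANCE_METRIC', 'L2', 'TYPE', 'FLOAT32']
```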
+            query_type (QueryType): The type of query to be profiled.
+            query_options (Optional[Union[FtSearchOptions, FtAggregateOptions]]): The arguments/options for the FT.AGGREGATE/FT.SEARCH command being profiled.
+            limited (Optional[bool]): If `True`, provides a brief version of the output; otherwise a full verbose output is provided.
+        """
+        self.query = query
+        self.query_type = query_type
+        self.query_options = query_options
+        self.limited = limited
+
+    @classmethod
+    def from_query_options(
+        cls,
+        query: TEncodable,
+        query_options: Union[FtSearchOptions, FtAggregateOptions],
+        limited: Optional[bool] = False,
+    ):
+        """
+        A class method to create FtProfileOptions with FT.SEARCH/FT.AGGREGATE options.
+
+        Args:
+            query (TEncodable): The query that is being profiled. This is the query argument from the FT.AGGREGATE/FT.SEARCH command.
+            query_options (Union[FtSearchOptions, FtAggregateOptions]): The arguments/options for the FT.AGGREGATE/FT.SEARCH command being profiled.
+            limited (Optional[bool]): If `True`, provides a brief version of the output; otherwise a full verbose output is provided.
+        """
+        query_type: QueryType = QueryType.SEARCH
+        if isinstance(query_options, FtAggregateOptions):
+            query_type = QueryType.AGGREGATE
+        return cls(query, query_type, query_options, limited)
+
+    @classmethod
+    def from_query_type(
+        cls, query: TEncodable, query_type: QueryType, limited: Optional[bool] = False
+    ):
+        """
+        A class method to create FtProfileOptions with QueryType.
+
+        Args:
+            query (TEncodable): The query that is being profiled. This is the query argument from the FT.AGGREGATE/FT.SEARCH command.
+            query_type (QueryType): The type of query to be profiled.
+            limited (Optional[bool]): If `True`, provides a brief version of the output; otherwise a full verbose output is provided.
+        """
+        return cls(query, query_type, None, limited)
+
+    def to_args(self) -> List[TEncodable]:
+        """
+        Get the remaining arguments for the FT.PROFILE command.
+
+        Returns:
+            List[TEncodable]: A list of remaining arguments for the FT.PROFILE command.
+        """
+        args: List[TEncodable] = [self.query_type.value]
+        if self.limited:
+            args.append(FtProfileKeywords.LIMITED)
+        args.extend([FtProfileKeywords.QUERY, self.query])
+        if self.query_options:
+            if isinstance(self.query_options, FtAggregateOptions):
+                args.extend(cast(FtAggregateOptions, self.query_options).to_args())
+            else:
+                args.extend(cast(FtSearchOptions, self.query_options).to_args())
+        return args
diff --git a/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py b/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py
new file mode 100644
index 0000000000..f76b309b0f
--- /dev/null
+++ b/python/python/glide/async_commands/server_modules/ft_options/ft_search_options.py
@@ -0,0 +1,131 @@
+# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+
+from typing import List, Mapping, Optional
+
+from glide.async_commands.server_modules.ft_options.ft_constants import FtSearchKeywords
+from glide.constants import TEncodable
+
+
+class FtSearchLimit:
+    """
+    This class represents the arguments for the LIMIT option of the FT.SEARCH command.
+    """
+
+    def __init__(self, offset: int, count: int):
+        """
+        Initialize a new FtSearchLimit instance.
+
+        Args:
+            offset (int): The number of keys to skip before returning the result for the FT.SEARCH command.
+            count (int): The total number of keys to be returned by the FT.SEARCH command.
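For reference, `FtProfileOptions.to_args()` always emits the query type token first, then the optional LIMITED flag, then QUERY and the query itself. A small sketch; the literal `'LIMITED'`/`'QUERY'` strings are assumptions about the `FtProfileKeywords` constants, while the ordering follows the code above:

```python
from glide.async_commands.server_modules.ft_options.ft_profile_options import (
    FtProfileOptions,
    QueryType,
)

# Profile a search query and request the brief output form.
profile = FtProfileOptions.from_query_type(
    "@price:[0 100]", QueryType.SEARCH, limited=True
)
print(profile.to_args())
# Expected shape, assuming the keyword constants resolve to their literal names:
# ['SEARCH', 'LIMITED', 'QUERY', '@price:[0 100]']
```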
+ """ + self.offset = offset + self.count = count + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the LIMIT option of FT.SEARCH. + + Returns: + List[TEncodable]: A list of LIMIT option arguments. + """ + args: List[TEncodable] = [ + FtSearchKeywords.LIMIT, + str(self.offset), + str(self.count), + ] + return args + + +class ReturnField: + """ + This class represents the arguments for the RETURN option of the FT.SEARCH command. + """ + + def __init__( + self, field_identifier: TEncodable, alias: Optional[TEncodable] = None + ): + """ + Initialize a new ReturnField instance. + + Args: + field_identifier (TEncodable): The identifier for the field of the key that has to returned as a result of FT.SEARCH command. + alias (Optional[TEncodable]): The alias to override the name of the field in the FT.SEARCH result. + """ + self.field_identifier = field_identifier + self.alias = alias + + def to_args(self) -> List[TEncodable]: + """ + Get the arguments for the RETURN option of FT.SEARCH. + + Returns: + List[TEncodable]: A list of RETURN option arguments. + """ + args: List[TEncodable] = [self.field_identifier] + if self.alias: + args.append(FtSearchKeywords.AS) + args.append(self.alias) + return args + + +class FtSearchOptions: + """ + This class represents the input options to be used in the FT.SEARCH command. + All fields in this class are optional inputs for FT.SEARCH. + """ + + def __init__( + self, + return_fields: Optional[List[ReturnField]] = None, + timeout: Optional[int] = None, + params: Optional[Mapping[TEncodable, TEncodable]] = None, + limit: Optional[FtSearchLimit] = None, + count: Optional[bool] = False, + ): + """ + Initialize the FT.SEARCH optional fields. + + Args: + return_fields (Optional[List[ReturnField]]): The fields of a key that are returned by FT.SEARCH command. See `ReturnField`. + timeout (Optional[int]): This value overrides the timeout parameter of the module. The unit for the timout is in milliseconds. + params (Optional[Mapping[TEncodable, TEncodable]]): Param key/value pairs that can be referenced from within the query expression. + limit (Optional[FtSearchLimit]): This option provides pagination capability. Only the keys that satisfy the offset and count values are returned. See `FtSearchLimit`. + count (Optional[bool]): This flag option suppresses returning the contents of keys. Only the number of keys is returned. + """ + self.return_fields = return_fields + self.timeout = timeout + self.params = params + self.limit = limit + self.count = count + + def to_args(self) -> List[TEncodable]: + """ + Get the optional arguments for the FT.SEARCH command. + + Returns: + List[TEncodable]: + List of FT.SEARCH optional agruments. 
+ """ + args: List[TEncodable] = [] + if self.return_fields: + args.append(FtSearchKeywords.RETURN) + return_field_args: List[TEncodable] = [] + for return_field in self.return_fields: + return_field_args.extend(return_field.to_args()) + args.append(str(len(return_field_args))) + args.extend(return_field_args) + if self.timeout: + args.append(FtSearchKeywords.TIMEOUT) + args.append(str(self.timeout)) + if self.params: + args.append(FtSearchKeywords.PARAMS) + args.append(str(len(self.params))) + for name, value in self.params.items(): + args.append(name) + args.append(value) + if self.limit: + args.extend(self.limit.to_args()) + if self.count: + args.append(FtSearchKeywords.COUNT) + return args diff --git a/python/python/glide/async_commands/server_modules/glide_json.py b/python/python/glide/async_commands/server_modules/glide_json.py new file mode 100644 index 0000000000..ba9fe26e57 --- /dev/null +++ b/python/python/glide/async_commands/server_modules/glide_json.py @@ -0,0 +1,1254 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +"""Glide module for `JSON` commands. + + Examples: + + >>> from glide import glide_json + >>> import json + >>> value = {'a': 1.0, 'b': 2} + >>> json_str = json.dumps(value) # Convert Python dictionary to JSON string using json.dumps() + >>> await json.set(client, "doc", "$", json_str) + 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. + >>> json_get = await glide_json.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string. + >>> print(json_get) + b"[{\"a\":1.0,\"b\":2}]" + >>> json.loads(str(json_get)) + [{"a": 1.0, "b" :2}] # JSON object retrieved from the key `doc` using json.loads() + """ +from typing import List, Optional, Union, cast + +from glide.async_commands.core import ConditionalChange +from glide.constants import TOK, TEncodable, TJsonResponse, TJsonUniversalResponse +from glide.glide_client import TGlideClient +from glide.protobuf.command_request_pb2 import RequestType + + +class JsonGetOptions: + """ + Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. + + Args: + indent (Optional[str]): Sets an indentation string for nested levels. Defaults to None. + newline (Optional[str]): Sets a string that's printed at the end of each line. Defaults to None. + space (Optional[str]): Sets a string that's put between a key and a value. Defaults to None. + """ + + def __init__( + self, + indent: Optional[str] = None, + newline: Optional[str] = None, + space: Optional[str] = None, + ): + self.indent = indent + self.new_line = newline + self.space = space + + def get_options(self) -> List[str]: + args = [] + if self.indent: + args.extend(["INDENT", self.indent]) + if self.new_line: + args.extend(["NEWLINE", self.new_line]) + if self.space: + args.extend(["SPACE", self.space]) + return args + + +class JsonArrIndexOptions: + """ + Options for the `JSON.ARRINDEX` command. + + Args: + start (int): The inclusive start index from which the search begins. Defaults to None. + end (Optional[int]): The exclusive end index where the search stops. Defaults to None. + + Note: + - If `start` is greater than `end`, the command returns `-1` to indicate that the value was not found. + - Indices that exceed the array bounds are automatically adjusted to the nearest valid position. 
+ """ + + def __init__(self, start: int, end: Optional[int] = None): + self.start = start + self.end = end + + def to_args(self) -> List[str]: + """ + Get the options as a list of arguments for the JSON.ARRINDEX command. + + Returns: + List[str]: A list containing the start and end indices if specified. + """ + args = [str(self.start)] + if self.end is not None: + args.append(str(self.end)) + return args + + +class JsonArrPopOptions: + """ + Options for the JSON.ARRPOP command. + + Args: + path (TEncodable): The path within the JSON document. + index (Optional[int]): The index of the element to pop. If not specified, will pop the last element. + Out of boundary indexes are rounded to their respective array boundaries. Defaults to None. + """ + + def __init__(self, path: TEncodable, index: Optional[int] = None): + self.path = path + self.index = index + + def to_args(self) -> List[TEncodable]: + """ + Get the options as a list of arguments for the `JSON.ARRPOP` command. + + Returns: + List[TEncodable]: A list containing the path and, if specified, the index. + """ + args = [self.path] + if self.index is not None: + args.append(str(self.index)) + return args + + +async def set( + client: TGlideClient, + key: TEncodable, + path: TEncodable, + value: TEncodable, + set_condition: Optional[ConditionalChange] = None, +) -> Optional[TOK]: + """ + Sets the JSON value at the specified `path` stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (TEncodable): Represents the path within the JSON document where the value will be set. + The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added. + value (TEncodable): The value to set at the specific path, in JSON formatted bytes or str. + set_condition (Optional[ConditionalChange]): Set the value only if the given condition is met (within the key or path). + Equivalent to [`XX` | `NX`] in the RESP API. Defaults to None. + + Returns: + Optional[TOK]: If the value is successfully set, returns OK. + If `value` isn't set because of `set_condition`, returns None. + + Examples: + >>> from glide import glide_json + >>> import json + >>> value = {'a': 1.0, 'b': 2} + >>> json_str = json.dumps(value) + >>> await glide_json.set(client, "doc", "$", json_str) + 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. + """ + args = ["JSON.SET", key, path, value] + if set_condition: + args.append(set_condition.value) + + return cast(Optional[TOK], await client.custom_command(args)) + + +async def get( + client: TGlideClient, + key: TEncodable, + paths: Optional[Union[TEncodable, List[TEncodable]]] = None, + options: Optional[JsonGetOptions] = None, +) -> TJsonResponse[Optional[bytes]]: + """ + Retrieves the JSON value at the specified `paths` stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON document. Default to None. + options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`. 
+
+    Returns:
+        TJsonResponse[Optional[bytes]]:
+            If one path is given:
+                For JSONPath (path starts with `$`):
+                    Returns a stringified JSON list of bytes replies for every possible path,
+                    or a byte string representation of an empty array, if path doesn't exist.
+                    If `key` doesn't exist, returns None.
+                For legacy path (path doesn't start with `$`):
+                    Returns a byte string representation of the value in `path`.
+                    If `path` doesn't exist, an error is raised.
+                    If `key` doesn't exist, returns None.
+            If multiple paths are given:
+                Returns a stringified JSON object in bytes, in which each path is a key, and its corresponding value is the value as if the path had been executed in the command as a single path.
+            If `paths` is a mix of both JSONPath and legacy paths, the command behaves as if all are JSONPath paths.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json, JsonGetOptions
+        >>> import json
+        >>> json_str = await glide_json.get(client, "doc", "$")
+        >>> json.loads(str(json_str)) # Parse JSON string to Python data
+            [{"a": 1.0, "b": 2}] # JSON object retrieved from the key `doc` using json.loads()
+        >>> await glide_json.get(client, "doc", "$")
+            b"[{\"a\":1.0,\"b\":2}]" # Returns the value at path '$' in the JSON document stored at `doc`.
+        >>> await glide_json.get(client, "doc", ["$.a", "$.b"], JsonGetOptions(indent=" ", newline="\n", space=" "))
+            b"{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}" # Returns the values at paths '$.a' and '$.b' in the JSON document stored at `doc`, with specified formatting options.
+        >>> await glide_json.get(client, "doc", "$.non_existing_path")
+            b"[]" # Returns an empty array since the path '$.non_existing_path' does not exist in the JSON document stored at `doc`.
+    """
+    args = ["JSON.GET", key]
+    if options:
+        args.extend(options.get_options())
+    if paths:
+        if isinstance(paths, (str, bytes)):
+            paths = [paths]
+        args.extend(paths)
+
+    return cast(TJsonResponse[Optional[bytes]], await client.custom_command(args))
+
+
+async def arrappend(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    values: List[TEncodable],
+) -> TJsonResponse[int]:
+    """
+    Appends one or more `values` to the JSON array at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): Represents the path within the JSON document where the `values` will be appended.
+        values (List[TEncodable]): The values to append to the JSON array at the specified path.
+            JSON string values must be wrapped with quotes. For example, to append `"foo"`, pass `"\"foo\""`.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the new length of the array after appending `values`,
+                or None for JSON values matching the path that are not an array.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the array after appending `values` to the array at `path`.
+                If multiple paths match, the length of the first updated array is returned.
+                If the JSON value at `path` is not an array or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> import json
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1, "b": ["one", "two"]}')
+            'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`.
+        >>> await glide_json.arrappend(client, "doc", "$.b", ['"three"'])
+            [3] # Returns the new length of the array at path '$.b' after appending the value.
+        >>> await glide_json.arrappend(client, "doc", ".b", ['"four"'])
+            4 # Returns the new length of the array at path '.b' after appending the value.
+        >>> json.loads(await glide_json.get(client, "doc", "."))
+            {"a": 1, "b": ["one", "two", "three", "four"]} # Returns the updated JSON document
+    """
+    args = ["JSON.ARRAPPEND", key, path] + values
+    return cast(TJsonResponse[int], await client.custom_command(args))
+
+
+async def arrindex(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    value: TEncodable,
+    options: Optional[JsonArrIndexOptions] = None,
+) -> TJsonResponse[int]:
+    """
+    Searches for the first occurrence of a scalar JSON value (i.e., a value that is neither an object nor an array) within arrays at the specified `path` in the JSON document stored at `key`.
+
+    If specified, `options.start` and `options.end` define the search range within the array
+    (where `options.start` is inclusive and `options.end` is exclusive).
+
+    Out-of-range indices adjust to the nearest valid position, and negative values count from the end (e.g., `-1` is the last element, `-2` the second last).
+
+    Setting `options.end` to `0` behaves like `-1`, extending the range to the array's end (inclusive).
+
+    If `options.start` exceeds `options.end`, `-1` is returned, indicating that the value was not found.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        value (TEncodable): The value to search for within the arrays.
+        options (Optional[JsonArrIndexOptions]): Options specifying an inclusive `start` index and an optional exclusive `end` index for a range-limited search.
+            Defaults to the full array if not provided. See `JsonArrIndexOptions`.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with `$`):
+                Returns an array of integers for every possible path, indicating the index of the first occurrence of `value` within the array,
+                or None for JSON values matching the path that are not an array.
+                A returned value of `-1` indicates that the value was not found in that particular array.
+                If `path` does not exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns an integer representing the index of the first occurrence of `value` within the array at the specified path.
+                A returned value of `-1` indicates that the value was not found in that particular array.
+                If multiple paths match, the index of the value from the first matching array is returned.
+                If the JSON value at the `path` is not an array or if `path` does not exist, an error is raised.
+                If `key` does not exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]')
+            'OK'
+        >>> await glide_json.arrindex(client, "doc", "$[*]", '"b"')
+            [-1, -1, 1, 1]
+        >>> await glide_json.set(client, "doc", ".", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}')
+            'OK'
+        >>> await glide_json.arrindex(client, "doc", ".children", '"Tom"')
+            2
+        >>> await glide_json.set(client, "doc", "$", '{"fruits": ["apple", "banana", "cherry", "banana", "grape"]}')
+            'OK'
+        >>> await glide_json.arrindex(client, "doc", "$.fruits", '"banana"', JsonArrIndexOptions(start=2, end=4))
+            3
+        >>> await glide_json.set(client, "k", ".", '[1, 2, "a", 4, "a", 6, 7, "b"]')
+            'OK'
+        >>> await glide_json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=0))
+            7 # "b" found at index 7 within the specified range, treating end=0 as the entire array's end.
+        >>> await glide_json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=-1))
+            7 # "b" found at index 7, with end=-1 covering the full array to its last element.
+        >>> await glide_json.arrindex(client, "k", ".", '"b"', JsonArrIndexOptions(start=4, end=7))
+            -1 # "b" not found within the range from index 4 to exclusive end at index 7.
+    """
+    args = ["JSON.ARRINDEX", key, path, value]
+
+    if options:
+        args.extend(options.to_args())
+
+    return cast(TJsonResponse[int], await client.custom_command(args))
+
+
+async def arrinsert(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    index: int,
+    values: List[TEncodable],
+) -> TJsonResponse[int]:
+    """
+    Inserts one or more values into the array at the specified `path` within the JSON document stored at `key`, before the given `index`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        index (int): The array index before which values are inserted.
+        values (List[TEncodable]): The JSON values to be inserted into the array, in JSON formatted bytes or str.
+            JSON string values must be wrapped with single quotes. For example, to insert `"foo"`, pass `'"foo"'`.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with '$'):
+                Returns a list of integer replies for every possible path, indicating the new length of the array,
+                or None for JSON values matching the path that are not an array.
+                If `path` does not exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with '$'):
+                Returns an integer representing the new length of the array.
+                If multiple paths are matched, returns the length of the first modified array.
+                If `path` doesn't exist or the value at `path` is not an array, an error is raised.
+                If the index is out of bounds, an error is raised.
+                If `key` doesn't exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
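A connected-client sketch of the `arrindex` forms documented above (a plain legacy-path lookup and a range-limited search via `JsonArrIndexOptions`); `client` is assumed to be an established `TGlideClient`:

```python
from glide import glide_json
from glide.async_commands.server_modules.glide_json import JsonArrIndexOptions


async def find_values(client) -> None:
    await glide_json.set(
        client, "doc", "$", '{"children": ["John", "Jack", "Tom", "Bob"]}'
    )
    # Legacy path: a single integer index is returned.
    print(await glide_json.arrindex(client, "doc", ".children", '"Tom"'))  # 2
    # Range-limited search: start=2 is inclusive, end=4 is exclusive.
    print(
        await glide_json.arrindex(
            client, "doc", ".children", '"Bob"', JsonArrIndexOptions(start=2, end=4)
        )
    )  # 3
```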
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]')
+            'OK'
+        >>> await glide_json.arrinsert(client, "doc", "$[*]", 0, ['"c"', '{"key": "value"}', "true", "null", '["bar"]'])
+            [5, 6, 7] # New lengths of arrays after insertion
+        >>> await glide_json.get(client, "doc")
+            b'[["c",{"key":"value"},true,null,["bar"]],["c",{"key":"value"},true,null,["bar"],"a"],["c",{"key":"value"},true,null,["bar"],"a","b"]]'
+
+        >>> await glide_json.set(client, "doc", "$", '[[], ["a"], ["a", "b"]]')
+            'OK'
+        >>> await glide_json.arrinsert(client, "doc", ".", 0, ['"c"'])
+            4 # New length of the root array after insertion
+        >>> await glide_json.get(client, "doc")
+            b'[\"c\",[],[\"a\"],[\"a\",\"b\"]]'
+    """
+    args = ["JSON.ARRINSERT", key, path, str(index)] + values
+    return cast(TJsonResponse[int], await client.custom_command(args))
+
+
+async def arrlen(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[TJsonResponse[int]]:
+    """
+    Retrieves the length of the array at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        Optional[TJsonResponse[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the array,
+                or None for JSON values matching the path that are not an array.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the array at `path`.
+                If multiple paths match, the length of the first array match is returned.
+                If the JSON value at `path` is not an array or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}')
+            'OK' # JSON is successfully set for doc
+        >>> await glide_json.arrlen(client, "doc", "$")
+            [None] # No array at the root path.
+        >>> await glide_json.arrlen(client, "doc", "$.a")
+            [3] # Retrieves the length of the array at path $.a.
+        >>> await glide_json.arrlen(client, "doc", "$..a")
+            [3, 2, None] # Retrieves lengths of arrays found at all levels of the path `$..a`.
+        >>> await glide_json.arrlen(client, "doc", "..a")
+            3 # Legacy path retrieves the first array match at path `..a`.
+        >>> await glide_json.arrlen(client, "non_existing_key", "$.a")
+            None # Returns None because the key does not exist.
+
+        >>> await glide_json.set(client, "doc", "$", '[1, 2, 3, 4]')
+            'OK' # JSON is successfully set for doc
+        >>> await glide_json.arrlen(client, "doc")
+            4 # Retrieves the length of the array at the root.
+    """
+    args = ["JSON.ARRLEN", key]
+    if path:
+        args.append(path)
+    return cast(
+        Optional[TJsonResponse[int]],
+        await client.custom_command(args),
+    )
+
+
+async def arrpop(
+    client: TGlideClient,
+    key: TEncodable,
+    options: Optional[JsonArrPopOptions] = None,
+) -> Optional[TJsonResponse[bytes]]:
+    """
+    Pops an element from the array located at the specified path within the JSON document stored at `key`.
+    If `options.index` is provided, it pops the element at that index instead of the last element.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        options (Optional[JsonArrPopOptions]): Options including the path and optional index. See `JsonArrPopOptions`. Defaults to None.
+            If not specified, attempts to pop the last element from the root value if it's an array.
+            If the root value is not an array, an error will be raised.
+
+    Returns:
+        Optional[TJsonResponse[bytes]]:
+            For JSONPath (`options.path` starts with `$`):
+                Returns a list of bytes string replies for every possible path, representing the popped JSON values,
+                or None for JSON values matching the path that are not an array or are an empty array.
+                If `options.path` doesn't exist, an empty list will be returned.
+            For legacy path (`options.path` doesn't start with `$`):
+                Returns a bytes string representing the popped JSON value, or None if the array at `options.path` is empty.
+                If multiple paths match, the value from the first matching array that is not empty is returned.
+                If the JSON value at `options.path` is not an array or if `options.path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false], 5], "c": {"a": 42}}}')
+            'OK'
+        >>> await glide_json.arrpop(client, "doc", JsonArrPopOptions(path="$.a", index=1))
+            [b'2'] # Pop the second element from the array at path $.a
+        >>> await glide_json.arrpop(client, "doc", JsonArrPopOptions(path="$..a"))
+            [b'true', b'5', None] # Pop the last element from all arrays matching path `$..a`
+
+        #### Using a legacy path (..) to pop the first matching array
+        >>> await glide_json.arrpop(client, "doc", JsonArrPopOptions(path="..a"))
+            b"1" # First match popped (from array at path ..a)
+
+        #### Even though only one value is returned from `..a`, subsequent arrays are also affected
+        >>> await glide_json.get(client, "doc", "$..a")
+            b"[[], [3, 4], 42]" # Remaining elements after pop show the changes
+
+        >>> await glide_json.set(client, "doc", "$", '[[], ["a"], ["a", "b", "c"]]')
+            'OK' # JSON is successfully set
+        >>> await glide_json.arrpop(client, "doc", JsonArrPopOptions(path=".", index=-1))
+            b'["a","b","c"]' # Pop the last element of the root array
+        >>> await glide_json.arrpop(client, "doc")
+            b'["a"]' # Pop the new last element of the root array
+    """
+    args = ["JSON.ARRPOP", key]
+    if options:
+        args.extend(options.to_args())
+
+    return cast(
+        Optional[TJsonResponse[bytes]],
+        await client.custom_command(args),
+    )
+
+
+async def arrtrim(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    start: int,
+    end: int,
+) -> TJsonResponse[int]:
+    """
+    Trims an array at the specified `path` within the JSON document stored at `key` so that it becomes a subarray [start, end], both inclusive.
+    If `start` < 0, it is treated as 0.
+    If `end` >= size (the size of the array), it is treated as size - 1.
+    If `start` >= size or `start` > `end`, the array is emptied and 0 is returned.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        start (int): The start index, inclusive.
+        end (int): The end index, inclusive.
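A sketch of `arrpop` with `JsonArrPopOptions`, again assuming a connected `client`:

```python
from glide import glide_json
from glide.async_commands.server_modules.glide_json import JsonArrPopOptions


async def pop_elements(client) -> None:
    await glide_json.set(client, "doc", "$", '{"a": [1, 2, 3]}')
    # Pop the element at index 1 from every array matching $.a.
    print(
        await glide_json.arrpop(client, "doc", JsonArrPopOptions(path="$.a", index=1))
    )  # [b'2']
    # Without an index, the last element of the (now shorter) array is popped.
    print(await glide_json.arrpop(client, "doc", JsonArrPopOptions(path=".a")))  # b'3'
```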
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with '$'):
+                Returns a list of integer replies for every possible path, indicating the new length of the array, or None for JSON values matching the path that are not an array.
+                If a value is an empty array, its corresponding return value is 0.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns an integer representing the new length of the array.
+                If the array is empty, returns 0.
+                If multiple paths match, the length of the first trimmed array match is returned.
+                If `path` doesn't exist, or the value at `path` is not an array, an error is raised.
+                If `key` doesn't exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]')
+            'OK'
+        >>> await glide_json.arrtrim(client, "doc", "$[*]", 0, 1)
+            [0, 1, 2, 2]
+        >>> await glide_json.get(client, "doc")
+            b'[[],[\"a\"],[\"a\",\"b\"],[\"a\",\"b\"]]'
+
+        >>> await glide_json.set(client, "doc", "$", '{"children": ["John", "Jack", "Tom", "Bob", "Mike"]}')
+            'OK'
+        >>> await glide_json.arrtrim(client, "doc", ".children", 0, 1)
+            2
+        >>> await glide_json.get(client, "doc", ".children")
+            b'["John","Jack"]'
+    """
+    return cast(
+        TJsonResponse[int],
+        await client.custom_command(["JSON.ARRTRIM", key, path, str(start), str(end)]),
+    )
+
+
+async def clear(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[str] = None,
+) -> int:
+    """
+    Clears arrays or objects at the specified JSON path in the document stored at `key`.
+    Numeric values are set to `0`, boolean values are set to `false`, and string values are converted to empty strings.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[str]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        int: The number of containers cleared, numeric values zeroed, booleans toggled to `false`,
+            and string values converted to empty strings.
+            If `path` doesn't exist, or the value at `path` is already empty (e.g., an empty array, object, or string), 0 is returned.
+            If `key` doesn't exist, an error is raised.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"obj":{"a":1, "b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}')
+            'OK' # JSON document is successfully set.
+        >>> await glide_json.clear(client, "doc", "$.*")
+            6 # 6 values are cleared (arrays/objects/strings/numbers/booleans); `null` remains as is.
+        >>> await glide_json.get(client, "doc", "$")
+            b'[{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}]'
+        >>> await glide_json.clear(client, "doc", "$.*")
+            0 # No further clearing needed since the containers are already empty and the values are defaults.
+
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}')
+            'OK'
+        >>> await glide_json.clear(client, "doc", "b.a[1:3]")
+            2 # 2 elements (`6` and `7`) are cleared.
+        >>> await glide_json.clear(client, "doc", "b.a[1:3]")
+            0 # No elements cleared since the specified slice has already been cleared.
+        >>> await glide_json.get(client, "doc", "$..a")
+            b'[1,[5,0,0],true,"value",3.5,{"foo":"foo"}]'
+
+        >>> await glide_json.clear(client, "doc", "$..a")
+            6 # All numeric, boolean, and string values across paths are cleared.
+        >>> await glide_json.get(client, "doc", "$..a")
+            b'[0,[],false,"",0.0,{}]'
+    """
+    args = ["JSON.CLEAR", key]
+    if path:
+        args.append(path)
+
+    return cast(int, await client.custom_command(args))
+
+
+async def debug_fields(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[TJsonUniversalResponse[int]]:
+    """
+    Returns the number of fields of the JSON value at the specified `path` within the JSON document stored at `key`.
+
+    - **Primitive Values**: Each non-container JSON value (e.g., strings, numbers, booleans, and null) counts as one field.
+    - **Arrays and Objects**: Each item in an array and each key-value pair in an object is counted as one field. (Each top-level value counts as one field, regardless of its type.)
+        - Their nested values are counted recursively and added to the total.
+        - **Example**: For the JSON `{"a": 1, "b": [2, 3, {"c": 4}]}`, the count would be:
+            - Top-level: 2 fields (`"a"` and `"b"`)
+            - Nested: 3 fields in the array (`2`, `3`, and `{"c": 4}`) plus 1 for the object (`"c"`)
+            - Total: 2 (top-level) + 3 (from array) + 1 (from nested object) = 6 fields.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to root if not provided.
+
+    Returns:
+        Optional[TJsonUniversalResponse[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns an array of integers, each indicating the number of fields for each matched `path`.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns an integer indicating the number of fields for each matched `path`.
+                If multiple paths match, the number of fields of the first JSON value match is returned.
+                If `path` doesn't exist, an error is raised.
+            If `path` is not provided, it reports the total number of fields in the entire JSON document.
+            If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonUniversalResponse`.
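The counting rule documented above can be modeled locally to sanity-check expected `debug_fields` results. The sketch below illustrates the documented rule only; it is not the server's implementation:

```python
from typing import Any


def count_fields(value: Any) -> int:
    """Model of the documented rule: each array item or object entry is one
    field, nested container entries are added recursively, and a bare scalar
    counts as a single field."""
    if isinstance(value, (dict, list)):
        items = value.values() if isinstance(value, dict) else value
        return sum(
            1 + (count_fields(v) if isinstance(v, (dict, list)) else 0)
            for v in items
        )
    return 1


# Matches the worked example above: 2 top-level + 3 array items + 1 nested key.
assert count_fields({"a": 1, "b": [2, 3, {"c": 4}]}) == 6
# Matches the docstring example for '.': 9 top-level fields + 5 nested fields.
assert count_fields([1, 2.3, "foo", True, None, {}, [], {"a": 1, "b": 2}, [1, 2, 3]]) == 14
```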
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "k1", "$", '[1, 2.3, "foo", true, null, {}, [], {"a":1, "b":2}, [1,2,3]]')
+            'OK'
+        >>> await glide_json.debug_fields(client, "k1", "$[*]")
+            [1, 1, 1, 1, 1, 0, 0, 2, 3]
+        >>> await glide_json.debug_fields(client, "k1", ".")
+            14 # 9 top-level fields + 5 nested fields
+
+        >>> await glide_json.set(client, "k1", "$", '{"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true,"address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"},"phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}],"children":[],"spouse":null}')
+            'OK'
+        >>> await glide_json.debug_fields(client, "k1")
+            19
+        >>> await glide_json.debug_fields(client, "k1", ".address")
+            4
+    """
+    args = ["JSON.DEBUG", "FIELDS", key]
+    if path:
+        args.append(path)
+
+    return cast(
+        Optional[TJsonUniversalResponse[int]], await client.custom_command(args)
+    )
+
+
+async def debug_memory(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[TJsonUniversalResponse[int]]:
+    """
+    Reports the memory usage in bytes of a JSON value at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        Optional[TJsonUniversalResponse[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns an array of integers, indicating the memory usage in bytes of a JSON value for each matched `path`.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns an integer, indicating the memory usage in bytes for the JSON value in `path`.
+                If multiple paths match, the memory usage of the first JSON value match is returned.
+                If `path` doesn't exist, an error is raised.
+            If `path` is not provided, it reports the total memory usage in bytes of the entire JSON document.
+            If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonUniversalResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "k1", "$", '[1, 2.3, "foo", true, null, {}, [], {"a":1, "b":2}, [1,2,3]]')
+            'OK'
+        >>> await glide_json.debug_memory(client, "k1", "$[*]")
+            [16, 16, 19, 16, 16, 16, 16, 66, 64]
+
+        >>> await glide_json.set(client, "k1", "$", '{"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true,"address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"},"phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}],"children":[],"spouse":null}')
+            'OK'
+        >>> await glide_json.debug_memory(client, "k1")
+            472
+        >>> await glide_json.debug_memory(client, "k1", ".phoneNumbers")
+            164
+    """
+    args = ["JSON.DEBUG", "MEMORY", key]
+    if path:
+        args.append(path)
+
+    return cast(
+        Optional[TJsonUniversalResponse[int]], await client.custom_command(args)
+    )
+
+
+async def delete(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> int:
+    """
+    Deletes the JSON value at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document.
+            If None, deletes the entire JSON document at `key`. Defaults to None.
+
+    Returns:
+        int: The number of elements removed.
+            If `key` or `path` doesn't exist, returns 0.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+            'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`.
+        >>> await glide_json.delete(client, "doc", "$..a")
+            2 # Indicates successful deletion of the specific values in the key stored at `doc`.
+        >>> await glide_json.get(client, "doc", "$")
+            b"[{\"nested\":{\"b\":3}}]" # Returns the value at path '$' in the JSON document stored at `doc`.
+        >>> await glide_json.delete(client, "doc")
+            1 # Deletes the entire JSON document stored at `doc`.
+    """
+
+    return cast(
+        int, await client.custom_command(["JSON.DEL", key] + ([path] if path else []))
+    )
+
+
+async def forget(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[int]:
+    """
+    Deletes the JSON value at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document.
+            If None, deletes the entire JSON document at `key`. Defaults to None.
+
+    Returns:
+        Optional[int]: The number of elements removed.
+            If `key` or `path` doesn't exist, returns 0.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+            'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`.
+        >>> await glide_json.forget(client, "doc", "$..a")
+            2 # Indicates successful deletion of the specific values in the key stored at `doc`.
+        >>> await glide_json.get(client, "doc", "$")
+            b"[{\"nested\":{\"b\":3}}]" # Returns the value at path '$' in the JSON document stored at `doc`.
+        >>> await glide_json.forget(client, "doc")
+            1 # Deletes the entire JSON document stored at `doc`.
+    """
+
+    return cast(
+        Optional[int],
+        await client.custom_command(["JSON.FORGET", key] + ([path] if path else [])),
+    )
+
+
+async def mget(
+    client: TGlideClient,
+    keys: List[TEncodable],
+    path: TEncodable,
+) -> List[Optional[bytes]]:
+    """
+    Retrieves the JSON values at the specified `path` stored at multiple `keys`.
+
+    Note:
+        In cluster mode, if keys in `keys` map to different hash slots, the command
+        will be split across these slots and executed separately for each. This means the command
+        is atomic only at the slot level. If one or more slot-specific requests fail, the entire
+        call will return the first encountered error, even though some requests may have succeeded
+        while others did not. If this behavior impacts your application logic, consider splitting
+        the request into sub-requests per slot to ensure atomicity.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        keys (List[TEncodable]): A list of keys for the JSON documents.
+        path (TEncodable): The path within the JSON documents.
+
+    Returns:
+        List[Optional[bytes]]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of byte representations of the values found at the given path for each key.
+                If `path` does not exist within the key, the entry will be an empty array.
+            For legacy path (`path` doesn't start with `$`):
+                Returns a list of byte representations of the values found at the given path for each key.
+                If `path` does not exist within the key, the entry will be None.
+            If a key doesn't exist, the corresponding list element will be None.
+
+
+    Examples:
+        >>> from glide import glide_json
+        >>> import json
+        >>> json_strs = await glide_json.mget(client, ["doc1", "doc2"], "$")
+        >>> [json.loads(js) for js in json_strs] # Parse JSON strings to Python data
+            [[{"a": 1.0, "b": 2}], [{"a": 2.0, "b": {"a": 3.0, "b": 4.0}}]] # JSON objects retrieved from keys `doc1` and `doc2`
+        >>> await glide_json.mget(client, ["doc1", "doc2"], "$.a")
+            [b"[1.0]", b"[2.0]"] # Returns values at path '$.a' for the JSON documents stored at `doc1` and `doc2`.
+        >>> await glide_json.mget(client, ["doc1"], "$.non_existing_path")
+            [None] # Returns None for the path '$.non_existing_path' in the JSON document stored at `doc1`.
+    """
+    args = ["JSON.MGET"] + keys + [path]
+    return cast(List[Optional[bytes]], await client.custom_command(args))
+
+
+async def numincrby(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    number: Union[int, float],
+) -> bytes:
+    """
+    Increments or decrements the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        number (Union[int, float]): The number to increment or decrement by.
+
+    Returns:
+        bytes:
+            For JSONPath (`path` starts with `$`):
+                Returns a bytes string representation of an array of bulk strings, indicating the new values after incrementing for each matched `path`.
+                If a value is not a number, its corresponding return value will be `null`.
+                If `path` doesn't exist, a byte string representation of an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns a bytes string representation of the resulting value after the increment or decrement.
+                If multiple paths match, the result of the last updated value is returned.
+                If the value at the `path` is not a number or `path` doesn't exist, an error is raised.
+            If `key` does not exist, an error is raised.
+            If the result is out of the range of 64-bit IEEE double, an error is raised.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}')
+            'OK'
+        >>> await glide_json.numincrby(client, "doc", "$.d[*]", 10)
+            b'[11,12,13]' # Increment each element in `d` array by 10.
+        >>> await glide_json.numincrby(client, "doc", ".c[1]", 10)
+            b'12' # Increment the second element in the `c` array by 10.
+    """
+    args = ["JSON.NUMINCRBY", key, path, str(number)]
+
+    return cast(bytes, await client.custom_command(args))
+
+
+async def nummultby(
+    client: TGlideClient,
+    key: TEncodable,
+    path: TEncodable,
+    number: Union[int, float],
+) -> bytes:
+    """
+    Multiplies the JSON value(s) at the specified `path` by `number` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (TEncodable): The path within the JSON document.
+        number (Union[int, float]): The number to multiply by.
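A connected-client sketch combining `numincrby` above with `nummultby`, whose documentation continues below:

```python
from glide import glide_json


async def adjust_numbers(client) -> None:
    await glide_json.set(client, "doc", "$", '{"d": [1, 2, 3]}')
    # JSONPath form: the new values for every match, JSON-encoded.
    print(await glide_json.numincrby(client, "doc", "$.d[*]", 10))  # b'[11,12,13]'
    # Legacy form: the single resulting value.
    print(await glide_json.nummultby(client, "doc", ".d[1]", 2))  # b'24'
```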
+
+    Returns:
+        bytes:
+            For JSONPath (`path` starts with `$`):
+                Returns a bytes string representation of an array of bulk strings, indicating the new values after multiplication for each matched `path`.
+                If a value is not a number, its corresponding return value will be `null`.
+                If `path` doesn't exist, a byte string representation of an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns a bytes string representation of the resulting value after multiplication.
+                If multiple paths match, the result of the last updated value is returned.
+                If the value at the `path` is not a number or `path` doesn't exist, an error is raised.
+            If `key` does not exist, an error is raised.
+            If the result is out of the range of 64-bit IEEE double, an error is raised.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": [], "b": [1], "c": [1, 2], "d": [1, 2, 3]}')
+            'OK'
+        >>> await glide_json.nummultby(client, "doc", "$.d[*]", 2)
+            b'[2,4,6]' # Multiplies each element in the `d` array by 2.
+        >>> await glide_json.nummultby(client, "doc", ".c[1]", 2)
+            b'4' # Multiplies the second element in the `c` array by 2.
+    """
+    args = ["JSON.NUMMULTBY", key, path, str(number)]
+
+    return cast(bytes, await client.custom_command(args))
+
+
+async def objlen(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[TJsonResponse[int]]:
+    """
+    Retrieves the number of key-value pairs in the object stored at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        Optional[TJsonResponse[int]]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the object,
+                or None for JSON values matching the path that are not an object.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the object at `path`.
+                If multiple paths match, the length of the first object match is returned.
+                If the JSON value at `path` is not an object or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonResponse`.
+
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}')
+            'OK' # Indicates successful setting of the value at the root path '$' in the key `doc`.
+        >>> await glide_json.objlen(client, "doc", "$")
+            [2] # Returns the number of key-value pairs at the root object, which has 2 keys: 'a' and 'b'.
+        >>> await glide_json.objlen(client, "doc", ".")
+            2 # Returns the number of key-value pairs for the object matching the path '.', which has 2 keys: 'a' and 'b'.
+        >>> await glide_json.objlen(client, "doc", "$.b")
+            [3] # Returns the length of the object at path '$.b', which has 3 keys: 'a', 'b', and 'c'.
+        >>> await glide_json.objlen(client, "doc", ".b")
+            3 # Returns the length of the nested object at path '.b', which has 3 keys.
+        >>> await glide_json.objlen(client, "doc", "$..a")
+            [None, 2]
+        >>> await glide_json.objlen(client, "doc")
+            2 # Returns the number of key-value pairs of the object matching the path '.', which has 2 keys: 'a' and 'b'.
+ """ + args = ["JSON.OBJLEN", key] + if path: + args.append(path) + return cast( + Optional[TJsonResponse[int]], + await client.custom_command(args), + ) + + +async def objkeys( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> Optional[TJsonUniversalResponse[List[bytes]]]: + """ + Retrieves key names in the object values at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): Represents the path within the JSON document where the key names will be retrieved. + Defaults to None. + + Returns: + Optional[TJsonUniversalResponse[List[bytes]]]: + For JSONPath (`path` starts with `$`): + Returns a list of arrays containing key names for each matching object. + If a value matching the path is not an object, an empty array is returned. + If `path` doesn't exist, an empty array is returned. + For legacy path (`path` starts with `.`): + Returns a list of key names for the object value matching the path. + If multiple objects match the path, the key names of the first object are returned. + If a value matching the path is not an object, an error is raised. + If `path` doesn't exist, None is returned. + If `key` doesn't exist, None is returned. + For more information about the returned type, see `TJsonUniversalResponse`. + + Examples: + >>> from glide import glide_json + >>> await glide_json.set(client, "doc", "$", '{"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": true}}') + b'OK' # Indicates successful setting of the value at the root path '$' in the key `doc`. + >>> await glide_json.objkeys(client, "doc", "$") + [[b"a", b"b"]] # Returns a list of arrays containing the key names for objects matching the path '$'. + >>> await glide_json.objkeys(client, "doc", ".") + [b"a", b"b"] # Returns key names for the object matching the path '.' as it is the only match. + >>> await glide_json.objkeys(client, "doc", "$.b") + [[b"a", b"b", b"c"]] # Returns key names as a nested list for objects matching the JSONPath '$.b'. + >>> await glide_json.objkeys(client, "doc", ".b") + [b"a", b"b", b"c"] # Returns key names for the nested object at path '.b'. + """ + args = ["JSON.OBJKEYS", key] + if path: + args.append(path) + return cast( + Optional[Union[List[bytes], List[List[bytes]]]], + await client.custom_command(args), + ) + + +async def resp( + client: TGlideClient, key: TEncodable, path: Optional[TEncodable] = None +) -> TJsonUniversalResponse[ + Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]] +]: + """ + Retrieve the JSON value at the specified `path` within the JSON document stored at `key`. + The returning result is in the Valkey or Redis OSS Serialization Protocol (RESP).\n + JSON null is mapped to the RESP Null Bulk String.\n + JSON Booleans are mapped to RESP Simple string.\n + JSON integers are mapped to RESP Integers.\n + JSON doubles are mapped to RESP Bulk Strings.\n + JSON strings are mapped to RESP Bulk Strings.\n + JSON arrays are represented as RESP arrays, where the first element is the simple string [, followed by the array's elements.\n + JSON objects are represented as RESP object, where the first element is the simple string {, followed by key-value pairs, each of which is a RESP bulk string.\n + + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. 
+
+    Returns:
+        TJsonUniversalResponse[Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]]]:
+            For JSONPath (`path` starts with '$'):
+                Returns a list of replies for every possible path, indicating the RESP form of the JSON value.
+                If `path` doesn't exist, returns an empty list.
+            For legacy path (`path` doesn't start with `$`):
+                Returns a single reply for the JSON value at the specified path, in its RESP form.
+                This can be a bytes object, an integer, None, or a list representing complex structures.
+                If multiple paths match, the value of the first JSON value match is returned.
+                If `path` doesn't exist, an error is raised.
+            If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonUniversalResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}')
+            'OK'
+        >>> await glide_json.resp(client, "doc", "$..a")
+            [[b"[", 1, 2, 3],[b"[", 1, 2],42]
+        >>> await glide_json.resp(client, "doc", "..a")
+            [b"[", 1, 2, 3]
+    """
+    args = ["JSON.RESP", key]
+    if path:
+        args.append(path)
+
+    return cast(
+        TJsonUniversalResponse[
+            Optional[Union[bytes, int, List[Optional[Union[bytes, int]]]]]
+        ],
+        await client.custom_command(args),
+    )
+
+
+async def strappend(
+    client: TGlideClient,
+    key: TEncodable,
+    value: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> TJsonResponse[int]:
+    """
+    Appends the specified `value` to the string stored at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        value (TEncodable): The value to append to the string. Must be wrapped with single quotes. For example, to append "foo", pass '"foo"'.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        TJsonResponse[int]:
+            For JSONPath (`path` starts with `$`):
+                Returns a list of integer replies for every possible path, indicating the length of the resulting string after appending `value`,
+                or None for JSON values matching the path that are not a string.
+                If `key` doesn't exist, an error is raised.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the length of the resulting string after appending `value` to the string at `path`.
+                If multiple paths match, the length of the last updated string is returned.
+                If the JSON value at `path` is not a string or if `path` doesn't exist, an error is raised.
+                If `key` doesn't exist, an error is raised.
+            For more information about the returned type, see `TJsonResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> import json
+        >>> await glide_json.set(client, "doc", "$", json.dumps({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}))
+            'OK'
+        >>> await glide_json.strappend(client, "doc", json.dumps("baz"), "$..a")
+            [6, 8, None] # The new lengths of the string values at path '$..a' in the key stored at `doc` after the append operation.
+        >>> await glide_json.strappend(client, "doc", '"foo"', "nested.a")
+            11 # The length of the string value after appending "foo" to the string at path 'nested.a' in the key stored at `doc`.
+        >>> json.loads(await glide_json.get(client, "doc", "$"))
+            [{"a":"foobaz", "nested": {"a": "hellobazfoo"}, "nested2": {"a": 31}}] # The updated JSON value in the key stored at `doc`.
+ """ + + return cast( + TJsonResponse[int], + await client.custom_command( + ["JSON.STRAPPEND", key] + ([path, value] if path else [value]) + ), + ) + + +async def strlen( + client: TGlideClient, + key: TEncodable, + path: Optional[TEncodable] = None, +) -> TJsonResponse[Optional[int]]: + """ + Returns the length of the JSON string value stored at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (Optional[TEncodable]): The path within the JSON document. Default to None. + + Returns: + TJsonResponse[Optional[int]]: + For JSONPath (`path` starts with `$`): + Returns a list of integer replies for every possible path, indicating the length of the JSON string value, + or None for JSON values matching the path that are not string. + For legacy path (`path` doesn't start with `$`): + Returns the length of the JSON value at `path` or None if `key` doesn't exist. + If multiple paths match, the length of the first mached string is returned. + If the JSON value at `path` is not a string of if `path` doesn't exist, an error is raised. + If `key` doesn't exist, None is returned. + For more information about the returned type, see `TJsonResponse`. + + Examples: + >>> from glide import glide_json + >>> import json + >>> await glide_json.set(client, "doc", "$", json.dumps({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31}})) + 'OK' + >>> await glide_json.strlen(client, "doc", "$..a") + [3, 5, None] # The length of the string values at path '$..a' in the key stored at `doc`. + >>> await glide_json.strlen(client, "doc", "nested.a") + 5 # The length of the JSON value at path 'nested.a' in the key stored at `doc`. + >>> await glide_json.strlen(client, "doc", "$") + [None] # Returns an array with None since the value at root path does in the JSON document stored at `doc` is not a string. + >>> await glide_json.strlen(client, "non_existing_key", ".") + None # `key` doesn't exist. + """ + + return cast( + TJsonResponse[Optional[int]], + await client.custom_command( + ["JSON.STRLEN", key, path] if path else ["JSON.STRLEN", key] + ), + ) + + +async def toggle( + client: TGlideClient, + key: TEncodable, + path: TEncodable, +) -> TJsonResponse[bool]: + """ + Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`. + + Args: + client (TGlideClient): The client to execute the command. + key (TEncodable): The key of the JSON document. + path (TEncodable): The path within the JSON document. Default to None. + + Returns: + TJsonResponse[bool]: + For JSONPath (`path` starts with `$`): + Returns a list of boolean replies for every possible path, with the toggled boolean value, + or None for JSON values matching the path that are not boolean. + If `key` doesn't exist, an error is raised. + For legacy path (`path` doesn't start with `$`): + Returns the value of the toggled boolean in `path`. + If the JSON value at `path` is not a boolean of if `path` doesn't exist, an error is raised. + If `key` doesn't exist, an error is raised. + For more information about the returned type, see `TJsonResponse`. 
+
+    Examples:
+        >>> from glide import glide_json
+        >>> import json
+        >>> await glide_json.set(client, "doc", "$", json.dumps({"bool": True, "nested": {"bool": False, "nested": {"bool": 10}}}))
+            'OK'
+        >>> await glide_json.toggle(client, "doc", "$..bool")
+            [False, True, None]  # Indicates successful toggling of the Boolean values at path '$..bool' in the key stored at `doc`.
+        >>> await glide_json.toggle(client, "doc", "bool")
+            True  # Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`.
+        >>> json.loads(await glide_json.get(client, "doc", "$"))
+            [{"bool": True, "nested": {"bool": True, "nested": {"bool": 10}}}]  # The updated JSON value in the key stored at `doc`.
+    """
+
+    return cast(
+        TJsonResponse[bool],
+        await client.custom_command(["JSON.TOGGLE", key, path]),
+    )
+
+
+async def type(
+    client: TGlideClient,
+    key: TEncodable,
+    path: Optional[TEncodable] = None,
+) -> Optional[TJsonUniversalResponse[bytes]]:
+    """
+    Retrieves the type of the JSON value at the specified `path` within the JSON document stored at `key`.
+
+    Args:
+        client (TGlideClient): The client to execute the command.
+        key (TEncodable): The key of the JSON document.
+        path (Optional[TEncodable]): The path within the JSON document. Defaults to None.
+
+    Returns:
+        Optional[TJsonUniversalResponse[bytes]]:
+            For JSONPath ('path' starts with '$'):
+                Returns a list of byte string replies for every possible path, indicating the type of the JSON value.
+                If `path` doesn't exist, an empty array will be returned.
+            For legacy path (`path` doesn't start with `$`):
+                Returns the type of the JSON value at `path`.
+                If multiple paths match, the type of the first matching JSON value is returned.
+                If `path` doesn't exist, None will be returned.
+                If `key` doesn't exist, None is returned.
+            For more information about the returned type, see `TJsonUniversalResponse`.
+
+    Examples:
+        >>> from glide import glide_json
+        >>> await glide_json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}')
+            'OK'
+        >>> await glide_json.type(client, "doc", "$.nested")
+            [b'object']  # Indicates the type of the value at path '$.nested' in the key stored at `doc`.
+        >>> await glide_json.type(client, "doc", "$.nested.a")
+            [b'integer']  # Indicates the type of the value at path '$.nested.a' in the key stored at `doc`.
+        >>> await glide_json.type(client, "doc", "$[*]")
+            [b'integer', b'object']  # Array of types in all top-level elements.
+    """
+    args = ["JSON.TYPE", key]
+    if path:
+        args.append(path)
+
+    return cast(
+        Optional[TJsonUniversalResponse[bytes]], await client.custom_command(args)
+    )
diff --git a/python/python/glide/async_commands/server_modules/json.py b/python/python/glide/async_commands/server_modules/json.py
deleted file mode 100644
index d1709806bc..0000000000
--- a/python/python/glide/async_commands/server_modules/json.py
+++ /dev/null
@@ -1,351 +0,0 @@
-# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
-"""module for `RedisJSON` commands.
-
-    Examples:
-
-        >>> from glide import json as redisJson
-        >>> import json
-        >>> value = {'a': 1.0, 'b': 2}
-        >>> json_str = json.dumps(value) # Convert Python dictionary to JSON string using json.dumps()
-        >>> await redisJson.set(client, "doc", "$", json_str)
-            'OK'  # Indicates successful setting of the value at path '$' in the key stored at `doc`.
-        >>> json_get = await redisJson.get(client, "doc", "$") # Returns the value at path '$' in the JSON document stored at `doc` as JSON string.
- >>> print(json_get) - b"[{\"a\":1.0,\"b\":2}]" - >>> json.loads(str(json_get)) - [{"a": 1.0, "b" :2}] # JSON object retrieved from the key `doc` using json.loads() - """ -from typing import List, Optional, Union, cast - -from glide.async_commands.core import ConditionalChange -from glide.constants import TOK, TEncodable, TJsonResponse -from glide.glide_client import TGlideClient -from glide.protobuf.command_request_pb2 import RequestType - - -class JsonGetOptions: - """ - Represents options for formatting JSON data, to be used in the [JSON.GET](https://valkey.io/commands/json.get/) command. - - Args: - indent (Optional[str]): Sets an indentation string for nested levels. Defaults to None. - newline (Optional[str]): Sets a string that's printed at the end of each line. Defaults to None. - space (Optional[str]): Sets a string that's put between a key and a value. Defaults to None. - """ - - def __init__( - self, - indent: Optional[str] = None, - newline: Optional[str] = None, - space: Optional[str] = None, - ): - self.indent = indent - self.new_line = newline - self.space = space - - def get_options(self) -> List[str]: - args = [] - if self.indent: - args.extend(["INDENT", self.indent]) - if self.new_line: - args.extend(["NEWLINE", self.new_line]) - if self.space: - args.extend(["SPACE", self.space]) - return args - - -async def set( - client: TGlideClient, - key: TEncodable, - path: TEncodable, - value: TEncodable, - set_condition: Optional[ConditionalChange] = None, -) -> Optional[TOK]: - """ - Sets the JSON value at the specified `path` stored at `key`. - - See https://valkey.io/commands/json.set/ for more details. - - Args: - client (TGlideClient): The Redis client to execute the command. - key (TEncodable): The key of the JSON document. - path (TEncodable): Represents the path within the JSON document where the value will be set. - The key will be modified only if `value` is added as the last child in the specified `path`, or if the specified `path` acts as the parent of a new child being added. - value (TEncodable): The value to set at the specific path, in JSON formatted bytes or str. - set_condition (Optional[ConditionalChange]): Set the value only if the given condition is met (within the key or path). - Equivalent to [`XX` | `NX`] in the Redis API. Defaults to None. - - Returns: - Optional[TOK]: If the value is successfully set, returns OK. - If value isn't set because of `set_condition`, returns None. - - Examples: - >>> from glide import json as redisJson - >>> import json - >>> value = {'a': 1.0, 'b': 2} - >>> json_str = json.dumps(value) - >>> await redisJson.set(client, "doc", "$", json_str) - 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. - """ - args = ["JSON.SET", key, path, value] - if set_condition: - args.append(set_condition.value) - - return cast(Optional[TOK], await client.custom_command(args)) - - -async def get( - client: TGlideClient, - key: TEncodable, - paths: Optional[Union[TEncodable, List[TEncodable]]] = None, - options: Optional[JsonGetOptions] = None, -) -> Optional[bytes]: - """ - Retrieves the JSON value at the specified `paths` stored at `key`. - - See https://valkey.io/commands/json.get/ for more details. - - Args: - client (TGlideClient): The Redis client to execute the command. - key (TEncodable): The key of the JSON document. - paths (Optional[Union[TEncodable, List[TEncodable]]]): The path or list of paths within the JSON document. Default is root `$`. 
- options (Optional[JsonGetOptions]): Options for formatting the byte representation of the JSON data. See `JsonGetOptions`. - - Returns: - bytes: A bytes representation of the returned value. - If `key` doesn't exists, returns None. - - Examples: - >>> from glide import json as redisJson - >>> import json - >>> json_str = await redisJson.get(client, "doc", "$") - >>> json.loads(str(json_str)) # Parse JSON string to Python data - [{"a": 1.0, "b" :2}] # JSON object retrieved from the key `doc` using json.loads() - >>> await redisJson.get(client, "doc", "$") - b"[{\"a\":1.0,\"b\":2}]" # Returns the value at path '$' in the JSON document stored at `doc`. - >>> await redisJson.get(client, "doc", ["$.a", "$.b"], json.JsonGetOptions(indent=" ", newline="\n", space=" ")) - b"{\n \"$.a\": [\n 1.0\n ],\n \"$.b\": [\n 2\n ]\n}" # Returns the values at paths '$.a' and '$.b' in the JSON document stored at `doc`, with specified formatting options. - >>> await redisJson.get(client, "doc", "$.non_existing_path") - b"[]" # Returns an empty array since the path '$.non_existing_path' does not exist in the JSON document stored at `doc`. - """ - args = ["JSON.GET", key] - if options: - args.extend(options.get_options()) - if paths: - if isinstance(paths, (str, bytes)): - paths = [paths] - args.extend(paths) - - return cast(bytes, await client.custom_command(args)) - - -async def arrlen( - client: TGlideClient, - key: TEncodable, - path: Optional[TEncodable] = None, -) -> Optional[TJsonResponse[int]]: - """ - Retrieves the length of the array at the specified `path` within the JSON document stored at `key`. - - Args: - client (TGlideClient): The client to execute the command. - key (TEncodable): The key of the JSON document. - path (Optional[TEncodable]): The path within the JSON document. Defaults to None. - - Returns: - Optional[TJsonResponse[int]]: - For JSONPath (`path` starts with `$`): - Returns a list of integer replies for every possible path, indicating the length of the array, - or None for JSON values matching the path that are not an array. - If `path` doesn't exist, an empty array will be returned. - For legacy path (`path` doesn't starts with `$`): - Returns the length of the array at `path`. - If multiple paths match, the length of the first array match is returned. - If the JSON value at `path` is not a array or if `path` doesn't exist, an error is raised. - If `key` doesn't exist, None is returned. - - Examples: - >>> from glide import json - >>> await json.set(client, "doc", "$", '{"a": [1, 2, 3], "b": {"a": [1, 2], "c": {"a": 42}}}') - b'OK' # JSON is successfully set for doc - >>> await json.arrlen(client, "doc", "$") - [None] # No array at the root path. - >>> await json.arrlen(client, "doc", "$.a") - [3] # Retrieves the length of the array at path $.a. - >>> await json.arrlen(client, "doc", "$..a") - [3, 2, None] # Retrieves lengths of arrays found at all levels of the path `..a`. - >>> await json.arrlen(client, "doc", "..a") - 3 # Legacy path retrieves the first array match at path `..a`. - >>> await json.arrlen(client, "non_existing_key", "$.a") - None # Returns None because the key does not exist. - - >>> await json.set(client, "doc", "$", '[1, 2, 3, 4]') - b'OK' # JSON is successfully set for doc - >>> await json.arrlen(client, "doc") - 4 # Retrieves lengths of arrays in root. 
- """ - args = ["JSON.ARRLEN", key] - if path: - args.append(path) - return cast( - Optional[TJsonResponse[int]], - await client.custom_command(args), - ) - - -async def delete( - client: TGlideClient, - key: TEncodable, - path: Optional[TEncodable] = None, -) -> int: - """ - Deletes the JSON value at the specified `path` within the JSON document stored at `key`. - - See https://valkey.io/commands/json.del/ for more details. - - Args: - client (TGlideClient): The Redis client to execute the command. - key (TEncodable): The key of the JSON document. - path (Optional[TEncodable]): Represents the path within the JSON document where the value will be deleted. - If None, deletes the entire JSON document at `key`. Defaults to None. - - Returns: - int: The number of elements removed. - If `key` or path doesn't exist, returns 0. - - Examples: - >>> from glide import json as redisJson - >>> await redisJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}') - 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. - >>> await redisJson.delete(client, "doc", "$..a") - 2 # Indicates successful deletion of the specific values in the key stored at `doc`. - >>> await redisJson.get(client, "doc", "$") - "[{\"nested\":{\"b\":3}}]" # Returns the value at path '$' in the JSON document stored at `doc`. - >>> await redisJson.delete(client, "doc") - 1 # Deletes the entire JSON document stored at `doc`. - """ - - return cast( - int, await client.custom_command(["JSON.DEL", key] + ([path] if path else [])) - ) - - -async def forget( - client: TGlideClient, - key: TEncodable, - path: Optional[TEncodable] = None, -) -> Optional[int]: - """ - Deletes the JSON value at the specified `path` within the JSON document stored at `key`. - - See https://valkey.io/commands/json.forget/ for more details. - - Args: - client (TGlideClient): The Redis client to execute the command. - key (TEncodable): The key of the JSON document. - path (Optional[TEncodable]): Represents the path within the JSON document where the value will be deleted. - If None, deletes the entire JSON document at `key`. Defaults to None. - - Returns: - int: The number of elements removed. - If `key` or path doesn't exist, returns 0. - - Examples: - >>> from glide import json as redisJson - >>> await redisJson.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}') - 'OK' # Indicates successful setting of the value at path '$' in the key stored at `doc`. - >>> await redisJson.forget(client, "doc", "$..a") - 2 # Indicates successful deletion of the specific values in the key stored at `doc`. - >>> await redisJson.get(client, "doc", "$") - "[{\"nested\":{\"b\":3}}]" # Returns the value at path '$' in the JSON document stored at `doc`. - >>> await redisJson.forget(client, "doc") - 1 # Deletes the entire JSON document stored at `doc`. - """ - - return cast( - Optional[int], - await client.custom_command(["JSON.FORGET", key] + ([path] if path else [])), - ) - - -async def toggle( - client: TGlideClient, - key: TEncodable, - path: TEncodable, -) -> TJsonResponse[bool]: - """ - Toggles a Boolean value stored at the specified `path` within the JSON document stored at `key`. - - See https://valkey.io/commands/json.toggle/ for more details. - - Args: - client (TGlideClient): The Redis client to execute the command. - key (TEncodable): The key of the JSON document. - path (TEncodable): The JSONPath to specify. 
- - Returns: - TJsonResponse[bool]: For JSONPath (`path` starts with `$`), returns a list of boolean replies for every possible path, with the toggled boolean value, - or None for JSON values matching the path that are not boolean. - For legacy path (`path` doesn't starts with `$`), returns the value of the toggled boolean in `path`. - Note that when sending legacy path syntax, If `path` doesn't exist or the value at `path` isn't a boolean, an error is raised. - For more information about the returned type, see `TJsonResponse`. - - Examples: - >>> from glide import json as redisJson - >>> import json - >>> await redisJson.set(client, "doc", "$", json.dumps({"bool": True, "nested": {"bool": False, "nested": {"bool": 10}}})) - 'OK' - >>> await redisJson.toggle(client, "doc", "$.bool") - [False, True, None] # Indicates successful toggling of the Boolean values at path '$.bool' in the key stored at `doc`. - >>> await redisJson.toggle(client, "doc", "bool") - True # Indicates successful toggling of the Boolean value at path 'bool' in the key stored at `doc`. - >>> json.loads(await redisJson.get(client, "doc", "$")) - [{"bool": True, "nested": {"bool": True, "nested": {"bool": 10}}}] # The updated JSON value in the key stored at `doc`. - """ - - return cast( - TJsonResponse[bool], - await client.custom_command(["JSON.TOGGLE", key, path]), - ) - - -async def type( - client: TGlideClient, - key: TEncodable, - path: Optional[TEncodable] = None, -) -> Optional[Union[bytes, List[bytes]]]: - """ - Retrieves the type of the JSON value at the specified `path` within the JSON document stored at `key`. - - Args: - client (TGlideClient): The client to execute the command. - key (TEncodable): The key of the JSON document. - path (Optional[TEncodable]): Represents the path within the JSON document where the type will be retrieved. - Defaults to None. - - Returns: - Optional[Union[bytes, List[bytes]]]: - For JSONPath ('path' starts with '$'): - Returns a list of byte string replies for every possible path, indicating the type of the JSON value. - If `path` doesn't exist, an empty array will be returned. - For legacy path (`path` doesn't starts with `$`): - Returns the type of the JSON value at `path`. - If multiple paths match, the type of the first JSON value match is returned. - If `path` doesn't exist, None will be returned. - If `key` doesn't exist, None is returned. - - Examples: - >>> from glide import json - >>> await json.set(client, "doc", "$", '{"a": 1, "nested": {"a": 2, "b": 3}}') - >>> await json.type(client, "doc", "$.nested") - [b'object'] # Indicates the type of the value at path '$.nested' in the key stored at `doc`. - >>> await json.type(client, "doc", "$.nested.a") - [b'integer'] # Indicates the type of the value at path '$.nested.a' in the key stored at `doc`. - >>> await json.type(client, "doc", "$[*]") - [b'integer', b'object'] # Array of types in all top level elements. 
-    """
-    args = ["JSON.TYPE", key]
-    if path:
-        args.append(path)
-
-    return cast(Optional[Union[bytes, List[bytes]]], await client.custom_command(args))
diff --git a/python/python/glide/async_commands/standalone_commands.py b/python/python/glide/async_commands/standalone_commands.py
index 6bf5e81140..1cb2230a87 100644
--- a/python/python/glide/async_commands/standalone_commands.py
+++ b/python/python/glide/async_commands/standalone_commands.py
@@ -2,19 +2,17 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Mapping, Optional, Set, Union, cast
+from typing import Dict, List, Mapping, Optional, Union, cast
 
-from glide.async_commands.command_args import Limit, ObjectType, OrderBy
+from glide.async_commands.command_args import ObjectType
 from glide.async_commands.core import (
     CoreCommands,
     FlushMode,
     FunctionRestorePolicy,
     InfoSection,
-    _build_sort_args,
 )
 from glide.async_commands.transaction import Transaction
 from glide.constants import (
-    OK,
     TOK,
     TEncodable,
     TFunctionListResponse,
@@ -23,7 +21,7 @@
 )
 from glide.protobuf.command_request_pb2 import RequestType
 
-from ..glide import ClusterScanCursor, Script
+from ..glide import Script
 
 
 class StandaloneCommands(CoreCommands):
diff --git a/python/python/glide/config.py b/python/python/glide/config.py
index db85202876..b33c037cbf 100644
--- a/python/python/glide/config.py
+++ b/python/python/glide/config.py
@@ -41,6 +41,11 @@ class ReadFrom(Enum):
     Spread the requests between all replicas in a round robin manner.
     If no replica is available, route the requests to the primary.
     """
+    AZ_AFFINITY = ProtobufReadFrom.AZAffinity
+    """
+    Spread the read requests between replicas in the same client's AZ (Availability Zone) in a round robin manner,
+    falling back to other replicas or the primary if needed.
+    """
 
 
 class ProtocolVersion(Enum):
@@ -135,6 +140,7 @@ def __init__(
         client_name: Optional[str] = None,
         protocol: ProtocolVersion = ProtocolVersion.RESP3,
         inflight_requests_limit: Optional[int] = None,
+        client_az: Optional[str] = None,
     ):
         """
         Represents the configuration settings for a Glide client.
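[Reviewer's note, not part of the patch: the config.py hunks above and below introduce the `AZ_AFFINITY` read strategy and the `client_az` field it requires. A minimal usage sketch follows; the address and AZ string are hypothetical placeholders, and the names used are only those added by this change or already present in the client API.]

```python
# Minimal sketch (illustrative, not part of the diff): configuring a cluster
# client with the new AZ_AFFINITY read strategy. Address and AZ are assumptions.
import asyncio

from glide.config import GlideClusterClientConfiguration, NodeAddress, ReadFrom
from glide.glide_client import GlideClusterClient


async def main() -> None:
    config = GlideClusterClientConfiguration(
        addresses=[NodeAddress("localhost", 6379)],
        # Route reads to replicas in the client's AZ, falling back to other
        # replicas or the primary when none are available.
        read_from=ReadFrom.AZ_AFFINITY,
        client_az="us-east-1a",  # omitting this with AZ_AFFINITY raises ValueError
    )
    client = await GlideClusterClient.create(config)
    print(await client.get("foo"))  # served by a same-AZ replica when possible
    await client.close()


asyncio.run(main())
```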
@@ -172,6 +178,12 @@ def __init__(
         self.client_name = client_name
         self.protocol = protocol
         self.inflight_requests_limit = inflight_requests_limit
+        self.client_az = client_az
+
+        if read_from == ReadFrom.AZ_AFFINITY and not client_az:
+            raise ValueError(
+                "client_az must be set when read_from is set to AZ_AFFINITY"
+            )
 
     def _create_a_protobuf_conn_request(
         self, cluster_mode: bool = False
@@ -204,6 +216,8 @@ def _create_a_protobuf_conn_request(
         request.protocol = self.protocol.value
         if self.inflight_requests_limit:
             request.inflight_requests_limit = self.inflight_requests_limit
+        if self.client_az:
+            request.client_az = self.client_az
 
         return request
 
@@ -293,6 +307,7 @@ def __init__(
         protocol: ProtocolVersion = ProtocolVersion.RESP3,
         pubsub_subscriptions: Optional[PubSubSubscriptions] = None,
         inflight_requests_limit: Optional[int] = None,
+        client_az: Optional[str] = None,
     ):
         super().__init__(
             addresses=addresses,
@@ -303,6 +318,7 @@ def __init__(
             client_name=client_name,
             protocol=protocol,
             inflight_requests_limit=inflight_requests_limit,
+            client_az=client_az,
         )
         self.reconnect_strategy = reconnect_strategy
         self.database_id = database_id
@@ -442,6 +458,7 @@ def __init__(
         ] = PeriodicChecksStatus.ENABLED_DEFAULT_CONFIGS,
         pubsub_subscriptions: Optional[PubSubSubscriptions] = None,
         inflight_requests_limit: Optional[int] = None,
+        client_az: Optional[str] = None,
    ):
         super().__init__(
             addresses=addresses,
@@ -452,6 +469,7 @@ def __init__(
             client_name=client_name,
             protocol=protocol,
             inflight_requests_limit=inflight_requests_limit,
+            client_az=client_az,
         )
         self.periodic_checks = periodic_checks
         self.pubsub_subscriptions = pubsub_subscriptions
diff --git a/python/python/glide/constants.py b/python/python/glide/constants.py
index 754aacf6fa..9740ac8cf6 100644
--- a/python/python/glide/constants.py
+++ b/python/python/glide/constants.py
@@ -1,6 +1,6 @@
 # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
 
-from typing import Dict, List, Literal, Mapping, Optional, Set, TypeVar, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Set, TypeVar, Union
 
 from glide.protobuf.command_request_pb2 import CommandRequest
 from glide.protobuf.connection_request_pb2 import ConnectionRequest
@@ -33,8 +33,27 @@
 TSingleNodeRoute = Union[RandomNode, SlotKeyRoute, SlotIdRoute, ByAddressRoute]
 # When specifying legacy path (path doesn't start with `$`), response will be T
 # Otherwise, (when specifying JSONPath), response will be List[Optional[T]].
+#
+# TJsonResponse is designed to handle scenarios where some paths may not contain valid values, especially with JSONPath targeting multiple paths.
+# In such cases, the response may include None values, represented as `Optional[T]` in the list.
+# This type provides flexibility for commands where a subset of the paths may return None.
+#
 # For more information, see: https://redis.io/docs/data-types/json/path/ .
 TJsonResponse = Union[T, List[Optional[T]]]
+
+# When specifying legacy path (path doesn't start with `$`), response will be T.
+# Otherwise (when specifying JSONPath), response will be List[T].
+# This type represents the response format for commands that apply to every path and every type in a JSON document.
+# It covers both singular and multiple paths, ensuring that the command returns valid results for each matched path without None values.
+#
+# TJsonUniversalResponse is considered "universal" because it applies to every matched path and
+# guarantees valid, non-null results across all paths, covering both singular and multiple paths.
+# This type is used for commands that return results from all matched paths, ensuring that each
+# path contains meaningful values without None entries (unless None is part of the command's response).
+# It is typically used in scenarios where each target is expected to yield a valid response, i.e. for
+# commands that are valid for all target types.
+#
+# For more information, see: https://redis.io/docs/data-types/json/path/ .
+TJsonUniversalResponse = Union[T, List[T]]
 TEncodable = Union[str, bytes]
 TFunctionListResponse = List[
     Mapping[
@@ -74,3 +93,28 @@
         List[Mapping[bytes, Union[bytes, int, List[List[Union[bytes, int]]]]]],
     ],
 ]
+
+FtInfoResponse = Mapping[
+    TEncodable,
+    Union[
+        TEncodable,
+        int,
+        List[TEncodable],
+        List[
+            Mapping[
+                TEncodable,
+                Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]],
+            ]
+        ],
+    ],
+]
+
+FtSearchResponse = List[
+    Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]
+]
+
+FtAggregateResponse = List[Mapping[TEncodable, Any]]
+
+FtProfileResponse = List[
+    Union[FtSearchResponse, FtAggregateResponse, Mapping[str, int]]
+]
diff --git a/python/python/glide/glide.pyi b/python/python/glide/glide.pyi
index b544a3948e..bbd5274770 100644
--- a/python/python/glide/glide.pyi
+++ b/python/python/glide/glide.pyi
@@ -31,5 +31,6 @@ def start_socket_listener_external(init_callback: Callable) -> None: ...
 def value_from_pointer(pointer: int) -> TResult: ...
 def create_leaked_value(message: str) -> int: ...
 def create_leaked_bytes_vec(args_vec: List[bytes]) -> int: ...
+def get_statistics() -> dict: ...
 def py_init(level: Optional[Level], file_name: Optional[str]) -> Level: ...
 def py_log(log_level: Level, log_identifier: str, message: str) -> None: ...
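[Reviewer's note, not part of the patch: a small, non-normative sketch of how the two JSON response aliases above differ in practice. The literal values are hypothetical, chosen to match the shapes described in the comments; they are not output from a live server.]

```python
# Illustrative shapes only (hypothetical values, not server output).
from typing import Optional

from glide.constants import TJsonResponse, TJsonUniversalResponse

# JSONPath query (path starts with '$'): one entry per match; None where the
# matched value is of the wrong type for the command (e.g. JSON.STRLEN on an int).
strlen_jsonpath: TJsonResponse[Optional[int]] = [3, 5, None]

# Legacy path: a single scalar reply instead of a list.
strlen_legacy: TJsonResponse[Optional[int]] = 3

# "Universal" response (e.g. JSON.TYPE): every matched path yields a value, no None entries.
type_jsonpath: TJsonUniversalResponse[bytes] = [b"object", b"integer"]
type_legacy: TJsonUniversalResponse[bytes] = b"object"
```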
diff --git a/python/python/glide/glide_client.py b/python/python/glide/glide_client.py
index f53644fa3d..6178b997a7 100644
--- a/python/python/glide/glide_client.py
+++ b/python/python/glide/glide_client.py
@@ -9,7 +9,7 @@
 from glide.async_commands.command_args import ObjectType
 from glide.async_commands.core import CoreCommands
 from glide.async_commands.standalone_commands import StandaloneCommands
-from glide.config import BaseClientConfiguration
+from glide.config import BaseClientConfiguration, ServerCredentials
 from glide.constants import DEFAULT_READ_BYTES_SIZE, OK, TEncodable, TRequest, TResult
 from glide.exceptions import (
     ClosingError,
@@ -32,6 +32,7 @@
     MAX_REQUEST_ARGS_LEN,
     ClusterScanCursor,
     create_leaked_bytes_vec,
+    get_statistics,
     start_socket_listener_external,
     value_from_pointer,
 )
@@ -523,7 +524,7 @@ async def _reader_loop(self) -> None:
                     read_bytes, read_bytes_view, offset, Response
                 )
             except PartialMessageException:
-                # Recieved only partial response, break the inner loop
+                # Received only partial response, break the inner loop
                 remaining_read_bytes = read_bytes[offset:]
                 break
             response = cast(Response, response)
@@ -532,6 +533,25 @@ async def _reader_loop(self) -> None:
             else:
                 await self._process_response(response=response)
 
+    async def get_statistics(self) -> dict:
+        return get_statistics()
+
+    async def _update_connection_password(
+        self, password: Optional[str], immediate_auth: bool
+    ) -> TResult:
+        request = CommandRequest()
+        request.callback_idx = self._get_callback_index()
+        if password is not None:
+            request.update_connection_password.password = password
+        request.update_connection_password.immediate_auth = immediate_auth
+        response = await self._write_request_await_response(request)
+        # Update the password on the client binding side if the core configuration password was changed
+        if response == OK:
+            if self.config.credentials is None:
+                self.config.credentials = ServerCredentials(password=password or "")
+            self.config.credentials.password = password or ""
+        return response
+
 
 class GlideClusterClient(BaseClient, ClusterCommands):
     """
diff --git a/python/python/tests/conftest.py b/python/python/tests/conftest.py
index 9b1db487da..0937ca2067 100644
--- a/python/python/tests/conftest.py
+++ b/python/python/tests/conftest.py
@@ -9,12 +9,16 @@
     GlideClusterClientConfiguration,
     NodeAddress,
     ProtocolVersion,
+    ReadFrom,
     ServerCredentials,
 )
+from glide.exceptions import ClosingError
 from glide.glide_client import GlideClient, GlideClusterClient, TGlideClient
 from glide.logger import Level as logLevel
 from glide.logger import Logger
+from glide.routes import AllNodes
 from tests.utils.cluster import ValkeyCluster
+from tests.utils.utils import check_if_server_version_lt
 
 DEFAULT_HOST = "localhost"
 DEFAULT_PORT = 6379
@@ -129,6 +133,7 @@ def create_clusters(tls, load_module, cluster_endpoints, standalone_endpoints):
         cluster_mode=True,
         load_module=load_module,
         addresses=cluster_endpoints,
+        replica_count=2,
     )
     pytest.standalone_cluster = ValkeyCluster(
         tls=tls,
@@ -203,9 +208,11 @@ def pytest_collection_modifyitems(config, items):
             )
 
 
-@pytest.fixture()
+@pytest.fixture(scope="function")
 async def glide_client(
-    request, cluster_mode: bool, protocol: ProtocolVersion
+    request,
+    cluster_mode: bool,
+    protocol: ProtocolVersion,
 ) -> AsyncGenerator[TGlideClient, None]:
     "Get async socket client for tests"
     client = await create_client(request, cluster_mode, protocol=protocol)
@@ -214,6 +221,19 @@ async def glide_client(
     await client.close()
 
 
+@pytest.fixture(scope="function")
+async def management_client(
+    request,
+    cluster_mode: bool,
+    protocol: ProtocolVersion,
+) -> AsyncGenerator[TGlideClient, None]:
+    "Get an async socket client for tests; used to manage server state in tests that exercise the client's ability to connect"
+    client = await create_client(request, cluster_mode, protocol=protocol)
+    yield client
+    await test_teardown(request, cluster_mode, protocol)
+    await client.close()
+
+
 async def create_client(
     request,
     cluster_mode: bool,
@@ -222,7 +242,7 @@ async def create_client(
     addresses: Optional[List[NodeAddress]] = None,
     client_name: Optional[str] = None,
     protocol: ProtocolVersion = ProtocolVersion.RESP3,
-    timeout: Optional[int] = None,
+    timeout: Optional[int] = 1000,
     cluster_mode_pubsub: Optional[
         GlideClusterClientConfiguration.PubSubSubscriptions
     ] = None,
@@ -230,6 +250,8 @@ async def create_client(
         GlideClientConfiguration.PubSubSubscriptions
     ] = None,
     inflight_requests_limit: Optional[int] = None,
+    read_from: ReadFrom = ReadFrom.PRIMARY,
+    client_az: Optional[str] = None,
 ) -> Union[GlideClient, GlideClusterClient]:
     # Create async socket client
     use_tls = request.config.getoption("--tls")
@@ -247,6 +269,8 @@ async def create_client(
             request_timeout=timeout,
             pubsub_subscriptions=cluster_mode_pubsub,
             inflight_requests_limit=inflight_requests_limit,
+            read_from=read_from,
+            client_az=client_az,
         )
         return await GlideClusterClient.create(cluster_config)
     else:
@@ -263,18 +287,105 @@ async def create_client(
             request_timeout=timeout,
             pubsub_subscriptions=standalone_mode_pubsub,
             inflight_requests_limit=inflight_requests_limit,
+            read_from=read_from,
+            client_az=client_az,
         )
         return await GlideClient.create(config)
 
 
+NEW_PASSWORD = "new_secure_password"
+WRONG_PASSWORD = "wrong_password"
+
+
+async def auth_client(client: TGlideClient, password):
+    """
+    Authenticates the given TGlideClient against the server it is connected to.
+    """
+    if isinstance(client, GlideClient):
+        await client.custom_command(["AUTH", password])
+    elif isinstance(client, GlideClusterClient):
+        await client.custom_command(["AUTH", password], route=AllNodes())
+
+
+async def config_set_new_password(client: TGlideClient, password):
+    """
+    Sets a new password on the server the given TGlideClient is connected to.
+    This function updates the server to require the new password.
+    """
+    if isinstance(client, GlideClient):
+        await client.config_set({"requirepass": password})
+    elif isinstance(client, GlideClusterClient):
+        await client.config_set({"requirepass": password}, route=AllNodes())
+
+
+async def kill_connections(client: TGlideClient):
+    """
+    Kills all connections to the server the given TGlideClient is connected to.
+    """
+    if isinstance(client, GlideClient):
+        await client.custom_command(["CLIENT", "KILL", "TYPE", "normal"])
+    elif isinstance(client, GlideClusterClient):
+        await client.custom_command(
+            ["CLIENT", "KILL", "TYPE", "normal"], route=AllNodes()
+        )
+
+
 async def test_teardown(request, cluster_mode: bool, protocol: ProtocolVersion):
     """
     Perform teardown tasks such as flushing all data from the cluster.
 
-    We create a new client here because some tests load lots of data to the cluster,
-    which might cause the client to time out during flushing. Therefore, we create
-    a client with a custom timeout to ensure the operation completes successfully.
+    If authentication is required, attempt to connect with the known password,
+    reset it back to empty, and proceed with teardown.
     """
-    client = await create_client(request, cluster_mode, protocol=protocol, timeout=2000)
-    await client.custom_command(["FLUSHALL"])
-    await client.close()
+    credentials = None
+    try:
+        # Try connecting without credentials
+        client = await create_client(
+            request, cluster_mode, protocol=protocol, timeout=2000
+        )
+        await client.custom_command(["FLUSHALL"])
+        await client.close()
+    except ClosingError as e:
+        # Check if the error is due to authentication
+        if "NOAUTH" in str(e):
+            # Use the known password to authenticate
+            credentials = ServerCredentials(password=NEW_PASSWORD)
+            client = await create_client(
+                request,
+                cluster_mode,
+                protocol=protocol,
+                timeout=2000,
+                credentials=credentials,
+            )
+            try:
+                await auth_client(client, NEW_PASSWORD)
+                # Reset the server password back to empty
+                await config_set_new_password(client, "")
+                await client.update_connection_password(None)
+                # Perform the teardown
+                await client.custom_command(["FLUSHALL"])
+            finally:
+                await client.close()
+        else:
+            raise e
+
+
+@pytest.fixture(autouse=True)
+async def skip_if_version_below(request):
+    """
+    Skip test(s) if the server version is below the given version. Can skip a complete test suite.
+
+    Example:
+
+        @pytest.mark.skip_if_version_below('7.0.0')
+        async def test_meow_meow(...):
+            ...
+    """
+    if request.node.get_closest_marker("skip_if_version_below"):
+        min_version = request.node.get_closest_marker("skip_if_version_below").args[0]
+        client = await create_client(request, False)
+        if await check_if_server_version_lt(client, min_version):
+            pytest.skip(
+                reason=f"This feature was added in version {min_version}",
+                allow_module_level=True,
+            )
diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py
index 7566194dcc..b32aa6936d 100644
--- a/python/python/tests/test_async_client.py
+++ b/python/python/tests/test_async_client.py
@@ -1,4 +1,5 @@
 # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
+# mypy: disable_error_code="arg-type"
 
 from __future__ import annotations
 
@@ -104,6 +105,7 @@
     get_random_string,
     is_single_response,
     parse_info_response,
+    round_values,
 )
 
 
@@ -135,6 +137,7 @@ async def test_send_and_receive_large_values(self, request, cluster_mode, protocol):
         assert len(value) == length
         await glide_client.set(key, value)
         assert await glide_client.get(key) == value.encode()
+        await glide_client.close()
 
     @pytest.mark.parametrize("cluster_mode", [True, False])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
@@ -200,6 +203,8 @@ async def test_can_connect_with_auth_requirepass(
             key = get_random_string(10)
             assert await auth_client.set(key, key) == OK
             assert await auth_client.get(key) == key.encode()
+            await auth_client.close()
+
         finally:
             # Reset the password
             auth_client = await create_client(
@@ -209,6 +214,7 @@ async def test_can_connect_with_auth_requirepass(
                 addresses=glide_client.config.addresses,
             )
             await auth_client.custom_command(["CONFIG", "SET", "requirepass", ""])
+            await auth_client.close()
 
     @pytest.mark.parametrize("cluster_mode", [True, False])
     @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
@@ -252,6 +258,7 @@ async def test_can_connect_with_auth_acl(
             # This client isn't authorized to perform SET
             await testuser_client.set("foo", "bar")
             assert "NOPERM" in str(e)
+            await testuser_client.close()
         finally:
             # Delete this user
             await glide_client.custom_command(["ACL", "DELUSER", username])
@@ -263,6 +270,7 @@ async def test_select_standalone_database_id(self,
request, cluster_mode): ) client_info = await glide_client.custom_command(["CLIENT", "INFO"]) assert b"db=4" in client_info + await glide_client.close() @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) @@ -275,6 +283,7 @@ async def test_client_name(self, request, cluster_mode, protocol): ) client_info = await glide_client.custom_command(["CLIENT", "INFO"]) assert b"name=TEST_CLIENT_NAME" in client_info + await glide_client.close() @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) @@ -284,6 +293,15 @@ async def test_closed_client_raises_error(self, glide_client: TGlideClient): await glide_client.set("foo", "bar") assert "the client is closed" in str(e) + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_statistics(self, glide_client: TGlideClient): + stats = await glide_client.get_statistics() + assert isinstance(stats, dict) + assert "total_connections" in stats + assert "total_clients" in stats + assert len(stats) == 2 + @pytest.mark.asyncio class TestCommands: @@ -2686,6 +2704,7 @@ async def test_geosearchstore_by_box(self, glide_client: TGlideClient): ) expected_map = {member: value[1] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a box search, unit: kilometes, from a geospatial data, with distance @@ -2705,6 +2724,8 @@ async def test_geosearchstore_by_box(self, glide_client: TGlideClient): ) expected_map = {member: value[0] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) + sorted_expected_map = round_values(sorted_expected_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a box search, unit: kilometes, from a geospatial data, with count @@ -2745,6 +2766,8 @@ async def test_geosearchstore_by_box(self, glide_client: TGlideClient): b"Palermo": 166274.15156960033, b"edge2": 236529.17986494553, } + zrange_map = round_values(zrange_map, 9) + expected_distances = round_values(expected_distances, 9) assert compare_maps(zrange_map, expected_distances) is True # Test search by box, unit: feet, from a member, with limited ANY count to 2, with hash @@ -2826,6 +2849,8 @@ async def test_geosearchstore_by_radius(self, glide_client: TGlideClient): b"Catania": 0.0, b"Palermo": 166274.15156960033, } + zrange_map = round_values(zrange_map, 9) + expected_distances = round_values(expected_distances, 9) assert compare_maps(zrange_map, expected_distances) is True # Test search by radius, unit: miles, from a geospatial data @@ -2859,6 +2884,8 @@ async def test_geosearchstore_by_radius(self, glide_client: TGlideClient): ) expected_map = {member: value[0] for member, value in result.items()} sorted_expected_map = dict(sorted(expected_map.items(), key=lambda x: x[1])) + zrange_map = round_values(zrange_map, 10) + sorted_expected_map = round_values(sorted_expected_map, 10) assert compare_maps(zrange_map, sorted_expected_map) is True # Test storing results of a radius search, unit: kilometers, from a geospatial data, with limited ANY count to 1 @@ -8462,6 +8489,7 @@ async def 
wait_and_function_kill(): with pytest.raises(RequestError) as e: assert await glide_client.function_kill() assert "NotBusy" in str(e) + await test_client.close() @pytest.mark.parametrize("cluster_mode", [False, True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) @@ -8512,6 +8540,7 @@ async def wait_and_function_kill(): endless_fcall_route_call(), wait_and_function_kill(), ) + await test_client.close() @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) @@ -10426,7 +10455,7 @@ async def test_script_flush(self, glide_client: TGlideClient): assert await glide_client.script_exists([script.get_hash()]) == [False] @pytest.mark.parametrize("cluster_mode", [True]) - @pytest.mark.parametrize("single_route", [True, False]) + @pytest.mark.parametrize("single_route", [True]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_script_kill_route( self, diff --git a/python/python/tests/test_auth.py b/python/python/tests/test_auth.py new file mode 100644 index 0000000000..7e3fc67851 --- /dev/null +++ b/python/python/tests/test_auth.py @@ -0,0 +1,174 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +import asyncio + +import pytest +from glide.config import ProtocolVersion +from glide.constants import OK +from glide.exceptions import RequestError +from glide.glide_client import GlideClient, GlideClusterClient, TGlideClient +from tests.conftest import ( + NEW_PASSWORD, + WRONG_PASSWORD, + auth_client, + config_set_new_password, + kill_connections, +) + + +@pytest.mark.asyncio +class TestAuthCommands: + """Test cases for password authentication and management""" + + @pytest.fixture(autouse=True, scope="function") + async def cleanup(self, request, management_client: TGlideClient): + """ + Ensure password is reset after each test, regardless of test outcome. + This fixture runs after each test. + """ + yield + try: + await auth_client(management_client, NEW_PASSWORD) + await config_set_new_password(management_client, "") + await management_client.update_connection_password(None) + except RequestError: + pass + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_update_connection_password( + self, glide_client: TGlideClient, management_client: TGlideClient + ): + """ + Test replacing the connection password without immediate re-authentication. + Verifies that: + 1. The client can update its internal password + 2. The client remains connected with current auth + 3. 
The client can reconnect using the new password after a server-side password change
+        This test is only for cluster mode, as standalone mode does not have a connection-available handler
+        """
+        result = await glide_client.update_connection_password(
+            NEW_PASSWORD, immediate_auth=False
+        )
+        assert result == OK
+        # Verify that the client is still authenticated
+        assert await glide_client.set("test_key", "test_value") == OK
+        value = await glide_client.get("test_key")
+        assert value == b"test_value"
+        await config_set_new_password(glide_client, NEW_PASSWORD)
+        await kill_connections(management_client)
+        # Add a short delay to allow the server to apply the new password.
+        # Without this delay, the command may or may not time out while the client
+        # reconnects, ending up with a flaky test.
+        await asyncio.sleep(1)
+        # Verify that the client is able to reconnect with the new password.
+        value = await glide_client.get("test_key")
+        assert value == b"test_value"
+        await glide_client.update_connection_password(None)
+        await kill_connections(management_client)
+        # Verify that the client can re-authenticate immediately (immediate_auth) with the new password after its connections are killed
+        result = await glide_client.update_connection_password(
+            NEW_PASSWORD, immediate_auth=True
+        )
+        assert result == OK
+        # Verify that the client is still authenticated
+        assert await glide_client.set("test_key", "test_value") == OK
+
+    @pytest.mark.parametrize("cluster_mode", [False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_update_connection_password_connection_lost_before_password_update(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test changing the server password when the connection is lost before the password update.
+        Verifies that the client will not be able to reach the inner core and an error is returned.
+        """
+        await glide_client.set("test_key", "test_value")
+        await config_set_new_password(glide_client, NEW_PASSWORD)
+        await kill_connections(management_client)
+        await asyncio.sleep(1)
+        with pytest.raises(RequestError):
+            await glide_client.update_connection_password(
+                NEW_PASSWORD, immediate_auth=False
+            )
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_update_connection_password_no_server_auth(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test that immediate re-authentication fails when no server password is set.
+        This verifies proper error handling when trying to re-authenticate with a
+        password when the server has no password set.
+        """
+        with pytest.raises(RequestError):
+            await glide_client.update_connection_password(
+                WRONG_PASSWORD, immediate_auth=True
+            )
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_update_connection_password_long(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test replacing the connection password with a long password string.
+        Verifies that the client can handle long passwords (1000 characters).
+        """
+        long_password = "p" * 1000
+        result = await glide_client.update_connection_password(
+            long_password, immediate_auth=False
+        )
+        assert result == OK
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_replace_password_immediate_auth_wrong_password(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test that re-authentication fails when using a wrong password.
+        Verifies proper error handling when immediate re-authentication is attempted
+        with a password that doesn't match the server's password.
+        """
+        await config_set_new_password(glide_client, NEW_PASSWORD)
+        with pytest.raises(RequestError):
+            await glide_client.update_connection_password(
+                WRONG_PASSWORD, immediate_auth=True
+            )
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_update_connection_password_with_immediate_auth(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test replacing the connection password with immediate re-authentication.
+        Verifies that:
+        1. The client can update its password and re-authenticate immediately
+        2. The client remains operational after re-authentication
+        """
+        await config_set_new_password(glide_client, NEW_PASSWORD)
+        result = await glide_client.update_connection_password(
+            NEW_PASSWORD, immediate_auth=True
+        )
+        assert result == OK
+        # Verify that the client is still authenticated
+        assert await glide_client.set("test_key", "test_value") == OK
+        value = await glide_client.get("test_key")
+        assert value == b"test_value"
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_update_connection_password_auth_non_valid_pass(
+        self, glide_client: TGlideClient, management_client: TGlideClient
+    ):
+        """
+        Test replacing the connection password with immediate re-authentication using a non-valid password.
+        Verifies that immediate re-authentication fails when the password is not valid.
+ """ + with pytest.raises(RequestError): + await glide_client.update_connection_password(None, immediate_auth=True) + with pytest.raises(RequestError): + await glide_client.update_connection_password("", immediate_auth=True) diff --git a/python/python/tests/test_config.py b/python/python/tests/test_config.py index 93c280245f..3b22adb09c 100644 --- a/python/python/tests/test_config.py +++ b/python/python/tests/test_config.py @@ -52,3 +52,18 @@ def test_periodic_checks_interval_to_protobuf(): config.periodic_checks = PeriodicChecksManualInterval(30) request = config._create_a_protobuf_conn_request(cluster_mode=True) assert request.periodic_checks_manual_interval.duration_in_sec == 30 + + +def test_convert_config_with_azaffinity_to_protobuf(): + az = "us-east-1a" + config = BaseClientConfiguration( + [NodeAddress("127.0.0.1")], + use_tls=True, + read_from=ReadFrom.AZ_AFFINITY, + client_az=az, + ) + request = config._create_a_protobuf_conn_request() + assert isinstance(request, ConnectionRequest) + assert request.tls_mode is TlsMode.SecureTls + assert request.read_from == ProtobufReadFrom.AZAffinity + assert request.client_az == az diff --git a/python/python/tests/test_pubsub.py b/python/python/tests/test_pubsub.py index 4c2e5757e8..6069104ed7 100644 --- a/python/python/tests/test_pubsub.py +++ b/python/python/tests/test_pubsub.py @@ -13,9 +13,9 @@ GlideClusterClientConfiguration, ProtocolVersion, ) -from glide.constants import OK, TEncodable +from glide.constants import OK from glide.exceptions import ConfigurationError -from glide.glide_client import BaseClient, GlideClient, GlideClusterClient, TGlideClient +from glide.glide_client import GlideClient, GlideClusterClient, TGlideClient from tests.conftest import create_client from tests.utils.utils import check_if_server_version_lt, get_random_string @@ -469,6 +469,7 @@ async def test_pubsub_exact_happy_path_many_channels_co_existence( await client_cleanup(listening_client, pub_sub if cluster_mode else None) await client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -506,9 +507,6 @@ async def test_sharded_pubsub( listening_client, publishing_client = await create_two_clients_with_pubsub( request, cluster_mode, pub_sub ) - min_version = "7.0.0" - if await check_if_server_version_lt(publishing_client, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") assert ( await cast(GlideClusterClient, publishing_client).publish( @@ -534,6 +532,7 @@ async def test_sharded_pubsub( await client_cleanup(listening_client, pub_sub if cluster_mode else None) await client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) async def test_sharded_pubsub_co_existence(self, request, cluster_mode: bool): """ @@ -563,10 +562,6 @@ async def test_sharded_pubsub_co_existence(self, request, cluster_mode: bool): request, cluster_mode, pub_sub ) - min_version = "7.0.0" - if await check_if_server_version_lt(publishing_client, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - assert ( await cast(GlideClusterClient, publishing_client).publish( message, channel, sharded=True @@ -608,6 +603,7 @@ async def test_sharded_pubsub_co_existence(self, request, cluster_mode: bool): await client_cleanup(listening_client, pub_sub if cluster_mode else None) await 
client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -656,10 +652,6 @@ async def test_sharded_pubsub_many_channels( request, cluster_mode, pub_sub ) - min_version = "7.0.0" - if await check_if_server_version_lt(publishing_client, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - # Publish messages to each channel for channel, message in channels_and_messages.items(): assert ( @@ -1172,6 +1164,7 @@ async def test_pubsub_combined_exact_and_pattern_multiple_clients( ) await client_cleanup(client_dont_care, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -1247,10 +1240,6 @@ async def test_pubsub_combined_exact_pattern_and_sharded_one_client( pub_sub_exact, ) - # Setup PUBSUB for sharded channels (Valkey version > 7) - if await check_if_server_version_lt(publishing_client, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") - # Publish messages to all channels for channel, message in { **exact_channels_and_messages, @@ -1308,6 +1297,7 @@ async def test_pubsub_combined_exact_pattern_and_sharded_one_client( ) await client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -1399,10 +1389,6 @@ async def test_pubsub_combined_exact_pattern_and_sharded_multi_client( ) ) - # Setup PUBSUB for sharded channels (Valkey version > 7) - if await check_if_server_version_lt(publishing_client, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") - if method == MethodTesting.Callback: context = callback_messages_pattern @@ -1534,6 +1520,7 @@ async def test_pubsub_combined_exact_pattern_and_sharded_multi_client( listening_client_sharded, pub_sub_sharded if cluster_mode else None ) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -1603,10 +1590,6 @@ async def test_pubsub_combined_different_channels_with_same_name( ) ) - # (Valkey version > 7) - if await check_if_server_version_lt(publishing_client, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") - # Setup PUBSUB for pattern channel if method == MethodTesting.Callback: context = callback_messages_pattern @@ -1801,6 +1784,7 @@ async def test_pubsub_two_publishing_clients_same_name( client_pattern, pub_sub_pattern if cluster_mode else None ) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) @pytest.mark.parametrize( "method", [MethodTesting.Async, MethodTesting.Sync, MethodTesting.Callback] @@ -1892,9 +1876,6 @@ async def test_pubsub_three_publishing_clients_same_name_with_sharded( client_sharded, client_dont_care = await create_two_clients_with_pubsub( request, cluster_mode, pub_sub_sharded ) - # (Valkey version > 7) - if await check_if_server_version_lt(client_pattern, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") # Publish messages to each channel - both clients publishing assert ( @@ -2024,6 +2005,7 @@ async def test_pubsub_exact_max_size_message(self, request, cluster_mode: bool): 
await client_cleanup(listening_client, pub_sub if cluster_mode else None) await client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.skip( reason="This test requires special configuration for client-output-buffer-limit for valkey-server and timeouts seems to vary across platforms and server versions" ) @@ -2062,10 +2044,6 @@ async def test_pubsub_sharded_max_size_message(self, request, cluster_mode: bool timeout=10000, ) - # (Valkey version > 7) - if await check_if_server_version_lt(publishing_client, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") - assert ( await cast(GlideClusterClient, publishing_client).publish( message, channel, sharded=True @@ -2161,6 +2139,7 @@ async def test_pubsub_exact_max_size_message_callback( await client_cleanup(listening_client, pub_sub if cluster_mode else None) await client_cleanup(publishing_client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.skip( reason="This test requires special configuration for client-output-buffer-limit for valkey-server and timeouts seems to vary across platforms and server versions" ) @@ -2201,10 +2180,6 @@ async def test_pubsub_sharded_max_size_message_callback( request, cluster_mode, pub_sub, timeout=10000 ) - # (Valkey version > 7) - if await check_if_server_version_lt(publishing_client, "7.0.0"): - pytest.skip("Valkey version required >= 7.0.0") - assert ( await cast(GlideClusterClient, publishing_client).publish( message, channel, sharded=True @@ -2463,6 +2438,7 @@ async def test_pubsub_numsub(self, request, cluster_mode: bool): await client_cleanup(client4, None) await client_cleanup(client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) async def test_pubsub_shardchannels(self, request, cluster_mode: bool): """ @@ -2479,9 +2455,6 @@ async def test_pubsub_shardchannels(self, request, cluster_mode: bool): pattern = "test_*" client = await create_client(request, cluster_mode) - min_version = "7.0.0" - if await check_if_server_version_lt(client, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") assert type(client) == GlideClusterClient # Assert no sharded channels exist yet assert await client.pubsub_shardchannels() == [] @@ -2524,6 +2497,7 @@ async def test_pubsub_shardchannels(self, request, cluster_mode: bool): await client_cleanup(client2, None) await client_cleanup(client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) async def test_pubsub_shardnumsub(self, request, cluster_mode: bool): """ @@ -2578,9 +2552,7 @@ async def test_pubsub_shardnumsub(self, request, cluster_mode: bool): # Create a client and check initial subscribers client = await create_client(request, cluster_mode) - min_version = "7.0.0" - if await check_if_server_version_lt(client, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") + assert type(client) == GlideClusterClient assert await client.pubsub_shardnumsub([channel1, channel2, channel3]) == { channel1_bytes: 0, @@ -2620,6 +2592,7 @@ async def test_pubsub_shardnumsub(self, request, cluster_mode: bool): await client_cleanup(client4, None) await client_cleanup(client, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) async def test_pubsub_channels_and_shardchannels_separation( self, request, cluster_mode: bool @@ -2655,10 +2628,6 @@ async def test_pubsub_channels_and_shardchannels_separation( request, 
cluster_mode, pub_sub ) - min_version = "7.0.0" - if await check_if_server_version_lt(client1, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - assert type(client2) == GlideClusterClient # Test pubsub_channels assert await client2.pubsub_channels() == [regular_channel_bytes] @@ -2670,6 +2639,7 @@ async def test_pubsub_channels_and_shardchannels_separation( await client_cleanup(client1, pub_sub if cluster_mode else None) await client_cleanup(client2, None) + @pytest.mark.skip_if_version_below("7.0.0") @pytest.mark.parametrize("cluster_mode", [True]) async def test_pubsub_numsub_and_shardnumsub_separation( self, request, cluster_mode: bool @@ -2715,10 +2685,6 @@ async def test_pubsub_numsub_and_shardnumsub_separation( request, cluster_mode, pub_sub1, pub_sub2 ) - min_version = "7.0.0" - if await check_if_server_version_lt(client1, min_version): - pytest.skip(reason=f"Valkey version required >= {min_version}") - assert type(client2) == GlideClusterClient # Test pubsub_numsub diff --git a/python/python/tests/test_read_from_strategy.py b/python/python/tests/test_read_from_strategy.py new file mode 100644 index 0000000000..03f3f8e9ae --- /dev/null +++ b/python/python/tests/test_read_from_strategy.py @@ -0,0 +1,221 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +import re + +import pytest +from glide.async_commands.core import InfoSection +from glide.config import ProtocolVersion, ReadFrom +from glide.constants import OK +from glide.glide_client import GlideClusterClient +from glide.routes import AllNodes, SlotIdRoute, SlotType +from tests.conftest import create_client +from tests.utils.utils import get_first_result + + +@pytest.mark.asyncio +# @pytest.mark.usefixtures("multiple_replicas_cluster") +class TestAZAffinity: + async def _get_num_replicas(self, client: GlideClusterClient) -> int: + info_replicas = get_first_result( + await client.info([InfoSection.REPLICATION]) + ).decode() + match = re.search(r"connected_slaves:(\d+)", info_replicas) + if match: + return int(match.group(1)) + else: + raise ValueError( + "Could not find the number of replicas in the INFO REPLICATION response" + ) + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_routing_with_az_affinity_strategy_to_1_replica( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + ): + """Test that the client with az affinity strategy will only route to the 1 replica with the same az""" + az = "us-east-1a" + GET_CALLS = 3 + get_cmdstat = f"cmdstat_get:calls={GET_CALLS}" + + client_for_config_set = await create_client( + request, + cluster_mode, + # addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + timeout=2000, + ) + + # Reset the availability zone for all nodes + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", ""], + route=AllNodes(), + ) + assert await client_for_config_set.config_resetstat() == OK + + # 12182 is the slot of "foo" + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", az], + route=SlotIdRoute(SlotType.REPLICA, 12182), + ) + + client_for_testing_az = await create_client( + request, + cluster_mode, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az=az, + ) + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + info_result = await 
client_for_testing_az.info( + [InfoSection.SERVER, InfoSection.COMMAND_STATS], AllNodes() + ) + + # Check that only the replica with az has all the GET calls + matching_entries_count = sum( + 1 + for value in info_result.values() + if get_cmdstat in value.decode() and az in value.decode() + ) + assert matching_entries_count == 1 + + # Check that the other replicas have no availability zone set + changed_az_count = sum( + 1 + for node in info_result.values() + if f"availability_zone:{az}" in node.decode() + ) + assert changed_az_count == 1 + await client_for_testing_az.close() + await client_for_config_set.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_routing_by_slot_to_replica_with_az_affinity_strategy_to_all_replicas( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + ): + """Test that the client with AZ affinity strategy routes in a round-robin manner to all replicas within the specified AZ""" + + az = "us-east-1a" + client_for_config_set = await create_client( + request, + cluster_mode, + # addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + timeout=2000, + ) + assert await client_for_config_set.config_resetstat() == OK + await client_for_config_set.custom_command( + ["CONFIG", "SET", "availability-zone", az], AllNodes() + ) + + client_for_testing_az = await create_client( + request, + cluster_mode, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az=az, + ) + azs = await client_for_testing_az.custom_command( + ["CONFIG", "GET", "availability-zone"], AllNodes() + ) + + # Check that all replicas have the availability zone set to the az + assert all( + ( + node[1].decode() == az + if isinstance(node, list) + else node[b"availability-zone"].decode() == az + ) + for node in azs.values() + ) + + n_replicas = await self._get_num_replicas(client_for_testing_az) + GET_CALLS = 4 * n_replicas + get_cmdstat = f"cmdstat_get:calls={GET_CALLS // n_replicas}" + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + info_result = await client_for_testing_az.info( + [InfoSection.COMMAND_STATS, InfoSection.SERVER], AllNodes() + ) + + # Check that all replicas have the same number of GET calls + matching_entries_count = sum( + 1 + for value in info_result.values() + if get_cmdstat in value.decode() and az in value.decode() + ) + assert matching_entries_count == n_replicas + + await client_for_config_set.close() + await client_for_testing_az.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_az_affinity_non_existing_az( + self, + request, + cluster_mode: bool, + protocol: ProtocolVersion, + ): + GET_CALLS = 4 + + client_for_testing_az = await create_client( + request, + cluster_mode, + # addresses=multiple_replicas_cluster.nodes_addr, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + client_az="non-existing-az", + ) + assert await client_for_testing_az.config_resetstat() == OK + + for _ in range(GET_CALLS): + await client_for_testing_az.get("foo") + + n_replicas = await self._get_num_replicas(client_for_testing_az) + # We expect the calls to be distributed evenly among the replicas + get_cmdstat = f"cmdstat_get:calls={GET_CALLS // n_replicas}" + + info_result = await 
client_for_testing_az.info( + [InfoSection.COMMAND_STATS, InfoSection.SERVER], AllNodes() + ) + + matching_entries_count = sum( + 1 for value in info_result.values() if get_cmdstat in value.decode() + ) + assert matching_entries_count == n_replicas + + await client_for_testing_az.close() + + @pytest.mark.skip_if_version_below("8.0.0") + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_az_affinity_requires_client_az( + self, request, cluster_mode: bool, protocol: ProtocolVersion + ): + """Test that setting read_from to AZ_AFFINITY without client_az raises an error.""" + with pytest.raises(ValueError): + await create_client( + request, + cluster_mode=cluster_mode, + protocol=protocol, + read_from=ReadFrom.AZ_AFFINITY, + timeout=2000, + ) diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index dadef84200..ccdb309f58 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -1033,6 +1033,9 @@ async def test_standalone_transaction(self, glide_client: GlideClient): assert result[5:13] == [2, 2, 2, [b"Bob", b"Alice"], 2, OK, None, 0] assert result[13:] == expected + @pytest.mark.filterwarnings( + action="ignore", message="The test " + ) def test_transaction_clear(self): transaction = Transaction() transaction.info() diff --git a/python/python/tests/tests_server_modules/test_ft.py b/python/python/tests/tests_server_modules/test_ft.py new file mode 100644 index 0000000000..ee2b9416ee --- /dev/null +++ b/python/python/tests/tests_server_modules/test_ft.py @@ -0,0 +1,1092 @@ +# Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +import json +import time +import uuid +from typing import List, Mapping, Union, cast + +import pytest +from glide.async_commands.command_args import OrderBy +from glide.async_commands.server_modules import ft +from glide.async_commands.server_modules import glide_json as GlideJson +from glide.async_commands.server_modules.ft_options.ft_aggregate_options import ( + FtAggregateApply, + FtAggregateGroupBy, + FtAggregateOptions, + FtAggregateReducer, + FtAggregateSortBy, + FtAggregateSortProperty, +) +from glide.async_commands.server_modules.ft_options.ft_create_options import ( + DataType, + DistanceMetricType, + Field, + FtCreateOptions, + NumericField, + TagField, + TextField, + VectorAlgorithm, + VectorField, + VectorFieldAttributesHnsw, + VectorType, +) +from glide.async_commands.server_modules.ft_options.ft_profile_options import ( + FtProfileOptions, +) +from glide.async_commands.server_modules.ft_options.ft_search_options import ( + FtSearchOptions, + ReturnField, +) +from glide.config import ProtocolVersion +from glide.constants import OK, FtSearchResponse, TEncodable +from glide.exceptions import RequestError +from glide.glide_client import GlideClusterClient + + +@pytest.mark.asyncio +class TestFt: + SearchResultField = Mapping[ + TEncodable, Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]] + ] + + SerchResultFieldsList = List[ + Mapping[ + TEncodable, + Union[TEncodable, Mapping[TEncodable, Union[TEncodable, int]]], + ] + ] + + sleep_wait_time = 1 # This value is in seconds + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_create(self, glide_client: GlideClusterClient): + fields: List[Field] = [ + TextField("$title"), + 
NumericField("$published_at"), + TextField("$category"), + ] + prefixes: List[TEncodable] = ["blog:post:"] + + # Create an index with multiple fields with Hash data type. + index = str(uuid.uuid4()) + assert ( + await ft.create( + glide_client, index, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + == OK + ) + assert await ft.dropindex(glide_client, index) == OK + + # Create an index with multiple fields with JSON data type. + index2 = str(uuid.uuid4()) + assert ( + await ft.create( + glide_client, index2, fields, FtCreateOptions(DataType.JSON, prefixes) + ) + == OK + ) + assert await ft.dropindex(glide_client, index2) == OK + + # Create an index for vectors of size 2 + # FT.CREATE hash_idx1 ON HASH PREFIX 1 hash: SCHEMA vec AS VEC VECTOR HNSW 6 DIM 2 TYPE FLOAT32 DISTANCE_METRIC L2 + index3 = str(uuid.uuid4()) + prefixes = ["hash:"] + fields = [ + VectorField( + name="vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dimensions=2, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, + ), + alias="VEC", + ) + ] + + assert ( + await ft.create( + glide_client, index3, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + == OK + ) + assert await ft.dropindex(glide_client, index3) == OK + + # Create a 6-dimensional JSON index using the HNSW algorithm + # FT.CREATE json_idx1 ON JSON PREFIX 1 json: SCHEMA $.vec AS VEC VECTOR HNSW 6 DIM 6 TYPE FLOAT32 DISTANCE_METRIC L2 + index4 = str(uuid.uuid4()) + prefixes = ["json:"] + fields = [ + VectorField( + name="$.vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dimensions=6, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, + ), + alias="VEC", + ) + ] + + assert ( + await ft.create( + glide_client, index4, fields, FtCreateOptions(DataType.JSON, prefixes) + ) + == OK + ) + assert await ft.dropindex(glide_client, index4) == OK + + # Create an index without FtCreateOptions + + index5 = str(uuid.uuid4()) + assert await ft.create(glide_client, index5, fields, FtCreateOptions()) == OK + assert await ft.dropindex(glide_client, index5) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_create_byte_type_input(self, glide_client: GlideClusterClient): + fields: List[Field] = [ + TextField(b"$title"), + NumericField(b"$published_at"), + TextField(b"$category"), + ] + prefixes: List[TEncodable] = [b"blog:post:"] + + # Create an index with multiple fields with Hash data type with byte type input. + index = str(uuid.uuid4()) + assert ( + await ft.create( + glide_client, + index.encode("utf-8"), + fields, + FtCreateOptions(DataType.HASH, prefixes), + ) + == OK + ) + assert await ft.dropindex(glide_client, index) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_dropindex(self, glide_client: GlideClusterClient): + # Index name for the index to be dropped. + index_name = str(uuid.uuid4()) + fields: List[Field] = [TextField("$title")] + prefixes: List[TEncodable] = ["blog:post:"] + + # Create an index with multiple fields with Hash data type. + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.HASH, prefixes), + ) + == OK + ) + + # Drop the index. Expects "OK" as a response. 
+ assert await ft.dropindex(glide_client, index_name) == OK + + # Create an index with multiple fields with Hash data type for byte type testing + index_name_for_bytes_type_input = str(uuid.uuid4()) + assert ( + await ft.create( + glide_client, + index_name_for_bytes_type_input, + fields, + FtCreateOptions(DataType.HASH, prefixes), + ) + == OK + ) + + # Drop the index. Expects "OK" as a response. + assert ( + await ft.dropindex( + glide_client, index_name_for_bytes_type_input.encode("utf-8") + ) + == OK + ) + + # Drop a non existent index. Expects a RequestError. + with pytest.raises(RequestError): + await ft.dropindex(glide_client, index_name) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_search(self, glide_client: GlideClusterClient): + prefix = "{json-search-" + str(uuid.uuid4()) + "}:" + json_key1 = prefix + str(uuid.uuid4()) + json_key2 = prefix + str(uuid.uuid4()) + json_value1 = {"a": 11111, "b": 2, "c": 3} + json_value2 = {"a": 22222, "b": 2, "c": 3} + index = prefix + str(uuid.uuid4()) + + # Create an index. + assert ( + await ft.create( + glide_client, + index, + schema=[ + NumericField("$.a", "a"), + NumericField("$.b", "b"), + ], + options=FtCreateOptions(DataType.JSON), + ) + == OK + ) + + # Create a json key. + assert ( + await GlideJson.set(glide_client, json_key1, "$", json.dumps(json_value1)) + == OK + ) + assert ( + await GlideJson.set(glide_client, json_key2, "$", json.dumps(json_value2)) + == OK + ) + + # Wait for index to be updated to avoid this error - ResponseError: The index is under construction. + time.sleep(self.sleep_wait_time) + + ft_search_options = FtSearchOptions( + return_fields=[ + ReturnField(field_identifier="a", alias="a_new"), + ReturnField(field_identifier="b", alias="b_new"), + ] + ) + + # Search the index for string inputs. + result1 = await ft.search(glide_client, index, "*", options=ft_search_options) + # Check if we get the expected result from ft.search for string inputs. + TestFt._ft_search_deep_compare_result( + self, + result=result1, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + # Test FT.PROFILE for the above mentioned FT.SEARCH query and search options. + + ft_profile_result = await ft.profile( + glide_client, + index, + FtProfileOptions.from_query_options( + query="*", query_options=ft_search_options + ), + ) + assert len(ft_profile_result) > 0 + + # Check if we get the expected result from FT.PROFILE for string inputs. + TestFt._ft_search_deep_compare_result( + self, + result=cast(FtSearchResponse, ft_profile_result[0]), + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + ft_search_options_bytes_input = FtSearchOptions( + return_fields=[ + ReturnField(field_identifier=b"a", alias=b"a_new"), + ReturnField(field_identifier=b"b", alias=b"b_new"), + ] + ) + + # Search the index for byte type inputs. + result2 = await ft.search( + glide_client, + index.encode("utf-8"), + b"*", + options=ft_search_options_bytes_input, + ) + + # Check if we get the expected result from ft.search for byte type inputs. 
+ TestFt._ft_search_deep_compare_result( + self, + result=result2, + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + # Test FT.PROFILE for the above mentioned FT.SEARCH query and search options for byte type inputs. + ft_profile_result = await ft.profile( + glide_client, + index.encode("utf-8"), + FtProfileOptions.from_query_options( + query=b"*", query_options=ft_search_options_bytes_input + ), + ) + assert len(ft_profile_result) > 0 + + # Check if we get the expected result from FT.PROFILE for byte type inputs. + TestFt._ft_search_deep_compare_result( + self, + result=cast(FtSearchResponse, ft_profile_result[0]), + json_key1=json_key1, + json_key2=json_key2, + json_value1=json_value1, + json_value2=json_value2, + fieldName1="a", + fieldName2="b", + ) + + assert await ft.dropindex(glide_client, index) == OK + + def _ft_search_deep_compare_result( + self, + result: List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]], + json_key1: str, + json_key2: str, + json_value1: dict, + json_value2: dict, + fieldName1: str, + fieldName2: str, + ): + """ + Deep compare the keys and values in FT.SEARCH result array. + + Args: + result (List[Union[int, Mapping[TEncodable, Mapping[TEncodable, TEncodable]]]]): + json_key1 (str): The first key in search result. + json_key2 (str): The second key in the search result. + json_value1 (dict): The fields map for first key in the search result. + json_value2 (dict): The fields map for second key in the search result. + """ + assert len(result) == 2 + assert result[0] == 2 + search_result_map: Mapping[TEncodable, Mapping[TEncodable, TEncodable]] = cast( + Mapping[TEncodable, Mapping[TEncodable, TEncodable]], result[1] + ) + expected_result_map: Mapping[TEncodable, Mapping[TEncodable, TEncodable]] = { + json_key1.encode(): { + fieldName1.encode(): str(json_value1.get(fieldName1)).encode(), + fieldName2.encode(): str(json_value1.get(fieldName2)).encode(), + }, + json_key2.encode(): { + fieldName1.encode(): str(json_value2.get(fieldName1)).encode(), + fieldName2.encode(): str(json_value2.get(fieldName2)).encode(), + }, + } + assert search_result_map == expected_result_map + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasadd(self, glide_client: GlideClusterClient): + index_name: str = str(uuid.uuid4()) + alias: str = "alias" + # Test ft.aliasadd throws an error if index does not exist. + with pytest.raises(RequestError): + await ft.aliasadd(glide_client, alias, index_name) + + # Test ft.aliasadd successfully adds an alias to an existing index. + await TestFt._create_test_index_hash_type(self, glide_client, index_name) + assert await ft.aliasadd(glide_client, alias, index_name) == OK + assert await ft.dropindex(glide_client, index_name) == OK + + # Test ft.aliasadd for input of bytes type. 
+ index_name_string = str(uuid.uuid4()) + index_names_bytes = index_name_string.encode("utf-8") + alias_name_bytes = b"alias-bytes" + await TestFt._create_test_index_hash_type(self, glide_client, index_name_string) + assert ( + await ft.aliasadd(glide_client, alias_name_bytes, index_names_bytes) == OK + ) + assert await ft.dropindex(glide_client, index_name_string) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasdel(self, glide_client: GlideClusterClient): + index_name: TEncodable = str(uuid.uuid4()) + alias: str = "alias" + await TestFt._create_test_index_hash_type(self, glide_client, index_name) + + # Test if deleting a non existent alias throws an error. + with pytest.raises(RequestError): + await ft.aliasdel(glide_client, alias) + + # Test if an existing alias is deleted successfully. + assert await ft.aliasadd(glide_client, alias, index_name) == OK + assert await ft.aliasdel(glide_client, alias) == OK + + # Test if an existing alias is deleted successfully for bytes type input. + assert await ft.aliasadd(glide_client, alias, index_name) == OK + assert await ft.aliasdel(glide_client, alias.encode("utf-8")) == OK + + assert await ft.dropindex(glide_client, index_name) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliasupdate(self, glide_client: GlideClusterClient): + index_name: str = str(uuid.uuid4()) + alias: str = "alias" + await TestFt._create_test_index_hash_type(self, glide_client, index_name) + assert await ft.aliasadd(glide_client, alias, index_name) == OK + new_alias_name: str = "newAlias" + new_index_name: str = str(uuid.uuid4()) + + await TestFt._create_test_index_hash_type(self, glide_client, new_index_name) + assert await ft.aliasadd(glide_client, new_alias_name, new_index_name) == OK + + # Test if updating an already existing alias to point to an existing index returns "OK". + assert await ft.aliasupdate(glide_client, new_alias_name, index_name) == OK + assert ( + await ft.aliasupdate( + glide_client, alias.encode("utf-8"), new_index_name.encode("utf-8") + ) + == OK + ) + + assert await ft.dropindex(glide_client, index_name) == OK + assert await ft.dropindex(glide_client, new_index_name) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_dropindex_ft_list(self, glide_client: GlideClusterClient): + indexName = str(uuid.uuid4()).encode() + await TestFt._create_test_index_hash_type(self, glide_client, indexName) + + before = await ft.list(glide_client) + assert indexName in before + + assert await ft.dropindex(glide_client, indexName) == OK + after = await ft.list(glide_client) + assert indexName not in after + + assert {_ for _ in after + [indexName]} == {_ for _ in before} + + # Drop a non existent index. Expects a RequestError. + with pytest.raises(RequestError): + await ft.dropindex(glide_client, indexName) + + async def _create_test_index_hash_type( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating a basic index with hash data type with one text field. 
+ fields: List[Field] = [TextField("title")] + prefix = "{hash-search-" + str(uuid.uuid4()) + "}:" + prefixes: List[TEncodable] = [prefix] + result = await ft.create( + glide_client, index_name, fields, FtCreateOptions(DataType.HASH, prefixes) + ) + assert result == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_info(self, glide_client: GlideClusterClient): + index_name = str(uuid.uuid4()) + await TestFt._create_test_index_with_vector_field( + self, glide_client, index_name + ) + result = await ft.info(glide_client, index_name) + assert await ft.dropindex(glide_client, index_name) == OK + TestFt._ft_info_deep_compare_result(self, index_name, result) + + # Test for bytes type input. + index_name_for_bytes_input = str(uuid.uuid4()) + await TestFt._create_test_index_with_vector_field( + self, glide_client, index_name_for_bytes_input + ) + result = await ft.info(glide_client, index_name_for_bytes_input.encode("utf-8")) + assert await ft.dropindex(glide_client, index_name_for_bytes_input) == OK + TestFt._ft_info_deep_compare_result(self, index_name_for_bytes_input, result) + + # Querying a missing index throws an error. + with pytest.raises(RequestError): + await ft.info(glide_client, str(uuid.uuid4())) + + def _ft_info_deep_compare_result(self, index_name: str, result): + assert index_name.encode() == result.get(b"index_name") + assert b"JSON" == result.get(b"key_type") + assert [b"key-prefix"] == result.get(b"key_prefixes") + + # Get vector and text fields from the fields array. + fields: TestFt.SerchResultFieldsList = cast( + TestFt.SerchResultFieldsList, result.get(b"fields") + ) + assert len(fields) == 2 + text_field: TestFt.SearchResultField = {} + vector_field: TestFt.SearchResultField = {} + if fields[0].get(b"type") == b"VECTOR": + vector_field = cast(TestFt.SearchResultField, fields[0]) + text_field = cast(TestFt.SearchResultField, fields[1]) + else: + vector_field = cast(TestFt.SearchResultField, fields[1]) + text_field = cast(TestFt.SearchResultField, fields[0]) + + # Compare vector field arguments + assert b"$.vec" == vector_field.get(b"identifier") + assert b"VECTOR" == vector_field.get(b"type") + assert b"VEC" == vector_field.get(b"field_name") + vector_field_params: Mapping[TEncodable, Union[TEncodable, int]] = cast( + Mapping[TEncodable, Union[TEncodable, int]], + vector_field.get(b"vector_params"), + ) + assert DistanceMetricType.L2.value.encode() == vector_field_params.get( + b"distance_metric" + ) + assert 2 == vector_field_params.get(b"dimension") + assert b"HNSW" == vector_field_params.get(b"algorithm") + assert b"FLOAT32" == vector_field_params.get(b"data_type") + + # Compare text field arguments. + assert b"$.text-field" == text_field.get(b"identifier") + assert b"TEXT" == text_field.get(b"type") + assert b"text-field" == text_field.get(b"field_name") + + async def _create_test_index_with_vector_field( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating an index with JSON data type with a text and vector field. 
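+        # Roughly equivalent to the raw command (sketch, with placeholder values):
+        # FT.CREATE <index_name> ON JSON PREFIX 1 key-prefix SCHEMA $.vec AS VEC
+        #   VECTOR HNSW 6 DIM 2 TYPE FLOAT32 DISTANCE_METRIC L2 $.text-field AS text-field TEXT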
+ fields: List[Field] = [ + VectorField( + name="$.vec", + algorithm=VectorAlgorithm.HNSW, + attributes=VectorFieldAttributesHnsw( + dimensions=2, + distance_metric=DistanceMetricType.L2, + type=VectorType.FLOAT32, + ), + alias="VEC", + ), + TextField("$.text-field", "text-field"), + ] + + prefixes: List[TEncodable] = ["key-prefix"] + + await ft.create( + glide_client, + index_name, + schema=fields, + options=FtCreateOptions(DataType.JSON, prefixes=prefixes), + ) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_explain(self, glide_client: GlideClusterClient): + index_name = str(uuid.uuid4()) + await TestFt._create_test_index_for_ft_explain_commands( + self, glide_client, index_name + ) + + # FT.EXPLAIN on a search query containing numeric field. + query = "@price:[0 10]" + result = await ft.explain(glide_client, index_name, query) + result_string = cast(bytes, result).decode(encoding="utf-8") + assert ( + "price" in result_string and "0" in result_string and "10" in result_string + ) + + # FT.EXPLAIN on a search query containing numeric field and having bytes type input to the command. + result = await ft.explain(glide_client, index_name.encode(), query.encode()) + result_string = cast(bytes, result).decode(encoding="utf-8") + assert ( + "price" in result_string and "0" in result_string and "10" in result_string + ) + + # FT.EXPLAIN on a search query that returns all data. + result = await ft.explain(glide_client, index_name, query="*") + result_string = cast(bytes, result).decode(encoding="utf-8") + assert "*" in result_string + + assert await ft.dropindex(glide_client, index_name) + + # FT.EXPLAIN on a missing index throws an error. + with pytest.raises(RequestError): + await ft.explain(glide_client, str(uuid.uuid4()), query="*") + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_explaincli(self, glide_client: GlideClusterClient): + index_name = str(uuid.uuid4()) + await TestFt._create_test_index_for_ft_explain_commands( + self, glide_client, index_name + ) + + # FT.EXPLAINCLI on a search query containing numeric field. + query = "@price:[0 10]" + result = await ft.explaincli(glide_client, index_name, query) + result_string_arr = [] + for i in result: + result_string_arr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert ( + "price" in result_string_arr + and "0" in result_string_arr + and "10" in result_string_arr + ) + + # FT.EXPLAINCLI on a search query containing numeric field and having bytes type input to the command. + result = await ft.explaincli(glide_client, index_name.encode(), query.encode()) + result_string_arr = [] + for i in result: + result_string_arr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert ( + "price" in result_string_arr + and "0" in result_string_arr + and "10" in result_string_arr + ) + + # FT.EXPLAINCLI on a search query that returns all data. + result = await ft.explaincli(glide_client, index_name, query="*") + result_string_arr = [] + for i in result: + result_string_arr.append(cast(bytes, i).decode(encoding="utf-8").strip()) + assert "*" in result_string_arr + + assert await ft.dropindex(glide_client, index_name) + + # FT.EXPLAINCLI on a missing index throws an error. 
+ with pytest.raises(RequestError): + await ft.explaincli(glide_client, str(uuid.uuid4()), "*") + + async def _create_test_index_for_ft_explain_commands( + self, glide_client: GlideClusterClient, index_name: TEncodable + ): + # Helper function used for creating an index having hash data type, one text field and one numeric field. + fields: List[Field] = [TextField("title"), NumericField("price")] + prefix = "{hash-search-" + str(uuid.uuid4()) + "}:" + prefixes: List[TEncodable] = [prefix] + + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.HASH, prefixes), + ) + == OK + ) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, protocol + ): + prefix_bicycles = "{bicycles}:" + index_bicycles = prefix_bicycles + str(uuid.uuid4()) + await TestFt._create_index_for_ft_aggregate_with_bicycles_data( + self, + glide_client, + index_bicycles, + prefix_bicycles, + ) + await TestFt._create_json_keys_for_ft_aggregate_with_bicycles_data( + self, glide_client, prefix_bicycles + ) + time.sleep(self.sleep_wait_time) + + ft_aggregate_options: FtAggregateOptions = FtAggregateOptions( + loadFields=["__key"], + clauses=[ + FtAggregateGroupBy( + ["@condition"], [FtAggregateReducer("COUNT", [], "bicycles")] + ) + ], + ) + + # Run FT.AGGREGATE command with the following arguments: ['FT.AGGREGATE', '{bicycles}:1e15faab-a870-488e-b6cd-f2b76c6916a3', '*', 'LOAD', '1', '__key', 'GROUPBY', '1', '@condition', 'REDUCE', 'COUNT', '0', 'AS', 'bicycles'] + result = await ft.aggregate( + glide_client, + index_bicycles, + query="*", + options=ft_aggregate_options, + ) + sorted_result = sorted(result, key=lambda x: (x[b"condition"], x[b"bicycles"])) + expected_result = sorted( + [ + { + b"condition": b"refurbished", + b"bicycles": b"1" if (protocol == ProtocolVersion.RESP2) else 1.0, + }, + { + b"condition": b"new", + b"bicycles": b"5" if (protocol == ProtocolVersion.RESP2) else 5.0, + }, + { + b"condition": b"used", + b"bicycles": b"4" if (protocol == ProtocolVersion.RESP2) else 4.0, + }, + ], + key=lambda x: (x[b"condition"], x[b"bicycles"]), + ) + assert sorted_result == expected_result + + # Test FT.PROFILE for the above mentioned FT.AGGREGATE query + ft_profile_result = await ft.profile( + glide_client, + index_bicycles, + FtProfileOptions.from_query_options( + query="*", query_options=ft_aggregate_options + ), + ) + assert len(ft_profile_result) > 0 + assert ( + sorted( + ft_profile_result[0], key=lambda x: (x[b"condition"], x[b"bicycles"]) + ) + == expected_result + ) + + assert await ft.dropindex(glide_client, index_bicycles) == OK + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, protocol + ): + prefix_movies = "{movies}:" + index_movies = prefix_movies + str(uuid.uuid4()) + # Create index for movies data. + await TestFt._create_index_for_ft_aggregate_with_movies_data( + self, + glide_client, + index_movies, + prefix_movies, + ) + # Set JSON keys with movies data. + await TestFt._create_hash_keys_for_ft_aggregate_with_movies_data( + self, glide_client, prefix_movies + ) + # Wait for index to be updated. 
+ time.sleep(self.sleep_wait_time) + + # Run FT.AGGREGATE command with the following arguments: + # ['FT.AGGREGATE', '{movies}:5a0e6257-3488-4514-96f2-f4c80f6cb0a9', '*', 'LOAD', '*', 'APPLY', 'ceil(@rating)', 'AS', 'r_rating', 'GROUPBY', '1', '@genre', 'REDUCE', 'COUNT', '0', 'AS', 'nb_of_movies', 'REDUCE', 'SUM', '1', 'votes', 'AS', 'nb_of_votes', 'REDUCE', 'AVG', '1', 'r_rating', 'AS', 'avg_rating', 'SORTBY', '4', '@avg_rating', 'DESC', '@nb_of_votes', 'DESC'] + # Testing for bytes type input. + ft_aggregate_options: FtAggregateOptions = FtAggregateOptions( + loadAll=True, + clauses=[ + FtAggregateApply(expression=b"ceil(@rating)", name=b"r_rating"), + FtAggregateGroupBy( + [b"@genre"], + [ + FtAggregateReducer(b"COUNT", [], b"nb_of_movies"), + FtAggregateReducer(b"SUM", [b"votes"], b"nb_of_votes"), + FtAggregateReducer(b"AVG", [b"r_rating"], b"avg_rating"), + ], + ), + FtAggregateSortBy( + properties=[ + FtAggregateSortProperty(b"@avg_rating", OrderBy.DESC), + FtAggregateSortProperty(b"@nb_of_votes", OrderBy.DESC), + ] + ), + ], + ) + result = await ft.aggregate( + glide_client, + index_name=index_movies.encode("utf-8"), + query=b"*", + options=ft_aggregate_options, + ) + sorted_result = sorted( + result, + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + expected_result = sorted( + [ + { + b"genre": b"Drama", + b"nb_of_movies": ( + b"1" if (protocol == ProtocolVersion.RESP2) else 1.0 + ), + b"nb_of_votes": ( + b"1563839" if (protocol == ProtocolVersion.RESP2) else 1563839.0 + ), + b"avg_rating": ( + b"10" if (protocol == ProtocolVersion.RESP2) else 10.0 + ), + }, + { + b"genre": b"Action", + b"nb_of_movies": ( + b"2" if (protocol == ProtocolVersion.RESP2) else 2.0 + ), + b"nb_of_votes": ( + b"2033895" if (protocol == ProtocolVersion.RESP2) else 2033895.0 + ), + b"avg_rating": b"9" if (protocol == ProtocolVersion.RESP2) else 9.0, + }, + { + b"genre": b"Thriller", + b"nb_of_movies": ( + b"1" if (protocol == ProtocolVersion.RESP2) else 1.0 + ), + b"nb_of_votes": ( + b"559490" if (protocol == ProtocolVersion.RESP2) else 559490.0 + ), + b"avg_rating": b"9" if (protocol == ProtocolVersion.RESP2) else 9.0, + }, + ], + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + assert expected_result == sorted_result + + # Test FT.PROFILE for the above mentioned FT.AGGREGATE query + ft_profile_result = await ft.profile( + glide_client, + index_movies, + FtProfileOptions.from_query_options( + query="*", query_options=ft_aggregate_options + ), + ) + assert len(ft_profile_result) > 0 + assert ( + sorted( + ft_profile_result[0], + key=lambda x: ( + x[b"genre"], + x[b"nb_of_movies"], + x[b"nb_of_votes"], + x[b"avg_rating"], + ), + ) + == expected_result + ) + + assert await ft.dropindex(glide_client, index_movies) == OK + + async def _create_index_for_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, index_name: TEncodable, prefix + ): + fields: List[Field] = [ + TextField("$.model", "model"), + TextField("$.description", "description"), + NumericField("$.price", "price"), + TagField("$.condition", "condition", ","), + ] + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.JSON, prefixes=[prefix]), + ) + == OK + ) + + async def _create_json_keys_for_ft_aggregate_with_bicycles_data( + self, glide_client: GlideClusterClient, prefix + ): + assert ( + await GlideJson.set( + glide_client, + prefix + "0", + ".", + '{"brand": "Velorim", 
"model": "Jigger", "price": 270, "description":' + + ' "Small and powerful, the Jigger is the best ride for the smallest of tikes!' + + " This is the tiniest kids\\u2019 pedal bike on the market available without a" + + " coaster brake, the Jigger is the vehicle of choice for the rare tenacious" + + ' little rider raring to go.", "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "1", + ".", + '{"brand": "Bicyk", "model": "Hillcraft", "price": 1200, "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "2", + ".", + '{"brand": "Nord", "model": "Chook air 5", "price": 815, "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "3", + ".", + '{"brand": "Eva", "model": "Eva 291", "price": 3400, "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "4", + ".", + '{"brand": "Noka Bikes", "model": "Kahuna", "price": 3200, "condition": "used"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "5", + ".", + '{"brand": "Breakout", "model": "XBN 2.1 Alloy", "price": 810, "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "6", + ".", + '{"brand": "ScramBikes", "model": "WattBike", "price": 2300, "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "7", + ".", + '{"brand": "Peaknetic", "model": "Secto", "price": 430, "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "8", + ".", + '{"brand": "nHill", "model": "Summit", "price": 1200, "condition": "new"}', + ) + == OK + ) + + assert ( + await GlideJson.set( + glide_client, + prefix + "9", + ".", + '{"model": "ThrillCycle", "brand": "BikeShind", "price": 815, "condition": "refurbished"}', + ) + == OK + ) + + async def _create_index_for_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, index_name: TEncodable, prefix + ): + fields: List[Field] = [ + TextField("title"), + NumericField("release_year"), + NumericField("rating"), + TagField("genre"), + NumericField("votes"), + ] + assert ( + await ft.create( + glide_client, + index_name, + fields, + FtCreateOptions(DataType.HASH, prefixes=[prefix]), + ) + == OK + ) + + async def _create_hash_keys_for_ft_aggregate_with_movies_data( + self, glide_client: GlideClusterClient, prefix + ): + await glide_client.hset( + prefix + "11002", + { + "title": "Star Wars: Episode V - The Empire Strikes Back", + "release_year": "1980", + "genre": "Action", + "rating": "8.7", + "votes": "1127635", + "imdb_id": "tt0080684", + }, + ) + + await glide_client.hset( + prefix + "11003", + { + "title": "The Godfather", + "release_year": "1972", + "genre": "Drama", + "rating": "9.2", + "votes": "1563839", + "imdb_id": "tt0068646", + }, + ) + + await glide_client.hset( + prefix + "11004", + { + "title": "Heat", + "release_year": "1995", + "genre": "Thriller", + "rating": "8.2", + "votes": "559490", + "imdb_id": "tt0113277", + }, + ) + + await glide_client.hset( + prefix + "11005", + { + "title": "Star Wars: Episode VI - Return of the Jedi", + "release_year": "1983", + "genre": "Action", + "rating": "8.3", + "votes": "906260", + "imdb_id": "tt0086190", + }, + ) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_ft_aliaslist(self, glide_client: 
GlideClusterClient): + index_name: str = str(uuid.uuid4()) + alias: str = "alias" + # Create an index and add an alias. + await TestFt._create_test_index_hash_type(self, glide_client, index_name) + assert await ft.aliasadd(glide_client, alias, index_name) == OK + + # Create a second index and add an alias. + index_name_string = str(uuid.uuid4()) + index_name_bytes = bytes(index_name_string, "utf-8") + alias_name_bytes = b"alias-bytes" + await TestFt._create_test_index_hash_type(self, glide_client, index_name_string) + assert await ft.aliasadd(glide_client, alias_name_bytes, index_name_bytes) == OK + + # List all aliases. + result = await ft.aliaslist(glide_client) + assert result == { + b"alias": index_name.encode("utf-8"), + b"alias-bytes": index_name_bytes, + } + + # Drop all indexes. + assert await ft.dropindex(glide_client, index_name) == OK + assert await ft.dropindex(glide_client, index_name_string) == OK diff --git a/python/python/tests/tests_server_modules/test_json.py b/python/python/tests/tests_server_modules/test_json.py index a69c3010e2..85657914de 100644 --- a/python/python/tests/tests_server_modules/test_json.py +++ b/python/python/tests/tests_server_modules/test_json.py @@ -1,11 +1,18 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +import copy import json as OuterJson +import random +import typing import pytest from glide.async_commands.core import ConditionalChange, InfoSection -from glide.async_commands.server_modules import json -from glide.async_commands.server_modules.json import JsonGetOptions +from glide.async_commands.server_modules import glide_json as json +from glide.async_commands.server_modules.glide_json import ( + JsonArrIndexOptions, + JsonArrPopOptions, + JsonGetOptions, +) from glide.config import ProtocolVersion from glide.constants import OK from glide.exceptions import RequestError @@ -13,6 +20,19 @@ from tests.test_async_client import get_random_string, parse_info_response +def get_random_value(value_type="str"): + if value_type == "int": + return random.randint(1, 100) + elif value_type == "float": + return round(random.uniform(1, 100), 2) + elif value_type == "str": + return "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=5)) + elif value_type == "bool": + return random.choice([True, False]) + elif value_type == "null": + return None + + @pytest.mark.asyncio class TestJson: @pytest.mark.parametrize("cluster_mode", [True, False]) @@ -144,42 +164,260 @@ async def test_json_get_formatting(self, glide_client: TGlideClient): @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_del(self, glide_client: TGlideClient): + async def test_json_mget(self, glide_client: TGlideClient): + key1 = get_random_string(5) + key2 = get_random_string(5) + + json1_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}} + json2_value = {"a": 3.0, "b": {"a": 1, "b": 4}} + + assert ( + await json.set(glide_client, key1, "$", OuterJson.dumps(json1_value)) == OK + ) + assert ( + await json.set(glide_client, key2, "$", OuterJson.dumps(json2_value)) == OK + ) + + # Test with root JSONPath + result = await json.mget( + glide_client, + [key1, key2], + "$", + ) + expected_result = [ + b'[{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}]', + b'[{"a":3.0,"b":{"a":1,"b":4}}]', + ] + assert result == expected_result + + # Retrieves the full JSON objects from multiple keys. 
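+        # (Note the difference from the "$" form above: the legacy "." path returns
+        # each document as-is, while JSONPath results are wrapped in a JSON array.)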
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            ".",
+        )
+        expected_result = [
+            b'{"a":1.0,"b":{"a":1,"b":2.5,"c":true}}',
+            b'{"a":3.0,"b":{"a":1,"b":4}}',
+        ]
+        assert result == expected_result
+
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            "$.a",
+        )
+        expected_result = [b"[1.0]", b"[3.0]"]
+        assert result == expected_result
+
+        # Retrieves the value of the 'b' field for multiple keys.
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            "$.b",
+        )
+        expected_result = [b'[{"a":1,"b":2.5,"c":true}]', b'[{"a":1,"b":4}]']
+        assert result == expected_result
+
+        # Retrieves all values of 'b' fields using recursive path for multiple keys
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            "$..b",
+        )
+        expected_result = [b'[{"a":1,"b":2.5,"c":true},2.5]', b'[{"a":1,"b":4},4]']
+        assert result == expected_result
+
+        # Retrieves the value of the nested 'b.b' field for multiple keys
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            ".b.b",
+        )
+        expected_result = [b"2.5", b"4"]
+        assert result == expected_result
+
+        # JSONPath that exists in only one of the keys
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            "$.b.c",
+        )
+        expected_result = [b"[true]", b"[]"]
+        assert result == expected_result
+
+        # Legacy path that exists in only one of the keys
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            ".b.c",
+        )
+        expected_result = [b"true", None]
+        assert result == expected_result
+
+        # JSONPath doesn't exist
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            "$non_existing_path",
+        )
+        expected_result = [b"[]", b"[]"]
+        assert result == expected_result
+
+        # Legacy path doesn't exist
+        result = await json.mget(
+            glide_client,
+            [key1, key2],
+            ".non_existing_path",
+        )
+        assert result == [None, None]
+
+        # JSONPath one key doesn't exist
+        result = await json.mget(
+            glide_client,
+            [key1, "{non_existing_key}"],
+            "$.a",
+        )
+        assert result == [b"[1.0]", None]
+
+        # Legacy path one key doesn't exist
+        result = await json.mget(
+            glide_client,
+            [key1, "{non_existing_key}"],
+            ".a",
+        )
+        assert result == [b"1.0", None]
+
+        # Both keys don't exist
+        result = await json.mget(
+            glide_client,
+            ["{non_existing_key}1", "{non_existing_key}2"],
+            "$a",
+        )
+        assert result == [None, None]
+
+        # Test with only one key
+        result = await json.mget(
+            glide_client,
+            [key1],
+            "$.a",
+        )
+        expected_result = [b"[1.0]"]
+        assert result == expected_result
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_del(self, glide_client: TGlideClient):
         key = get_random_string(5)

         json_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}}
         assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK

+        # Non-existing paths
+        assert await json.delete(glide_client, key, "$..path") == 0
+        assert await json.delete(glide_client, key, "..path") == 0
+
         assert await json.delete(glide_client, key, "$..a") == 2
         assert await json.get(glide_client, key, "$..a") == b"[]"

+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.delete(glide_client, key, "..a") == 2
+        with pytest.raises(RequestError):
+            assert await json.get(glide_client, key, "..a")
+
         result = await json.get(glide_client, key, "$")
         assert isinstance(result, bytes)
         assert OuterJson.loads(result) == [{"b": {"b": 2.5, "c": True}}]

         assert await json.delete(glide_client, key, "$") == 1
+        assert
await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + assert await json.delete(glide_client, key, ".") == 1 + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + assert await json.delete(glide_client, key) == 1 assert await json.delete(glide_client, key) == 0 assert await json.get(glide_client, key, "$") == None + # Non-existing keys + assert await json.delete(glide_client, "non_existing_key", "$") == 0 + assert await json.delete(glide_client, "non_existing_key", ".") == 0 + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_forget(self, glide_client: TGlideClient): + async def test_json_forget(self, glide_client: TGlideClient): key = get_random_string(5) json_value = {"a": 1.0, "b": {"a": 1, "b": 2.5, "c": True}} assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + # Non-existing paths + assert await json.forget(glide_client, key, "$..path") == 0 + assert await json.forget(glide_client, key, "..path") == 0 + assert await json.forget(glide_client, key, "$..a") == 2 assert await json.get(glide_client, key, "$..a") == b"[]" + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + assert await json.forget(glide_client, key, "..a") == 2 + with pytest.raises(RequestError): + assert await json.get(glide_client, key, "..a") + result = await json.get(glide_client, key, "$") assert isinstance(result, bytes) assert OuterJson.loads(result) == [{"b": {"b": 2.5, "c": True}}] assert await json.forget(glide_client, key, "$") == 1 + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + assert await json.forget(glide_client, key, ".") == 1 + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + assert await json.forget(glide_client, key) == 1 assert await json.forget(glide_client, key) == 0 assert await json.get(glide_client, key, "$") == None + # Non-existing keys + assert await json.forget(glide_client, "non_existing_key", "$") == 0 + assert await json.forget(glide_client, "non_existing_key", ".") == 0 + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_objkeys(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = {"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": True}} + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + keys = await json.objkeys(glide_client, key, "$") + assert keys == [[b"a", b"b"]] + + keys = await json.objkeys(glide_client, key, ".") + assert keys == [b"a", b"b"] + + keys = await json.objkeys(glide_client, key, "$..") + assert keys == [[b"a", b"b"], [b"a", b"b", b"c"], [b"x", b"y"]] + + keys = await json.objkeys(glide_client, key, "..") + assert keys == [b"a", b"b"] + + keys = await json.objkeys(glide_client, key, "$..b") + assert keys == [[b"a", b"b", b"c"], []] + + keys = await json.objkeys(glide_client, key, "..b") + assert keys == [b"a", b"b", b"c"] + + # path doesn't exist + assert await json.objkeys(glide_client, key, "$.non_existing_path") == [] + assert await json.objkeys(glide_client, key, "non_existing_path") == None + + # Value at path isnt an object + assert await json.objkeys(glide_client, key, "$.a") == [[]] + with pytest.raises(RequestError): + assert await json.objkeys(glide_client, key, ".a") + + # Non-existing 
key + assert await json.objkeys(glide_client, "non_exiting_key", "$") == None + assert await json.objkeys(glide_client, "non_exiting_key", ".") == None + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_toggle(self, glide_client: TGlideClient): @@ -189,11 +427,15 @@ async def test_json_toggle(self, glide_client: TGlideClient): assert await json.toggle(glide_client, key, "$..bool") == [False, True, None] assert await json.toggle(glide_client, key, "bool") is True + assert await json.toggle(glide_client, key, "$.not_existing") == [] assert await json.toggle(glide_client, key, "$.nested") == [None] with pytest.raises(RequestError): assert await json.toggle(glide_client, key, "nested") + with pytest.raises(RequestError): + assert await json.toggle(glide_client, key, ".not_existing") + with pytest.raises(RequestError): assert await json.toggle(glide_client, "non_exiting_key", "$") @@ -277,6 +519,56 @@ async def test_json_type(self, glide_client: TGlideClient): result = await json.type(glide_client, key, "[*]") assert result == b"string" # Expecting only the first type (string for key1) + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_objlen(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = {"a": 1.0, "b": {"a": {"x": 1, "y": 2}, "b": 2.5, "c": True}} + + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + len = await json.objlen(glide_client, key, "$") + assert len == [2] + + len = await json.objlen(glide_client, key, ".") + assert len == 2 + + len = await json.objlen(glide_client, key, "$..") + assert len == [2, 3, 2] + + len = await json.objlen(glide_client, key, "..") + assert len == 2 + + len = await json.objlen(glide_client, key, "$..b") + assert len == [3, None] + + len = await json.objlen(glide_client, key, "..b") + assert len == 3 + + len = await json.objlen(glide_client, key, "..a") + assert len == 2 + + len = await json.objlen(glide_client, key) + assert len == 2 + + # path doesn't exist + assert await json.objlen(glide_client, key, "$.non_existing_path") == [] + with pytest.raises(RequestError): + await json.objlen(glide_client, key, "non_existing_path") + + # Value at path isnt an object + assert await json.objlen(glide_client, key, "$.a") == [None] + with pytest.raises(RequestError): + await json.objlen(glide_client, key, ".a") + + # Non-existing key + assert await json.objlen(glide_client, "non_exiting_key", "$") == None + assert await json.objlen(glide_client, "non_exiting_key", ".") == None + + assert await json.set(glide_client, key, "$", '{"a": 1, "b": 2, "c":3, "d":4}') + assert await json.objlen(glide_client, key) == 4 + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_json_arrlen(self, glide_client: TGlideClient): @@ -312,3 +604,1496 @@ async def test_json_arrlen(self, glide_client: TGlideClient): assert await json.set(glide_client, key, "$", "[1, 2, 3, 4]") == OK assert await json.arrlen(glide_client, key) == 4 + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_clear(self, glide_client: TGlideClient): + key = get_random_string(5) + + json_value = '{"obj":{"a":1, 
"b":2}, "arr":[1,2,3], "str": "foo", "bool": true, "int": 42, "float": 3.14, "nullVal": null}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.clear(glide_client, key, "$.*") == 6 + result = await json.get(glide_client, key, "$") + assert ( + result + == b'[{"obj":{},"arr":[],"str":"","bool":false,"int":0,"float":0.0,"nullVal":null}]' + ) + assert await json.clear(glide_client, key, "$.*") == 0 + + assert await json.set(glide_client, key, "$", json_value) == OK + assert await json.clear(glide_client, key, "*") == 6 + + json_value = '{"a": 1, "b": {"a": [5, 6, 7], "b": {"a": true}}, "c": {"a": "value", "b": {"a": 3.5}}, "d": {"a": {"foo": "foo"}}, "nullVal": null}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.clear(glide_client, key, "b.a[1:3]") == 2 + assert await json.clear(glide_client, key, "b.a[1:3]") == 0 + assert ( + await json.get(glide_client, key, "$..a") + == b'[1,[5,0,0],true,"value",3.5,{"foo":"foo"}]' + ) + assert await json.clear(glide_client, key, "..a") == 6 + assert await json.get(glide_client, key, "$..a") == b'[0,[],false,"",0.0,{}]' + + assert await json.clear(glide_client, key, "$..a") == 0 + + # Path doesn't exists + assert await json.clear(glide_client, key, "$.path") == 0 + assert await json.clear(glide_client, key, "path") == 0 + + # Key doesn't exists + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key") + + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key", "$") + + with pytest.raises(RequestError): + await json.clear(glide_client, "non_existing_key", ".") + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_numincrby(self, glide_client: TGlideClient): + key = get_random_string(10) + + json_value = { + "key1": 1, + "key2": 3.5, + "key3": {"nested_key": {"key1": [4, 5]}}, + "key4": [1, 2, 3], + "key5": 0, + "key6": "hello", + "key7": None, + "key8": {"nested_key": {"key1": 69}}, + "key9": 1.7976931348623157e308, + } + + # Set the initial JSON document at the key + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + # Test JSONPath + # Increment integer value (key1) by 5 + result = await json.numincrby(glide_client, key, "$.key1", 5) + assert result == b"[6]" # Expect 1 + 5 = 6 + + # Increment float value (key2) by 2.5 + result = await json.numincrby(glide_client, key, "$.key2", 2.5) + assert result == b"[6]" # Expect 3.5 + 2.5 = 6 + + # Increment nested object (key3.nested_key.key1[0]) by 7 + result = await json.numincrby(glide_client, key, "$.key3.nested_key.key1[1]", 7) + assert result == b"[12]" # Expect 4 + 7 = 12 + + # Increment array element (key4[1]) by 1 + result = await json.numincrby(glide_client, key, "$.key4[1]", 1) + assert result == b"[3]" # Expect 2 + 1 = 3 + + # Increment zero value (key5) by 10.23 (float number) + result = await json.numincrby(glide_client, key, "$.key5", 10.23) + assert result == b"[10.23]" # Expect 0 + 10.23 = 10.23 + + # Increment a string value (key6) by a number + result = await json.numincrby(glide_client, key, "$.key6", 99) + assert result == b"[null]" # Expect null + + # Increment a None value (key7) by a number + result = await json.numincrby(glide_client, key, "$.key7", 51) + assert result == b"[null]" # Expect null + + # Check increment for all numbers in the document using JSON Path (First Null: key3 as an entire object. 
+        result = await json.numincrby(glide_client, key, "$..*", 5)
+        assert (
+            result
+            == b"[11,11,null,null,15.23,null,null,null,1.7976931348623157e+308,null,null,9,17,6,8,8,null,74]"
+        )
+
+        # Check for multiple path matches with the enhanced (JSONPath) syntax
+        result = await json.numincrby(glide_client, key, "$..key1", 1)
+        assert result == b"[12,null,75]"
+
+        # Check for a non-existent path in JSONPath
+        result = await json.numincrby(glide_client, key, "$.key10", 51)
+        assert result == b"[]"  # Expect Empty Array
+
+        # Check for a non-existent key in JSONPath
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, "non_existent_key", "$.key10", 51)
+
+        # Check for Overflow in JSONPath
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, "$.key9", 1.7976931348623157e308)
+
+        # Decrement integer value (key1) by 12
+        result = await json.numincrby(glide_client, key, "$.key1", -12)
+        assert result == b"[0]"  # Expect 12 - 12 = 0
+
+        # Decrement integer value (key1) by 0.5
+        result = await json.numincrby(glide_client, key, "$.key1", -0.5)
+        assert result == b"[-0.5]"  # Expect 0 - 0.5 = -0.5
+
+        # Check 'null' value
+        result = await json.numincrby(glide_client, key, "$.key7", 5)
+        assert result == b"[null]"  # Expect 'null'
+
+        # Test Legacy Path
+        # Increment float value (key1) by 5 (an integer)
+        result = await json.numincrby(glide_client, key, "key1", 5)
+        assert result == b"4.5"  # Expect -0.5 + 5 = 4.5
+
+        # Decrement float value (key1) by 5.5 (a float number)
+        result = await json.numincrby(glide_client, key, "key1", -5.5)
+        assert result == b"-1"  # Expect 4.5 - 5.5 = -1
+
+        # Increment int value (key2) by 2.5 (a float number)
+        result = await json.numincrby(glide_client, key, "key2", 2.5)
+        assert result == b"13.5"  # Expect 11 + 2.5 = 13.5
+
+        # Increment nested value (key3.nested_key.key1[0]) by 7
+        result = await json.numincrby(glide_client, key, "key3.nested_key.key1[0]", 7)
+        assert result == b"16"  # Expect 9 + 7 = 16
+
+        # Increment array element (key4[1]) by 1
+        result = await json.numincrby(glide_client, key, "key4[1]", 1)
+        assert result == b"9"  # Expect 8 + 1 = 9
+
+        # Increment a float value (key5) by 10.2 (a float number)
+        result = await json.numincrby(glide_client, key, "key5", 10.2)
+        assert result == b"25.43"  # Expect 15.23 + 10.2 = 25.43
+
+        # Check for multiple path matches in legacy syntax and ensure that the result of the last updated value is returned
+        result = await json.numincrby(glide_client, key, "..key1", 1)
+        assert result == b"76"
+
+        # Check that the rest of the key1 path matches were updated and not only the last value
+        result = await json.get(glide_client, key, "$..key1")  # type: ignore
+        assert (
+            result == b"[0,[16,17],76]"
+        )  # First is 0 as -1 + 1 = 0, the second doesn't change as it's an array type (non-numeric), and the third is 76 as 75 + 1 = 76
+
+        # Check for a non-existent path in legacy syntax
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, ".key10", 51)
+
+        # Check for a non-existent key in legacy syntax
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, "non_existent_key", ".key10", 51)
+
+        # Check for Overflow in legacy syntax
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, ".key9", 1.7976931348623157e308)
+
+        # Check 'null' value
+        with pytest.raises(RequestError):
+            await json.numincrby(glide_client, key, ".key7", 5)
+
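+    # Note: numincrby and nummultby raise RequestError when a result would overflow the largest JSON double (1.7976931348623157e308); the overflow checks in these tests exercise that behavior.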
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_nummultby(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+
+        json_value = {
+            "key1": 1,
+            "key2": 3.5,
+            "key3": {"nested_key": {"key1": [4, 5]}},
+            "key4": [1, 2, 3],
+            "key5": 0,
+            "key6": "hello",
+            "key7": None,
+            "key8": {"nested_key": {"key1": 69}},
+            "key9": 3.5953862697246314e307,
+        }
+
+        # Set the initial JSON document at the key
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        # Test JSONPath
+        # Multiply integer value (key1) by 5
+        result = await json.nummultby(glide_client, key, "$.key1", 5)
+        assert result == b"[5]"  # Expect 1 * 5 = 5
+
+        # Multiply float value (key2) by 2.5
+        result = await json.nummultby(glide_client, key, "$.key2", 2.5)
+        assert result == b"[8.75]"  # Expect 3.5 * 2.5 = 8.75
+
+        # Multiply nested array element (key3.nested_key.key1[1]) by 7
+        result = await json.nummultby(glide_client, key, "$.key3.nested_key.key1[1]", 7)
+        assert result == b"[35]"  # Expect 5 * 7 = 35
+
+        # Multiply array element (key4[1]) by 1
+        result = await json.nummultby(glide_client, key, "$.key4[1]", 1)
+        assert result == b"[2]"  # Expect 2 * 1 = 2
+
+        # Multiply zero value (key5) by 10.23 (float number)
+        result = await json.nummultby(glide_client, key, "$.key5", 10.23)
+        assert result == b"[0]"  # Expect 0 * 10.23 = 0
+
+        # Multiply a string value (key6) by a number
+        result = await json.nummultby(glide_client, key, "$.key6", 99)
+        assert result == b"[null]"  # Expect null
+
+        # Multiply a null value (key7) by a number
+        result = await json.nummultby(glide_client, key, "$.key7", 51)
+        assert result == b"[null]"  # Expect null
+
+        # Check multiplication for all numbers in the document using JSON Path
+        # key1: 5 * 5 = 25
+        # key2: 8.75 * 5 = 43.75
+        # key3.nested_key.key1[0]: 4 * 5 = 20
+        # key3.nested_key.key1[1]: 35 * 5 = 175
+        # key4[0]: 1 * 5 = 5
+        # key4[1]: 2 * 5 = 10
+        # key4[2]: 3 * 5 = 15
+        # key5: 0 * 5 = 0
+        # key8.nested_key.key1: 69 * 5 = 345
+        # key9: 3.5953862697246314e307 * 5 = 1.7976931348623157e308
+        result = await json.nummultby(glide_client, key, "$..*", 5)
+        assert (
+            result
+            == b"[25,43.75,null,null,0,null,null,null,1.7976931348623157e+308,null,null,20,175,5,10,15,null,345]"
+        )
+
+        # Check for multiple path matches in JSONPath
+        # key1: 25 * 2 = 50
+        # key8.nested_key.key1: 345 * 2 = 690
+        result = await json.nummultby(glide_client, key, "$..key1", 2)
+        assert result == b"[50,null,690]"  # After previous multiplications
+
+        # Check for a non-existent path in JSONPath
+        result = await json.nummultby(glide_client, key, "$.key10", 51)
+        assert result == b"[]"  # Expect Empty Array
+
+        # Check for a non-existent key in JSONPath
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, "non_existent_key", "$.key10", 51)
+
+        # Check for Overflow in JSONPath
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, "$.key9", 1.7976931348623157e308)
+
+        # Multiply integer value (key1) by -12
+        result = await json.nummultby(glide_client, key, "$.key1", -12)
+        assert result == b"[-600]"  # Expect 50 * -12 = -600
+
+        # Multiply integer value (key1) by -0.5
+        result = await json.nummultby(glide_client, key, "$.key1", -0.5)
+        assert result == b"[300]"  # Expect -600 * -0.5 = 300
+
+        # Test Legacy Path
+        # Multiply int value (key1) by 5 (an integer)
+        result = await json.nummultby(glide_client, key, "key1", 5)
+        assert result == b"1500"  # Expect 300 * 5 = 1500
+
+        # Multiply int value (key1) by -5.5 (a float number)
+        result = await json.nummultby(glide_client, key, "key1", -5.5)
+        assert result == b"-8250"  # Expect 1500 * -5.5 = -8250
+
+        # Multiply float value (key2) by 2.5 (a float number)
+        result = await json.nummultby(glide_client, key, "key2", 2.5)
+        assert result == b"109.375"  # Expect 43.75 * 2.5 = 109.375
+
+        # Multiply nested value (key3.nested_key.key1[0]) by 7
+        result = await json.nummultby(glide_client, key, "key3.nested_key.key1[0]", 7)
+        assert result == b"140"  # Expect 20 * 7 = 140
+
+        # Multiply array element (key4[1]) by 1
+        result = await json.nummultby(glide_client, key, "key4[1]", 1)
+        assert result == b"10"  # Expect 10 * 1 = 10
+
+        # Multiply a float value (key5) by 10.2 (a float number)
+        result = await json.nummultby(glide_client, key, "key5", 10.2)
+        assert result == b"0"  # Expect 0 * 10.2 = 0
+
+        # Check for multiple path matches in legacy syntax and ensure that the result of the last updated value is returned
+        # last updated value is key8.nested_key.key1: 690 * 2 = 1380
+        result = await json.nummultby(glide_client, key, "..key1", 2)
+        assert result == b"1380"  # Expect the last updated key1 value multiplied by 2
+
+        # Check that the rest of the key1 path matches were updated and not only the last value
+        result = await json.get(glide_client, key, "$..key1")  # type: ignore
+        assert result == b"[-16500,[140,175],1380]"
+
+        # Check 'null' in legacy syntax
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, ".key7", 5)
+
+        # Check for a non-existent path in legacy syntax
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, ".key10", 51)
+
+        # Check for a non-existent key in legacy syntax
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, "non_existent_key", ".key10", 51)
+
+        # Check for Overflow in legacy syntax
+        with pytest.raises(RequestError):
+            await json.nummultby(glide_client, key, ".key9", 1.7976931348623157e308)
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_strlen(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        json_value = {"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.strlen(glide_client, key, "$..a") == [3, 5, None]
+        assert await json.strlen(glide_client, key, "a") == 3
+
+        assert await json.strlen(glide_client, key, "$.nested") == [None]
+        with pytest.raises(RequestError):
+            assert await json.strlen(glide_client, key, "nested")
+
+        with pytest.raises(RequestError):
+            assert await json.strlen(glide_client, key)
+
+        assert await json.strlen(glide_client, key, "$.non_existing_path") == []
+        with pytest.raises(RequestError):
+            await json.strlen(glide_client, key, ".non_existing_path")
+
+        assert await json.strlen(glide_client, "non_existing_key", ".") is None
+        assert await json.strlen(glide_client, "non_existing_key", "$") is None
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_strappend(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        json_value = {"a": "foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        assert await json.strappend(glide_client, key, '"bar"', "$..a") == [6, 8, None]
+        assert await json.strappend(glide_client, key, OuterJson.dumps("foo"), "a") == 9
+
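+        # Note: strappend returns the new string lengths: a list with one entry per match (None for non-string values) for the JSONPath syntax, and a single length for the legacy syntax, as the asserts above show.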
json_str = await json.get(glide_client, key, ".") + assert isinstance(json_str, bytes) + assert OuterJson.loads(json_str) == { + "a": "foobarfoo", + "nested": {"a": "hellobar"}, + "nested2": {"a": 31}, + } + + assert await json.strappend( + glide_client, key, OuterJson.dumps("bar"), "$.nested" + ) == [None] + + with pytest.raises(RequestError): + await json.strappend(glide_client, key, OuterJson.dumps("bar"), ".nested") + + with pytest.raises(RequestError): + await json.strappend(glide_client, key, OuterJson.dumps("bar")) + + assert ( + await json.strappend( + glide_client, key, OuterJson.dumps("try"), "$.non_existing_path" + ) + == [] + ) + with pytest.raises(RequestError): + await json.strappend( + glide_client, key, OuterJson.dumps("try"), "non_existing_path" + ) + + with pytest.raises(RequestError): + await json.strappend( + glide_client, "non_exiting_key", OuterJson.dumps("try") + ) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + @typing.no_type_check # since this is a complex test, skip typing to be more effective + async def test_json_arrinsert(self, glide_client: TGlideClient): + key = get_random_string(10) + + assert ( + await json.set( + glide_client, + key, + "$", + """ + { + "a": [], + "b": { "a": [1, 2, 3, 4] }, + "c": { "a": "not an array" }, + "d": [{ "a": ["x", "y"] }, { "a": [["foo"]] }], + "e": [{ "a": 42 }, { "a": {} }], + "f": { "a": [true, false, null] } + } + """, + ) + == OK + ) + + # Insert different types of values into the matching paths + result = await json.arrinsert( + glide_client, + key, + "$..a", + 0, + ['"string_value"', "123", '{"key": "value"}', "true", "null", '["bar"]'], + ) + assert result == [6, 10, None, 8, 7, None, None, 9] + + updated_doc = await json.get(glide_client, key) + + expected_doc = { + "a": ["string_value", 123, {"key": "value"}, True, None, ["bar"]], + "b": { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + 1, + 2, + 3, + 4, + ], + }, + "c": {"a": "not an array"}, + "d": [ + { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + "x", + "y", + ] + }, + { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + ["foo"], + ] + }, + ], + "e": [{"a": 42}, {"a": {}}], + "f": { + "a": [ + "string_value", + 123, + {"key": "value"}, + True, + None, + ["bar"], + True, + False, + None, + ] + }, + } + + assert OuterJson.loads(updated_doc) == expected_doc + + # Insert into a specific index (non-zero) + result = await json.arrinsert( + glide_client, + key, + "$..a", + 2, + ['"insert_at_2"'], + ) + assert result == [7, 11, None, 9, 8, None, None, 10] + + # Check document after insertion at index 2 + updated_doc_at_2 = await json.get(glide_client, key) + expected_doc["a"].insert(2, "insert_at_2") + expected_doc["b"]["a"].insert(2, "insert_at_2") + expected_doc["d"][0]["a"].insert(2, "insert_at_2") + expected_doc["d"][1]["a"].insert(2, "insert_at_2") + expected_doc["f"]["a"].insert(2, "insert_at_2") + assert OuterJson.loads(updated_doc_at_2) == expected_doc + + # Insert with a legacy path + result = await json.arrinsert( + glide_client, + key, + "..a", # legacy path + 0, + ['"legacy_value"'], + ) + assert ( + result == 8 + ) # Returns length of the first modified array (in this case, 'a') + + # Check document after insertion at root legacy path (all matching arrays should be updated) + updated_doc_legacy = await json.get(glide_client, key) + + # Update 
`expected_doc` with the new value inserted at index 0 of all matching arrays
+        expected_doc["a"].insert(0, "legacy_value")
+        expected_doc["b"]["a"].insert(0, "legacy_value")
+        expected_doc["d"][0]["a"].insert(0, "legacy_value")
+        expected_doc["d"][1]["a"].insert(0, "legacy_value")
+        expected_doc["f"]["a"].insert(0, "legacy_value")
+
+        assert OuterJson.loads(updated_doc_legacy) == expected_doc
+
+        # Insert with an index out of range for some arrays
+        with pytest.raises(RequestError):
+            await json.arrinsert(
+                glide_client,
+                key,
+                "$..a",
+                10,  # Index out of range for some paths but valid for others
+                ['"out_of_range_value"'],
+            )
+
+        with pytest.raises(RequestError):
+            await json.arrinsert(
+                glide_client,
+                key,
+                "..a",
+                10,  # Index out of range for some paths but valid for others
+                ['"out_of_range_value"'],
+            )
+
+        # Negative index insertion (should insert from the end of the array)
+        result = await json.arrinsert(
+            glide_client,
+            key,
+            "$..a",
+            -1,
+            ['"negative_index_value"'],
+        )
+        assert result == [9, 13, None, 11, 10, None, None, 12]  # Updated valid paths
+
+        # Check document after negative index insertion
+        updated_doc_negative = await json.get(glide_client, key)
+        expected_doc["a"].insert(-1, "negative_index_value")
+        expected_doc["b"]["a"].insert(-1, "negative_index_value")
+        expected_doc["d"][0]["a"].insert(-1, "negative_index_value")
+        expected_doc["d"][1]["a"].insert(-1, "negative_index_value")
+        expected_doc["f"]["a"].insert(-1, "negative_index_value")
+        assert OuterJson.loads(updated_doc_negative) == expected_doc
+
+        # Non-existing path
+        with pytest.raises(RequestError):
+            await json.arrinsert(glide_client, key, ".path", 5, ['"value"'])
+
+        assert await json.arrinsert(glide_client, key, "$.path", 5, ['"value"']) == []
+
+        # Key doesn't exist
+        with pytest.raises(RequestError):
+            await json.arrinsert(glide_client, "non_existent_key", "$", 5, ['"value"'])
+
+        with pytest.raises(RequestError):
+            await json.arrinsert(glide_client, "non_existent_key", ".", 5, ['"value"'])
+
+        # Value at path is not an array
+        with pytest.raises(RequestError):
+            await json.arrinsert(glide_client, key, ".e", 5, ['"value"'])
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_debug_fields(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+
+        json_value = {
+            "key1": 1,
+            "key2": 3.5,
+            "key3": {"nested_key": {"key1": [4, 5]}},
+            "key4": [1, 2, 3],
+            "key5": 0,
+            "key6": "hello",
+            "key7": None,
+            "key8": {"nested_key": {"key1": 3.5953862697246314e307}},
+            "key9": 3.5953862697246314e307,
+            "key10": True,
+        }
+
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+
+        # Test JSONPath - Fields Subcommand
+        # Test integer
+        result = await json.debug_fields(glide_client, key, "$.key1")
+        assert result == [1]
+
+        # Test float
+        result = await json.debug_fields(glide_client, key, "$.key2")
+        assert result == [1]
+
+        # Test Nested Value
+        result = await json.debug_fields(glide_client, key, "$.key3")
+        assert result == [4]
+
+        result = await json.debug_fields(glide_client, key, "$.key3.nested_key.key1")
+        assert result == [2]
+
+        # Test Array
+        result = await json.debug_fields(glide_client, key, "$.key4[2]")
+        assert result == [1]
+
+        # Test String
+        result = await json.debug_fields(glide_client, key, "$.key6")
+        assert result == [1]
+
+        # Test Null
+        result = await json.debug_fields(glide_client, key, "$.key7")
+        assert result == [1]
+
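+        # Note: DEBUG FIELDS counts a scalar as one field and a container as the total of its nested fields; JSONPath queries return a list of per-match counts while legacy paths return a single count, as the asserts in this test show.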
+        # Test Bool
+        result = await json.debug_fields(glide_client, key, "$.key10")
+        assert result == [1]
+
+        # Test all keys
+        result = await json.debug_fields(glide_client, key, "$[*]")
+        assert result == [1, 1, 4, 3, 1, 1, 1, 2, 1, 1]
+
+        # Test multiple paths
+        result = await json.debug_fields(glide_client, key, "$..key1")
+        assert result == [1, 2, 1]
+
+        # Test for non-existent path
+        result = await json.debug_fields(glide_client, key, "$.key11")
+        assert result == []
+
+        # Test for non-existent key
+        result = await json.debug_fields(glide_client, "non_existent_key", "$.key10")
+        assert result is None
+
+        # Test no provided path
+        # Total Fields (19) - breakdown:
+        # Top-Level Fields: 10
+        # Fields within key3: 4 ($.key3.nested_key, $.key3.nested_key.key1, $.key3.nested_key.key1[0], $.key3.nested_key.key1[1])
+        # Fields within key4: 3 ($.key4[0], $.key4[1], $.key4[2])
+        # Fields within key8: 2 ($.key8.nested_key, $.key8.nested_key.key1)
+        result = await json.debug_fields(glide_client, key)
+        assert result == 19
+
+        # Test legacy path - Fields Subcommand
+        # Test integer
+        result = await json.debug_fields(glide_client, key, ".key1")
+        assert result == 1
+
+        # Test float
+        result = await json.debug_fields(glide_client, key, ".key2")
+        assert result == 1
+
+        # Test Nested Value
+        result = await json.debug_fields(glide_client, key, ".key3")
+        assert result == 4
+
+        result = await json.debug_fields(glide_client, key, ".key3.nested_key.key1")
+        assert result == 2
+
+        # Test Array
+        result = await json.debug_fields(glide_client, key, ".key4[2]")
+        assert result == 1
+
+        # Test String
+        result = await json.debug_fields(glide_client, key, ".key6")
+        assert result == 1
+
+        # Test Null
+        result = await json.debug_fields(glide_client, key, ".key7")
+        assert result == 1
+
+        # Test Bool
+        result = await json.debug_fields(glide_client, key, ".key10")
+        assert result == 1
+
+        # Test multiple paths
+        result = await json.debug_fields(glide_client, key, "..key1")
+        assert result == 1  # Returns the number of fields of the first JSON value
+
+        # Test for non-existent path
+        with pytest.raises(RequestError):
+            await json.debug_fields(glide_client, key, ".key11")
+
+        # Test for non-existent key
+        result = await json.debug_fields(glide_client, "non_existent_key", ".key10")
+        assert result is None
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_debug_memory(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+
+        json_value = {
+            "key1": 1,
+            "key2": 3.5,
+            "key3": {"nested_key": {"key1": [4, 5]}},
+            "key4": [1, 2, 3],
+            "key5": 0,
+            "key6": "hello",
+            "key7": None,
+            "key8": {"nested_key": {"key1": 3.5953862697246314e307}},
+            "key9": 3.5953862697246314e307,
+            "key10": True,
+        }
+
+        assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK
+        # Test JSONPath - Memory Subcommand
+        # Test integer
+        result = await json.debug_memory(glide_client, key, "$.key1")
+        assert result == [16]
+        # Test float
+        result = await json.debug_memory(glide_client, key, "$.key2")
+        assert result == [16]
+        # Test Nested Value
+        result = await json.debug_memory(glide_client, key, "$.key3.nested_key.key1[0]")
+        assert result == [16]
+        # Test Array
+        result = await json.debug_memory(glide_client, key, "$.key4")
+        assert result == [16 * 4]
+
+        result = await json.debug_memory(glide_client, key, "$.key4[2]")
+        assert result == [16]
+        # Test String
+        result = await json.debug_memory(glide_client, key, "$.key6")
+        assert result == [16]
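+        # Note: the byte counts asserted in this test (e.g. 16 bytes per scalar) reflect the server implementation under test and may differ across engine versions.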
+        # Test Null
+        result = await json.debug_memory(glide_client, key, "$.key7")
+        assert result == [16]
+        # Test Bool
+        result = await json.debug_memory(glide_client, key, "$.key10")
+        assert result == [16]
+        # Test all keys
+        result = await json.debug_memory(glide_client, key, "$[*]")
+        assert result == [16, 16, 110, 64, 16, 16, 16, 101, 39, 16]
+        # Test multiple paths
+        result = await json.debug_memory(glide_client, key, "$..key1")
+        assert result == [16, 48, 39]
+        # Test for non-existent path
+        result = await json.debug_memory(glide_client, key, "$.key11")
+        assert result == []
+        # Test for non-existent key
+        result = await json.debug_memory(glide_client, "non_existent_key", "$.key10")
+        assert result is None
+        # Test no provided path
+        # Total Memory (504 bytes) - visual breakdown:
+        # ├── Root Object Overhead (129 bytes)
+        # └── JSON Elements (374 bytes)
+        #     ├── key1: 16 bytes
+        #     ├── key2: 16 bytes
+        #     ├── key3: 110 bytes
+        #     ├── key4: 64 bytes
+        #     ├── key5: 16 bytes
+        #     ├── key6: 16 bytes
+        #     ├── key7: 16 bytes
+        #     ├── key8: 101 bytes
+        #     └── key9: 39 bytes
+        result = await json.debug_memory(glide_client, key)
+        assert result == 504
+        # Test Legacy Path - Memory Subcommand
+        # Test integer
+        result = await json.debug_memory(glide_client, key, ".key1")
+        assert result == 16
+        # Test float
+        result = await json.debug_memory(glide_client, key, ".key2")
+        assert result == 16
+        # Test Nested Value
+        result = await json.debug_memory(glide_client, key, ".key3.nested_key.key1[0]")
+        assert result == 16
+        # Test Array
+        result = await json.debug_memory(glide_client, key, ".key4[2]")
+        assert result == 16
+        # Test String
+        result = await json.debug_memory(glide_client, key, ".key6")
+        assert result == 16
+        # Test Null
+        result = await json.debug_memory(glide_client, key, ".key7")
+        assert result == 16
+        # Test Bool
+        result = await json.debug_memory(glide_client, key, ".key10")
+        assert result == 16
+        # Test multiple paths
+        result = await json.debug_memory(glide_client, key, "..key1")
+        assert result == 16  # Returns the memory usage of the first JSON value
+        # Test for non-existent path
+        with pytest.raises(RequestError):
+            await json.debug_memory(glide_client, key, ".key11")
+        # Test for non-existent key
+        result = await json.debug_memory(glide_client, "non_existent_key", ".key10")
+        assert result is None
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @typing.no_type_check
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_arrtrim(self, glide_client: TGlideClient):
+        key = get_random_string(5)
+
+        # Test with enhanced path syntax
+        json_value = '{"a": [0, 1, 2, 3, 4, 5, 6, 7, 8], "b": {"a": [0, 9, 10, 11, 12, 13], "c": {"a": 42}}}'
+        assert await json.set(glide_client, key, "$", json_value) == OK
+
+        # Basic trim
+        assert await json.arrtrim(glide_client, key, "$..a", 1, 7) == [7, 5, None]
+        assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [
+            [1, 2, 3, 4, 5, 6, 7],
+            [9, 10, 11, 12, 13],
+            42,
+        ]
+
+        # Test negative start (should be treated as 0)
+        assert await json.arrtrim(glide_client, key, "$.a", -1, 5) == [6]
+        assert OuterJson.loads(await json.get(glide_client, key, "$.a")) == [
+            [1, 2, 3, 4, 5, 6]
+        ]
+        assert await json.arrtrim(glide_client, key, ".a", -1, 5) == 6
+        assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [
+            1,
+            2,
+            3,
+            4,
+            5,
+            6,
+        ]
+
+        # Test end >= size (should be treated as size-1)
+        assert await json.arrtrim(glide_client, key, "$.a", 0, 10) == [6]
+        assert
OuterJson.loads(await json.get(glide_client, key, "$.a")) == [ + [1, 2, 3, 4, 5, 6] + ] + + assert await json.arrtrim(glide_client, key, ".a", 0, 10) == 6 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + + # Test start >= size (should empty the array) + assert await json.arrtrim(glide_client, key, "$.a", 7, 10) == [0] + assert OuterJson.loads(await json.get(glide_client, key, "$.a")) == [[]] + + assert await json.set(glide_client, key, ".a", '["a", "b", "c"]') == OK + assert await json.arrtrim(glide_client, key, ".a", 7, 10) == 0 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [] + + # Test start > end (should empty the array) + assert await json.arrtrim(glide_client, key, "$..a", 2, 1) == [0, 0, None] + assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [ + [], + [], + 42, + ] + assert await json.set(glide_client, key, "..a", '["a", "b", "c", "d"]') == OK + assert await json.arrtrim(glide_client, key, "..a", 2, 1) == 0 + assert OuterJson.loads(await json.get(glide_client, key, ".a")) == [] + + # Multiple path match + assert await json.set(glide_client, key, "$", json_value) == OK + assert await json.arrtrim(glide_client, key, "..a", 1, 10) == 8 + assert OuterJson.loads(await json.get(glide_client, key, "$..a")) == [ + [1, 2, 3, 4, 5, 6, 7, 8], + [9, 10, 11, 12, 13], + 42, + ] + + # Test with non-existent path + with pytest.raises(RequestError): + await json.arrtrim(glide_client, key, ".non_existent", 0, 1) + + assert await json.arrtrim(glide_client, key, "$.non_existent", 0, 1) == [] + + # Test with non-array path + assert await json.arrtrim(glide_client, key, "$", 0, 1) == [None] + + with pytest.raises(RequestError): + await json.arrtrim(glide_client, key, ".", 0, 1) + + # Test with non-existent key + with pytest.raises(RequestError): + await json.arrtrim(glide_client, "non_existent_key", "$", 0, 1) + + # Test with non-existent key + with pytest.raises(RequestError): + await json.arrtrim(glide_client, "non_existent_key", ".", 0, 1) + + # Test empty array + assert await json.set(glide_client, key, "$.empty", "[]") == OK + assert await json.arrtrim(glide_client, key, "$.empty", 0, 1) == [0] + assert await json.arrtrim(glide_client, key, ".empty", 0, 1) == 0 + assert OuterJson.loads(await json.get(glide_client, key, "$.empty")) == [[]] + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrindex(self, glide_client: TGlideClient): + key = get_random_string(10) + + json_value = { + "empty_array": [], + "single_element": ["apple"], + "multiple_elements": ["banana", "cherry", "date"], + "nested_arrays": [ + ["alpha"], + ["beta", "gamma"], + ["delta", "epsilon", "zeta"], + ], + "mixed_types": [1, "two", True, None, 5.5], + "not_array": 5, + "nested_arrays2": [ + ["a"], + ["ab", "abc"], + ["abcd", "abcde", "abcdef", "abcdefg", "abcdefgh", 1, 2, None, "gamma"], + ], + "nested_structure": { + "level1": { + "level2": { + "level3": [ + ["gamma", "theta"], + ["iota", "kappa", "gamma"], + ] + } + } + }, + } + + assert await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == OK + + # JSONPath Syntax Tests + # Search for "beta" in all arrays at the root level, Non-array values return null + result = await json.arrindex(glide_client, key, "$[*]", '"beta"') + assert result == [-1, -1, -1, -1, -1, None, -1, None] + + # Search for a boolean + result = await json.arrindex(glide_client, 
key, "$.mixed_types", "true") + assert result == [2] # True found at index 2 in the "mixed_types" array + + # Search for a float + result = await json.arrindex(glide_client, key, "$.mixed_types", "5.5") + assert result == [4] # 5.5 found at index 4 in the "mixed_types" array + + # Search for "gamma" at nested level + result = await json.arrindex(glide_client, key, "$.nested_arrays[*]", '"gamma"') + assert result == [-1, 1, -1] # "gamma" found at index 1 in the second array + + # Search for "gamma" at nested level with a specified range + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays2[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=5), + ) + assert result == [-1, -1, -1] + + # Search for "gamma" at nested level with start > end + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[*]", + '"gamma"', + JsonArrIndexOptions(start=2, end=1), + ) + assert result == [-1, -1, -1] # Invalid range, returns -1 for all + + # Search for "omega" which does not exist + result = await json.arrindex(glide_client, key, "$[*]", '"omega"') + assert result == [-1, -1, -1, -1, -1, None, -1, None] # "omega" not found + + # Search for null values, null found at at third index in the fifth array + result = await json.arrindex(glide_client, key, "$[*]", "null") + assert result == [-1, -1, -1, -1, 3, None, -1, None] + + # Search in mixed types, "two" found at first index in the fifth array + result = await json.arrindex(glide_client, key, "$[*]", '"two"') + assert result == [-1, -1, -1, -1, 1, None, -1, None] + + # Out of range check for "start" value + result = await json.arrindex( + glide_client, key, "$[*]", '"apple"', JsonArrIndexOptions(start=-200) + ) + assert result == [ + -1, + 0, + -1, + -1, + -1, + None, + -1, + None, + ] # Rounded to the array's start + + # Check for end = -1, tests if the function includes the last element, found "gamma" at index 8 at the third array + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays2[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == [-1, -1, 8] + + # Check for non-existent key + with pytest.raises(RequestError): + await json.arrindex( + glide_client, + "Non_existent", + "$.nested_arrays2[*]", + '"abcdefg"', + JsonArrIndexOptions(start=0, end=-1), + ) + + # Check for non-existent path + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays3[*]", + '"abcdefg"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == [] + + # Using JSONPath syntax to search for "gamma" in nested_structure.level1.level2.level3 + result = await json.arrindex( + glide_client, key, "$.nested_structure.level1.level2.level3[*]", '"gamma"' + ) + assert result == [ + 0, + 2, + ] # "gamma" at index 0 in first array, index 2 in second array + + # Check for inclusive behavior of start in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_structure.level1.level2.level3[*]", + '"gamma"', + JsonArrIndexOptions(start=0), + ) + assert result == [ + 0, + 2, + ] # "gamma" at index 0 of level3[0] and index 2 of level3[1]. + + # Check for exclusive behavior of end in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_structure.level1.level2.level3[*]", + '"gamma"', + JsonArrIndexOptions(start=0, end=2), + ) + assert result == [ + 0, + -1, + ] # Only "gamma" at index 0 of level3[0] is found; gamma at index 2 of level3[1] is excluded as its not within the search range. 
+ + # Check for passing start = 0, end = 0 in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=0, end=0), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = 0 (start>end) but end is a "special value" in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=0), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Check for passing start = 1, end = -1 (start>end) but end is a "special value" in JSONPath syntax + result = await json.arrindex( + glide_client, + key, + "$.nested_arrays[2]", + '"zeta"', + JsonArrIndexOptions(start=1, end=-1), + ) + assert result == [2] # "zeta" found at index 2 as the whole range was searched + + # Restricted Path Syntax Tests + # Search for "abcd" in the "nested_arrays2" array + result = await json.arrindex(glide_client, key, ".nested_arrays2[2]", '"abcd"') + assert result == 0 # "abcd" found at index 0 + + # Search for "abcd" in the "nested_arrays2" array with specified range + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"abcd"', + JsonArrIndexOptions(start=1, end=4), + ) + assert result == -1 # "abcd" not found at the specified range + + # Search for "abcdefg" in the "nested_arrays2" with start > end + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"abcdefg"', + JsonArrIndexOptions(start=4, end=3), + ) + assert result == -1 + + # Search for "theta" which does not exist + result = await json.arrindex(glide_client, key, ".multiple_elements", '"theta"') + assert result == -1 # "theta" not found + + # Check for non_existent path + with pytest.raises(RequestError): + await json.arrindex(glide_client, key, ".non_existent", '"value"') + + # Search in an empty array + result = await json.arrindex(glide_client, key, ".empty_array", '"anything"') + assert result == -1 # Nothing to find in empty array + + # Search for a boolean + result = await json.arrindex(glide_client, key, ".mixed_types", "true") + assert result == 2 # True found at index 2 + + # Search for a float + result = await json.arrindex(glide_client, key, ".mixed_types", "5.5") + assert result == 4 # 5.5 found at index 4 + + # Search for null value + result = await json.arrindex(glide_client, key, ".mixed_types", "null") + assert result == 3 # null found at index 3 + + # Out of range check for "start" value + result = await json.arrindex( + glide_client, + key, + ".single_element", + '"apple"', + JsonArrIndexOptions(start=-200), + ) + assert result == 0 # Rounded to the array's start + + # Check for end = -1, tests if the function includes the last element + result = await json.arrindex( + glide_client, + key, + ".nested_arrays2[2]", + '"gamma"', + JsonArrIndexOptions(start=0, end=-1), + ) + assert result == 8 + + # Check for non-existent key + with pytest.raises(RequestError): + await json.arrindex( + glide_client, "Non_existent", ".nested_arrays2[1]", '"abcdefg"' + ) + + # Check for value at path is not an array + with pytest.raises(RequestError): + await json.arrindex(glide_client, key, ".not_array", "val") + + # Using legacy syntax to search for "gamma" in nested_structure + result = await json.arrindex( + glide_client, key, ".nested_structure.level1.level2.level3[*]", '"gamma"' + ) + assert result == 0 # Legacy syntax returns index from first matching 
array
+
+        # Check for inclusive behavior of start in legacy syntax
+        result = await json.arrindex(
+            glide_client,
+            key,
+            ".nested_arrays[2]",
+            '"epsilon"',
+            JsonArrIndexOptions(start=1),
+        )
+        assert result == 1  # "epsilon" found at index 1 in nested_arrays[2].
+
+        # Check for exclusive behavior of end in legacy syntax
+        result = await json.arrindex(
+            glide_client,
+            key,
+            ".nested_arrays[2]",
+            '"zeta"',
+            JsonArrIndexOptions(start=1, end=2),
+        )
+        assert result == -1  # "zeta" at index 2 is excluded due to the exclusive end.
+
+        # Check for passing start = 0, end = 0
+        result = await json.arrindex(
+            glide_client,
+            key,
+            ".nested_arrays[2]",
+            '"zeta"',
+            JsonArrIndexOptions(start=0, end=0),
+        )
+        assert result == 2  # "zeta" found at index 2 as the whole range was searched
+
+        # Check for passing start = 1, end = 0 (start > end) but end is a "special value"
+        result = await json.arrindex(
+            glide_client,
+            key,
+            ".nested_arrays[2]",
+            '"zeta"',
+            JsonArrIndexOptions(start=1, end=0),
+        )
+        assert result == 2  # "zeta" found at index 2 as the whole range was searched
+
+        # Check for passing start = 1, end = -1 (start > end) but end is a "special value"
+        result = await json.arrindex(
+            glide_client,
+            key,
+            ".nested_arrays[2]",
+            '"zeta"',
+            JsonArrIndexOptions(start=1, end=-1),
+        )
+        assert result == 2  # "zeta" found at index 2 as the whole range was searched
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_arrappend(self, glide_client: TGlideClient):
+        key = get_random_string(10)
+        initial_json_value = '{"a": 1, "b": ["one", "two"]}'
+        assert await json.set(glide_client, key, "$", initial_json_value) == OK
+
+        assert await json.arrappend(glide_client, key, "$.b", ['"three"']) == [3]
+        assert await json.arrappend(glide_client, key, ".b", ['"four"', '"five"']) == 5
+
+        result = await json.get(glide_client, key, "$")
+        assert isinstance(result, bytes)
+        assert OuterJson.loads(result) == [
+            {"a": 1, "b": ["one", "two", "three", "four", "five"]}
+        ]
+
+        assert await json.arrappend(glide_client, key, "$.a", ['"value"']) == [None]
+
+        # JSONPath syntax: the path doesn't exist
+        assert await json.arrappend(glide_client, key, "$.c", ['"value"']) == []
+        # Legacy path: `path` doesn't exist
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, key, ".c", ['"value"'])
+
+        # Legacy path: the JSON value at `path` is not an array
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, key, ".a", ['"value"'])
+
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, "non_existing_key", "$.b", ['"six"'])
+        with pytest.raises(RequestError):
+            await json.arrappend(glide_client, "non_existing_key", ".b", ['"six"'])
+
+        # Multiple path match
+        json_value = '[[], ["a"], ["a", "b"]]'
+        assert await json.set(glide_client, key, "$", json_value) == OK
+        assert await json.arrappend(glide_client, key, "[*]", ['"c"']) == 1
+        result = await json.get(glide_client, key, "$")
+        assert isinstance(result, bytes)
+        assert OuterJson.loads(result) == [[["c"], ["a", "c"], ["a", "b", "c"]]]
+
+    @pytest.mark.parametrize("cluster_mode", [True, False])
+    @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
+    async def test_json_resp(self, glide_client: TGlideClient):
+        key = get_random_string(5)
+
+        # Generate random JSON content with the specified types
+        json_value = {
+            "obj": {"a": get_random_value("int"), "b": get_random_value("float")},
"arr": [get_random_value("int") for _ in range(3)], + "str": get_random_value("str"), + "bool": get_random_value("bool"), + "int": get_random_value("int"), + "float": get_random_value("float"), + "nullVal": get_random_value("null"), + } + + json_value_expected = copy.deepcopy(json_value) + json_value_expected["obj"]["b"] = str(json_value["obj"]["b"]).encode() + json_value_expected["float"] = str(json_value["float"]).encode() + json_value_expected["str"] = str(json_value["str"]).encode() + json_value_expected["bool"] = str(json_value["bool"]).lower().encode() + assert ( + await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == "OK" + ) + + assert await json.resp(glide_client, key, "$.*") == [ + [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ], + [b"[", *json_value_expected["arr"]], + json_value_expected["str"], + json_value_expected["bool"], + json_value_expected["int"], + json_value_expected["float"], + json_value_expected["nullVal"], + ] + + # multiple path match, the first will be returned + assert await json.resp(glide_client, key, "*") == [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ] + + assert await json.resp(glide_client, key, "$") == [ + [ + b"{", + [ + b"obj", + [ + b"{", + [b"a", json_value_expected["obj"]["a"]], + [b"b", json_value_expected["obj"]["b"]], + ], + ], + [b"arr", [b"[", *json_value_expected["arr"]]], + [ + b"str", + json_value_expected["str"], + ], + [ + b"bool", + json_value_expected["bool"], + ], + [b"int", json_value["int"]], + [ + b"float", + json_value_expected["float"], + ], + [b"nullVal", json_value["nullVal"]], + ], + ] + + assert await json.resp(glide_client, key, "$.str") == [ + json_value_expected["str"] + ] + assert await json.resp(glide_client, key, ".str") == json_value_expected["str"] + + # Further tests with a new random JSON structure + json_value = { + "a": [random.randint(1, 10) for _ in range(3)], + "b": { + "a": [random.randint(1, 10) for _ in range(2)], + "c": {"a": random.randint(1, 10)}, + }, + } + assert ( + await json.set(glide_client, key, "$", OuterJson.dumps(json_value)) == "OK" + ) + + # Multiple path match + assert await json.resp(glide_client, key, "$..a") == [ + [b"[", *json_value["a"]], + [b"[", *json_value["b"]["a"]], + json_value["b"]["c"]["a"], + ] + + assert await json.resp(glide_client, key, "..a") == [b"[", *json_value["a"]] + + # Test for non-existent paths + assert await json.resp(glide_client, key, "$.nonexistent") == [] + with pytest.raises(RequestError): + await json.resp(glide_client, key, "nonexistent") + + # Test for non-existent key + assert await json.resp(glide_client, "nonexistent_key", "$") is None + assert await json.resp(glide_client, "nonexistent_key", ".") is None + assert await json.resp(glide_client, "nonexistent_key") is None + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_json_arrpop(self, glide_client: TGlideClient): + key = get_random_string(5) + key2 = get_random_string(5) + + json_value = '{"a": [1, 2, true], "b": {"a": [3, 4, ["value", 3, false] ,5], "c": {"a": 42}}}' + assert await json.set(glide_client, key, "$", json_value) == OK + + assert await json.arrpop( + glide_client, key, JsonArrPopOptions(path="$.a", index=1) + ) == [b"2"] + assert ( + await json.arrpop(glide_client, key, JsonArrPopOptions(path="$..a")) + ) == [b"true", b"5", None] + + assert ( + await 
json.arrpop(glide_client, key, JsonArrPopOptions(path="..a")) == b"1"
+        )
+        # Even though only one array element was returned, ensure the second array at `..a` was popped
+        assert await json.get(glide_client, key, "$..a") == b"[[],[3,4],42]"
+
+        # Out-of-range index
+        assert await json.arrpop(
+            glide_client, key, JsonArrPopOptions(path="$..a", index=10)
+        ) == [None, b"4", None]
+
+        assert (
+            await json.arrpop(
+                glide_client, key, JsonArrPopOptions(path="..a", index=-10)
+            )
+            == b"3"
+        )
+
+        # Path is not an array
+        assert await json.arrpop(glide_client, key, JsonArrPopOptions(path="$")) == [
+            None
+        ]
+        with pytest.raises(RequestError):
+            assert await json.arrpop(glide_client, key, JsonArrPopOptions(path="."))
+        with pytest.raises(RequestError):
+            assert await json.arrpop(glide_client, key)
+
+        # Non-existing path
+        assert (
+            await json.arrpop(
+                glide_client, key, JsonArrPopOptions(path="$.non_existing_path")
+            )
+            == []
+        )
+        with pytest.raises(RequestError):
+            assert await json.arrpop(
+                glide_client, key, JsonArrPopOptions(path="non_existing_path")
+            )
+
+        # Non-existing key
+        with pytest.raises(RequestError):
+            await json.arrpop(
+                glide_client, "non_existing_key", JsonArrPopOptions(path="$.a")
+            )
+        with pytest.raises(RequestError):
+            await json.arrpop(
+                glide_client, "non_existing_key", JsonArrPopOptions(path=".a")
+            )
+
+        assert (
+            await json.set(glide_client, key2, "$", '[[], ["a"], ["a", "b", "c"]]')
+            == OK
+        )
+        assert (
+            await json.arrpop(glide_client, key2, JsonArrPopOptions(path=".", index=-1))
+            == b'["a","b","c"]'
+        )
+        assert await json.arrpop(glide_client, key2) == b'["a"]'
+
+        # Pop from an empty array
+        assert await json.arrpop(glide_client, key2, JsonArrPopOptions("$[0]")) == [
+            None
+        ]
+        assert await json.arrpop(glide_client, key2, JsonArrPopOptions("$[0]", 10)) == [
+            None
+        ]
+        assert await json.arrpop(glide_client, key2, JsonArrPopOptions("[0]")) is None
+        assert (
+            await json.arrpop(glide_client, key2, JsonArrPopOptions("[0]", 10)) is None
+        )
+
+        # A non-JSONPath (legacy) path pops from all matching paths, even though only one result is returned
+        assert (
+            await json.set(
+                glide_client, key2, "$", '[[], ["a"], ["a", "b"], ["a", "b", "c"]]'
+            )
+            == OK
+        )
+
+        assert await json.arrpop(glide_client, key2, JsonArrPopOptions("[*]")) == b'"a"'
+        assert await json.get(glide_client, key2, ".") == b'[[],[],["a"],["a","b"]]'
diff --git a/python/python/tests/utils/cluster.py b/python/python/tests/utils/cluster.py
index e0bfb231ae..fa17742e7b 100644
--- a/python/python/tests/utils/cluster.py
+++ b/python/python/tests/utils/cluster.py
@@ -45,7 +45,7 @@ def __init__(
             stderr=subprocess.PIPE,
             text=True,
         )
-        output, err = p.communicate(timeout=40)
+        output, err = p.communicate(timeout=80)
         if p.returncode != 0:
             raise Exception(f"Failed to create a cluster. Executed: {p}:\n{err}")
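+        # Assumption: the communicate() timeout above was doubled (40s -> 80s) to give slower CI hosts enough time to start the cluster.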
         self.parse_cluster_script_start_output(output)
diff --git a/python/python/tests/utils/utils.py b/python/python/tests/utils/utils.py
index 497342b5c7..f912d5f6bd 100644
--- a/python/python/tests/utils/utils.py
+++ b/python/python/tests/utils/utils.py
@@ -137,6 +137,11 @@ def compare_maps(
     )
 
 
+def round_values(map_data: dict, decimal_places: int) -> dict:
+    """Round the values in a map to the specified number of decimal places."""
+    return {key: round(value, decimal_places) for key, value in map_data.items()}
+
+
 def convert_bytes_to_string_object(
     # TODO: remove the str options
     byte_string_dict: Optional[
diff --git a/python/requirements.txt b/python/requirements.txt
index 63b2be3603..b39d1d96c8 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,7 +1,7 @@
 async-timeout==4.0.2;python_version<"3.11"
-maturin==0.13.0
+maturin==0.14.17 # higher versions break the build and need structural changes; the project name is not the same as the package name, and naming both "glide" creates a circular dependency - TODO: fix this
 protobuf==3.20.*
-pytest==7.1.2
-pytest-asyncio==0.19.0
+pytest
+pytest-asyncio
 typing_extensions==4.8.0;python_version<"3.11"
 pytest-html
diff --git a/python/src/lib.rs b/python/src/lib.rs
index d00034e8cf..6b41123dd3 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -4,12 +4,15 @@ use glide_core::client::FINISHED_SCAN_CURSOR;
  * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
  */
 use glide_core::start_socket_listener;
+use glide_core::Telemetry;
 use glide_core::MAX_REQUEST_ARGS_LENGTH;
 use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
-use pyo3::types::{PyAny, PyBool, PyBytes, PyDict, PyFloat, PyList, PySet};
+use pyo3::types::{PyAny, PyBool, PyBytes, PyDict, PyFloat, PyList, PySet, PyString};
 use pyo3::Python;
 use redis::Value;
+use std::collections::HashMap;
+use std::sync::Arc;
 
 pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 =
     glide_core::client::DEFAULT_RESPONSE_TIMEOUT.as_millis() as u32;
@@ -113,31 +116,68 @@ fn glide(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
         DEFAULT_TIMEOUT_IN_MILLISECONDS,
     )?;
     m.add("MAX_REQUEST_ARGS_LEN", MAX_REQUEST_ARGS_LEN)?;
+    m.add_function(wrap_pyfunction!(py_log, m)?)?;
+    m.add_function(wrap_pyfunction!(py_init, m)?)?;
+    m.add_function(wrap_pyfunction!(start_socket_listener_external, m)?)?;
+    m.add_function(wrap_pyfunction!(value_from_pointer, m)?)?;
+    m.add_function(wrap_pyfunction!(create_leaked_value, m)?)?;
+    m.add_function(wrap_pyfunction!(create_leaked_bytes_vec, m)?)?;
+    m.add_function(wrap_pyfunction!(get_statistics, m)?)?;
 
-    #[pyfn(m)]
+    #[pyfunction]
     fn py_log(log_level: Level, log_identifier: String, message: String) {
         log(log_level, log_identifier, message);
     }
 
-    #[pyfn(m)]
+    #[pyfunction]
+    fn get_statistics(_py: Python) -> PyResult<PyObject> {
+        let mut stats_map = HashMap::<String, String>::new();
+        stats_map.insert(
+            "total_connections".to_string(),
+            Telemetry::total_connections().to_string(),
+        );
+        stats_map.insert(
+            "total_clients".to_string(),
+            Telemetry::total_clients().to_string(),
+        );
+
+        Python::with_gil(|py| {
+            let py_dict = PyDict::new_bound(py);
+
+            for (key, value) in stats_map {
+                py_dict.set_item(
+                    PyString::new_bound(py, &key),
+                    PyString::new_bound(py, &value),
+                )?;
+            }
+
+            Ok(py_dict.into_py(py))
+        })
+    }
+
+    #[pyfunction]
     #[pyo3(signature = (level=None, file_name=None))]
     fn py_init(level: Option<Level>, file_name: Option<&str>) -> Level {
         init(level, file_name)
     }
-
-    #[pyfn(m)]
+    #[pyfunction]
     fn start_socket_listener_external(init_callback: PyObject) ->
PyResult { - start_socket_listener(move |socket_path| { - Python::with_gil(|py| { - match socket_path { - Ok(path) => { - let _ = init_callback.call_bound(py, (path, py.None()), None); - } - Err(error_message) => { - let _ = init_callback.call_bound(py, (py.None(), error_message), None); - } - }; - }); + let init_callback = Arc::new(init_callback); + start_socket_listener({ + let init_callback = Arc::clone(&init_callback); + move |socket_path| { + let init_callback = Arc::clone(&init_callback); + Python::with_gil(|py| { + match socket_path { + Ok(path) => { + let _ = init_callback.call_bound(py, (path, py.None()), None); + } + Err(error_message) => { + let _ = init_callback.call_bound(py, (py.None(), error_message), None); + } + }; + }); + } }); Ok(Python::with_gil(|py| "OK".into_py(py))) } @@ -213,13 +253,13 @@ fn glide(_py: Python, m: &Bound) -> PyResult<()> { } } - #[pyfn(m)] + #[pyfunction] pub fn value_from_pointer(py: Python, pointer: u64) -> PyResult { let value = unsafe { Box::from_raw(pointer as *mut Value) }; resp_value_to_py(py, *value) } - #[pyfn(m)] + #[pyfunction] /// This function is for tests that require a value allocated on the heap. /// Should NOT be used in production. pub fn create_leaked_value(message: String) -> usize { @@ -227,7 +267,7 @@ fn glide(_py: Python, m: &Bound) -> PyResult<()> { Box::leak(Box::new(value)) as *mut Value as usize } - #[pyfn(m)] + #[pyfunction] pub fn create_leaked_bytes_vec(args_vec: Vec<&PyBytes>) -> usize { // Convert the bytes vec -> Bytes vector let bytes_vec: Vec = args_vec @@ -241,7 +281,6 @@ fn glide(_py: Python, m: &Bound) -> PyResult<()> { } Ok(()) } - impl From for Level { fn from(level: logger_core::Level) -> Self { match level { diff --git a/submodules/redis-rs b/submodules/redis-rs deleted file mode 160000 index 396536db31..0000000000 --- a/submodules/redis-rs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 396536db31fbf2de0f272d8179d68286329fa70e diff --git a/utils/TestUtils.ts b/utils/TestUtils.ts index f6c493771b..423bf8e9cb 100644 --- a/utils/TestUtils.ts +++ b/utils/TestUtils.ts @@ -22,7 +22,7 @@ function parseOutput(input: string): { .map((address) => address.split(":")) .map((address) => [address[0], Number(address[1])]) as [ string, - number + number, ][]; if (clusterFolder === undefined || ports === undefined) { @@ -43,7 +43,7 @@ export class ValkeyCluster { private constructor( version: string, addresses: [string, number][], - clusterFolder?: string + clusterFolder?: string, ) { this.addresses = addresses; this.clusterFolder = clusterFolder; @@ -56,9 +56,9 @@ export class ValkeyCluster { replicaCount: number, getVersionCallback: ( addresses: [string, number][], - clusterMode: boolean + clusterMode: boolean, ) => Promise, - loadModule?: string[] + loadModule?: string[], ): Promise { return new Promise((resolve, reject) => { let command = `start -r ${replicaCount} -n ${shardCount}`; @@ -70,7 +70,7 @@ export class ValkeyCluster { if (loadModule) { if (loadModule.length === 0) { throw new Error( - "Please provide the path(s) to the module(s) you want to load." 
+ "Please provide the path(s) to the module(s) you want to load.", ); } @@ -94,12 +94,12 @@ export class ValkeyCluster { new ValkeyCluster( ver, addresses, - clusterFolder - ) - ) + clusterFolder, + ), + ), ); } - } + }, ); }); } @@ -109,11 +109,11 @@ export class ValkeyCluster { addresses: [string, number][], getVersionCallback: ( addresses: [string, number][], - clusterMode: boolean - ) => Promise + clusterMode: boolean, + ) => Promise, ): Promise { return getVersionCallback(addresses, cluster_mode).then( - (ver) => new ValkeyCluster(ver, addresses, "") + (ver) => new ValkeyCluster(ver, addresses, ""), ); } diff --git a/utils/cluster_manager.py b/utils/cluster_manager.py index 03adcaba00..dc196bcd4f 100644 --- a/utils/cluster_manager.py +++ b/utils/cluster_manager.py @@ -4,12 +4,15 @@ import argparse import logging -import os +import os, signal import random import socket import string import subprocess import time +import json +import re + from datetime import datetime, timezone from pathlib import Path from typing import List, Optional, Tuple @@ -66,7 +69,9 @@ def should_generate_new_tls_certs() -> bool: except FileExistsError: files_list = [CA_CRT, REDIS_KEY, REDIS_CRT] for file in files_list: - if check_if_tls_cert_exist(file) and check_if_tls_cert_is_valid(file): + if check_if_tls_cert_exist(file) and check_if_tls_cert_is_valid( + file + ): return False return True @@ -155,7 +160,9 @@ def make_key(name: str, size: int): ) _redis_key_output, err = p.communicate(timeout=10) if p.returncode != 0: - raise Exception(f"Failed to read Redis key. Executed: {str(p.args)}:\n{err}") + raise Exception( + f"Failed to read Redis key. Executed: {str(p.args)}:\n{err}" + ) # Build redis cert p = subprocess.Popen( @@ -185,7 +192,9 @@ def make_key(name: str, size: int): ) output, err = p.communicate(timeout=10) if p.returncode != 0: - raise Exception(f"Failed to create redis cert. Executed: {str(p.args)}:\n{err}") + raise Exception( + f"Failed to create redis cert. 
Executed: {str(p.args)}:\n{err}" + ) toc = time.perf_counter() logging.debug(f"generate_tls_certs() Elapsed time: {toc - tic:0.4f}") logging.debug(f"TLS files= {REDIS_CRT}, {REDIS_KEY}, {CA_CRT}") @@ -222,10 +231,40 @@ class RedisServer: def __init__(self, host: str, port: int) -> None: self.host = host self.port = port + self.pid = -1 + self.is_primary = True def __str__(self) -> str: return f"{self.host}:{self.port}" + def process_id(self) -> int: + return self.pid + + def set_process_id(self, pid: int): + self.pid = pid + + def to_dictionary(self) -> dict: + return { + "host": self.host, + "port": self.port, + "pid": self.pid, + "is_primary": self.is_primary, + } + + def set_primary(self, is_primary: bool): + self.is_primary = is_primary + + +def print_servers_json(servers: List[RedisServer]): + """ + Print the list of servers to the stdout as JSON array + """ + arr = [] + for server in servers: + arr.append(server.to_dictionary()) + + print("SERVERS_JSON={}".format(json.dumps(arr))) + def next_free_port( min_port: int = 6379, max_port: int = 55535, timeout: int = 60 @@ -240,7 +279,9 @@ def next_free_port( sock.bind(("127.0.0.1", port)) sock.close() toc = time.perf_counter() - logging.debug(f"next_free_port() is {port} Elapsed time: {toc - tic:0.4f}") + logging.debug( + f"next_free_port() is {port} Elapsed time: {toc - tic:0.4f}" + ) return port except OSError as e: logging.warning(f"next_free_port error for port {port}: {e}") @@ -299,9 +340,12 @@ def get_server_command() -> str: return server except Exception as e: logging.error(f"Error checking {server}: {e}") - raise Exception("Neither valkey-server nor redis-server found in the system.") + raise Exception( + "Neither valkey-server nor redis-server found in the system." + ) server_name = get_server_command() + logfile = f"{node_folder}/redis.log" # Define command arguments cmd_args = [ server_name, @@ -314,7 +358,7 @@ def get_server_command() -> str: "--daemonize", "yes", "--logfile", - f"{node_folder}/redis.log", + logfile, ] if load_module: if len(load_module) == 0: @@ -337,6 +381,15 @@ def get_server_command() -> str: ) server = RedisServer(host, port) + + # Read the process ID from the log file + # Note that `p.pid` is not good here since we daemonize the process + process_id = wait_for_regex_in_log( + logfile, "version=(.*?)pid=([\d]+), just started", 2 + ) + if process_id: + server.set_process_id(int(process_id)) + return server, node_folder @@ -349,6 +402,7 @@ def create_servers( tls: bool, cluster_mode: bool, load_module: Optional[List[str]] = None, + json_output: bool = False, ) -> List[RedisServer]: tic = time.perf_counter() logging.debug("## Creating servers") @@ -383,7 +437,13 @@ def create_servers( port = ports[i] if ports else None servers_to_check.add( start_redis_server( - host, port, cluster_folder, tls, tls_args, cluster_mode, load_module + host, + port, + cluster_folder, + tls, + tls_args, + cluster_mode, + load_module, ) ) # Check all servers @@ -451,11 +511,17 @@ def create_cluster( if err or "[OK] All 16384 slots covered." 
@@ -349,6 +402,7 @@ def create_servers(
     tls: bool,
     cluster_mode: bool,
     load_module: Optional[List[str]] = None,
+    json_output: bool = False,
 ) -> List[RedisServer]:
     tic = time.perf_counter()
     logging.debug("## Creating servers")
@@ -383,7 +437,13 @@
         port = ports[i] if ports else None
         servers_to_check.add(
             start_redis_server(
-                host, port, cluster_folder, tls, tls_args, cluster_mode, load_module
+                host,
+                port,
+                cluster_folder,
+                tls,
+                tls_args,
+                cluster_mode,
+                load_module,
             )
         )
     # Check all servers
@@ -451,11 +511,17 @@
     if err or "[OK] All 16384 slots covered." not in output:
         raise Exception(f"Failed to create cluster: {err if err else output}")
-    wait_for_a_message_in_redis_logs(cluster_folder, "Cluster state changed: ok")
+    wait_for_a_message_in_redis_logs(
+        cluster_folder, "Cluster state changed: ok"
+    )
     wait_for_all_topology_views(servers, cluster_folder, use_tls)
+    print_servers_json(servers)
+
     logging.debug("The cluster was successfully created!")
     toc = time.perf_counter()
-    logging.debug(f"create_cluster {cluster_folder} Elapsed time: {toc - tic:0.4f}")
+    logging.debug(
+        f"create_cluster {cluster_folder} Elapsed time: {toc - tic:0.4f}"
+    )
 
 
 def create_standalone_replication(
@@ -518,7 +584,10 @@
             continue
 
         log_file = f"{dir}/redis.log"
-        if server_ports and os.path.basename(os.path.normpath(dir)) not in server_ports:
+        if (
+            server_ports
+            and os.path.basename(os.path.normpath(dir)) not in server_ports
+        ):
             continue
         if not wait_for_message(log_file, message, 10):
             raise Exception(
@@ -527,6 +596,58 @@
             )
 
 
+def parse_cluster_nodes(command_output: Optional[str]) -> Optional[dict]:
+    """
+    Parameters
+    ----------
+    command_output: str :
+        The output returned from valkey for the command 'CLUSTER NODES'
+
+    Returns
+    -------
+    A dictionary for the current node's details
+    """
+    if command_output is None:
+        return None
+
+    lines = command_output.splitlines(keepends=False)
+    for line in lines:
+        tokens = line.split(" ")
+        if len(tokens) < 3:
+            continue
+
+        node_id = tokens[0].strip()
+        network = tokens[1].strip()
+        flags = tokens[2].strip()
+
+        if "myself" in flags:
+            # This is us
+            return {
+                "node_id": node_id,
+                "network": network,
+                "is_primary": "master" in flags,
+            }
+    return None
+
+
+def redis_cli_run_command(cmd_args: List[str]) -> Optional[str]:
+    try:
+        p = subprocess.Popen(
+            cmd_args,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        output, err = p.communicate(timeout=5)
+        if err:
+            raise Exception(
+                f"Failed to execute command: {str(p.args)}\n Return code: {p.returncode}\n Error: {err}"
+            )
+        return output
+    except subprocess.TimeoutExpired:
+        return None
+
+
 def wait_for_all_topology_views(
     servers: List[RedisServer], cluster_folder: str, use_tls: bool
 ):
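To make the parsing concrete: `parse_cluster_nodes` scans `CLUSTER NODES` output for the line flagged `myself` and reduces it to the three fields the manager cares about. A usage sketch with a made-up node line:

```python
# Made-up single line in CLUSTER NODES format:
# <node-id> <ip:port@cport> <flags> <primary-id> ...
sample = (
    "07c37dfeb235213a872192d90877d0cd55635b91 "
    "127.0.0.1:30004@40004 myself,master - 0 1426238317239 4 connected 0-5460"
)

node = parse_cluster_nodes(sample)
assert node == {
    "node_id": "07c37dfeb235213a872192d90877d0cd55635b91",
    "network": "127.0.0.1:30004@40004",
    "is_primary": True,  # "master" appears in the flags token
}
```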
@@ -547,31 +668,33 @@ def wait_for_all_topology_views(
         ]
         logging.debug(f"Executing: {cmd_args}")
         retries = 60
-        output = ""
         while retries >= 0:
-            try:
-                p = subprocess.Popen(
-                    cmd_args,
-                    stdout=subprocess.PIPE,
-                    stderr=subprocess.PIPE,
-                    text=True,
-                )
-                output, err = p.communicate(timeout=5)
-                if err:
-                    raise Exception(
-                        f"Failed to execute command: {str(p.args)}\n Return code: {p.returncode}\n Error: {err}"
-                    )
-
-                if output.count(f"{server.host}") == len(servers):
-                    logging.debug(f"Server {server} is ready!")
-                    break
-                else:
-                    retries -= 1
-                    time.sleep(0.5)
-                    continue
-            except subprocess.TimeoutExpired:
-                time.sleep(0.5)
+            output = redis_cli_run_command(cmd_args)
+            if output is not None and output.count(f"{server.host}") == len(
+                servers
+            ):
+                # Server is ready, get the node's role
+                cmd_args = [
+                    "redis-cli",
+                    "-h",
+                    server.host,
+                    "-p",
+                    str(server.port),
+                    *get_redis_cli_option_args(cluster_folder, use_tls),
+                    "cluster",
+                    "nodes",
+                ]
+                cluster_slots_output = redis_cli_run_command(cmd_args)
+                node_info = parse_cluster_nodes(cluster_slots_output)
+                if node_info:
+                    server.set_primary(node_info["is_primary"])
+                logging.debug(f"Server {server} is ready!")
+                break
+            else:
                 retries -= 1
+                time.sleep(0.5)
+                continue
+
         if retries < 0:
             raise Exception(
                 f"Timeout exceeded trying to wait for server {server} to know all hosts.\n"
@@ -633,10 +756,39 @@ def wait_for_message(
         else:
             time.sleep(0.1)
             continue
-    logging.warn(f"Timeout exceeded trying to check if {log_file} contains {message}")
+    logging.warn(
+        f"Timeout exceeded trying to check if {log_file} contains {message}"
+    )
     return False
 
 
+def wait_for_regex_in_log(
+    logfile: str,
+    pattern: str,
+    group: int,
+    timeout: int = 5,
+) -> Optional[str]:
+    """Read the log file and search for a regular expression 'pattern'. If match is found
+    return the regex group identified by 'group'"""
+
+    logging.debug(f"searching regex pattern: '{pattern}' in file: '{logfile}'")
+    timeout_start = time.time()
+
+    while time.time() < timeout_start + timeout:
+        with open(logfile, "r") as f:
+            content = f.read()
+            lines = content.splitlines(keepends=False)
+            for line in lines:
+                result = re.search(pattern, line)
+                if result:
+                    return result.group(group)
+
+            else:
+                time.sleep(0.1)
+                continue
+    return None
+
+
 def is_address_already_in_use(
     server: RedisServer,
     log_file: str,
@@ -674,7 +826,9 @@ def dir_path(path: str):
         raise NotADirectoryError(path)
 
 
-def stop_server(server: RedisServer, cluster_folder: str, use_tls: bool, auth: str):
+def stop_server(
+    server: RedisServer, cluster_folder: str, use_tls: bool, auth: str
+):
     logging.debug(f"Stopping server {server}")
     cmd_args = [
         "redis-cli",
@@ -699,22 +853,26 @@
             )
             output, err = p.communicate(timeout=5)
             if err and "Warning: Using a password with '-a'" not in err:
-                err_msg = (
-                    f"Failed to shutdown host {server.host}:{server.port}:\n {err}"
-                )
+                err_msg = f"Failed to shutdown host {server.host}:{server.port}:\n {err}"
                 logging.error(err_msg)
                 raise Exception(
                     f"Failed to execute command: {str(p.args)}\n Return code: {p.returncode}\n Error: {err}"
                 )
-            if not wait_for_server_shutdown(server, cluster_folder, use_tls, auth):
-                err_msg = "Timeout elapsed while waiting for the node to shutdown"
+            if not wait_for_server_shutdown(
+                server, cluster_folder, use_tls, auth
+            ):
+                err_msg = (
+                    "Timeout elapsed while waiting for the node to shutdown"
+                )
                 logging.error(err_msg)
                 raise Exception(err_msg)
             return
         except subprocess.TimeoutExpired as e:
             raise_err = e
             retries -= 1
-    err_msg = f"Failed to shutdown host {server.host}:{server.port}: {raise_err}"
+    err_msg = (
+        f"Failed to shutdown host {server.host}:{server.port}: {raise_err}"
+    )
     logging.error(err_msg)
     raise Exception(err_msg)
@@ -789,7 +947,18 @@ def stop_clusters(
     auth: str,
     logfile: Optional[str],
     keep_folder: bool,
+    pids: Optional[str],
 ):
+    if pids:
+        pid_arr = pids.split(",")
+        for pid in pid_arr:
+            try:
+                # Kill the process group
+                os.killpg(int(pid), signal.SIGKILL)
+            except ProcessLookupError as e:
+                logging.debug(f"Could not kill server with PID: {pid}. {e}")
+                pass
+
     if cluster_folder:
         cluster_folders = [cluster_folder]
     else:
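Note that `os.killpg` takes a process-group ID, so this cleanup path assumes each daemonized server's PID doubles as its group ID (true when the server is its own group leader). An illustrative variant, not part of the patch, that tries a graceful signal before escalating:

```python
import os
import signal
import time

def kill_server_group(pid: int, grace: float = 1.0) -> None:
    """Illustrative helper: SIGTERM first, then SIGKILL, treating
    `pid` as a process-group ID."""
    try:
        os.killpg(pid, signal.SIGTERM)  # ask the group to exit
        time.sleep(grace)
        os.killpg(pid, signal.SIGKILL)  # force whatever is left
    except ProcessLookupError:
        pass  # the group already exited
```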
{e}") + pass + if cluster_folder: cluster_folders = [cluster_folder] else: @@ -800,8 +969,13 @@ def stop_clusters( and prefix is not None and dirname.startswith(prefix) ] + + # request for graceful shutdown only if PID list was not provided + graceful_shutdown = pids is None for folder in cluster_folders: - stop_cluster(host, folder, use_tls, auth, logfile, keep_folder) + stop_cluster( + host, folder, use_tls, auth, logfile, keep_folder, graceful_shutdown + ) def stop_cluster( @@ -811,15 +985,24 @@ def stop_cluster( auth: str, logfile: Optional[str], keep_folder: bool, + graceful_shutdown: bool, ): - logfile = f"{cluster_folder}/cluster_manager.log" if not logfile else logfile - init_logger(logfile) - logging.debug(f"## Stopping cluster in path {cluster_folder}") - for it in os.scandir(cluster_folder): - if it.is_dir() and it.name.isdigit(): - port = it.name - stop_server(RedisServer(host, int(port)), cluster_folder, use_tls, auth) - logging.debug("All hosts were stopped") + if graceful_shutdown: + logfile = ( + f"{cluster_folder}/cluster_manager.log" if not logfile else logfile + ) + init_logger(logfile) + logging.debug(f"## Stopping cluster in path {cluster_folder}") + for it in os.scandir(cluster_folder): + if it.is_dir() and it.name.isdigit(): + port = it.name + stop_server( + RedisServer(host, int(port)), cluster_folder, use_tls, auth + ) + logging.debug("All hosts were stopped") + else: + logging.debug("Servers terminated using kill") + if not keep_folder: remove_folder(cluster_folder) @@ -875,6 +1058,7 @@ def main(): help="Create a Redis Cluster with cluster mode enabled. If not specified, a Standalone Redis cluster will be created.", required=False, ) + parser_start.add_argument( "--folder-path", type=dir_path, @@ -882,6 +1066,7 @@ def main(): required=False, default=CLUSTERS_FOLDER, ) + parser_start.add_argument( "-p", "--ports", @@ -928,7 +1113,9 @@ def main(): ) # Stop parser - parser_stop = subparsers.add_parser("stop", help="Shutdown a running cluster") + parser_stop = subparsers.add_parser( + "stop", help="Shutdown a running cluster" + ) parser_stop.add_argument( "--folder-path", type=dir_path, @@ -950,6 +1137,7 @@ def main(): help="Stop the cluster in the specified folder path. 
diff --git a/utils/package.json b/utils/package.json
index 0bbd5c9d5b..6d3100505c 100644
--- a/utils/package.json
+++ b/utils/package.json
@@ -12,9 +12,9 @@
     "author": "",
     "license": "Apache-2.0",
     "devDependencies": {
-        "@types/node": "^20.12.12",
+        "@types/node": "22.9",
         "@types/semver": "^7.5.8",
-        "prettier": "^2.8.8"
+        "prettier": "^3.3"
     },
     "dependencies": {
         "child_process": "^1.0.2",
diff --git a/utils/release-candidate-testing/node/index.js b/utils/release-candidate-testing/node/index.js
index be55f97cc4..450ed0e308 100644
--- a/utils/release-candidate-testing/node/index.js
+++ b/utils/release-candidate-testing/node/index.js
@@ -4,7 +4,6 @@
 
 import { GlideClient, GlideClusterClient } from "@valkey/valkey-glide";
 import { ValkeyCluster } from "../../TestUtils.js";
-
 async function runCommands(client) {
     console.log("Executing commands");
     // Set a bunch of keys
@@ -41,7 +40,9 @@ async function runCommands(client) {
     // check that the correct number of keys were deleted
     if (deletedKeysNum !== 3) {
         console.log(deletedKeysNum);
-        throw new Error(`Unexpected number of keys deleted, expected 3, got ${deletedKeysNum}`);
+        throw new Error(
+            `Unexpected number of keys deleted, expected 3, got ${deletedKeysNum}`,
+        );
     }
     // check that the keys were deleted
     for (let i = 1; i <= 3; i++) {
@@ -74,7 +75,8 @@ async function clusterTests() {
     try {
         console.log("Testing cluster");
         console.log("Creating cluster");
-        let valkeyCluster = await ValkeyCluster.createCluster(true,
+        let valkeyCluster = await ValkeyCluster.createCluster(
+            true,
             3,
             1,
             getServerVersion,
@@ -82,8 +84,12 @@ async function clusterTests() {
         console.log("Cluster created");
 
         console.log("Connecting to cluster");
-        let addresses = valkeyCluster.getAddresses().map((address) => { return { host: address[0], port: address[1] } });
-        const client = await GlideClusterClient.createClient({ addresses: addresses });
+        let addresses = valkeyCluster.getAddresses().map((address) => {
+            return { host: address[0], port: address[1] };
+        });
+        const client = await GlideClusterClient.createClient({
+            addresses: addresses,
+        });
         console.log("Connected to cluster");
 
         await runCommands(client);
@@ -103,9 +109,10 @@
 
 async function standaloneTests() {
     try {
-        console.log("Testing standalone Cluster")
+        console.log("Testing standalone Cluster");
         console.log("Creating Cluster");
-        let valkeyCluster = await ValkeyCluster.createCluster(false,
+        let valkeyCluster = await ValkeyCluster.createCluster(
+            false,
             1,
             1,
             getServerVersion,
@@ -113,13 +120,14 @@
         console.log("Cluster created");
 
         console.log("Connecting to Cluster");
-        let addresses = valkeyCluster.getAddresses().map((address) => { return { host: address[0], port: address[1] } });
+        let addresses = valkeyCluster.getAddresses().map((address) => {
+            return { host: address[0], port: address[1] };
+        });
         const client = await GlideClient.createClient({ addresses: addresses });
         console.log("Connected to Cluster");
 
         await closeClientAndCluster(client, valkeyCluster);
         console.log("Done");
-
     } catch (error) {
         // Need this part just when running in our self-hosted runner, so if the test fails before closing Clusters we still kill them and clean up
         if (process.platform === "linux" && process.arch in ["arm", "arm64"]) {
@@ -130,7 +138,6 @@
         }
     }
 
-
 async function main() {
     await clusterTests();
     console.log("Cluster tests passed");