Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-2560 Retry KMS requests on transient errors #2024

Merged
merged 9 commits into from
Dec 4, 2024
64 changes: 46 additions & 18 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
Expand Down Expand Up @@ -63,7 +64,11 @@
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.asynchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
Expand All @@ -72,7 +77,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
Expand All @@ -88,6 +93,9 @@
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext

from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address


_IS_SYNC = False

Expand All @@ -103,6 +111,13 @@
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)


async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return await _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))


@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
Expand Down Expand Up @@ -166,18 +181,22 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
connect_timeout=connect_timeout,
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
await asyncio.sleep(sleep_sec)
try:
conn = await _configured_socket((host, port), opts)
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
Expand All @@ -194,20 +213,29 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err

async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
Expand Down
64 changes: 46 additions & 18 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
Expand Down Expand Up @@ -67,7 +68,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
Expand All @@ -80,14 +81,21 @@
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.synchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext

from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address


_IS_SYNC = True

Expand All @@ -103,6 +111,13 @@
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)


def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))


@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
Expand Down Expand Up @@ -166,18 +181,22 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
connect_timeout=connect_timeout,
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
time.sleep(sleep_sec)
try:
conn = _configured_socket((host, port), opts)
conn = _connect_kms(address, opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
Expand All @@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err

def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
Expand Down
86 changes: 84 additions & 2 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import base64
import copy
import http.client
import json
import os
import pathlib
import re
Expand Down Expand Up @@ -91,6 +93,7 @@
WriteError,
)
from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern

_IS_SYNC = False
Expand Down Expand Up @@ -1366,9 +1369,8 @@ async def test_04_aws_endpoint_invalid_port(self):
"key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"),
"endpoint": "kms.us-east-1.amazonaws.com:12345",
}
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx:
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"):
await self.client_encryption.create_data_key("aws", master_key=master_key)
self.assertIsInstance(ctx.exception.cause, AutoReconnect)

@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def test_05_aws_endpoint_wrong_region(self):
Expand Down Expand Up @@ -2853,6 +2855,86 @@ async def test_accepts_trim_factor_0(self):
assert len(payload) > len(self.payload_defaults)


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests
class TestKmsRetryProse(AsyncEncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def asyncSetUp(self):
await super().asyncSetUp()
# 1, create client with only tlsCAFile.
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003"
providers["gcp"]["endpoint"] = "127.0.0.1:9003"
kms_tls_opts = {
p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers
}
self.client_encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)

async def http_post(self, path, data=None):
# Note, the connection to the mock server needs to be closed after
# each request because the server is single threaded.
ctx: ssl.SSLContext = get_ssl_context(
CLIENT_PEM, # certfile
None, # passphrase
CA_PEM, # ca_certs
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
)
conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx)
try:
if data is not None:
headers = {"Content-type": "application/json"}
body = json.dumps(data)
else:
headers = {}
body = None
conn.request("POST", path, body, headers)
res = conn.getresponse()
res.read()
finally:
conn.close()

async def _test(self, provider, master_key):
await self.http_post("/reset")
# Case 1: createDataKey and encrypt with TCP retry
await self.http_post("/set_failpoint/network", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/network", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)

# Case 2: createDataKey and encrypt with HTTP retry
await self.http_post("/set_failpoint/http", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/http", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)

# Case 3: createDataKey fails after too many retries
await self.http_post("/set_failpoint/network", {"count": 4})
with self.assertRaisesRegex(EncryptionError, "KMS request failed after"):
await self.client_encryption.create_data_key(provider, master_key=master_key)

async def test_kms_retry(self):
await self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"})
await self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"})
await self._test(
"gcp",
{
"projectId": "foo",
"location": "bar",
"keyRing": "baz",
"keyName": "qux",
"endpoint": "127.0.0.1:9003",
},
)


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys
class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
@async_client_context.require_no_standalone
Expand Down
Loading
Loading