From e84c217a468a71b9475ac34679d0b1e1de6e0e55 Mon Sep 17 00:00:00 2001 From: Sahas Subramanian Date: Tue, 22 Oct 2024 14:57:55 +0200 Subject: [PATCH] Support http2 keep-alive Signed-off-by: Sahas Subramanian --- src/frequenz/client/base/channel.py | 105 +++++++++++++++++++++++++++- tests/test_channel.py | 102 ++++++++++++++++++++++++++- 2 files changed, 202 insertions(+), 5 deletions(-) diff --git a/src/frequenz/client/base/channel.py b/src/frequenz/client/base/channel.py index cfbd9f8..fde5a5b 100644 --- a/src/frequenz/client/base/channel.py +++ b/src/frequenz/client/base/channel.py @@ -5,6 +5,7 @@ import dataclasses import pathlib +from datetime import timedelta from typing import assert_never from urllib.parse import parse_qs, urlparse @@ -41,6 +42,20 @@ class SslOptions: """ +@dataclasses.dataclass(frozen=True) +class KeepAliveOptions: + """Options for HTTP2 keep-alive pings.""" + + enabled: bool = True + """Whether keep-alive should be enabled.""" + + interval: timedelta = timedelta(seconds=60) + """The interval between pings.""" + + timeout: timedelta = timedelta(seconds=20) + """The time in milliseconds to wait for a keep-alive response.""" + + @dataclasses.dataclass(frozen=True) class ChannelOptions: """Options for a gRPC channel.""" @@ -51,6 +66,9 @@ class ChannelOptions: ssl: SslOptions = SslOptions() """SSL options for the channel.""" + keep_alive: KeepAliveOptions = KeepAliveOptions() + """HTTP2 keep-alive options for the channel.""" + def parse_grpc_uri( uri: str, @@ -120,6 +138,40 @@ def parse_grpc_uri( parsed_uri.netloc if parsed_uri.port else f"{parsed_uri.netloc}:{defaults.port}" ) + keep_alive = ( + defaults.keep_alive.enabled + if options.keep_alive is None + else options.keep_alive + ) + channel_options = ( + [ + ("grpc.http2.max_pings_without_data", 0), + ("grpc.keepalive_permit_without_calls", 1), + ( + "grpc.keepalive_time_ms", + ( + ( + options.keep_alive_interval + if options.keep_alive_interval is not None + else defaults.keep_alive.interval + ).total_seconds() + * 1000 + ), + ), + ( + "grpc.keepalive_timeout_ms", + ( + options.keep_alive_timeout + if options.keep_alive_timeout is not None + else defaults.keep_alive.timeout + ).total_seconds() + * 1000, + ), + ] + if keep_alive + else None + ) + ssl = defaults.ssl.enabled if options.ssl is None else options.ssl if ssl: return secure_channel( @@ -141,8 +193,9 @@ def parse_grpc_uri( defaults.ssl.certificate_chain, ), ), + channel_options, ) - return insecure_channel(target) + return insecure_channel(target, channel_options) def _to_bool(value: str) -> bool: @@ -160,6 +213,9 @@ class _QueryParams: ssl_root_certificates_path: pathlib.Path | None ssl_private_key_path: pathlib.Path | None ssl_certificate_chain_path: pathlib.Path | None + keep_alive: bool | None + keep_alive_interval: timedelta | None + keep_alive_timeout: timedelta | None def _parse_query_params(uri: str, query_string: str) -> _QueryParams: @@ -200,6 +256,26 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams: f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled", ) + keep_alive_option = options.pop("keep_alive", None) + keep_alive: bool | None = None + if keep_alive_option is not None: + keep_alive = _to_bool(keep_alive_option) + + keep_alive_opts = { + k: options.pop(k, None) + for k in ("keep_alive_interval_s", "keep_alive_timeout_s") + } + + if keep_alive is False: + erros = [] + for opt_name, opt in keep_alive_opts.items(): + if opt is not None: + erros.append(opt_name) + if erros: + raise ValueError( + f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled", + ) + if options: names = ", ".join(options) raise ValueError( @@ -209,7 +285,32 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams: return _QueryParams( ssl=ssl, - **{k: pathlib.Path(v) if v is not None else None for k, v in ssl_opts.items()}, + ssl_root_certificates_path=( + pathlib.Path(ssl_opts["ssl_root_certificates_path"]) + if ssl_opts["ssl_root_certificates_path"] is not None + else None + ), + ssl_private_key_path=( + pathlib.Path(ssl_opts["ssl_private_key_path"]) + if ssl_opts["ssl_private_key_path"] is not None + else None + ), + ssl_certificate_chain_path=( + pathlib.Path(ssl_opts["ssl_certificate_chain_path"]) + if ssl_opts["ssl_certificate_chain_path"] is not None + else None + ), + keep_alive=keep_alive, + keep_alive_interval=( + timedelta(seconds=int(keep_alive_opts["keep_alive_interval_s"])) + if keep_alive_opts["keep_alive_interval_s"] is not None + else None + ), + keep_alive_timeout=( + timedelta(seconds=int(keep_alive_opts["keep_alive_timeout_s"])) + if keep_alive_opts["keep_alive_timeout_s"] is not None + else None + ), ) diff --git a/tests/test_channel.py b/tests/test_channel.py index 43490aa..14e2527 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -5,6 +5,7 @@ import dataclasses import pathlib +from datetime import timedelta from unittest import mock import pytest @@ -13,6 +14,7 @@ from frequenz.client.base.channel import ( ChannelOptions, + KeepAliveOptions, SslOptions, _to_bool, parse_grpc_uri, @@ -136,6 +138,67 @@ class _ValidUrlTestCase: ), ), ), + _ValidUrlTestCase( + title="Keep-alive no defaults", + uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300" + + "&keep_alive_timeout_s=60", + expected_host="localhost", + expected_port=1234, + expected_options=ChannelOptions( + keep_alive=KeepAliveOptions( + enabled=True, + interval=timedelta(minutes=5), + timeout=timedelta(minutes=1), + ), + ), + ), + _ValidUrlTestCase( + title="Keep-alive default timeout", + uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300", + defaults=ChannelOptions( + keep_alive=KeepAliveOptions( + enabled=True, + interval=timedelta(seconds=10), + timeout=timedelta(seconds=2), + ), + ), + expected_host="localhost", + expected_port=1234, + expected_options=ChannelOptions( + keep_alive=KeepAliveOptions( + enabled=True, + interval=timedelta(seconds=300), + timeout=timedelta(seconds=2), + ), + ), + ), + _ValidUrlTestCase( + title="Keep-alive default interval", + uri="grpc://localhost:1234?keep_alive=1&keep_alive_timeout_s=60", + defaults=ChannelOptions( + keep_alive=KeepAliveOptions( + enabled=True, interval=timedelta(minutes=30) + ), + ), + expected_host="localhost", + expected_port=1234, + expected_options=ChannelOptions( + keep_alive=KeepAliveOptions( + enabled=True, + timeout=timedelta(minutes=1), + interval=timedelta(minutes=30), + ), + ), + ), + _ValidUrlTestCase( + title="keep-alive disabled", + uri="grpc://localhost:1234?keep_alive=0", + expected_host="localhost", + expected_port=1234, + expected_options=ChannelOptions( + keep_alive=KeepAliveOptions(enabled=False), + ), + ), ], ids=lambda case: case.title, ) @@ -154,7 +217,9 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals ) expected_port = case.expected_port expected_ssl = ( - expected_options.ssl.enabled if "ssl=" in uri else defaults.ssl.enabled + expected_options.ssl.enabled + if "ssl=" in uri or defaults.ssl.enabled is None + else defaults.ssl.enabled ) expected_root_certificates = ( expected_options.ssl.root_certificates @@ -196,6 +261,35 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals assert channel == expected_channel expected_target = f"{expected_host}:{expected_port}" + expected_keep_alive = ( + expected_options.keep_alive if "keep_alive=" in uri else defaults.keep_alive + ) + expected_keep_alive_interval = ( + expected_keep_alive.interval + if "keep_alive_interval_s=" in uri + else defaults.keep_alive.interval + ) + expected_keep_alive_timeout = ( + expected_keep_alive.timeout + if "keep_alive_timeout_s=" in uri + else defaults.keep_alive.timeout + ) + expected_channel_options = ( + [ + ("grpc.http2.max_pings_without_data", 0), + ("grpc.keepalive_permit_without_calls", 1), + ( + "grpc.keepalive_time_ms", + (expected_keep_alive_interval.total_seconds() * 1000), + ), + ( + "grpc.keepalive_timeout_ms", + expected_keep_alive_timeout.total_seconds() * 1000, + ), + ] + if expected_keep_alive.enabled + else None + ) if expected_ssl: if isinstance(expected_root_certificates, pathlib.Path): get_contents_mock.assert_any_call( @@ -221,10 +315,12 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals certificate_chain=expected_certificate_chain, ) secure_channel_mock.assert_called_once_with( - expected_target, expected_credentials + expected_target, expected_credentials, expected_channel_options ) else: - insecure_channel_mock.assert_called_once_with(expected_target) + insecure_channel_mock.assert_called_once_with( + expected_target, expected_channel_options + ) @pytest.mark.parametrize("value", ["true", "on", "1", "TrUe", "On", "ON", "TRUE"])