Skip to content

Commit

Permalink
Support http2 keep-alive
Browse files Browse the repository at this point in the history
Signed-off-by: Sahas Subramanian <[email protected]>
  • Loading branch information
shsms committed Oct 29, 2024
1 parent 689feaa commit e84c217
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 5 deletions.
105 changes: 103 additions & 2 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dataclasses
import pathlib
from datetime import timedelta
from typing import assert_never
from urllib.parse import parse_qs, urlparse

Expand Down Expand Up @@ -41,6 +42,20 @@ class SslOptions:
"""


@dataclasses.dataclass(frozen=True)
class KeepAliveOptions:
"""Options for HTTP2 keep-alive pings."""

enabled: bool = True
"""Whether keep-alive should be enabled."""

interval: timedelta = timedelta(seconds=60)
"""The interval between pings."""

timeout: timedelta = timedelta(seconds=20)
"""The time in milliseconds to wait for a keep-alive response."""


@dataclasses.dataclass(frozen=True)
class ChannelOptions:
"""Options for a gRPC channel."""
Expand All @@ -51,6 +66,9 @@ class ChannelOptions:
ssl: SslOptions = SslOptions()
"""SSL options for the channel."""

keep_alive: KeepAliveOptions = KeepAliveOptions()
"""HTTP2 keep-alive options for the channel."""


def parse_grpc_uri(
uri: str,
Expand Down Expand Up @@ -120,6 +138,40 @@ def parse_grpc_uri(
parsed_uri.netloc if parsed_uri.port else f"{parsed_uri.netloc}:{defaults.port}"
)

keep_alive = (
defaults.keep_alive.enabled
if options.keep_alive is None
else options.keep_alive
)
channel_options = (
[
("grpc.http2.max_pings_without_data", 0),
("grpc.keepalive_permit_without_calls", 1),
(
"grpc.keepalive_time_ms",
(
(
options.keep_alive_interval
if options.keep_alive_interval is not None
else defaults.keep_alive.interval
).total_seconds()
* 1000
),
),
(
"grpc.keepalive_timeout_ms",
(
options.keep_alive_timeout
if options.keep_alive_timeout is not None
else defaults.keep_alive.timeout
).total_seconds()
* 1000,
),
]
if keep_alive
else None
)

ssl = defaults.ssl.enabled if options.ssl is None else options.ssl
if ssl:
return secure_channel(
Expand All @@ -141,8 +193,9 @@ def parse_grpc_uri(
defaults.ssl.certificate_chain,
),
),
channel_options,
)
return insecure_channel(target)
return insecure_channel(target, channel_options)


def _to_bool(value: str) -> bool:
Expand All @@ -160,6 +213,9 @@ class _QueryParams:
ssl_root_certificates_path: pathlib.Path | None
ssl_private_key_path: pathlib.Path | None
ssl_certificate_chain_path: pathlib.Path | None
keep_alive: bool | None
keep_alive_interval: timedelta | None
keep_alive_timeout: timedelta | None


def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
Expand Down Expand Up @@ -200,6 +256,26 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but SSL is disabled",
)

keep_alive_option = options.pop("keep_alive", None)
keep_alive: bool | None = None
if keep_alive_option is not None:
keep_alive = _to_bool(keep_alive_option)

keep_alive_opts = {
k: options.pop(k, None)
for k in ("keep_alive_interval_s", "keep_alive_timeout_s")
}

if keep_alive is False:
erros = []
for opt_name, opt in keep_alive_opts.items():
if opt is not None:
erros.append(opt_name)
if erros:
raise ValueError(
f"Option(s) {', '.join(erros)} found in URI {uri!r}, but keep_alive is disabled",
)

if options:
names = ", ".join(options)
raise ValueError(
Expand All @@ -209,7 +285,32 @@ def _parse_query_params(uri: str, query_string: str) -> _QueryParams:

return _QueryParams(
ssl=ssl,
**{k: pathlib.Path(v) if v is not None else None for k, v in ssl_opts.items()},
ssl_root_certificates_path=(
pathlib.Path(ssl_opts["ssl_root_certificates_path"])
if ssl_opts["ssl_root_certificates_path"] is not None
else None
),
ssl_private_key_path=(
pathlib.Path(ssl_opts["ssl_private_key_path"])
if ssl_opts["ssl_private_key_path"] is not None
else None
),
ssl_certificate_chain_path=(
pathlib.Path(ssl_opts["ssl_certificate_chain_path"])
if ssl_opts["ssl_certificate_chain_path"] is not None
else None
),
keep_alive=keep_alive,
keep_alive_interval=(
timedelta(seconds=int(keep_alive_opts["keep_alive_interval_s"]))
if keep_alive_opts["keep_alive_interval_s"] is not None
else None
),
keep_alive_timeout=(
timedelta(seconds=int(keep_alive_opts["keep_alive_timeout_s"]))
if keep_alive_opts["keep_alive_timeout_s"] is not None
else None
),
)


Expand Down
102 changes: 99 additions & 3 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dataclasses
import pathlib
from datetime import timedelta
from unittest import mock

import pytest
Expand All @@ -13,6 +14,7 @@

from frequenz.client.base.channel import (
ChannelOptions,
KeepAliveOptions,
SslOptions,
_to_bool,
parse_grpc_uri,
Expand Down Expand Up @@ -136,6 +138,67 @@ class _ValidUrlTestCase:
),
),
),
_ValidUrlTestCase(
title="Keep-alive no defaults",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300"
+ "&keep_alive_timeout_s=60",
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(minutes=5),
timeout=timedelta(minutes=1),
),
),
),
_ValidUrlTestCase(
title="Keep-alive default timeout",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_interval_s=300",
defaults=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(seconds=10),
timeout=timedelta(seconds=2),
),
),
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
interval=timedelta(seconds=300),
timeout=timedelta(seconds=2),
),
),
),
_ValidUrlTestCase(
title="Keep-alive default interval",
uri="grpc://localhost:1234?keep_alive=1&keep_alive_timeout_s=60",
defaults=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True, interval=timedelta(minutes=30)
),
),
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(
enabled=True,
timeout=timedelta(minutes=1),
interval=timedelta(minutes=30),
),
),
),
_ValidUrlTestCase(
title="keep-alive disabled",
uri="grpc://localhost:1234?keep_alive=0",
expected_host="localhost",
expected_port=1234,
expected_options=ChannelOptions(
keep_alive=KeepAliveOptions(enabled=False),
),
),
],
ids=lambda case: case.title,
)
Expand All @@ -154,7 +217,9 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
)
expected_port = case.expected_port
expected_ssl = (
expected_options.ssl.enabled if "ssl=" in uri else defaults.ssl.enabled
expected_options.ssl.enabled
if "ssl=" in uri or defaults.ssl.enabled is None
else defaults.ssl.enabled
)
expected_root_certificates = (
expected_options.ssl.root_certificates
Expand Down Expand Up @@ -196,6 +261,35 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals

assert channel == expected_channel
expected_target = f"{expected_host}:{expected_port}"
expected_keep_alive = (
expected_options.keep_alive if "keep_alive=" in uri else defaults.keep_alive
)
expected_keep_alive_interval = (
expected_keep_alive.interval
if "keep_alive_interval_s=" in uri
else defaults.keep_alive.interval
)
expected_keep_alive_timeout = (
expected_keep_alive.timeout
if "keep_alive_timeout_s=" in uri
else defaults.keep_alive.timeout
)
expected_channel_options = (
[
("grpc.http2.max_pings_without_data", 0),
("grpc.keepalive_permit_without_calls", 1),
(
"grpc.keepalive_time_ms",
(expected_keep_alive_interval.total_seconds() * 1000),
),
(
"grpc.keepalive_timeout_ms",
expected_keep_alive_timeout.total_seconds() * 1000,
),
]
if expected_keep_alive.enabled
else None
)
if expected_ssl:
if isinstance(expected_root_certificates, pathlib.Path):
get_contents_mock.assert_any_call(
Expand All @@ -221,10 +315,12 @@ def test_parse_uri_ok( # pylint: disable=too-many-locals
certificate_chain=expected_certificate_chain,
)
secure_channel_mock.assert_called_once_with(
expected_target, expected_credentials
expected_target, expected_credentials, expected_channel_options
)
else:
insecure_channel_mock.assert_called_once_with(expected_target)
insecure_channel_mock.assert_called_once_with(
expected_target, expected_channel_options
)


@pytest.mark.parametrize("value", ["true", "on", "1", "TrUe", "On", "ON", "TRUE"])
Expand Down

0 comments on commit e84c217

Please sign in to comment.