Skip to content

Commit

Permalink
Merge pull request #206 from jumpstarter-dev/fix-router-tls
Browse files Browse the repository at this point in the history
Handle all instances of tls credentials
  • Loading branch information
mangelajo authored Jan 8, 2025
2 parents 9426fd9 + a05bfc7 commit 9b31895
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 29 deletions.
4 changes: 3 additions & 1 deletion jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from jumpstarter.common import MetadataFilter, TemporaryUnixListener
from jumpstarter.common.condition import condition_false, condition_present_and_equal, condition_true
from jumpstarter.common.streams import connect_router_stream
from jumpstarter.config.tls import TLSConfigV1Alpha1
from jumpstarter.v1 import jumpstarter_pb2, jumpstarter_pb2_grpc, kubernetes_pb2

logger = logging.getLogger(__name__)
Expand All @@ -27,6 +28,7 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager):
unsafe: bool
release: bool = True # release on contexts exit
controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False)
tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)

def __post_init__(self):
self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel)
Expand Down Expand Up @@ -110,7 +112,7 @@ def __exit__(self, exc_type, exc_value, traceback):
async def handle_async(self, stream):
logger.info("Connecting to Lease with name %s", self.name)
response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name))
async with connect_router_stream(response.router_endpoint, response.router_token, stream):
async with connect_router_stream(response.router_endpoint, response.router_token, stream, self.tls_config):
pass

@asynccontextmanager
Expand Down
8 changes: 4 additions & 4 deletions jumpstarter/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import grpc


def ssl_channel_credentials(target: str, insecure: bool = False, ca_certificate: str = ""):
if insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
def ssl_channel_credentials(target: str, tls_config):
if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
parsed = urlparse(f"//{target}")
port = parsed.port if parsed.port else 443
root_certificates = ssl.get_server_certificate((parsed.hostname, port))
return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode())
elif ca_certificate != "":
elif tls_config.ca != "":
# convert ca_certificate base64 encoded to pem encoded string
ca_certificate = base64.b64decode(ca_certificate)
ca_certificate = base64.b64decode(tls_config.ca)
return grpc.ssl_channel_credentials(ca_certificate)
else:
return grpc.ssl_channel_credentials()
Expand Down
4 changes: 2 additions & 2 deletions jumpstarter/common/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class StreamRequestMetadata(BaseModel):


@asynccontextmanager
async def connect_router_stream(endpoint, token, stream):
async def connect_router_stream(endpoint, token, stream, tls_config):
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(endpoint),
ssl_channel_credentials(endpoint, tls_config),
grpc.access_token_call_credentials(token),
)

Expand Down
16 changes: 9 additions & 7 deletions jumpstarter/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from anyio.from_thread import BlockingPortal, start_blocking_portal
from pydantic import BaseModel, Field, ValidationError

from jumpstarter.client import Lease
from jumpstarter.common import MetadataFilter
from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials
from jumpstarter.v1 import jumpstarter_pb2, jumpstarter_pb2_grpc

from .common import CONFIG_PATH
from .env import JMP_DRIVERS_ALLOW, JMP_ENDPOINT, JMP_TOKEN
from .tls import TLSConfigV1Alpha1


def _allow_from_env():
Expand All @@ -27,10 +27,6 @@ def _allow_from_env():
case _:
return allow.split(","), False

class ClientConfigV1Alpha1TLS(BaseModel):
ca: str = Field(default="")
insecure: bool = Field(default=False)

class ClientConfigV1Alpha1Drivers(BaseModel):
allow: list[str] = Field(default_factory=[])
unsafe: bool = Field(default=False)
Expand All @@ -46,14 +42,14 @@ class ClientConfigV1Alpha1(BaseModel):
kind: Literal["ClientConfig"] = Field(default="ClientConfig")

endpoint: str
tls: ClientConfigV1Alpha1TLS = Field(default_factory=ClientConfigV1Alpha1TLS)
tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1)
token: str

drivers: ClientConfigV1Alpha1Drivers

async def channel(self):
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(self.endpoint, self.tls.insecure, self.tls.ca),
ssl_channel_credentials(self.endpoint, self.tls),
grpc.access_token_call_credentials(self.token),
)

Expand All @@ -79,13 +75,16 @@ def release_lease(self, name):
portal.call(self.release_lease_async, name)

async def request_lease_async(self, metadata_filter: MetadataFilter, portal:BlockingPortal):
# dynamically import to avoid circular imports
from jumpstarter.client import Lease
lease = Lease(
channel=await self.channel(),
name=None,
metadata_filter=metadata_filter,
portal=portal,
allow=self.drivers.allow,
unsafe=self.drivers.unsafe,
tls_config=self.tls,
)
return await lease.request_async()

Expand All @@ -100,6 +99,8 @@ async def release_lease_async(self, name):
@asynccontextmanager
async def lease_async(self, metadata_filter: MetadataFilter, lease_name: str | None, portal: BlockingPortal,
release=True):
# dynamically import to avoid circular imports
from jumpstarter.client import Lease
async with Lease(
channel=await self.channel(),
name=lease_name,
Expand All @@ -108,6 +109,7 @@ async def lease_async(self, metadata_filter: MetadataFilter, lease_name: str | N
allow=self.drivers.allow,
unsafe=self.drivers.unsafe,
release=release,
tls_config=self.tls,
) as lease:
yield lease

Expand Down
16 changes: 9 additions & 7 deletions jumpstarter/config/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials
from jumpstarter.common.importlib import import_class
from jumpstarter.driver import Driver
from jumpstarter.exporter import Exporter, Session

from .tls import TLSConfigV1Alpha1


class ExporterConfigV1Alpha1DriverInstance(BaseModel):
Expand All @@ -27,10 +28,6 @@ def instantiate(self) -> Driver:

return driver_class(children=children, **self.config)

class ExporterConfigV1Alpha1TLS(BaseModel):
ca: str = Field(default="")
insecure: bool = Field(default=False)

class ExporterConfigV1Alpha1(BaseModel):
BASE_PATH: ClassVar[Path] = Path("/etc/jumpstarter/exporters")

Expand All @@ -40,7 +37,7 @@ class ExporterConfigV1Alpha1(BaseModel):
kind: Literal["ExporterConfig"] = "ExporterConfig"

endpoint: str
tls: ExporterConfigV1Alpha1TLS = Field(default_factory=ExporterConfigV1Alpha1TLS)
tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1)
token: str

export: dict[str, ExporterConfigV1Alpha1DriverInstance] = Field(default_factory=dict)
Expand Down Expand Up @@ -83,6 +80,8 @@ def delete(self):

@asynccontextmanager
async def serve_unix_async(self):
# dynamic import to avoid circular imports
from jumpstarter.exporter import Session
with Session(
root_device=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate(),
) as session:
Expand All @@ -96,15 +95,18 @@ def serve_unix(self):
yield path

async def serve(self):
# dynamic import to avoid circular imports
from jumpstarter.exporter import Exporter
def channel_factory():
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(self.endpoint, self.tls.insecure, self.tls.ca),
ssl_channel_credentials(self.endpoint, self.tls),
grpc.access_token_call_credentials(self.token),
)
return aio_secure_channel(self.endpoint, credentials)

async with Exporter(
channel_factory=channel_factory,
device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate,
tls=self.tls,
) as exporter:
await exporter.serve()
6 changes: 4 additions & 2 deletions jumpstarter/config/exporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from jumpstarter.common import MetadataFilter

from .client import ClientConfigV1Alpha1, ClientConfigV1Alpha1Drivers
from .exporter import ExporterConfigV1Alpha1, ExporterConfigV1Alpha1DriverInstance, ExporterConfigV1Alpha1TLS
from .exporter import ExporterConfigV1Alpha1, ExporterConfigV1Alpha1DriverInstance
from .tls import TLSConfigV1Alpha1

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -39,6 +40,7 @@ async def test_exporter_serve(mock_controller):
endpoint=mock_controller,
token="dummy-client-token",
drivers=ClientConfigV1Alpha1Drivers(allow=[], unsafe=True),
tls=TLSConfigV1Alpha1(insecure=True),
)

async with create_task_group() as tg:
Expand Down Expand Up @@ -100,7 +102,7 @@ def test_exporter_config(monkeypatch, tmp_path):
kind="ExporterConfig",
endpoint="jumpstarter.my-lab.com:1443",
token="dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz",
tls=ExporterConfigV1Alpha1TLS(ca="cacertificatedata", insecure=True),
tls=TLSConfigV1Alpha1(ca="cacertificatedata", insecure=True),
export={
"power": ExporterConfigV1Alpha1DriverInstance(
type="jumpstarter.drivers.power.PduPower",
Expand Down
6 changes: 6 additions & 0 deletions jumpstarter/config/tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel, Field


class TLSConfigV1Alpha1(BaseModel):
ca: str = Field(default="")
insecure: bool = Field(default=False)
8 changes: 5 additions & 3 deletions jumpstarter/exporter/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from jumpstarter.common import Metadata
from jumpstarter.common.streams import connect_router_stream
from jumpstarter.config.tls import TLSConfigV1Alpha1
from jumpstarter.driver import Driver
from jumpstarter.exporter.session import Session
from jumpstarter.v1 import (
Expand All @@ -24,6 +25,7 @@ class Exporter(AbstractAsyncContextManager, Metadata):
channel_factory: Callable[[], grpc.aio.Channel]
device_factory: Callable[[], Driver]
lease_name: str = field(init=False, default="")
tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)

async def __aexit__(self, exc_type, exc_value, traceback):
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
Expand All @@ -34,9 +36,9 @@ async def __aexit__(self, exc_type, exc_value, traceback):
)
)

async def __handle(self, path, endpoint, token):
async def __handle(self, path, endpoint, token, tls_config):
async with await connect_unix(path) as stream:
async with connect_router_stream(endpoint, token, stream):
async with connect_router_stream(endpoint, token, stream, tls_config):
pass

@asynccontextmanager
Expand Down Expand Up @@ -67,7 +69,7 @@ async def handle(self, lease_name, tg):
async with self.session() as path:
async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)):
logger.info("Handling new connection request on lease %s", lease_name)
tg.start_soon(self.__handle, path, request.router_endpoint, request.router_token)
tg.start_soon(self.__handle, path, request.router_endpoint, request.router_token, self.tls)

async def serve(self):
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
Expand Down
6 changes: 3 additions & 3 deletions jumpstarter/testing/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class JumpstarterTest:
This class provides a client fixture that can be used to interact with
Jumpstarter services in test cases.
Looks for the JUMPSTARTER_HOST environment variable to connect to an
established Jumpstarter shell, or otherwise it will try to acquire a
lease for a single exporter using the filter_labels annotation.
Looks for the `JUMPSTARTER_HOST` environment variable to connect to an
established Jumpstarter shell, otherwise it will try to acquire a lease
for a single exporter using the filter_labels annotation.
i.e.:
.. code-block:: python
Expand Down

0 comments on commit 9b31895

Please sign in to comment.