Skip to content

Commit

Permalink
Add migration from state.json on gateways
Browse files Browse the repository at this point in the history
This will allow running `dstack-proxy` on existing
gateway instances without losing the state left
from the older `dstack-gateway` application.
  • Loading branch information
jvstme committed Dec 23, 2024
1 parent b7c622b commit 7e5b612
Show file tree
Hide file tree
Showing 18 changed files with 644 additions and 14 deletions.
14 changes: 12 additions & 2 deletions src/dstack/_internal/proxy/gateway/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""FastAPI app running on a gateway."""

from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

from fastapi import FastAPI
Expand All @@ -16,7 +17,8 @@
get_gateway_injector_from_app,
get_gateway_proxy_repo,
)
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.state_v1 import migrate_from_state_v1
from dstack._internal.proxy.gateway.routers.auth import router as auth_router
from dstack._internal.proxy.gateway.routers.config import router as config_router
from dstack._internal.proxy.gateway.routers.registry import router as registry_router
Expand All @@ -31,6 +33,8 @@
from dstack.version import __version__

# State file in the current format; the gateway repo is loaded from it.
STATE_FILE = DSTACK_DIR_ON_GATEWAY / "state-v2.json"
# State file in the older format; used only as the input of the migration.
LEGACY_STATE_FILE = DSTACK_DIR_ON_GATEWAY / "state.json"
# Directory handed to the migration as `keys_dir`; "~" is expanded and the
# path is resolved once, at import time.
LEGACY_KEYS_DIR = Path("~/.ssh/projects").expanduser().resolve()


@asynccontextmanager
Expand All @@ -47,9 +51,15 @@ async def lifespan(app: FastAPI):


def make_app(repo: Optional[GatewayProxyRepo] = None, nginx: Optional[Nginx] = None) -> FastAPI:
if repo is None:
migrate_from_state_v1(
v1_file=LEGACY_STATE_FILE, v2_file=STATE_FILE, keys_dir=LEGACY_KEYS_DIR
)
repo = GatewayProxyRepo.load(STATE_FILE)

app = FastAPI(lifespan=lifespan)
app.state.proxy_dependency_injector = GatewayDependencyInjector(
repo=repo or GatewayProxyRepo.load(STATE_FILE),
repo=repo,
auth=GatewayProxyAuthProvider(
server_client=HTTPMultiClient(SERVER_CONNECTIONS_DIR_ON_GATEWAY)
),
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import Depends, FastAPI, Request

from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.services.nginx import Nginx
from dstack._internal.proxy.gateway.services.stats import StatsCollector
from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider
Expand Down
Empty file.
File renamed without changes.
156 changes: 156 additions & 0 deletions src/dstack/_internal/proxy/gateway/repo/state_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Migration from the legacy state.json file of dstack-gateway to the new
state-v2.json file of dstack-proxy.
"""

import json
from datetime import datetime
from pathlib import Path
from typing import Iterable

from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.proxy.gateway.models import ACMESettings, GlobalProxyConfig, ModelEntrypoint
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo, State
from dstack._internal.proxy.lib.models import (
AnyModelFormat,
ChatModel,
OpenAIChatModelFormat,
Project,
Replica,
Service,
TGIChatModelFormat,
)
from dstack._internal.utils.logging import get_logger

# Module-level logger; used to report a successful migration.
logger = get_logger(__name__)


def migrate_from_state_v1(v1_file: Path, v2_file: Path, keys_dir: Path) -> None:
    """Create ``v2_file`` from the legacy ``v1_file`` when that is still needed.

    Nothing happens if ``v2_file`` already exists or if ``v1_file`` is absent.
    Otherwise the JSON content of ``v1_file`` is converted with
    ``parse_state_v1`` and saved through a ``GatewayProxyRepo`` bound to
    ``v2_file``; success is reported via the module logger.

    Args:
        v1_file: Path of the legacy state file to read.
        v2_file: Path of the new state file to produce.
        keys_dir: Directory forwarded to ``parse_state_v1``.
    """
    # The new file is checked first: once it exists, the legacy file is
    # never looked at again.
    if v2_file.exists():
        return
    if not v1_file.exists():
        return

    legacy_state = json.loads(v1_file.read_text())
    GatewayProxyRepo(state=parse_state_v1(legacy_state, keys_dir), file=v2_file).save()
    logger.info("Successfully migrated state from %s to %s", v1_file, v2_file)


def parse_state_v1(state_v1: dict, keys_dir: Path) -> State:
    """Build a ``State`` from the parsed content of a legacy state file.

    Args:
        state_v1: The legacy state as a dict. Its ``"store"`` entry (and the
            ``"projects"`` / ``"nginx"`` entries inside it) may be missing, in
            which case empty dicts are used instead.
        keys_dir: Directory forwarded to ``get_projects``.

    Returns:
        A ``State`` assembled from the services, models, entrypoints,
        projects and config produced by the helper functions of this module.
    """
    services, models = get_services_models(state_v1)
    store = state_v1.get("store", {})
    entrypoints = get_entrypoints(store)
    projects = get_projects(store.get("projects", {}), keys_dir)
    config = get_config(store.get("nginx", {}).get("acme_settings", {}))
    return State(
        services=services,
        models=models,
        entrypoints=entrypoints,
        projects=projects,
        config=config,
    )


def get_services_models(
    state_v1: dict,
) -> tuple[dict[str, dict[str, Service]], dict[str, dict[str, ChatModel]]]:
    """Convert the legacy services, and the models they declare, to the new types.

    Args:
        state_v1: The legacy state as a dict. Services are read from
            ``state_v1["store"]["services"]`` and matched to projects through
            the service ids listed under ``state_v1["store"]["projects"]``.

    Returns:
        A pair ``(services, models)``. ``services`` maps project name to run
        name to ``Service``; ``models`` maps project name to model name to
        ``ChatModel``.

    Raises:
        KeyError: If a service id is not listed under any project, or if a
            required key is missing from a legacy entry.
    """
    store = state_v1.get("store", {})
    project_by_service_id = {
        service_id: project_name
        for project_name, service_ids in store.get("projects", {}).items()
        for service_id in service_ids
    }

    services: dict[str, dict[str, Service]] = {}
    models: dict[str, dict[str, ChatModel]] = {}
    for legacy_service in store.get("services", {}).values():
        owner = project_by_service_id[legacy_service["id"]]
        converted_replicas = [
            parse_replica(legacy_replica)
            for legacy_replica in legacy_service.get("replicas", [])
        ]
        converted = Service(
            project_name=owner,
            # The run name is the first label of the service domain.
            run_name=legacy_service["domain"].split(".")[0],
            domain=legacy_service["domain"],
            https=legacy_service.get("https", True),
            auth=legacy_service["auth"],
            client_max_body_size=legacy_service.get("client_max_body_size", 1024 * 1024),
            replicas=tuple(converted_replicas),
        )
        services.setdefault(owner, {})[converted.run_name] = converted

        model_spec = legacy_service.get("options", {}).get("openai", {}).get("model", {})
        if model_spec:
            converted_model = parse_model(
                owner, converted.run_name, model_spec, state_v1["openai"]["index"]
            )
            models.setdefault(owner, {})[converted_model.name] = converted_model

    return services, models


def parse_replica(replica: dict) -> Replica:
    """Convert one legacy replica entry into a ``Replica``.

    An SSH proxy is attached only when both ``"ssh_jump_host"`` and
    ``"ssh_jump_port"`` are present and truthy. The jump host is split on
    ``"@"`` into a username and a hostname.

    Args:
        replica: Legacy replica dict with ``"id"``, ``"app_port"``,
            ``"ssh_host"`` and ``"ssh_port"`` keys, plus the optional jump
            host keys.

    Returns:
        The converted ``Replica``; its ``ssh_proxy`` is ``None`` when no jump
        host is configured.

    Raises:
        ValueError: If the jump host does not contain exactly one ``"@"``.
        KeyError: If a required key is missing.
    """
    jump_params = None
    jump_destination = replica.get("ssh_jump_host")
    if jump_destination:
        jump_port = replica.get("ssh_jump_port")
        if jump_port:
            jump_user, jump_hostname = jump_destination.split("@")
            jump_params = SSHConnectionParams(
                hostname=jump_hostname,
                username=jump_user,
                port=jump_port,
            )

    return Replica(
        id=replica["id"],
        app_port=replica["app_port"],
        ssh_destination=replica["ssh_host"],
        ssh_port=replica["ssh_port"],
        ssh_proxy=jump_params,
    )


def parse_model(project_name: str, run_name: str, model: dict, openai_index: dict) -> ChatModel:
    """Convert one legacy model entry into a ``ChatModel``.

    Args:
        project_name: Project that owns the model.
        run_name: Run name of the service that serves the model.
        model: Legacy model dict with ``"name"`` and ``"format"`` keys. A
            ``"tgi"`` model also needs ``"chat_template"`` and
            ``"eos_token"``; any other format needs ``"prefix"``.
        openai_index: Legacy index; the creation timestamp is read from
            ``openai_index[project_name]["chat"][model["name"]]["created"]``.

    Returns:
        The converted ``ChatModel``. ``created_at`` comes from
        ``datetime.fromtimestamp`` without a timezone argument.

    Raises:
        KeyError: If a required key is missing.
    """
    model_name = model["name"]
    created_timestamp = openai_index[project_name]["chat"][model_name]["created"]

    model_format: AnyModelFormat
    if model["format"] != "tgi":
        # Every format other than "tgi" is treated as the OpenAI format.
        model_format = OpenAIChatModelFormat(prefix=model["prefix"])
    else:
        model_format = TGIChatModelFormat(
            chat_template=model["chat_template"],
            eos_token=model["eos_token"],
        )

    return ChatModel(
        project_name=project_name,
        name=model_name,
        created_at=datetime.fromtimestamp(created_timestamp),
        run_name=run_name,
        format_spec=model_format,
    )


def get_entrypoints(store: dict) -> dict[str, ModelEntrypoint]:
    """Collect the model entrypoints recorded in the legacy store.

    Args:
        store: The legacy ``"store"`` dict. ``store["entrypoints"]`` maps an
            entrypoint domain to a pair whose first item is the project name;
            ``store["nginx"]["configs"]`` holds the site configs, of which
            only those with ``"type" == "entrypoint"`` are converted.

    Returns:
        A dict mapping project name to its ``ModelEntrypoint``. If several
        entrypoint sites belong to one project, the last one wins.

    Raises:
        KeyError: If an entrypoint site's domain is not listed in
            ``store["entrypoints"]``, or a required key is missing.
    """
    project_by_domain = {
        domain: project_name
        for domain, (project_name, _) in store.get("entrypoints", {}).items()
    }

    found = {}
    for site in store.get("nginx", {}).get("configs", {}).values():
        if site["type"] != "entrypoint":
            continue
        site_domain = site["domain"]
        entrypoint = ModelEntrypoint(
            project_name=project_by_domain[site_domain],
            domain=site_domain,
            https=site["https"],
        )
        found[entrypoint.project_name] = entrypoint
    return found


def get_projects(project_names: Iterable[str], keys_dir: Path) -> dict[str, Project]:
    """Build a ``Project`` for every given project name.

    Args:
        project_names: Names of the projects to convert.
        keys_dir: Directory holding one file per project, named after the
            project; the file's text becomes the project's
            ``ssh_private_key``.

    Returns:
        A dict mapping each project name to its ``Project``.

    Raises:
        OSError: If a project's key file cannot be read.
    """
    return {
        name: Project(name=name, ssh_private_key=(keys_dir / name).read_text())
        for name in project_names
    }


def get_config(acme_settings: dict) -> GlobalProxyConfig:
    """Wrap the legacy ACME settings into a ``GlobalProxyConfig``.

    Args:
        acme_settings: Legacy dict that may contain ``"server"``,
            ``"eab_kid"`` and ``"eab_hmac_key"``; a missing key is passed on
            as ``None``.

    Returns:
        A ``GlobalProxyConfig`` whose ``acme_settings`` carries those three
        values.
    """
    acme = ACMESettings(
        server=acme_settings.get("server"),
        eab_kid=acme_settings.get("eab_kid"),
        eab_hmac_key=acme_settings.get("eab_hmac_key"),
    )
    return GlobalProxyConfig(acme_settings=acme)
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/routers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dstack._internal.proxy.gateway.deps import get_gateway_proxy_repo
from dstack._internal.proxy.gateway.models import ACMESettings, GlobalProxyConfig
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.schemas.common import OkResponse
from dstack._internal.proxy.gateway.schemas.config import ConfigRequest

Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import dstack._internal.proxy.gateway.services.registry as registry_services
from dstack._internal.proxy.gateway.deps import get_gateway_proxy_repo, get_nginx
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.schemas.common import OkResponse
from dstack._internal.proxy.gateway.schemas.registry import (
RegisterEntrypointRequest,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/routers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, Depends

from dstack._internal.proxy.gateway.deps import get_gateway_proxy_repo, get_stats_collector
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.schemas.stats import ServiceStats
from dstack._internal.proxy.gateway.services.stats import StatsCollector, get_service_stats

Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import dstack._internal.proxy.gateway.schemas.registry as schemas
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.proxy.gateway import models as gateway_models
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.services.nginx import (
ModelEntrypointConfig,
Nginx,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/proxy/gateway/services/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pydantic import BaseModel

from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, ServiceStats, Stat
from dstack._internal.utils.common import run_async

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from dstack._internal.proxy.gateway.models import ACMESettings, GlobalProxyConfig, ModelEntrypoint
from dstack._internal.proxy.gateway.repo import GatewayProxyRepo
from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.lib.testing.common import make_project, make_service
from tests._internal.proxy.lib.routers.test_model_proxy import make_model

Expand Down
Loading

0 comments on commit 7e5b612

Please sign in to comment.