diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py
index 4a1bdd10b..8f118d750 100644
--- a/src/dstack/_internal/core/models/instances.py
+++ b/src/dstack/_internal/core/models/instances.py
@@ -154,3 +154,10 @@ def is_available(self) -> bool:
             self.IDLE,
             self.BUSY,
         )
+
+    def is_active(self) -> bool:
+        return self not in self.finished_statuses()
+
+    @classmethod
+    def finished_statuses(cls) -> List["InstanceStatus"]:
+        return [cls.TERMINATING, cls.TERMINATED]
diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py
index 0dcca6d51..7422ef426 100644
--- a/src/dstack/_internal/core/models/volumes.py
+++ b/src/dstack/_internal/core/models/volumes.py
@@ -19,6 +19,13 @@ class VolumeStatus(str, Enum):
     ACTIVE = "active"
     FAILED = "failed"
 
+    def is_active(self) -> bool:
+        return self not in self.finished_statuses()
+
+    @classmethod
+    def finished_statuses(cls) -> List["VolumeStatus"]:
+        return [cls.FAILED]
+
 
 class VolumeConfiguration(CoreModel):
     type: Literal["volume"] = "volume"
diff --git a/src/dstack/_internal/server/routers/backends.py b/src/dstack/_internal/server/routers/backends.py
index 9de5f28a7..11df993bc 100644
--- a/src/dstack/_internal/server/routers/backends.py
+++ b/src/dstack/_internal/server/routers/backends.py
@@ -21,6 +21,7 @@
 )
 from dstack._internal.server.security.permissions import Authenticated, ProjectAdmin
 from dstack._internal.server.services import backends
+from dstack._internal.server.services.backends import handlers as backends_handlers
 from dstack._internal.server.services.config import (
     ServerConfigManager,
     create_backend_config_yaml,
@@ -87,8 +88,8 @@ async def delete_backends(
     user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
 ):
     _, project = user_project
-    await backends.delete_backends(
-        session=session, project=project, backends_types=body.backends_names
+    await backends_handlers.delete_backends_safe(
+        session=session, project=project, backends_types=body.backends_names, error=True
     )
     if settings.SERVER_CONFIG_ENABLED:
         await ServerConfigManager().sync_config(session=session)
diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py
index 75d8dbcd0..35696f2a6 100644
--- a/src/dstack/_internal/server/services/backends/__init__.py
+++ b/src/dstack/_internal/server/services/backends/__init__.py
@@ -1,6 +1,6 @@
 import asyncio
 import heapq
-from typing import Callable, Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
 from uuid import UUID
 
 from sqlalchemy import delete, update
@@ -229,16 +229,25 @@ async def get_config_info(
 async def delete_backends(
     session: AsyncSession,
     project: ProjectModel,
-    backends_types: Iterable[BackendType],
+    backends_types: List[BackendType],
 ):
     if BackendType.DSTACK in backends_types:
         raise ServerClientError("Cannot delete dstack backend")
+    current_backends_types = set(b.type for b in project.backends)
+    deleted_backends_types = current_backends_types.intersection(backends_types)
+    if len(deleted_backends_types) == 0:
+        return
     await session.execute(
         delete(BackendModel).where(
-            BackendModel.type.in_(backends_types),
+            BackendModel.type.in_(deleted_backends_types),
             BackendModel.project_id == project.id,
         )
     )
+    logger.info(
+        "Deleted backends %s in project %s",
+        [b.value for b in deleted_backends_types],
+        project.name,
+    )
 
 
 BackendTuple = Tuple[BackendModel, Backend]
diff --git a/src/dstack/_internal/server/services/backends/handlers.py b/src/dstack/_internal/server/services/backends/handlers.py
new file mode 100644
index 000000000..bcd4b857b
--- /dev/null
+++ b/src/dstack/_internal/server/services/backends/handlers.py
@@ -0,0 +1,98 @@
+from typing import List
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.server.models import ProjectModel
+from dstack._internal.server.services.backends import delete_backends
+from dstack._internal.server.services.fleets import list_project_fleet_models
+from dstack._internal.server.services.volumes import list_project_volumes
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+async def delete_backends_safe(
+    session: AsyncSession,
+    project: ProjectModel,
+    backends_types: List[BackendType],
+    error: bool = True,
+):
+    try:
+        await _check_active_instances(
+            session=session,
+            project=project,
+            backends_types=backends_types,
+            error=error,
+        )
+        await _check_active_volumes(
+            session=session,
+            project=project,
+            backends_types=backends_types,
+            error=error,
+        )
+    except ServerClientError as e:
+        if error:
+            raise
+        logger.warning("%s", e.msg)
+    await delete_backends(
+        session=session,
+        project=project,
+        backends_types=backends_types,
+    )
+
+
+async def _check_active_instances(
+    session: AsyncSession,
+    project: ProjectModel,
+    backends_types: List[BackendType],
+    error: bool,
+):
+    fleet_models = await list_project_fleet_models(
+        session=session,
+        project=project,
+    )
+    for fleet_model in fleet_models:
+        for instance in fleet_model.instances:
+            if instance.status.is_active() and instance.backend in backends_types:
+                if error:
+                    msg = (
+                        f"Backend {instance.backend.value} has active instances."
+                        " Delete instances before deleting the backend."
+                    )
+                else:
+                    msg = (
+                        f"Backend {instance.backend.value} has active instances."
+                        " The backend will be deleted but instances may be left hanging."
+                    )
+                raise ServerClientError(msg)
+
+
+async def _check_active_volumes(
+    session: AsyncSession,
+    project: ProjectModel,
+    backends_types: List[BackendType],
+    error: bool,
+):
+    volume_models = await list_project_volumes(
+        session=session,
+        project=project,
+    )
+    for volume_model in volume_models:
+        if (
+            volume_model.status.is_active()
+            and volume_model.provisioning_data is not None
+            and volume_model.provisioning_data.backend in backends_types
+        ):
+            if error:
+                msg = (
+                    f"Backend {volume_model.provisioning_data.backend.value} has active volumes."
+                    " Delete volumes before deleting the backend."
+                )
+            else:
+                msg = (
+                    f"Backend {volume_model.provisioning_data.backend.value} has active volumes."
+                    " The backend will be deleted but volumes may be left hanging."
+                )
+            raise ServerClientError(msg)
diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py
index 965854cc8..acdf53b16 100644
--- a/src/dstack/_internal/server/services/config.py
+++ b/src/dstack/_internal/server/services/config.py
@@ -29,6 +29,7 @@
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services import encryption as encryption_services
 from dstack._internal.server.services import projects as projects_services
+from dstack._internal.server.services.backends.handlers import delete_backends_safe
 from dstack._internal.server.services.encryption import AnyEncryptionKeyConfig
 from dstack._internal.server.services.permissions import (
     DefaultPermissions,
@@ -595,8 +596,11 @@ async def _apply_project_config(
             )
         except Exception as e:
             logger.warning("Failed to configure backend %s: %s", config_info.type, e)
-    await backends_services.delete_backends(
-        session=session, project=project, backends_types=backends_to_delete
+    await delete_backends_safe(
+        session=session,
+        project=project,
+        backends_types=list(backends_to_delete),
+        error=False,
     )
 
 async def _init_config(
diff --git a/src/dstack/_internal/server/services/repos.py b/src/dstack/_internal/server/services/repos.py
index 8bb859cc5..77f8ccb7f 100644
--- a/src/dstack/_internal/server/services/repos.py
+++ b/src/dstack/_internal/server/services/repos.py
@@ -29,6 +29,9 @@
 )
 from dstack._internal.server.services.storage import get_default_storage
 from dstack._internal.utils.common import run_async
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 async def list_repos(
@@ -170,6 +173,7 @@ async def delete_repos(
         delete(RepoModel).where(RepoModel.project_id == project.id, RepoModel.name.in_(repos_ids))
     )
     await session.commit()
+    logger.info("Deleted repos %s in project %s", repos_ids, project.name)
 
 
 async def get_repo_creds(
@@ -263,6 +267,7 @@ async def delete_repo_creds(
         )
     )
     await session.commit()
+    logger.info("Deleted repo creds for repo %s user %s", repo.name, user.name)
 
 
 async def upload_code(
diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py
index 6e5e667b3..8fbcd4286 100644
--- a/src/dstack/_internal/server/services/users.py
+++ b/src/dstack/_internal/server/services/users.py
@@ -146,6 +146,7 @@ async def delete_users(
 ):
     await session.execute(delete(UserModel).where(UserModel.name.in_(usernames)))
     await session.commit()
+    logger.info("Deleted users %s by user %s", usernames, user.name)
 
 
 async def get_user_model_by_name(
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 88a704619..fbfb4afd1 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -578,6 +578,7 @@ async def create_volume(
         else None,
         instances=[],
         deleted_at=deleted_at,
+        deleted=True if deleted_at else False,
     )
     session.add(vm)
     await session.commit()
@@ -641,8 +642,10 @@ def get_volume_provisioning_data(
     size_gb: int = 100,
     availability_zone: Optional[str] = None,
     backend_data: Optional[str] = None,
+    backend: Optional[BackendType] = None,
 ) -> VolumeProvisioningData:
     return VolumeProvisioningData(
+        backend=backend,
         volume_id=volume_id,
         size_gb=size_gb,
         availability_zone=availability_zone,
diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py
index 550e6d072..f1f816a40 100644
--- a/src/tests/_internal/server/routers/test_backends.py
+++ b/src/tests/_internal/server/routers/test_backends.py
@@ -1,4 +1,5 @@
 import json
+from datetime import datetime, timezone
 from operator import itemgetter
 from unittest.mock import Mock, patch
 
@@ -11,14 +12,21 @@
 from dstack._internal.core.backends.oci import region as oci_region
 from dstack._internal.core.errors import BackendAuthError
 from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.instances import InstanceStatus
 from dstack._internal.core.models.users import GlobalRole, ProjectRole
+from dstack._internal.core.models.volumes import VolumeStatus
 from dstack._internal.server.models import BackendModel
 from dstack._internal.server.services.projects import add_project_member
 from dstack._internal.server.testing.common import (
     create_backend,
+    create_fleet,
+    create_instance,
+    create_pool,
     create_project,
     create_user,
+    create_volume,
     get_auth_headers,
+    get_volume_provisioning_data,
 )
 
 FAKE_OCI_CLIENT_CREDS = {
@@ -1189,6 +1197,103 @@ async def test_deletes_backends(self, test_db, session: AsyncSession, client: As
         res = await session.execute(select(BackendModel))
         assert len(res.scalars().all()) == 0
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_returns_400_if_backend_has_active_instances(
+        self, test_db, session: AsyncSession, client: AsyncClient
+    ):
+        user = await create_user(session=session, global_role=GlobalRole.USER)
+        project = await create_project(session=session, owner=user)
+        await add_project_member(
+            session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+        )
+        backend = await create_backend(session=session, project_id=project.id)
+        pool = await create_pool(session=session, project=project)
+        fleet = await create_fleet(session=session, project=project)
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.TERMINATED,
+            backend=backend.type,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.IDLE,
+            backend=backend.type,
+        )
+        fleet.instances.append(instance1)
+        fleet.instances.append(instance2)
+        await session.commit()
+        response = await client.post(
+            f"/api/project/{project.name}/backends/delete",
+            headers=get_auth_headers(user.token),
+            json={"backends_names": [backend.type.value]},
+        )
+        assert response.status_code == 400
+        res = await session.execute(select(BackendModel))
+        assert len(res.scalars().all()) == 1
+        fleet.instances.pop()
+        await session.commit()
+        response = await client.post(
+            f"/api/project/{project.name}/backends/delete",
+            headers=get_auth_headers(user.token),
+            json={"backends_names": [backend.type.value]},
+        )
+        assert response.status_code == 200
+        res = await session.execute(select(BackendModel))
+        assert len(res.scalars().all()) == 0
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_returns_400_if_backend_has_active_volumes(
+        self, test_db, session: AsyncSession, client: AsyncClient
+    ):
+        user = await create_user(session=session, global_role=GlobalRole.USER)
+        project = await create_project(session=session, owner=user)
+        await add_project_member(
+            session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+        )
+        backend = await create_backend(session=session, project_id=project.id)
+        await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            backend=backend.type,
+            volume_provisioning_data=get_volume_provisioning_data(backend=backend.type),
+            status=VolumeStatus.ACTIVE,
+            deleted_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
+        )
+        volume2 = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            backend=backend.type,
+            volume_provisioning_data=get_volume_provisioning_data(backend=backend.type),
+            status=VolumeStatus.ACTIVE,
+        )
+        await session.commit()
+        response = await client.post(
+            f"/api/project/{project.name}/backends/delete",
+            headers=get_auth_headers(user.token),
+            json={"backends_names": [backend.type.value]},
+        )
+        assert response.status_code == 400
+        res = await session.execute(select(BackendModel))
+        assert len(res.scalars().all()) == 1
+        await session.delete(volume2)
+        await session.commit()
+        response = await client.post(
+            f"/api/project/{project.name}/backends/delete",
+            headers=get_auth_headers(user.token),
+            json={"backends_names": [backend.type.value]},
+        )
+        assert response.status_code == 200
+        res = await session.execute(select(BackendModel))
+        assert len(res.scalars().all()) == 0
+
 
 class TestGetConfigInfo:
     @pytest.mark.asyncio