Skip to content

Commit

Permalink
Support tests that access db from a new thread (#2116)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Dec 18, 2024
1 parent 84786d9 commit c135335
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
21 changes: 17 additions & 4 deletions src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from contextlib import asynccontextmanager
from typing import Optional

from alembic import command, config
from sqlalchemy import event
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import ConnectionPoolEntry

Expand All @@ -13,9 +14,12 @@


class Database:
def __init__(self, url: str):
def __init__(self, url: str, engine: Optional[AsyncEngine] = None):
self.url = url
self.engine = create_async_engine(self.url, echo=settings.SQL_ECHO_ENABLED)
if engine is not None:
self.engine = engine
else:
self.engine = create_async_engine(self.url, echo=settings.SQL_ECHO_ENABLED)
self.session_maker = sessionmaker(
bind=self.engine,
expire_on_commit=False,
Expand All @@ -40,7 +44,16 @@ def get_session(self) -> AsyncSession:
return self.session_maker()


_db = Database(url=DATABASE_URL)
def get_new_db() -> Database:
"""
Creates a new Database with a new Engine.
Use this when you need to access the DB in a new thread instead of calling Database directly
since it's easier to monkey-patch.
"""
return Database(url=DATABASE_URL)


_db = get_new_db()


def get_db() -> Database:
Expand Down
50 changes: 27 additions & 23 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,32 +459,36 @@ async def create_instance(
termination_idle_time: int = DEFAULT_POOL_TERMINATION_IDLE_TIME,
region: str = "eu-west",
remote_connection_info: Optional[RemoteConnectionInfo] = None,
job_provisioning_data: Optional[JobProvisioningData] = None,
) -> InstanceModel:
if instance_id is None:
instance_id = uuid.uuid4()
job_provisioning_data = {
"backend": backend.value,
"instance_type": {
"name": "instance",
"resources": {
"cpus": 1,
"memory_mib": 512,
"gpus": [],
"spot": spot,
"disk": {"size_mib": 102400},
"description": "",
if job_provisioning_data is None:
job_provisioning_data_dict = {
"backend": backend.value,
"instance_type": {
"name": "instance",
"resources": {
"cpus": 1,
"memory_mib": 512,
"gpus": [],
"spot": spot,
"disk": {"size_mib": 102400},
"description": "",
},
},
},
"instance_id": "running_instance.id",
"ssh_proxy": None,
"hostname": "running_instance.ip",
"region": region,
"price": 0.1,
"username": "root",
"ssh_port": 22,
"dockerized": True,
"backend_data": None,
}
"instance_id": "running_instance.id",
"ssh_proxy": None,
"hostname": "running_instance.ip",
"region": region,
"price": 0.1,
"username": "root",
"ssh_port": 22,
"dockerized": True,
"backend_data": None,
}
else:
job_provisioning_data_dict = job_provisioning_data.dict()
offer = {
"backend": backend.value,
"instance": {
Expand Down Expand Up @@ -530,7 +534,7 @@ async def create_instance(
created_at=created_at,
started_at=created_at,
finished_at=finished_at,
job_provisioning_data=json.dumps(job_provisioning_data),
job_provisioning_data=json.dumps(job_provisioning_data_dict),
offer=json.dumps(offer),
price=1,
region=region,
Expand Down
14 changes: 13 additions & 1 deletion src/dstack/_internal/server/testing/conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import pytest_asyncio
from sqlalchemy import StaticPool
from sqlalchemy.ext.asyncio import create_async_engine
from testcontainers.postgres import PostgresContainer

from dstack._internal.server import settings
from dstack._internal.server.db import Database, override_db
from dstack._internal.server.models import BaseModel

Expand All @@ -15,15 +18,24 @@ def postgres_container():
@pytest_asyncio.fixture
async def test_db(request):
db_type = getattr(request, "param", "sqlite")
engine = None
if db_type == "sqlite":
db_url = "sqlite+aiosqlite://"
# For SQLite, allow accessing the in-memory DB from multiple threads:
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#using-a-memory-database-in-multiple-threads
engine = create_async_engine(
db_url,
echo=settings.SQL_ECHO_ENABLED,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
elif db_type == "postgres":
if not request.config.getoption("--runpostgres"):
pytest.skip("Skipping Postgres tests as --runpostgres was not provided")
db_url = request.getfixturevalue("postgres_container")
else:
raise ValueError(f"Unknown db_type {db_type}")
db = Database(db_url)
db = Database(db_url, engine=engine)
override_db(db)
async with db.engine.begin() as conn:
await conn.run_sync(BaseModel.metadata.drop_all)
Expand Down

0 comments on commit c135335

Please sign in to comment.