-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
30 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,42 @@ | ||
from functools import lru_cache | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker, Session | ||
from typing import Iterator | ||
|
||
from fastapi_utils.session import FastAPISessionMaker | ||
from sqlalchemy.orm import Session | ||
|
||
from application.settings import get_settings | ||
import logging | ||
from application.settings import get_settings | ||
|
||
# Set up logging | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# this can be used in fast api path functions using Depends to inject a db session | ||
def get_session() -> Iterator[Session]: | ||
def _create_engine(): | ||
settings = get_settings() | ||
engine = create_engine( | ||
settings.READ_DATABASE_URL, | ||
pool_size=settings.DB_POOL_SIZE, | ||
max_overflow=settings.DB_POOL_MAX_OVERFLOW, | ||
) | ||
|
||
logger.info( | ||
f"Database engine created with pool_size={get_settings().DB_POOL_SIZE}, " | ||
f"max_overflow={get_settings().DB_POOL_MAX_OVERFLOW}" | ||
f"Engine created with pool_size={engine.pool.size()}, " | ||
f"max_overflow={engine.pool._max_overflow}, " | ||
) | ||
yield from _get_fastapi_sessionmaker().get_db() | ||
|
||
return engine | ||
|
||
|
||
# Create session factory | ||
engine = _create_engine() | ||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | ||
|
||
# this can be used in non path functions to create a context manager for a db session | ||
# see https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/session.py#L77:L91 | ||
def get_context_session() -> Iterator[Session]: | ||
return _get_fastapi_sessionmaker().context_session() | ||
|
||
def get_session() -> Iterator[Session]: | ||
db = SessionLocal() | ||
try: | ||
yield db | ||
finally: | ||
db.close() | ||
|
||
|
||
@lru_cache() | ||
def _get_fastapi_sessionmaker() -> FastAPISessionMaker: | ||
database_uri = get_settings().READ_DATABASE_URL | ||
return FastAPISessionMaker(database_uri) | ||
def get_context_session() -> Iterator[Session]: | ||
with SessionLocal() as session: | ||
yield session |