diff --git a/backend/db/utils.py b/backend/db/utils.py index af476905..418f40e0 100644 --- a/backend/db/utils.py +++ b/backend/db/utils.py @@ -206,6 +206,41 @@ def refresh_derived_tables( break +ENGINES = { + # 'n3c': get_engine(), + # 'public': get_engine(schema='public') +} + + +def get_engine(isolation_level='AUTOCOMMIT', schema: str = SCHEMA, local=False): + """Connect to db + :param local: If True, connection is on local instead of production database.""" + engine = ENGINES.get(schema, '') or create_engine(get_pg_connect_url(local), isolation_level=isolation_level) + if schema not in ENGINES: + ENGINES[schema] = engine + + # noinspection PyUnusedLocal + @event.listens_for(engine, "connect", insert=True) + def set_search_path(dbapi_connection, connection_record): + """This does "set search_path to n3c;" when you connect. + https://docs.sqlalchemy.org/en/14/dialects/postgresql.html#setting-alternate-search-paths-on-connect + :param connection_record: Part of the example but we're not using yet. + + Ideally, we'd want to be able to call this whenever we want. But cannot be called outside of context of + initializing a connection. + """ + if not schema: + return + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute(f"SET SESSION search_path='{schema}'") + cursor.close() + dbapi_connection.autocommit = existing_autocommit + + return engine + + # todo: make 'isolation_level' the final param, since we never override it. this would it so we dont' have to pass the # other params as named params. def get_db_connection(isolation_level='AUTOCOMMIT', schema: str = SCHEMA, local=False) -> Connection: diff --git a/backend/routes/db.py b/backend/routes/db.py index 75de5c1c..8afd5607 100644 --- a/backend/routes/db.py +++ b/backend/routes/db.py @@ -17,7 +17,8 @@ from backend.api_logger import Api_logger, get_ip_from_request from backend.db.refresh import refresh_db from backend.db.queries import get_concepts -from backend.db.utils import get_db_connection, sql_query, SCHEMA, sql_query_single_col, sql_in, sql_in_safe, run_sql +from backend.db.utils import get_db_connection, get_engine, sql_query, SCHEMA, sql_query_single_col, sql_in, \ + sql_in_safe, run_sql from backend.utils import return_err_with_trace, commify from enclave_wrangler.config import RESEARCHER_COLS from enclave_wrangler.models import convert_rows @@ -597,8 +598,14 @@ def n3c_comparison_rpt(): display comparison data compiled in generate_n3c_comparison_rpt() and get_comparison_rpt() """ - with get_db_connection() as con: + engine = get_engine(schema='public') + with engine.connect() as con: + # with get_db_connection(schema='public') as con: rpt = sql_query_single_col(con, "SELECT rpt FROM public.codeset_comparison WHERE rpt IS NOT NULL") + print() + # con.close() + engine.dispose() + print() return rpt @@ -609,7 +616,7 @@ def single_n3c_comparison_rpt(pair: str): and get_comparison_rpt() """ orig_codeset_id, new_codeset_id = pair.split('-') - with get_db_connection() as con: + with get_db_connection(schema='public') as con: rpt = sql_query_single_col( con, "SELECT rpt FROM public.codeset_comparison WHERE orig_codeset_id || '-' || new_codeset_id = :pair",