diff --git a/lib/galaxy/datatypes/display_applications/util.py b/lib/galaxy/datatypes/display_applications/util.py index aefc65d9ff12..47a70b706e8e 100644 --- a/lib/galaxy/datatypes/display_applications/util.py +++ b/lib/galaxy/datatypes/display_applications/util.py @@ -17,13 +17,13 @@ def decode_dataset_user(trans, dataset_hash, user_hash): # decode dataset id as usual # decode user id using the dataset create time as the key dataset_id = trans.security.decode_id(dataset_hash) - dataset = trans.sa_session.query(trans.app.model.HistoryDatasetAssociation).get(dataset_id) + dataset = trans.sa_session.get(trans.app.model.HistoryDatasetAssociation, dataset_id) assert dataset, "Bad Dataset id provided to decode_dataset_user" if user_hash in [None, "None"]: user = None else: security = IdEncodingHelper(id_secret=dataset.create_time) user_id = security.decode_id(user_hash) - user = trans.sa_session.query(trans.app.model.User).get(user_id) + user = trans.sa_session.get(trans.app.model.User, user_id) assert user, "A Bad user id was passed to decode_dataset_user" return dataset, user diff --git a/lib/galaxy/managers/model_stores.py b/lib/galaxy/managers/model_stores.py index f348958aee59..1e817d130a60 100644 --- a/lib/galaxy/managers/model_stores.py +++ b/lib/galaxy/managers/model_stores.py @@ -95,11 +95,11 @@ def setup_history_export_job(self, request: SetupHistoryExportJob): include_deleted = request.include_deleted store_directory = request.store_directory - history = self._sa_session.query(model.History).get(history_id) + history = self._sa_session.get(model.History, history_id) # symlink files on export, on worker files will tarred up in a dereferenced manner. with DirectoryModelExportStore(store_directory, app=self._app, export_files="symlink") as export_store: export_store.export_history(history, include_hidden=include_hidden, include_deleted=include_deleted) - job = self._sa_session.query(model.Job).get(job_id) + job = self._sa_session.get(model.Job, job_id) job.state = model.Job.states.NEW with transaction(self._sa_session): self._sa_session.commit() @@ -137,10 +137,10 @@ def prepare_history_content_download(self, request: GenerateHistoryContentDownlo short_term_storage_target.path ) as export_store: if request.content_type == HistoryContentType.dataset: - hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id) + hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id) export_store.add_dataset(hda) else: - hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id) + hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id) export_store.export_collection( hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted ) @@ -157,7 +157,7 @@ def prepare_invocation_download(self, request: GenerateInvocationDownload): export_files=export_files, bco_export_options=self._bco_export_options(request), )(short_term_storage_target.path) as export_store: - invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id) + invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id) export_store.export_workflow_invocation( invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted ) @@ -174,7 +174,7 @@ def write_invocation_to(self, request: WriteInvocationTo): bco_export_options=self._bco_export_options(request), user_context=user_context, )(target_uri) as export_store: - invocation = self._sa_session.query(model.WorkflowInvocation).get(request.invocation_id) + invocation = self._sa_session.get(model.WorkflowInvocation, request.invocation_id) export_store.export_workflow_invocation( invocation, include_hidden=request.include_hidden, include_deleted=request.include_deleted ) @@ -199,10 +199,10 @@ def write_history_content_to(self, request: WriteHistoryContentTo): self._app, model_store_format, export_files=export_files, user_context=user_context )(target_uri) as export_store: if request.content_type == HistoryContentType.dataset: - hda = self._sa_session.query(model.HistoryDatasetAssociation).get(request.content_id) + hda = self._sa_session.get(model.HistoryDatasetAssociation, request.content_id) export_store.add_dataset(hda) else: - hdca = self._sa_session.query(model.HistoryDatasetCollectionAssociation).get(request.content_id) + hdca = self._sa_session.get(model.HistoryDatasetCollectionAssociation, request.content_id) export_store.export_collection( hdca, include_hidden=request.include_hidden, include_deleted=request.include_deleted ) @@ -267,7 +267,7 @@ def import_model_store(self, request: ImportModelStoreTaskRequest): ) history_id = request.history_id if history_id: - history = self._sa_session.query(model.History).get(history_id) + history = self._sa_session.get(model.History, history_id) else: history = None user_context = self._build_user_context(request.user.user_id) diff --git a/lib/galaxy/managers/pages.py b/lib/galaxy/managers/pages.py index af5a2b6d491b..df886ff96afd 100644 --- a/lib/galaxy/managers/pages.py +++ b/lib/galaxy/managers/pages.py @@ -16,6 +16,7 @@ ) from sqlalchemy import ( + desc, false, or_, select, @@ -42,7 +43,12 @@ ready_galaxy_markdown_for_export, ready_galaxy_markdown_for_import, ) -from galaxy.model import PageRevision +from galaxy.model import ( + Page, + PageRevision, + PageUserShareAssociation, + User, +) from galaxy.model.base import transaction from galaxy.model.index_filter_util import ( append_user_filter, @@ -631,3 +637,28 @@ def placeholderRenderForSave(trans: ProvidesHistoryContext, item_class, item_id, def get_page_revision(session: Session, page_id: int): stmt = select(PageRevision).filter_by(page_id=page_id) return session.scalars(stmt) + + +def get_shared_pages(session: Session, user: User): + stmt = ( + select(PageUserShareAssociation) + .where(PageUserShareAssociation.user == user) + .join(Page) + .where(Page.deleted == false()) + .order_by(desc(Page.update_time)) + ) + return session.scalars(stmt) + + +def get_page(session: Session, user: User, slug: str): + stmt = _build_page_query(select(Page), user, slug) + return session.scalar(stmt).first() + + +def page_exists(session: Session, user: User, slug: str) -> bool: + stmt = _build_page_query(select(Page.id), user, slug) + return session.scalar(stmt).first() is not None + + +def _build_page_query(select_clause, user: User, slug: str): + return select_clause.where(Page.user == user).where(Page.slug == slug).where(Page.deleted == false()).limit(1) diff --git a/lib/galaxy/managers/quotas.py b/lib/galaxy/managers/quotas.py index 0340698aa28f..72045d1cc513 100644 --- a/lib/galaxy/managers/quotas.py +++ b/lib/galaxy/managers/quotas.py @@ -11,12 +11,19 @@ Union, ) +from sqlalchemy import select + from galaxy import ( model, util, ) from galaxy.exceptions import ActionInputError from galaxy.managers import base +from galaxy.model import ( + Group, + Quota, + User, +) from galaxy.model.base import transaction from galaxy.quota import DatabaseQuotaAgent from galaxy.quota._schema import ( @@ -46,7 +53,8 @@ def quota_agent(self) -> DatabaseQuotaAgent: def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str]: params = CreateQuotaParams.parse_obj(payload) create_amount = self._parse_amount(params.amount) - if self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first(): + stmt = select(Quota).where(Quota.name == params.name).limit(1) + if self.sa_session.scalars(stmt).first(): raise ActionInputError( "Quota names must be unique and a quota with that name already exists, please choose another name." ) @@ -74,12 +82,10 @@ def create_quota(self, payload: dict, decode_id=None) -> Tuple[model.Quota, str] else: # Create the UserQuotaAssociations in_users = [ - self.sa_session.query(model.User).get(decode_id(x) if decode_id else x) - for x in util.listify(params.in_users) + self.sa_session.get(User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users) ] in_groups = [ - self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x) - for x in util.listify(params.in_groups) + self.sa_session.get(Group, decode_id(x) if decode_id else x) for x in util.listify(params.in_groups) ] if None in in_users: raise ActionInputError("One or more invalid user id has been provided.") @@ -108,12 +114,10 @@ def _parse_amount(self, amount: str) -> Optional[Union[int, bool]]: return False def rename_quota(self, quota, params) -> str: + stmt = select(Quota).where(Quota.name == params.name).limit(1) if not params.name: raise ActionInputError("Enter a valid name.") - elif ( - params.name != quota.name - and self.sa_session.query(model.Quota).filter(model.Quota.name == params.name).first() - ): + elif params.name != quota.name and self.sa_session.scalars(stmt).first(): raise ActionInputError("A quota with that name already exists.") else: old_name = quota.name @@ -131,13 +135,12 @@ def manage_users_and_groups_for_quota(self, quota, params, decode_id=None) -> st raise ActionInputError("Default quotas cannot be associated with specific users and groups.") else: in_users = [ - self.sa_session.query(model.User).get(decode_id(x) if decode_id else x) - for x in util.listify(params.in_users) + self.sa_session.get(model.User, decode_id(x) if decode_id else x) for x in util.listify(params.in_users) ] if None in in_users: raise ActionInputError("One or more invalid user id has been provided.") in_groups = [ - self.sa_session.query(model.Group).get(decode_id(x) if decode_id else x) + self.sa_session.get(model.Group, decode_id(x) if decode_id else x) for x in util.listify(params.in_groups) ] if None in in_groups: diff --git a/lib/galaxy/managers/roles.py b/lib/galaxy/managers/roles.py index bbecbdb7f424..bc840937a942 100644 --- a/lib/galaxy/managers/roles.py +++ b/lib/galaxy/managers/roles.py @@ -4,7 +4,10 @@ import logging from typing import List -from sqlalchemy import false +from sqlalchemy import ( + false, + select, +) from sqlalchemy.orm import exc as sqlalchemy_exceptions import galaxy.exceptions @@ -44,7 +47,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role: :raises: InconsistentDatabase, RequestParameterInvalidException, InternalServerError """ try: - role = self.session().query(self.model_class).filter(self.model_class.id == role_id).one() + stmt = select(self.model_class).where(self.model_class.id == role_id) + role = self.session().execute(stmt).scalar_one() except sqlalchemy_exceptions.MultipleResultsFound: raise galaxy.exceptions.InconsistentDatabase("Multiple roles found with the same id.") except sqlalchemy_exceptions.NoResultFound: @@ -59,7 +63,8 @@ def get(self, trans: ProvidesUserContext, role_id: int) -> model.Role: def list_displayable_roles(self, trans: ProvidesUserContext) -> List[Role]: roles = [] - for role in trans.sa_session.query(Role).filter(Role.deleted == false()): + stmt = select(Role).where(Role.deleted == false()) + for role in trans.sa_session.scalars(stmt): if trans.user_is_admin or trans.app.security_agent.ok_to_display(trans.user, role): roles.append(role) return roles @@ -70,15 +75,16 @@ def create_role(self, trans: ProvidesUserContext, role_definition_model: RoleDef user_ids = role_definition_model.user_ids or [] group_ids = role_definition_model.group_ids or [] - if trans.sa_session.query(Role).filter(Role.name == name).first(): + stmt = select(Role).where(Role.name == name).limit(1) + if trans.sa_session.scalars(stmt).first(): raise RequestParameterInvalidException(f"A role with that name already exists [{name}]") role_type = Role.types.ADMIN # TODO: allow non-admins to create roles role = Role(name=name, description=description, type=role_type) trans.sa_session.add(role) - users = [trans.sa_session.query(model.User).get(i) for i in user_ids] - groups = [trans.sa_session.query(model.Group).get(i) for i in group_ids] + users = [trans.sa_session.get(model.User, i) for i in user_ids] + groups = [trans.sa_session.get(model.Group, i) for i in group_ids] # Create the UserRoleAssociations for user in users: diff --git a/lib/galaxy/managers/users.py b/lib/galaxy/managers/users.py index c1ed9f992f3a..b3c309383e84 100644 --- a/lib/galaxy/managers/users.py +++ b/lib/galaxy/managers/users.py @@ -39,6 +39,7 @@ ) from galaxy.model import ( User, + UserAddress, UserQuotaUsage, ) from galaxy.model.base import transaction @@ -233,13 +234,8 @@ def purge(self, user, flush=True): user.username = uname_hash # Redact user addresses as well if self.app.config.redact_user_address_during_deletion: - user_addresses = ( - self.session() - .query(self.app.model.UserAddress) - .filter(self.app.model.UserAddress.user_id == user.id) - .all() - ) - for addr in user_addresses: + stmt = select(UserAddress).where(UserAddress.user_id == user.id) + for addr in self.session().scalars(stmt): addr.desc = new_secure_hash_v2(addr.desc + pseudorandom_value) addr.name = new_secure_hash_v2(addr.name + pseudorandom_value) addr.institution = new_secure_hash_v2(addr.institution + pseudorandom_value) @@ -264,7 +260,7 @@ def _error_on_duplicate_email(self, email: str) -> None: raise exceptions.Conflict("Email must be unique", email=email) def by_id(self, user_id: int) -> model.User: - return self.app.model.session.query(self.model_class).get(user_id) + return self.app.model.session.get(self.model_class, user_id) # ---- filters def by_email(self, email: str, filters=None, **kwargs) -> Optional[model.User]: @@ -286,7 +282,8 @@ def by_api_key(self, api_key: str, sa_session=None): return schema.BootstrapAdminUser() sa_session = sa_session or self.app.model.session try: - provided_key = sa_session.query(self.app.model.APIKeys).filter_by(key=api_key, deleted=False).one() + stmt = select(self.app.model.APIKeys).filter_by(key=api_key, deleted=False) + provided_key = sa_session.execute(stmt).scalar_one() except NoResultFound: raise exceptions.AuthenticationFailed("Provided API key is not valid.") if provided_key.user.deleted: @@ -363,12 +360,7 @@ def get_user_by_identity(self, identity): user = get_user_by_email(self.session(), identity, self.model_class) if not user: # Try a case-insensitive match on the email - user = ( - self.session() - .query(self.model_class) - .filter(func.lower(self.model_class.table.c.email) == identity.lower()) - .first() - ) + user = self._get_user_by_email_case_insensitive(self.session(), identity) else: user = get_user_by_username(self.session(), identity, self.model_class) return user @@ -445,7 +437,7 @@ def change_password(self, trans, password=None, confirm=None, token=None, id=Non if not token and not id: return None, "Please provide a token or a user and password." if token: - token_result = trans.sa_session.query(self.app.model.PasswordResetToken).get(token) + token_result = trans.sa_session.get(self.app.model.PasswordResetToken, token) if not token_result or not token_result.expiration_time > datetime.utcnow(): return None, "Invalid or expired password reset token, please request a new one." user = token_result.user @@ -483,13 +475,14 @@ def __set_password(self, trans, user, password, confirm): user.set_password_cleartext(password) # Invalidate all other sessions if trans.galaxy_session: - for other_galaxy_session in trans.sa_session.query(self.app.model.GalaxySession).filter( + stmt = select(self.app.model.GalaxySession).where( and_( - self.app.model.GalaxySession.table.c.user_id == user.id, - self.app.model.GalaxySession.table.c.is_valid == true(), - self.app.model.GalaxySession.table.c.id != trans.galaxy_session.id, + self.app.model.GalaxySession.user_id == user.id, + self.app.model.GalaxySession.is_valid == true(), + self.app.model.GalaxySession.id != trans.galaxy_session.id, ) - ): + ) + for other_galaxy_session in trans.sa_session.scalars(stmt): other_galaxy_session.is_valid = False trans.sa_session.add(other_galaxy_session) trans.sa_session.add(user) @@ -581,11 +574,7 @@ def send_reset_email(self, trans, payload, **kwd): def get_reset_token(self, trans, email): reset_user = get_user_by_email(trans.sa_session, email, self.app.model.User) if not reset_user and email != email.lower(): - reset_user = ( - trans.sa_session.query(self.app.model.User) - .filter(func.lower(self.app.model.User.table.c.email) == email.lower()) - .first() - ) + reset_user = self._get_user_by_email_case_insensitive(trans.sa_session, email) if reset_user: prt = self.app.model.PasswordResetToken(reset_user) trans.sa_session.add(prt) @@ -644,9 +633,11 @@ def get_or_create_remote_user(self, remote_user_email): for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]: username = username.replace(char, "-") # Find a unique username - user can change it later - if self.session().query(self.app.model.User).filter_by(username=username).first(): + stmt = select(self.app.model.User).filter_by(username=username).limit(1) + if self.session().scalars(stmt).first(): i = 1 - while self.session().query(self.app.model.User).filter_by(username=f"{username}-{str(i)}").first(): + stmt = select(self.app.model.User).filter_by(username=f"{username}-{str(i)}").limit(1) + while self.session().scalars(stmt).first(): i += 1 username += f"-{str(i)}" user.username = username @@ -660,6 +651,10 @@ def get_or_create_remote_user(self, remote_user_email): # self.log_event( "Automatically created account '%s'", user.email ) return user + def _get_user_by_email_case_insensitive(self, session, email): + stmt = select(self.app.model.User).where(func.lower(self.app.model.User.email) == email.lower()).limit(1) + return session.scalars(stmt).first() + class UserSerializer(base.ModelSerializer, deletable.PurgableSerializerMixin): model_manager_class = UserManager diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py index 1c3cd86d30a2..894c700e50b6 100644 --- a/lib/galaxy/model/__init__.py +++ b/lib/galaxy/model/__init__.py @@ -60,6 +60,7 @@ Column, column, DateTime, + delete, desc, event, false, @@ -2838,12 +2839,12 @@ class HistoryAudit(Base, RepresentById): @classmethod def prune(cls, sa_session): latest_subq = ( - sa_session.query(cls.history_id, func.max(cls.update_time).label("max_update_time")) + select(cls.history_id, func.max(cls.update_time).label("max_update_time")) .group_by(cls.history_id) .subquery() ) not_latest_query = ( - sa_session.query(cls.history_id, cls.update_time) + select(cls.history_id, cls.update_time) .select_from(latest_subq) .join( cls, @@ -3263,8 +3264,8 @@ def resume_paused_jobs(self): @property def paused_jobs(self): - db_session = object_session(self) - return db_session.query(Job).filter(Job.history_id == self.id, Job.state == Job.states.PAUSED).all() + stmt = select(Job).where(Job.history_id == self.id, Job.state == Job.states.PAUSED) + return object_session(self).scalars(stmt).all() @hybrid.hybrid_property def disk_size(self): @@ -3273,23 +3274,17 @@ def disk_size(self): all non-purged, unique datasets within it. """ # non-.expression part of hybrid.hybrid_property: called when an instance is the namespace (not the class) - db_session = object_session(self) - rval = db_session.query( - func.sum( - db_session.query(HistoryDatasetAssociation.dataset_id, Dataset.total_size) - .join(Dataset) - .filter(HistoryDatasetAssociation.table.c.history_id == self.id) - .filter(HistoryDatasetAssociation.purged != true()) - .filter(Dataset.purged != true()) - # unique datasets only - .distinct() - .subquery() - .c.total_size - ) - ).first()[0] - if rval is None: - rval = 0 - return rval + subq = ( + select(HistoryDatasetAssociation.dataset_id, Dataset.total_size) + .join(Dataset) + .where(HistoryDatasetAssociation.history_id == self.id) + .where(HistoryDatasetAssociation.purged != true()) + .where(Dataset.purged != true()) + .distinct() # unique datasets only + .subquery() + ) + stmt = select(func.sum(subq.c.total_size)) + return object_session(self).scalar(stmt) or 0 @disk_size.expression # type: ignore[no-redef] def disk_size(cls): @@ -3331,14 +3326,12 @@ def disk_nice_size(self): """Returns human readable size of history on disk.""" return galaxy.util.nice_size(self.disk_size) - @property - def active_dataset_and_roles_query(self): - db_session = object_session(self) + def _active_dataset_and_roles_query(self): return ( - db_session.query(HistoryDatasetAssociation) - .filter(HistoryDatasetAssociation.table.c.history_id == self.id) - .filter(not_(HistoryDatasetAssociation.deleted)) - .order_by(HistoryDatasetAssociation.table.c.hid.asc()) + select(HistoryDatasetAssociation) + .where(HistoryDatasetAssociation.history_id == self.id) + .where(not_(HistoryDatasetAssociation.deleted)) + .order_by(HistoryDatasetAssociation.hid.asc()) .options( joinedload(HistoryDatasetAssociation.dataset) .joinedload(Dataset.actions) @@ -3350,33 +3343,32 @@ def active_dataset_and_roles_query(self): @property def active_datasets_and_roles(self): if not hasattr(self, "_active_datasets_and_roles"): - self._active_datasets_and_roles = self.active_dataset_and_roles_query.all() + stmt = self._active_dataset_and_roles_query() + self._active_datasets_and_roles = object_session(self).scalars(stmt).unique().all() return self._active_datasets_and_roles @property def active_visible_datasets_and_roles(self): if not hasattr(self, "_active_visible_datasets_and_roles"): - self._active_visible_datasets_and_roles = self.active_dataset_and_roles_query.filter( - HistoryDatasetAssociation.visible - ).all() + stmt = self._active_dataset_and_roles_query().where(HistoryDatasetAssociation.visible) + self._active_visible_datasets_and_roles = object_session(self).scalars(stmt).unique().all() return self._active_visible_datasets_and_roles @property def active_visible_dataset_collections(self): if not hasattr(self, "_active_visible_dataset_collections"): - db_session = object_session(self) - query = ( - db_session.query(HistoryDatasetCollectionAssociation) - .filter(HistoryDatasetCollectionAssociation.table.c.history_id == self.id) - .filter(not_(HistoryDatasetCollectionAssociation.deleted)) - .filter(HistoryDatasetCollectionAssociation.visible) - .order_by(HistoryDatasetCollectionAssociation.table.c.hid.asc()) + stmt = ( + select(HistoryDatasetCollectionAssociation) + .where(HistoryDatasetCollectionAssociation.history_id == self.id) + .where(not_(HistoryDatasetCollectionAssociation.deleted)) + .where(HistoryDatasetCollectionAssociation.visible) + .order_by(HistoryDatasetCollectionAssociation.hid.asc()) .options( joinedload(HistoryDatasetCollectionAssociation.collection), joinedload(HistoryDatasetCollectionAssociation.tags), ) ) - self._active_visible_dataset_collections = query.all() + self._active_visible_dataset_collections = object_session(self).scalars(stmt).unique().all() return self._active_visible_dataset_collections @property @@ -3402,35 +3394,36 @@ def contents_iter(self, **kwds): def __dataset_contents_iter(self, **kwds): return self.__filter_contents(HistoryDatasetAssociation, **kwds) + def __collection_contents_iter(self, **kwds): + return self.__filter_contents(HistoryDatasetCollectionAssociation, **kwds) + def __filter_contents(self, content_class, **kwds): - db_session = object_session(self) - assert db_session is not None - query = db_session.query(content_class).filter(content_class.table.c.history_id == self.id) - query = query.order_by(content_class.table.c.hid.asc()) + session = object_session(self) + stmt = select(content_class).where(content_class.history_id == self.id).order_by(content_class.hid.asc()) + deleted = galaxy.util.string_as_bool_or_none(kwds.get("deleted", None)) if deleted is not None: - query = query.filter(content_class.deleted == deleted) + stmt = stmt.where(content_class.deleted == deleted) + visible = galaxy.util.string_as_bool_or_none(kwds.get("visible", None)) if visible is not None: - query = query.filter(content_class.visible == visible) + stmt = stmt.where(content_class.visible == visible) + if "object_store_ids" in kwds: if content_class == HistoryDatasetAssociation: - query = query.join(content_class.dataset).filter( - Dataset.table.c.object_store_id.in_(kwds.get("object_store_ids")) - ) + stmt = stmt.join(content_class.dataset).where(Dataset.object_store_id.in_(kwds.get("object_store_ids"))) # else ignoring object_store_ids on HDCAs... + if "ids" in kwds: assert "object_store_ids" not in kwds ids = kwds["ids"] max_in_filter_length = kwds.get("max_in_filter_length", MAX_IN_FILTER_LENGTH) if len(ids) < max_in_filter_length: - query = query.filter(content_class.id.in_(ids)) + stmt = stmt.where(content_class.id.in_(ids)) else: - query = (content for content in query if content.id in ids) - return query + return (content for content in session.scalars(stmt) if content.id in ids) - def __collection_contents_iter(self, **kwds): - return self.__filter_contents(HistoryDatasetCollectionAssociation, **kwds) + return session.scalars(stmt) class UserShareAssociation(RepresentById): @@ -4845,13 +4838,9 @@ def convert_dataset(self, trans, target_type): # Check dataset state and return any messages. msg = None if converted_dataset and converted_dataset.state == Dataset.states.ERROR: - job_id = ( - trans.sa_session.query(JobToOutputDatasetAssociation) - .filter_by(dataset_id=converted_dataset.id) - .first() - .job_id - ) - job = trans.sa_session.query(Job).get(job_id) + stmt = select(JobToOutputDatasetAssociation.job_id).filter_by(dataset_id=converted_dataset.id).limit(1) + job_id = trans.sa_session.scalars(stmt).first() + job = trans.sa_session.get(Job, job_id) msg = {"kind": self.conversion_messages.ERROR, "message": job.stderr} elif not converted_dataset or converted_dataset.state != Dataset.states.OK: msg = self.conversion_messages.PENDING @@ -7343,14 +7332,13 @@ def get_internal_version(self, version): def show_in_tool_panel(self, user_id): sa_session = object_session(self) - return bool( - sa_session.query(StoredWorkflowMenuEntry) - .filter( - StoredWorkflowMenuEntry.stored_workflow_id == self.id, - StoredWorkflowMenuEntry.user_id == user_id, - ) - .count() + stmt = ( + select(func.count()) + .select_from(StoredWorkflowMenuEntry) + .where(StoredWorkflowMenuEntry.stored_workflow_id == self.id) + .where(StoredWorkflowMenuEntry.user_id == user_id) ) + return bool(sa_session.scalar(stmt)) def copy_tags_from(self, target_user, source_workflow): # Override to only copy owner tags. @@ -8155,16 +8143,13 @@ def step_invocation_for_label(self, label): @staticmethod def poll_unhandled_workflow_ids(sa_session): - and_conditions = [ - WorkflowInvocation.state == WorkflowInvocation.states.NEW, - WorkflowInvocation.handler.is_(None), - ] - query = ( - sa_session.query(WorkflowInvocation.id) - .filter(and_(*and_conditions)) - .order_by(WorkflowInvocation.table.c.id.asc()) + stmt = ( + select(WorkflowInvocation.id) + .where(WorkflowInvocation.state == WorkflowInvocation.states.NEW) + .where(WorkflowInvocation.handler.is_(None)) + .order_by(WorkflowInvocation.id.asc()) ) - return [wid for wid in query.all()] + return [wid for wid in sa_session.scalars(stmt)] @staticmethod def poll_active_workflow_ids(engine, scheduler=None, handler=None): @@ -9302,7 +9287,8 @@ def save(self): @classmethod def store(cls, server_url, association): try: - assoc = cls.sa_session.query(cls).filter_by(server_url=server_url, handle=association.handle)[0] + stmt = select(PSAAssociation).filter_by(server_url=server_url, handle=association.handle).limit(1) + assoc = cls.sa_session.scalars(stmt).first() except IndexError: assoc = cls(server_url=server_url, handle=association.handle) assoc.secret = base64.encodebytes(association.secret).decode() @@ -9315,11 +9301,17 @@ def store(cls, server_url, association): @classmethod def get(cls, *args, **kwargs): - return cls.sa_session.query(cls).filter_by(*args, **kwargs) + stmt = select(PSAAssociation).filter_by(*args, **kwargs) + return cls.sa_session.scalars(stmt) @classmethod def remove(cls, ids_to_delete): - cls.sa_session.query(cls).filter(cls.id.in_(ids_to_delete)).delete(synchronize_session="fetch") + stmt = ( + delete(PSAAssociation) + .where(PSAAssociation.id.in_(ids_to_delete)) + .execution_options(synchronize_session="fetch") + ) + PSAAssociation.sa_session.execute(stmt) class PSACode(Base, CodeMixin, RepresentById): @@ -9344,7 +9336,8 @@ def save(self): @classmethod def get_code(cls, code): - return cls.sa_session.query(cls).filter(cls.code == code).first() + stmt = select(PSACode).where(PSACode.code == code).limit(1) + return cls.sa_session.scalars(stmt).first() class PSANonce(Base, NonceMixin, RepresentById): @@ -9371,7 +9364,8 @@ def save(self): @classmethod def use(cls, server_url, timestamp, salt): try: - return cls.sa_session.query(cls).filter_by(server_url=server_url, timestamp=timestamp, salt=salt)[0] + stmt = select(PSANonce).where(server_url=server_url, timestamp=timestamp, salt=salt).limit(1) + return cls.sa_session.scalars(stmt).first() except IndexError: instance = cls(server_url=server_url, timestamp=timestamp, salt=salt) cls.sa_session.add(instance) @@ -9405,13 +9399,17 @@ def save(self): @classmethod def load(cls, token): - return cls.sa_session.query(cls).filter(cls.token == token).first() + stmt = select(PSAPartial).where(PSAPartial.token == token).limit(1) + return cls.sa_session.scalars(stmt).first() @classmethod def destroy(cls, token): partial = cls.load(token) if partial: - cls.sa_session.delete(partial) + session = cls.sa_session + session.execute(delete(partial)) + with transaction(session): + session.commit() class UserAuthnzToken(Base, UserMixin, RepresentById): @@ -9462,24 +9460,12 @@ def username_max_length(cls): # A better alternative is to retrieve this number from the table, instead of this const value. return 255 - @classmethod - def user_model(cls): - return User - @classmethod def changed(cls, user): cls.sa_session.add(user) with transaction(cls.sa_session): cls.sa_session.commit() - @classmethod - def user_query(cls): - return cls.sa_session.query(cls.user_model()) - - @classmethod - def user_exists(cls, *args, **kwargs): - return cls.user_query().filter_by(*args, **kwargs).count() > 0 - @classmethod def get_username(cls, user): return getattr(user, "username", None) @@ -9490,9 +9476,8 @@ def create_user(cls, *args, **kwargs): This is used by PSA authnz, do not use directly. Prefer using the user manager. """ - model = cls.user_model() - instance = model(*args, **kwargs) - if cls.get_users_by_email(instance.email).first(): + instance = User(*args, **kwargs) + if cls.email_exists(instance.email): raise Exception(f"User with this email '{instance.email}' already exists.") instance.set_random_password() cls.sa_session.add(instance) @@ -9502,28 +9487,30 @@ def create_user(cls, *args, **kwargs): @classmethod def get_user(cls, pk): - return cls.user_query().get(pk) + return UserAuthnzToken.sa_session.get(User, pk) @classmethod - def get_users_by_email(cls, email): - return cls.user_query().filter(func.lower(User.email) == email.lower()) + def email_exists(cls, email): + stmt = select(User).where(func.lower(User.email) == email.lower()).limit(1) + return bool(cls.sa_session.scalars(stmt).first()) @classmethod def get_social_auth(cls, provider, uid): uid = str(uid) try: - return cls.sa_session.query(cls).filter_by(provider=provider, uid=uid)[0] + stmt = select(UserAuthnzToken).filter_by(provider=provider, uid=uid).limit(1) + return cls.sa_session.scalars(stmt).first() except IndexError: return None @classmethod def get_social_auth_for_user(cls, user, provider=None, id=None): - qs = cls.sa_session.query(cls).filter_by(user_id=user.id) + stmt = select(UserAuthnzToken).filter_by(user_id=user.id) if provider: - qs = qs.filter_by(provider=provider) + stmt = stmt.filter_by(provider=provider) if id: - qs = qs.filter_by(id=id) - return qs + stmt = stmt.filter_by(id=id) + return cls.sa_session.scalars(stmt) @classmethod def create_social_auth(cls, user, uid, provider): diff --git a/lib/galaxy/model/item_attrs.py b/lib/galaxy/model/item_attrs.py index 90e236ec7fa8..998536e07b7d 100644 --- a/lib/galaxy/model/item_attrs.py +++ b/lib/galaxy/model/item_attrs.py @@ -1,5 +1,6 @@ import logging +from sqlalchemy import select from sqlalchemy.sql.expression import func # Cannot import galaxy.model b/c it creates a circular import graph. @@ -27,13 +28,13 @@ def get_ave_item_rating_data(self, db_session, item, webapp_model=None): if not item_rating_assoc_class: raise Exception(f"Item does not have ratings: {item.__class__.__name__}") item_id_filter = self._get_item_id_filter_str(item, item_rating_assoc_class) - ave_rating = db_session.query(func.avg(item_rating_assoc_class.rating)).filter(item_id_filter).scalar() + ave_rating = db_session.scalar(select(func.avg(item_rating_assoc_class.rating)).where(item_id_filter)) # Convert ave_rating to float; note: if there are no item ratings, ave rating is None. if ave_rating: ave_rating = float(ave_rating) else: ave_rating = 0 - num_ratings = int(db_session.query(func.count(item_rating_assoc_class.rating)).filter(item_id_filter).scalar()) + num_ratings = db_session.scalar(select(func.count(item_rating_assoc_class.rating)).where(item_id_filter)) return (ave_rating, num_ratings) def rate_item(self, db_session, user, item, rating, webapp_model=None): @@ -65,7 +66,9 @@ def get_user_item_rating(self, db_session, user, item, webapp_model=None): # Query rating table by user and item id. item_id_filter = self._get_item_id_filter_str(item, item_rating_assoc_class) - return db_session.query(item_rating_assoc_class).filter_by(user=user).filter(item_id_filter).first() + return db_session.scalars( + select(item_rating_assoc_class).filter_by(user=user).where(item_id_filter).limit(1) + ).first() def _get_item_rating_assoc_class(self, item, webapp_model=None): """Returns an item's item-rating association class.""" @@ -118,7 +121,7 @@ def get_item_annotation_obj(db_session, user, item): return None # Get annotation association object. - annotation_assoc = db_session.query(annotation_assoc_class).filter_by(user=user) + annotation_assoc = select(annotation_assoc_class).filter_by(user=user) if item.__class__ == galaxy.model.History: annotation_assoc = annotation_assoc.filter_by(history=item) @@ -134,7 +137,7 @@ def get_item_annotation_obj(db_session, user, item): annotation_assoc = annotation_assoc.filter_by(page=item) elif item.__class__ == galaxy.model.Visualization: annotation_assoc = annotation_assoc.filter_by(visualization=item) - return annotation_assoc.first() + return db_session.scalars(annotation_assoc.limit(1)).first() def get_item_annotation_str(db_session, user, item): diff --git a/lib/galaxy/model/metadata.py b/lib/galaxy/model/metadata.py index ed757314edca..35e7f73762a1 100644 --- a/lib/galaxy/model/metadata.py +++ b/lib/galaxy/model/metadata.py @@ -20,6 +20,7 @@ Union, ) +from sqlalchemy import select from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import flag_modified @@ -602,9 +603,9 @@ def wrap(self, value, session): if isinstance(value, galaxy.model.MetadataFile) or isinstance(value, MetadataTempFile): return value if isinstance(value, int): - return session.query(galaxy.model.MetadataFile).get(value) + return session.get(galaxy.model.MetadataFile, value) else: - return session.query(galaxy.model.MetadataFile).filter_by(uuid=value).one() + return session.execute(select(galaxy.model.MetadataFile).filter_by(uuid=value)).scalar_one() def make_copy(self, value, target_context: MetadataCollection, source_context): session = target_context._object_session(target_context.parent) diff --git a/lib/galaxy/model/security.py b/lib/galaxy/model/security.py index 84c3eb5fb976..48f966e0a9f5 100644 --- a/lib/galaxy/model/security.py +++ b/lib/galaxy/model/security.py @@ -9,12 +9,30 @@ from sqlalchemy import ( and_, false, + func, not_, or_, + select, ) from sqlalchemy.orm import joinedload import galaxy.model +from galaxy.model import ( + Dataset, + DatasetPermissions, + Group, + GroupRoleAssociation, + HistoryDatasetAssociationDisplayAtAuthorization, + Library, + LibraryDataset, + LibraryDatasetDatasetAssociation, + LibraryDatasetPermissions, + LibraryPermissions, + Role, + User, + UserGroupAssociation, + UserRoleAssociation, +) from galaxy.model.base import transaction from galaxy.security import ( Action, @@ -65,17 +83,12 @@ def _get_npns_roles(self, trans): """ non-private, non-sharing roles """ - return ( - trans.sa_session.query(trans.app.model.Role) - .filter( - and_( - self.model.Role.deleted == false(), - self.model.Role.type != self.model.Role.types.PRIVATE, - self.model.Role.type != self.model.Role.types.SHARING, - ) - ) - .order_by(self.model.Role.name) + stmt = ( + select(Role) + .where(and_(Role.deleted == false(), Role.type != Role.types.PRIVATE, Role.type != Role.types.SHARING)) + .order_by(Role.name) ) + return trans.sa_session.scalars(stmt) def get_all_roles(self, trans, cntrller): admin_controller = cntrller in ["library_admin"] @@ -84,11 +97,8 @@ def get_all_roles(self, trans, cntrller): return self._get_npns_roles(trans) if admin_controller: # The library is public and the user is an admin, so all roles are legitimate - for role in ( - trans.sa_session.query(trans.app.model.Role) - .filter(self.model.Role.deleted == false()) - .order_by(self.model.Role.name) - ): + stmt = select(Role).where(Role.deleted == false()).order_by(Role.name) + for role in trans.sa_session.scalars(stmt): roles.add(role) else: # Add the current user's private role @@ -146,27 +156,30 @@ def get_valid_roles(self, trans, item, query=None, page=None, page_limit=None, i # Admins can always choose from all non-deleted roles if trans.user_is_admin or trans.app.config.expose_user_email: if trans.user_is_admin: - db_query = trans.sa_session.query(trans.app.model.Role).filter(self.model.Role.deleted == false()) + stmt = select(Role).where(Role.deleted == false()) else: # User is not an admin but the configuration exposes all private roles to all users. - db_query = trans.sa_session.query(trans.app.model.Role).filter( - and_(self.model.Role.deleted == false(), self.model.Role.type == self.model.Role.types.PRIVATE) - ) + stmt = select(Role).where(and_(Role.deleted == false(), Role.type == Role.types.PRIVATE)) if search_query: - db_query = db_query.filter(self.model.Role.name.like(search_query, escape="/")) - total_count = db_query.count() + stmt = stmt.where(Role.name.like(search_query, escape="/")) + + count_stmt = select(func.count()).select_from(stmt) + total_count = trans.sa_session.scalar(count_stmt) + if limit is not None: # Takes the least number of results from beginning that includes the requested page - roles = db_query.order_by(self.model.Role.name).limit(limit).all() + stmt = stmt.order_by(Role.name).limit(limit) page_start = (page * page_limit) - page_limit page_end = page_start + page_limit if total_count < page_start + 1: # Return empty list if there are less results than the requested position roles = [] else: + roles = trans.sa_session.scalars(stmt).all() roles = roles[page_start:page_end] else: - roles = db_query.order_by(self.model.Role.name) + stmt = stmt.order_by(Role.name) + roles = trans.sa_session.scalars(stmt).all() # Non-admin and public item elif is_public_item: # Add the current user's private role @@ -324,17 +337,13 @@ def get_actions_for_items(self, trans, action, permission_items): # SM: NB: LibraryDatasets became Datasets for some odd reason. if isinstance(permission_items[0], trans.model.LibraryDataset): ids = [item.library_dataset_id for item in permission_items] - permissions = ( - trans.sa_session.query(trans.model.LibraryDatasetPermissions) - .filter( - and_( - trans.model.LibraryDatasetPermissions.library_dataset_id.in_(ids), - trans.model.LibraryDatasetPermissions.action == action.action, - ) + stmt = select(LibraryDatasetPermissions).where( + and_( + LibraryDatasetPermissions.library_dataset_id.in_(ids), + LibraryDatasetPermissions.action == action.action, ) - .all() ) - + permissions = trans.sa_session.scalars(stmt) # Massage the return data. We will return a list of permissions # for each library dataset. So we initialize the return list to # have an empty list for each dataset. Then each permission is @@ -347,17 +356,11 @@ def get_actions_for_items(self, trans, action, permission_items): ret_permissions[permission.library_dataset_id].append(permission) elif isinstance(permission_items[0], trans.model.Dataset): ids = [item.id for item in permission_items] - permissions = ( - trans.sa_session.query(trans.model.DatasetPermissions) - .filter( - and_( - trans.model.DatasetPermissions.dataset_id.in_(ids), - trans.model.DatasetPermissions.action == action.action, - ) - ) - .all() - ) + stmt = select(DatasetPermissions).where( + and_(DatasetPermissions.dataset_id.in_(ids), DatasetPermissions.action == action.action) + ) + permissions = trans.sa_session.scalars(stmt) # Massage the return data. We will return a list of permissions # for each library dataset. So we initialize the return list to # have an empty list for each dataset. Then each permission is @@ -540,38 +543,34 @@ def get_accessible_libraries(self, trans, user): accessible_libraries = [] current_user_role_ids = [role.id for role in user.all_roles()] library_access_action = self.permitted_actions.LIBRARY_ACCESS.action - restricted_library_ids = [ - lp.library_id - for lp in trans.sa_session.query(trans.model.LibraryPermissions) - .filter(trans.model.LibraryPermissions.table.c.action == library_access_action) - .distinct() - ] - accessible_restricted_library_ids = [ - lp.library_id - for lp in trans.sa_session.query(trans.model.LibraryPermissions).filter( - and_( - trans.model.LibraryPermissions.table.c.action == library_access_action, - trans.model.LibraryPermissions.table.c.role_id.in_(current_user_role_ids), - ) + + stmt = select(LibraryPermissions).where(LibraryPermissions.action == library_access_action).distinct() + restricted_library_ids = [lp.library_id for lp in trans.sa_session.scalars(stmt)] + + stmt = select(LibraryPermissions).where( + and_( + LibraryPermissions.action == library_access_action, + LibraryPermissions.role_id.in_(current_user_role_ids), ) - ] + ) + accessible_restricted_library_ids = [lp.library_id for lp in trans.sa_session.scalars(stmt)] + # Filter to get libraries accessible by the current user. Get both # public libraries and restricted libraries accessible by the current user. - for library in ( - trans.sa_session.query(trans.model.Library) - .filter( + stmt = ( + select(Library) + .where( and_( - trans.model.Library.table.c.deleted == false(), - ( - or_( - not_(trans.model.Library.table.c.id.in_(restricted_library_ids)), - trans.model.Library.table.c.id.in_(accessible_restricted_library_ids), - ) + Library.deleted == false(), + or_( + not_(Library.id.in_(restricted_library_ids)), + Library.id.in_(accessible_restricted_library_ids), ), ) ) - .order_by(trans.app.model.Library.name) - ): + .order_by(Library.name) + ) + for library in trans.sa_session.scalars(stmt): accessible_libraries.append(library) return accessible_libraries @@ -589,12 +588,10 @@ def has_accessible_folders(self, trans, folder, user, roles, search_downward=Tru return False def has_accessible_library_datasets(self, trans, folder, user, roles, search_downward=True): - for library_dataset in trans.sa_session.query(trans.model.LibraryDataset).filter( - and_( - trans.model.LibraryDataset.table.c.deleted == false(), - trans.app.model.LibraryDataset.table.c.folder_id == folder.id, - ) - ): + stmt = select(LibraryDataset).where( + and_(LibraryDataset.deleted == false(), LibraryDataset.folder_id == folder.id) + ) + for library_dataset in trans.sa_session.scalars(stmt): if self.can_access_library_item(roles, library_dataset, user): return True if search_downward: @@ -749,17 +746,14 @@ def create_private_user_role(self, user): return self.get_private_user_role(user) def get_private_user_role(self, user, auto_create=False): - role = ( - self.sa_session.query(self.model.Role) - .filter( - and_( - self.model.UserRoleAssociation.table.c.user_id == user.id, - self.model.Role.id == self.model.UserRoleAssociation.table.c.role_id, - self.model.Role.type == self.model.Role.types.PRIVATE, - ) + stmt = select(Role).where( + and_( + UserRoleAssociation.user_id == user.id, + Role.id == UserRoleAssociation.role_id, + Role.type == Role.types.PRIVATE, ) - .one_or_none() ) + role = self.sa_session.execute(stmt).scalar_one_or_none() if not role: if auto_create: return self.create_private_user_role(user) @@ -770,21 +764,18 @@ def get_private_user_role(self, user, auto_create=False): def get_role(self, name, type=None): type = type or self.model.Role.types.SYSTEM # will raise exception if not found - return ( - self.sa_session.query(self.model.Role) - .filter(and_(self.model.Role.name == name, self.model.Role.type == type)) - .one() - ) + stmt = select(Role).where(and_(Role.name == name, Role.type == type)) + return self.sa_session.execute(stmt).scalar_one() def create_role(self, name, description, in_users, in_groups, create_group_for_role=False, type=None): type = type or self.model.Role.types.SYSTEM role = self.model.Role(name=name, description=description, type=type) self.sa_session.add(role) # Create the UserRoleAssociations - for user in [self.sa_session.query(self.model.User).get(x) for x in in_users]: + for user in [self.sa_session.get(User, x) for x in in_users]: self.associate_user_role(user, role) # Create the GroupRoleAssociations - for group in [self.sa_session.query(self.model.Group).get(x) for x in in_groups]: + for group in [self.sa_session.get(Group, x) for x in in_groups]: self.associate_group_role(group, role) if create_group_for_role: # Create the group @@ -800,12 +791,10 @@ def create_role(self, name, description, in_users, in_groups, create_group_for_r return role, num_in_groups def get_sharing_roles(self, user): - return self.sa_session.query(self.model.Role).filter( - and_( - (self.model.Role.name).like(f"Sharing role for: %{user.email}%"), - self.model.Role.type == self.model.Role.types.SHARING, - ) + stmt = select(Role).where( + and_((Role.name).like(f"Sharing role for: %{user.email}%"), Role.type == Role.types.SHARING) ) + return self.sa_session.scalars(stmt) def user_set_default_permissions( self, @@ -1217,16 +1206,13 @@ def datasets_are_public(self, trans, datasets): datasets_public[dataset_id] = True # Now get all datasets which have DATASET_ACCESS actions: - access_data_perms = ( - trans.sa_session.query(trans.app.model.DatasetPermissions) - .filter( - and_( - trans.app.model.DatasetPermissions.dataset_id.in_(dataset_ids), - trans.app.model.DatasetPermissions.action == self.permitted_actions.DATASET_ACCESS.action, - ) + stmt = select(DatasetPermissions).where( + and_( + DatasetPermissions.dataset_id.in_(dataset_ids), + DatasetPermissions.action == self.permitted_actions.DATASET_ACCESS.action, ) - .all() ) + access_data_perms = trans.sa_session.scalars(stmt) # Every dataset returned has "access" privileges associated with it, # so it's not public. for permission in access_data_perms: @@ -1264,14 +1250,14 @@ def derive_roles_from_access(self, trans, item_id, cntrller, library=False, **kw error = False for k, v in get_permitted_actions(filter="DATASET").items(): # Change for removing the prefix '_in' from the roles select box - in_roles = [self.sa_session.query(self.model.Role).get(x) for x in listify(kwd[k])] + in_roles = [self.sa_session.get(Role, x) for x in listify(kwd[k])] if not in_roles: - in_roles = [self.sa_session.query(self.model.Role).get(x) for x in listify(kwd.get(f"{k}_in", []))] + in_roles = [self.sa_session.get(Role, x) for x in listify(kwd.get(f"{k}_in", []))] if v == self.permitted_actions.DATASET_ACCESS and in_roles: if library: - item = self.sa_session.query(self.model.Library).get(item_id) + item = self.sa_session.get(Library, item_id) else: - item = self.sa_session.query(self.model.Dataset).get(item_id) + item = self.sa_session.get(Dataset, item_id) if (library and not self.library_is_public(item)) or (not library and not self.dataset_is_public(item)): # Ensure that roles being associated with DATASET_ACCESS are a subset of the legitimate roles # derived from the roles associated with the access permission on item if it's not public. This @@ -1387,11 +1373,8 @@ def get_permitted_libraries(self, trans, user, actions): libraries = trans.app.security_agent.get_permitted_libraries( trans, user, [ trans.app.security_agent.permitted_actions.LIBRARY_ADD ] ) """ - all_libraries = ( - trans.sa_session.query(trans.app.model.Library) - .filter(trans.app.model.Library.table.c.deleted == false()) - .order_by(trans.app.model.Library.name) - ) + stmt = select(Library).where(Library.deleted == false()).order_by(Library.name) + all_libraries = trans.sa_session.scalars(stmt) roles = user.all_roles() actions_to_check = actions # The libraries dictionary looks like: { library : '1,2' }, library : '3' } @@ -1520,31 +1503,23 @@ def get_component_associations(self, **kwd): assert len(kwd) == 2, "You must specify exactly 2 Galaxy security components to check for associations." if "dataset" in kwd: if "action" in kwd: - return ( - self.sa_session.query(self.model.DatasetPermissions) + stmt = ( + select(DatasetPermissions) .filter_by(action=kwd["action"].action, dataset_id=kwd["dataset"].id) - .first() + .limit(1) ) + return self.sa_session.scalars(stmt).first() elif "user" in kwd: if "group" in kwd: - return ( - self.sa_session.query(self.model.UserGroupAssociation) - .filter_by(group_id=kwd["group"].id, user_id=kwd["user"].id) - .first() - ) + stmt = select(UserGroupAssociation).filter_by(group_id=kwd["group"].id, user_id=kwd["user"].id).limit(1) + return self.sa_session.scalars(stmt).first() elif "role" in kwd: - return ( - self.sa_session.query(self.model.UserRoleAssociation) - .filter_by(role_id=kwd["role"].id, user_id=kwd["user"].id) - .first() - ) + stmt = select(UserRoleAssociation).filter_by(role_id=kwd["role"].id, user_id=kwd["user"].id).limit(1) + return self.sa_session.scalars(stmt).first() elif "group" in kwd: if "role" in kwd: - return ( - self.sa_session.query(self.model.GroupRoleAssociation) - .filter_by(role_id=kwd["role"].id, group_id=kwd["group"].id) - .first() - ) + stmt = select(GroupRoleAssociation).filter_by(role_id=kwd["role"].id, group_id=kwd["group"].id).limit(1) + return self.sa_session.scalars(stmt).first() raise Exception(f"No valid method of associating provided components: {kwd}") def check_folder_contents(self, user, roles, folder, hidden_folder_ids=""): @@ -1561,17 +1536,14 @@ def check_folder_contents(self, user, roles, folder, hidden_folder_ids=""): return True, "" action = self.permitted_actions.DATASET_ACCESS - lddas = ( - self.sa_session.query(self.model.LibraryDatasetDatasetAssociation) - .join("library_dataset") - .filter(self.model.LibraryDataset.folder == folder) - .join("dataset") - .options( - joinedload(self.model.LibraryDatasetDatasetAssociation.dataset).joinedload(self.model.Dataset.actions) - ) - .all() + stmt = ( + select(LibraryDatasetDatasetAssociation) + .join(LibraryDatasetDatasetAssociation.library_dataset) + .where(LibraryDataset.folder == folder) + .join(Dataset) + .options(joinedload(LibraryDatasetDatasetAssociation.dataset).joinedload(Dataset.actions)) ) - + lddas = self.sa_session.scalars(stmt).unique() for ldda in lddas: ldda_access_permissions = self.get_item_actions(action, ldda.dataset) if not ldda_access_permissions: @@ -1635,11 +1607,12 @@ def allow_action(self, addr, action, **kwd): ]: log.debug("Allowing access to public dataset with hda: %i." % hda.id) return True # dataset has no roles associated with the access permission, thus is already public - hdadaa = ( - self.sa_session.query(self.model.HistoryDatasetAssociationDisplayAtAuthorization) + stmt = ( + select(HistoryDatasetAssociationDisplayAtAuthorization) .filter_by(history_dataset_association_id=hda.id) - .first() + .limit(1) ) + hdadaa = self.sa_session.scalars(stmt).first() if not hdadaa: log.debug( "Denying access to private dataset with hda: %i. No hdadaa record for this dataset." % hda.id @@ -1677,11 +1650,12 @@ def allow_action(self, addr, action, **kwd): raise Exception("The dataset access permission is the only valid permission in the host security agent.") def set_dataset_permissions(self, hda, user, site): - hdadaa = ( - self.sa_session.query(self.model.HistoryDatasetAssociationDisplayAtAuthorization) + stmt = ( + select(HistoryDatasetAssociationDisplayAtAuthorization) .filter_by(history_dataset_association_id=hda.id) - .first() + .limit(1) ) + hdadaa = self.sa_session.scalars(stmt).first() if hdadaa: hdadaa.update_time = datetime.utcnow() else: diff --git a/lib/galaxy/model/store/__init__.py b/lib/galaxy/model/store/__init__.py index ea10a0bca4d7..59873c02fb7c 100644 --- a/lib/galaxy/model/store/__init__.py +++ b/lib/galaxy/model/store/__init__.py @@ -39,6 +39,7 @@ WorkflowDescription, ) from rocrate.rocrate import ROCrate +from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm.scoping import scoped_session from sqlalchemy.sql import expression @@ -466,9 +467,8 @@ def handle_dataset_object_edit(dataset_instance, dataset_attrs): self.dataset_state_serialized = False if "id" in dataset_attrs and self.import_options.allow_edit and not self.sessionless: - dataset_instance: model.DatasetInstance = self.sa_session.query( - getattr(model, dataset_attrs["model_class"]) - ).get(dataset_attrs["id"]) + model_class = getattr(model, dataset_attrs["model_class"]) + dataset_instance: model.DatasetInstance = self.sa_session.get(model_class, dataset_attrs["id"]) attributes = [ "name", "extension", @@ -775,7 +775,7 @@ def import_folder(folder_attrs, root_folder=None): and not self.sessionless and self.import_options.allow_edit ): - library_folder = self.sa_session.query(model.LibraryFolder).get(library_attrs["id"]) + library_folder = self.sa_session.get(model.LibraryFolder, library_attrs["id"]) import_folder(library_attrs, root_folder=library_folder) else: assert self.import_options.allow_library_creation @@ -838,7 +838,7 @@ def materialize_elements(dc): dc.element_count = len(elements_attrs) if "id" in collection_attrs and self.import_options.allow_edit and not self.sessionless: - dc = self.sa_session.query(model.DatasetCollection).get(collection_attrs["id"]) + dc = self.sa_session.get(model.DatasetCollection, collection_attrs["id"]) attributes = [ "collection_type", "populated_state", @@ -865,7 +865,7 @@ def materialize_elements(dc): if "collection" in collection_attrs: dc = import_collection(collection_attrs["collection"]) if "id" in collection_attrs and self.import_options.allow_edit and not self.sessionless: - hdca = self.sa_session.query(model.HistoryDatasetCollectionAssociation).get(collection_attrs["id"]) + hdca = self.sa_session.get(model.HistoryDatasetCollectionAssociation, collection_attrs["id"]) # TODO: edit attributes... else: hdca = model.HistoryDatasetCollectionAssociation( @@ -1217,7 +1217,7 @@ def _import_jobs(self, object_import_tracker: "ObjectImportTracker", history: Op if "id" in job_attrs and not self.sessionless: # only thing we allow editing currently is associations for incoming jobs. assert self.import_options.allow_edit - job = self.sa_session.query(model.Job).get(job_attrs["id"]) + job = self.sa_session.get(model.Job, job_attrs["id"]) self._connect_job_io(job, job_attrs, _find_hda, _find_hdca, _find_dce) # type: ignore[attr-defined] self._set_job_attributes(job, job_attrs, force_terminal=False) # type: ignore[attr-defined] # Don't edit job @@ -2110,12 +2110,12 @@ def export_history( sa_session = app.model.session # Write collections' attributes (including datasets list) to file. - query = ( - sa_session.query(model.HistoryDatasetCollectionAssociation) - .filter(model.HistoryDatasetCollectionAssociation.history == history) - .filter(model.HistoryDatasetCollectionAssociation.deleted == expression.false()) + stmt = ( + select(model.HistoryDatasetCollectionAssociation) + .where(model.HistoryDatasetCollectionAssociation.history == history) + .where(model.HistoryDatasetCollectionAssociation.deleted == expression.false()) ) - collections = query.all() + collections = sa_session.scalars(stmt) for collection in collections: # filter this ? @@ -2128,15 +2128,16 @@ def export_history( # Write datasets' attributes to file. actions_backref = model.Dataset.actions # type: ignore[attr-defined] - query = ( - sa_session.query(model.HistoryDatasetAssociation) - .filter(model.HistoryDatasetAssociation.history == history) + + stmt = ( + select(model.HistoryDatasetAssociation) + .where(model.HistoryDatasetAssociation.history == history) .join(model.Dataset) .options(joinedload(model.HistoryDatasetAssociation.dataset).joinedload(actions_backref)) .order_by(model.HistoryDatasetAssociation.hid) - .filter(model.Dataset.purged == expression.false()) + .where(model.Dataset.purged == expression.false()) ) - datasets = query.all() + datasets = sa_session.scalars(stmt).unique() for dataset in datasets: dataset.annotation = get_item_annotation_str(sa_session, history.user, dataset) should_include_file = (dataset.visible or include_hidden) and (not dataset.deleted or include_deleted) diff --git a/lib/galaxy/model/store/discover.py b/lib/galaxy/model/store/discover.py index 832357acb45b..5f1a9a9b0b31 100644 --- a/lib/galaxy/model/store/discover.py +++ b/lib/galaxy/model/store/discover.py @@ -828,7 +828,7 @@ def collect_elements_for_history(elements): sa_session = ( model_persistence_context.sa_session or model_persistence_context.import_store.sa_session ) - primary_dataset = sa_session.query(galaxy.model.HistoryDatasetAssociation).get(hda_id) + primary_dataset = sa_session.get(galaxy.model.HistoryDatasetAssociation, hda_id) sources = fields_match.sources hashes = fields_match.hashes diff --git a/lib/galaxy/model/tags.py b/lib/galaxy/model/tags.py index 1b378a534643..739d8528195d 100644 --- a/lib/galaxy/model/tags.py +++ b/lib/galaxy/model/tags.py @@ -287,12 +287,12 @@ def get_tags_str(self, tags): def get_tag_by_id(self, tag_id): """Get a Tag object from a tag id.""" - return self.sa_session.query(galaxy.model.Tag).filter_by(id=tag_id).first() + return self.sa_session.get(galaxy.model.Tag, tag_id) def get_tag_by_name(self, tag_name): """Get a Tag object from a tag name (string).""" if tag_name: - return self.sa_session.query(galaxy.model.Tag).filter_by(name=tag_name.lower()).first() + return self.sa_session.scalars(select(galaxy.model.Tag).filter_by(name=tag_name.lower()).limit(1)).first() return None def _create_tag(self, tag_str: str): @@ -317,7 +317,7 @@ def _create_tag(self, tag_str: str): return tag def _get_tag(self, tag_name): - return self.sa_session.query(galaxy.model.Tag).filter_by(name=tag_name).first() + return self.sa_session.scalars(select(galaxy.model.Tag).filter_by(name=tag_name).limit(1)).first() def _create_tag_instance(self, tag_name): # For good performance caller should first check if there's already an appropriate tag diff --git a/lib/galaxy/webapps/galaxy/controllers/page.py b/lib/galaxy/webapps/galaxy/controllers/page.py index 3df87e27b7b1..7fb4112070a8 100644 --- a/lib/galaxy/webapps/galaxy/controllers/page.py +++ b/lib/galaxy/webapps/galaxy/controllers/page.py @@ -1,6 +1,5 @@ from markupsafe import escape from sqlalchemy import ( - desc, false, true, ) @@ -19,8 +18,14 @@ HistoryManager, HistorySerializer, ) -from galaxy.managers.pages import PageManager +from galaxy.managers.pages import ( + get_page as get_page_, + get_shared_pages, + page_exists, + PageManager, +) from galaxy.managers.sharable import SlugBuilder +from galaxy.managers.users import get_user_by_username from galaxy.managers.workflows import WorkflowsManager from galaxy.model.base import transaction from galaxy.model.item_attrs import UsesItemRatings @@ -377,7 +382,7 @@ def list(self, trans, *args, **kwargs): ids = util.listify(kwargs["id"]) for id in ids: if operation == "delete": - item = session.query(model.Page).get(self.decode_id(id)) + item = session.get(model.Page, self.decode_id(id)) self.security_check(trans, item, check_ownership=True) item.deleted = True with transaction(session): @@ -397,14 +402,7 @@ def list_published(self, trans, *args, **kwargs): def _get_shared(self, trans): """Identify shared pages""" - shared_by_others = ( - trans.sa_session.query(model.PageUserShareAssociation) - .filter_by(user=trans.get_user()) - .join(model.Page.table) - .filter(model.Page.deleted == false()) - .order_by(desc(model.Page.update_time)) - .all() - ) + shared_by_others = get_shared_pages(trans.sa_session, trans.get_user()) return [ {"username": p.page.user.username, "slug": p.page.slug, "title": p.page.title} for p in shared_by_others ] @@ -480,7 +478,7 @@ def edit(self, trans, payload=None, **kwd): return self.message_exception(trans, "No page id received for editing.") decoded_id = self.decode_id(id) user = trans.get_user() - p = trans.sa_session.query(model.Page).get(decoded_id) + p = trans.sa_session.get(model.Page, decoded_id) p = self.security_check(trans, p, check_ownership=True) if trans.request.method == "GET": if p.slug is None: @@ -515,10 +513,7 @@ def edit(self, trans, payload=None, **kwd): return self.message_exception( trans, "Page identifier can only contain lowercase letters, numbers, and dashes (-)." ) - elif ( - p_slug != p.slug - and trans.sa_session.query(model.Page).filter_by(user=p.user, slug=p_slug, deleted=False).first() - ): + elif p_slug != p.slug and page_exists(trans.sa_session, p.user, p_slug): return self.message_exception(trans, "Page id must be unique.") else: p.title = p_title @@ -535,7 +530,7 @@ def edit(self, trans, payload=None, **kwd): @web.require_login() def display(self, trans, id, **kwargs): id = self.decode_id(id) - page = trans.sa_session.query(model.Page).get(id) + page = trans.sa_session.get(model.Page, id) if not page: raise web.httpexceptions.HTTPNotFound() return self.display_by_username_and_slug(trans, page.user.username, page.slug) @@ -545,9 +540,8 @@ def display_by_username_and_slug(self, trans, username, slug, **kwargs): """Display page based on a username and slug.""" # Get page. - session = trans.sa_session - user = session.query(model.User).filter_by(username=username).first() - page = trans.sa_session.query(model.Page).filter_by(user=user, slug=slug, deleted=False).first() + user = get_user_by_username(trans.sa_session, username) + page = get_page_(trans.sa_session, user, slug) if page is None: raise web.httpexceptions.HTTPNotFound() @@ -605,7 +599,7 @@ def get_page(self, trans, id, check_ownership=True, check_accessible=False): """Get a page from the database by id.""" # Load history from database id = self.decode_id(id) - page = trans.sa_session.query(model.Page).get(id) + page = trans.sa_session.get(model.Page, id) if not page: error("Page not found") else: