Skip to content

Commit

Permalink
Back to pg8000 and add type of id to the model functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
pwei1018 committed Sep 20, 2024
1 parent d19e944 commit 35a1574
Show file tree
Hide file tree
Showing 15 changed files with 56 additions and 52 deletions.
2 changes: 1 addition & 1 deletion auth-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion auth-api/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "auth-api"
version = "3.0.3"
version = "3.0.4"
description = ""
authors = ["\"BC Registries and Online Services\""]
readme = "README.md"
Expand Down
10 changes: 6 additions & 4 deletions auth-api/src/auth_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ class _Config: # pylint: disable=too-few-public-methods
DB_PASSWORD = os.getenv("DATABASE_PASSWORD", "")
DB_NAME = os.getenv("DATABASE_NAME", "")
DB_HOST = os.getenv("DATABASE_HOST", "")
DB_PORT = os.getenv("DATABASE_PORT", "5432")
DB_PORT = int(os.getenv("DATABASE_PORT", "5432")) # POSTGRESQL
if DB_UNIX_SOCKET := os.getenv("DATABASE_UNIX_SOCKET", None):
SQLALCHEMY_DATABASE_URI = f"postgresql+psycopg2://{DB_USER}:{DB_PASSWORD}@{DB_UNIX_SOCKET}/{DB_NAME}"
SQLALCHEMY_DATABASE_URI = (
f"postgresql+pg8000://{DB_USER}:{DB_PASSWORD}@/{DB_NAME}?unix_sock={DB_UNIX_SOCKET}/.s.PGSQL.5432"
)
else:
SQLALCHEMY_DATABASE_URI = f"postgresql+psycopg2://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{int(DB_PORT)}/{DB_NAME}"
SQLALCHEMY_DATABASE_URI = f"postgresql+pg8000://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

# JWT_OIDC Settings
JWT_OIDC_WELL_KNOWN_CONFIG = os.getenv("JWT_OIDC_WELL_KNOWN_CONFIG")
Expand Down Expand Up @@ -220,7 +222,7 @@ class TestConfig(_Config): # pylint: disable=too-few-public-methods
DB_HOST = os.getenv("DATABASE_TEST_HOST", "localhost")
DB_PORT = os.getenv("DATABASE_TEST_PORT", "5432")
SQLALCHEMY_DATABASE_URI = os.getenv(
"DATABASE_TEST_URL", f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{int(DB_PORT)}/{DB_NAME}"
"DATABASE_TEST_URL", f"postgresql+pg8000://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{int(DB_PORT)}/{DB_NAME}"
)

# JWT OIDC settings
Expand Down
2 changes: 1 addition & 1 deletion auth-api/src/auth_api/models/account_login_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ class AccountLoginOptions(Versioned, BaseModel): # pylint: disable=too-few-publ
org = relationship("Org", foreign_keys=[org_id], lazy="select")

@classmethod
def find_active_by_org_id(cls, account_id):
def find_active_by_org_id(cls, account_id: int):
"""Find an account setting instance that matches the provided org_id."""
return cls.query.filter_by(org_id=account_id).filter_by(is_active=True).first()
4 changes: 2 additions & 2 deletions auth-api/src/auth_api/models/affiliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def filter_environment(cls, environment: str):
return query

@classmethod
def find_affiliation_by_org_and_entity_ids(cls, org_id, entity_id, environment) -> Affiliation:
def find_affiliation_by_org_and_entity_ids(cls, org_id: int, entity_id: int, environment) -> Affiliation:
"""Return an affiliation for the provided org and entity ids."""
query = cls.filter_environment(environment).filter_by(org_id=org_id, entity_id=entity_id)
return query.one_or_none()

@classmethod
def find_affiliations_by_entity_id(cls, entity_id, environment) -> List[Affiliation]:
def find_affiliations_by_entity_id(cls, entity_id: int, environment) -> List[Affiliation]:
"""Return affiliations for the provided entity id."""
return cls.filter_environment(environment).filter_by(entity_id=entity_id).all()

Expand Down
14 changes: 7 additions & 7 deletions auth-api/src/auth_api/models/affiliation_invitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,30 +159,30 @@ def filter_by(cls, search_filter: AffiliationInvitationSearch, query=None) -> li
return results.all()

@classmethod
def find_invitation_by_id(cls, invitation_id):
def find_invitation_by_id(cls, invitation_id: int):
"""Find an affiliation invitation record that matches the id."""
return cls.query.filter_by(id=invitation_id).first()

@classmethod
def find_invitations_from_org(cls, org_id, status=None):
def find_invitations_from_org(cls, org_id: int, status=None):
"""Find all affiliation invitations sent from a specific org filtered by status."""
results = db.session.query(AffiliationInvitation).filter(AffiliationInvitation.from_org_id == org_id)
return results.filter(AffiliationInvitation.status == status.value).all() if status else results.all()

@classmethod
def find_invitations_to_org(cls, org_id, status=None):
def find_invitations_to_org(cls, org_id: int, status=None):
"""Find all affiliation invitations sent to a specific org filtered by status."""
results = db.session.query(AffiliationInvitation).filter(AffiliationInvitation.to_org_id == org_id)
return results.filter(AffiliationInvitation.status == status.value).all() if status else results.all()

@classmethod
def find_invitations_by_entity(cls, entity_id, status=None):
def find_invitations_by_entity(cls, entity_id: int, status=None):
"""Find all affiliation invitations sent for specific entity filtered by status."""
results = db.session.query(AffiliationInvitation).filter(AffiliationInvitation.entity_id == entity_id)
return results.filter(AffiliationInvitation.status == status.value).all() if status else results.all()

@classmethod
def find_invitations_by_affiliation(cls, affiliation_id):
def find_invitations_by_affiliation(cls, affiliation_id: int):
"""Find all affiliation invitations associated to an affiliation."""
return cls.query.filter_by(affiliation_id=affiliation_id).all()

Expand All @@ -202,7 +202,7 @@ def find_invitations_by_org_entity_ids(from_org_id: int, entity_id: int):
.all()
)

def update_invitation_as_retried(self, sender_id):
def update_invitation_as_retried(self, sender_id: int):
"""Update this affiliation invitation with the new data."""
self.sender_id = sender_id
self.sent_date = datetime.now()
Expand All @@ -217,7 +217,7 @@ def set_status(self, new_status_code: str):
return self

@classmethod
def find_all_related_to_org(cls, org_id, search_filter=AffiliationInvitationSearch()):
def find_all_related_to_org(cls, org_id: int, search_filter=AffiliationInvitationSearch()):
"""Return all affiliation invitations that are related to the org (from org or to org) filtered by statuses."""
query = db.session.query(AffiliationInvitation).filter(
or_(AffiliationInvitation.to_org_id == org_id, AffiliationInvitation.from_org_id == org_id)
Expand Down
8 changes: 4 additions & 4 deletions auth-api/src/auth_api/models/contact_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ class ContactLink(Versioned, BaseModel): # pylint: disable=too-few-public-metho
affidavit = relationship("Affidavit", foreign_keys=[affidavit_id], lazy="select")

@classmethod
def find_by_entity_id(cls, entity_id):
def find_by_entity_id(cls, entity_id: int):
"""Return the first contact link with the provided entity id."""
return cls.query.filter_by(entity_id=entity_id).first()

@classmethod
def find_by_user_id(cls, user_id):
def find_by_user_id(cls, user_id: int):
"""Return the first contact link with the provided user id."""
return cls.query.filter_by(user_id=user_id).first()

@classmethod
def find_by_org_id(cls, org_id):
def find_by_org_id(cls, org_id: int):
"""Return the first contact link with the provided org id."""
return cls.query.filter_by(org_id=org_id).first()

@classmethod
def find_by_affidavit_id(cls, affidavit_id):
def find_by_affidavit_id(cls, affidavit_id: int):
"""Return the first contact link with the provided affidavit id."""
return cls.query.filter_by(affidavit_id=affidavit_id).one_or_none()

Expand Down
2 changes: 1 addition & 1 deletion auth-api/src/auth_api/models/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_from_dict(cls, entity_info: dict):
return None

@classmethod
def find_by_entity_id(cls, entity_id):
def find_by_entity_id(cls, entity_id: int):
"""Find an Entity instance that matches the provided id."""
return cls.query.filter_by(id=entity_id).first()

Expand Down
14 changes: 7 additions & 7 deletions auth-api/src/auth_api/models/invitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def status(self):
return self.invitation_status_code

@classmethod
def create_from_dict(cls, invitation_info: dict, user_id, invitation_type):
def create_from_dict(cls, invitation_info: dict, user_id: int, invitation_type):
"""Create a new Invitation from the provided dictionary."""
if invitation_info:
invitation = Invitation()
Expand All @@ -85,23 +85,23 @@ def create_from_dict(cls, invitation_info: dict, user_id, invitation_type):
return None

@classmethod
def find_invitations_by_user(cls, user_id):
def find_invitations_by_user(cls, user_id: int):
"""Find all invitation sent by the given user."""
return cls.query.filter_by(sender_id=user_id).all()

@classmethod
def find_invitation_by_id(cls, invitation_id):
def find_invitation_by_id(cls, invitation_id: int):
"""Find an invitation record that matches the id."""
return cls.query.filter_by(id=invitation_id).first()

@classmethod
def find_invitations_by_org(cls, org_id, status=None):
def find_invitations_by_org(cls, org_id: int, status=None):
"""Find all invitations sent for specific org filtered by status."""
results = cls.query.filter(Invitation.membership.any(InvitationMembership.org_id == org_id))
return results.filter(Invitation.status == status.value).all() if status else results.all()

@staticmethod
def find_pending_invitations_by_user(user_id):
def find_pending_invitations_by_user(user_id: int):
"""Find all invitations that are not in accepted state."""
return (
db.session.query(Invitation)
Expand All @@ -111,7 +111,7 @@ def find_pending_invitations_by_user(user_id):
)

@staticmethod
def find_pending_invitations_by_org(org_id):
def find_pending_invitations_by_org(org_id: int):
"""Find all invitations that are not in accepted state."""
return (
db.session.query(Invitation)
Expand All @@ -121,7 +121,7 @@ def find_pending_invitations_by_org(org_id):
)

@staticmethod
def find_invitations_by_status(user_id, status):
def find_invitations_by_status(user_id: int, status):
"""Find all invitations that are not in accepted state."""
return (
db.session.query(Invitation)
Expand Down
30 changes: 16 additions & 14 deletions auth-api/src/auth_api/models/membership.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ def __init__(self, **kwargs): # pylint: disable=super-init-not-called
self.status = kwargs.get("membership_type_status")

@classmethod
def find_membership_by_id(cls, membership_id) -> Membership:
def find_membership_by_id(cls, membership_id: int) -> Membership:
"""Find the first membership with the given id and return it."""
return cls.query.filter_by(id=membership_id).first()

@classmethod
def find_members_by_org_id(cls, org_id) -> List[Membership]:
def find_members_by_org_id(cls, org_id: int) -> List[Membership]:
"""Return all members of the org with a status."""
return cls.query.filter_by(org_id=org_id).all()

@classmethod
def get_pending_members_count_by_org_id(cls, org_id) -> int:
def get_pending_members_count_by_org_id(cls, org_id: int) -> int:
"""Return the count of pending members."""
query = (
db.session.query(Membership)
Expand All @@ -92,7 +92,9 @@ def get_pending_members_count_by_org_id(cls, org_id) -> int:
return count

@classmethod
def find_members_by_org_id_by_status_by_roles(cls, org_id, roles, status=Status.ACTIVE.value) -> List[Membership]:
def find_members_by_org_id_by_status_by_roles(
cls, org_id: int, roles, status=Status.ACTIVE.value
) -> List[Membership]:
"""Return all members of the org with a status."""
return (
db.session.query(Membership)
Expand All @@ -103,7 +105,7 @@ def find_members_by_org_id_by_status_by_roles(cls, org_id, roles, status=Status.
)

@classmethod
def find_orgs_for_user(cls, user_id, valid_statuses=VALID_STATUSES) -> List[OrgModel]:
def find_orgs_for_user(cls, user_id: int, valid_statuses=VALID_STATUSES) -> List[OrgModel]:
"""Find the orgs for a user."""
records = (
cls.query.join(OrgModel)
Expand All @@ -116,7 +118,7 @@ def find_orgs_for_user(cls, user_id, valid_statuses=VALID_STATUSES) -> List[OrgM
return list(map(lambda x: x.org, records))

@classmethod
def find_active_staff_org_memberships_for_user(cls, user_id) -> List[Membership]:
def find_active_staff_org_memberships_for_user(cls, user_id: int) -> List[Membership]:
"""Find staff orgs memberships for a user."""
return (
cls.query.join(OrgModel)
Expand All @@ -128,7 +130,7 @@ def find_active_staff_org_memberships_for_user(cls, user_id) -> List[Membership]
)

@classmethod
def add_membership_for_staff(cls, user_id):
def add_membership_for_staff(cls, user_id: int):
"""Add staff membership."""
if staff_orgs := OrgModel.find_by_org_type(OrgType.STAFF.value):
membership = cls.find_membership_by_user_and_org(user_id, staff_orgs[0].id)
Expand All @@ -138,15 +140,15 @@ def add_membership_for_staff(cls, user_id):
membership.save()

@classmethod
def remove_membership_for_staff(cls, user_id):
def remove_membership_for_staff(cls, user_id: int):
"""Remove staff membership."""
staff_memberships = cls.find_active_staff_org_memberships_for_user(user_id)
for staff_membership in staff_memberships:
staff_membership.status = Status.INACTIVE.value
staff_membership.save()

@classmethod
def find_membership_by_user_and_org(cls, user_id, org_id) -> Membership:
def find_membership_by_user_and_org(cls, user_id: int, org_id: int) -> Membership:
"""Get the membership for the specified user and org."""
records = (
cls.query.filter(cls.user_id == user_id)
Expand All @@ -159,21 +161,21 @@ def find_membership_by_user_and_org(cls, user_id, org_id) -> Membership:
return records

@classmethod
def find_membership_by_userid(cls, user_id) -> Membership:
def find_membership_by_userid(cls, user_id: int) -> Membership:
"""Get the membership for the specified user."""
records = cls.query.filter(cls.user_id == user_id).order_by(desc(Membership.created)).first()

return records

@classmethod
def find_memberships_by_user_ids(cls, user_id) -> List[Membership]:
def find_memberships_by_user_ids(cls, user_id: int) -> List[Membership]:
"""Get the memberships for the specified user ids."""
records = cls.query.filter(cls.user_id == user_id).order_by(desc(Membership.created)).all()

return records

@classmethod
def find_membership_by_user_and_org_all_status(cls, user_id, org_id) -> Membership:
def find_membership_by_user_and_org_all_status(cls, user_id: int, org_id: int) -> Membership:
"""Get the membership for the specified user and org with all membership statuses."""
records = (
cls.query.filter(cls.user_id == user_id)
Expand All @@ -185,7 +187,7 @@ def find_membership_by_user_and_org_all_status(cls, user_id, org_id) -> Membersh
return records

@classmethod
def get_count_active_owner_org_id(cls, org_id) -> int:
def get_count_active_owner_org_id(cls, org_id: int) -> int:
"""Return the count of pending members."""
query = (
db.session.query(Membership)
Expand All @@ -206,7 +208,7 @@ def get_count_active_owner_org_id(cls, org_id) -> int:
return count

@classmethod
def check_if_active_admin_or_owner_org_id(cls, org_id, user_id) -> int:
def check_if_active_admin_or_owner_org_id(cls, org_id: int, user_id: int) -> int:
"""Return the count of pending members."""
query = (
db.session.query(Membership)
Expand Down
6 changes: 3 additions & 3 deletions auth-api/src/auth_api/models/org.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def find_by_org_uuid(cls, org_uuid):
return cls.query.filter_by(uuid=org_uuid).first()

@classmethod
def find_by_org_id(cls, org_id):
def find_by_org_id(cls, org_id: int):
"""Find an Org instance that matches the provided id."""
return cls.query.filter_by(id=org_id).first()

Expand Down Expand Up @@ -238,7 +238,7 @@ def _search_for_statuses(cls, query, statuses):
return query

@classmethod
def search_pending_activation_orgs(cls, name):
def search_pending_activation_orgs(cls, name: str):
"""Find all orgs with the given type."""
query = (
db.session.query(Org)
Expand Down Expand Up @@ -271,7 +271,7 @@ def find_by_org_access_type(cls, org_type):
return cls.query.filter_by(access_type=org_type).all()

@classmethod
def find_similar_org_by_name(cls, name, org_id=None, branch_name=None):
def find_similar_org_by_name(cls, name, org_id: int = None, branch_name=None):
"""Find an Org instance that matches the provided name."""
query = cls.query.filter(
and_(
Expand Down
4 changes: 2 additions & 2 deletions auth-api/src/auth_api/models/org_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class OrgSettings(
org = relationship("Org")

@classmethod
def get_org_settings(cls, org_id):
def get_org_settings(cls, org_id: int):
"""Return the default status code for an Org."""
return cls.query.filter_by(org_id=org_id).all()

@classmethod
def is_admin_auto_approved_invitees(cls, org_id):
def is_admin_auto_approved_invitees(cls, org_id: int):
"""Return the default status code for an Org."""
org_model = cls.query.filter_by(org_id=org_id, setting="ADMIN_AUTO_APPROVAL_FOR_MEMBER_ACCEPTANCE").first()
if org_model is not None:
Expand Down
2 changes: 1 addition & 1 deletion auth-api/src/auth_api/models/product_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def find_by_org_ids(cls, org_ids, valid_statuses=VALID_SUBSCRIPTION_STATUSES):
).all()

@classmethod
def find_by_org_id_product_code(cls, org_id, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES):
def find_by_org_id_product_code(cls, org_id: int, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES):
"""Find an product subscription instance that matches the provided id."""
return cls.query.filter(
and_(
Expand Down
Loading

0 comments on commit 35a1574

Please sign in to comment.