diff --git a/backend/src/app/routers/editions/editions.py b/backend/src/app/routers/editions/editions.py index dff2ef30b..e000cf13d 100644 --- a/backend/src/app/routers/editions/editions.py +++ b/backend/src/app/routers/editions/editions.py @@ -16,6 +16,8 @@ # Don't add the "Editions" tag here, because then it gets applied # to all child routes as well +from ...utils.dependencies import require_admin, require_auth, require_coach + editions_router = APIRouter(prefix="/editions") # Register all child routers @@ -31,20 +33,20 @@ editions_router.include_router(router, prefix="/{edition_id}") -@editions_router.get("/",response_model=EditionList, tags=[Tags.EDITIONS]) +@editions_router.get("/", response_model=EditionList, tags=[Tags.EDITIONS], dependencies=[Depends(require_auth)]) async def get_editions(db: Session = Depends(get_session)): """Get a list of all editions. - Args: db (Session, optional): connection with the database. Defaults to Depends(get_session). Returns: EditionList: an object with a list of all the editions. """ + # TODO only return editions the user can see return logic_editions.get_editions(db) -@editions_router.get("/{edition_id}", response_model=Edition, tags=[Tags.EDITIONS]) +@editions_router.get("/{edition_id}", response_model=Edition, tags=[Tags.EDITIONS], dependencies=[Depends(require_coach)]) async def get_edition_by_id(edition_id: int, db: Session = Depends(get_session)): """Get a specific edition. @@ -58,7 +60,7 @@ async def get_edition_by_id(edition_id: int, db: Session = Depends(get_session)) return logic_editions.get_edition_by_id(db, edition_id) -@editions_router.post("/", status_code=status.HTTP_201_CREATED, response_model=Edition, tags=[Tags.EDITIONS]) +@editions_router.post("/", status_code=status.HTTP_201_CREATED, response_model=Edition, tags=[Tags.EDITIONS], dependencies=[Depends(require_admin)]) async def post_edition(edition: EditionBase, db: Session = Depends(get_session)): """ Create a new edition. @@ -71,7 +73,7 @@ async def post_edition(edition: EditionBase, db: Session = Depends(get_session)) return logic_editions.create_edition(db, edition) -@editions_router.delete("/{edition_id}", status_code=status.HTTP_204_NO_CONTENT, tags=[Tags.EDITIONS]) +@editions_router.delete("/{edition_id}", status_code=status.HTTP_204_NO_CONTENT, tags=[Tags.EDITIONS], dependencies=[Depends(require_admin)]) async def delete_edition(edition_id: int, db: Session = Depends(get_session)): """Delete an existing edition. diff --git a/backend/src/app/routers/editions/invites/invites.py b/backend/src/app/routers/editions/invites/invites.py index 465968b09..dd2480542 100644 --- a/backend/src/app/routers/editions/invites/invites.py +++ b/backend/src/app/routers/editions/invites/invites.py @@ -6,14 +6,14 @@ from src.app.logic.invites import create_mailto_link, delete_invite_link, get_pending_invites_list from src.app.routers.tags import Tags from src.app.schemas.invites import InvitesListResponse, EmailAddress, MailtoLink, InviteLink as InviteLinkModel -from src.app.utils.dependencies import get_edition, get_invite_link +from src.app.utils.dependencies import get_edition, get_invite_link, require_admin from src.database.database import get_session from src.database.models import Edition, InviteLink as InviteLinkDB invites_router = APIRouter(prefix="/invites", tags=[Tags.INVITES]) -@invites_router.get("/", response_model=InvitesListResponse) +@invites_router.get("/", response_model=InvitesListResponse, dependencies=[Depends(require_admin)]) async def get_invites(db: Session = Depends(get_session), edition: Edition = Depends(get_edition)): """ Get a list of all pending invitation links. @@ -21,15 +21,18 @@ async def get_invites(db: Session = Depends(get_session), edition: Edition = Dep return get_pending_invites_list(db, edition) -@invites_router.post("/", status_code=status.HTTP_201_CREATED, response_model=MailtoLink) -async def create_invite(email: EmailAddress, db: Session = Depends(get_session), edition: Edition = Depends(get_edition)): +@invites_router.post("/", status_code=status.HTTP_201_CREATED, response_model=MailtoLink, + dependencies=[Depends(require_admin)]) +async def create_invite(email: EmailAddress, db: Session = Depends(get_session), + edition: Edition = Depends(get_edition)): """ Create a new invitation link for the current edition. """ return create_mailto_link(db, edition, email) -@invites_router.delete("/{invite_uuid}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +@invites_router.delete("/{invite_uuid}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response, + dependencies=[Depends(require_admin)]) async def delete_invite(invite_link: InviteLinkDB = Depends(get_invite_link), db: Session = Depends(get_session)): """ Delete an existing invitation link manually so that it can't be used anymore. diff --git a/backend/src/app/routers/editions/webhooks/webhooks.py b/backend/src/app/routers/editions/webhooks/webhooks.py index 2fec59219..1026780fa 100644 --- a/backend/src/app/routers/editions/webhooks/webhooks.py +++ b/backend/src/app/routers/editions/webhooks/webhooks.py @@ -5,7 +5,7 @@ from src.database.crud.webhooks import get_webhook, create_webhook from src.app.schemas.webhooks import WebhookEvent, WebhookUrlResponse from src.database.models import Edition -from src.app.utils.dependencies import get_edition +from src.app.utils.dependencies import get_edition, require_admin from src.app.routers.tags import Tags from src.app.logic.webhooks import process_webhook from starlette import status @@ -18,10 +18,10 @@ def valid_uuid(uuid: str, database: Session = Depends(get_session)): get_webhook(database, uuid) -# TODO: check admin permission -@webhooks_router.post("/", response_model=WebhookUrlResponse, status_code=status.HTTP_201_CREATED) +@webhooks_router.post("/", response_model=WebhookUrlResponse, status_code=status.HTTP_201_CREATED, + dependencies=[Depends(require_admin)]) def new(edition: Edition = Depends(get_edition), database: Session = Depends(get_session)): - """Create e new webhook for an edition""" + """Create a new webhook for an edition""" return create_webhook(database, edition) diff --git a/backend/src/app/routers/skills/skills.py b/backend/src/app/routers/skills/skills.py index 64af3d57e..42b60d72d 100644 --- a/backend/src/app/routers/skills/skills.py +++ b/backend/src/app/routers/skills/skills.py @@ -2,18 +2,16 @@ from sqlalchemy.orm import Session from starlette import status -from src.database.database import get_session -from src.app.schemas.skills import SkillBase, Skill, SkillList from src.app.logic import skills as logic_skills - -from src.app.schemas.skills import SkillBase from src.app.routers.tags import Tags - +from src.app.schemas.skills import SkillBase, Skill, SkillList +from src.app.utils.dependencies import require_auth +from src.database.database import get_session skills_router = APIRouter(prefix="/skills", tags=[Tags.SKILLS]) -@skills_router.get("/", response_model=SkillList, tags=[Tags.SKILLS]) +@skills_router.get("/", response_model=SkillList, tags=[Tags.SKILLS], dependencies=[Depends(require_auth)]) async def get_skills(db: Session = Depends(get_session)): """Get a list of all the base skills that can be added to a student or project. @@ -26,7 +24,7 @@ async def get_skills(db: Session = Depends(get_session)): return logic_skills.get_skills(db) -@skills_router.post("/",status_code=status.HTTP_201_CREATED, response_model=Skill, tags=[Tags.SKILLS]) +@skills_router.post("/",status_code=status.HTTP_201_CREATED, response_model=Skill, tags=[Tags.SKILLS], dependencies=[Depends(require_auth)]) async def create_skill(skill: SkillBase, db: Session = Depends(get_session)): """Add a new skill into the database. @@ -40,7 +38,7 @@ async def create_skill(skill: SkillBase, db: Session = Depends(get_session)): return logic_skills.create_skill(db, skill) -@skills_router.delete("/{skill_id}", status_code=status.HTTP_204_NO_CONTENT, tags=[Tags.SKILLS]) +@skills_router.delete("/{skill_id}", status_code=status.HTTP_204_NO_CONTENT, tags=[Tags.SKILLS], dependencies=[Depends(require_auth)]) async def delete_skill(skill_id: int, db: Session = Depends(get_session)): """Delete an existing skill. diff --git a/backend/src/app/routers/users/users.py b/backend/src/app/routers/users/users.py index d2191b98f..652d6e725 100644 --- a/backend/src/app/routers/users/users.py +++ b/backend/src/app/routers/users/users.py @@ -4,12 +4,13 @@ from src.app.routers.tags import Tags import src.app.logic.users as logic from src.app.schemas.users import UsersListResponse, AdminPatch, UserRequestsResponse +from src.app.utils.dependencies import require_admin from src.database.database import get_session users_router = APIRouter(prefix="/users", tags=[Tags.USERS]) -@users_router.get("/", response_model=UsersListResponse) +@users_router.get("/", response_model=UsersListResponse, dependencies=[Depends(require_admin)]) async def get_users(admin: bool = Query(False), edition: int | None = Query(None), db: Session = Depends(get_session)): """ Get users @@ -18,7 +19,7 @@ async def get_users(admin: bool = Query(False), edition: int | None = Query(None return logic.get_users_list(db, admin, edition) -@users_router.patch("/{user_id}", status_code=204) +@users_router.patch("/{user_id}", status_code=204, dependencies=[Depends(require_admin)]) async def patch_admin_status(user_id: int, admin: AdminPatch, db: Session = Depends(get_session)): """ Set admin-status of user @@ -27,7 +28,7 @@ async def patch_admin_status(user_id: int, admin: AdminPatch, db: Session = Depe logic.edit_admin_status(db, user_id, admin) -@users_router.post("/{user_id}/editions/{edition_id}", status_code=204) +@users_router.post("/{user_id}/editions/{edition_id}", status_code=204, dependencies=[Depends(require_admin)]) async def add_to_edition(user_id: int, edition_id: int, db: Session = Depends(get_session)): """ Add user as coach of the given edition @@ -36,7 +37,7 @@ async def add_to_edition(user_id: int, edition_id: int, db: Session = Depends(ge logic.add_coach(db, user_id, edition_id) -@users_router.delete("/{user_id}/editions/{edition_id}", status_code=204) +@users_router.delete("/{user_id}/editions/{edition_id}", status_code=204, dependencies=[Depends(require_admin)]) async def remove_from_edition(user_id: int, edition_id: int, db: Session = Depends(get_session)): """ Remove user as coach of the given edition @@ -45,7 +46,7 @@ async def remove_from_edition(user_id: int, edition_id: int, db: Session = Depen logic.remove_coach(db, user_id, edition_id) -@users_router.get("/requests", response_model=UserRequestsResponse) +@users_router.get("/requests", response_model=UserRequestsResponse, dependencies=[Depends(require_admin)]) async def get_requests(edition: int | None = Query(None), db: Session = Depends(get_session)): """ Get pending userrequests @@ -54,7 +55,7 @@ async def get_requests(edition: int | None = Query(None), db: Session = Depends( return logic.get_request_list(db, edition) -@users_router.post("/requests/{request_id}/accept", status_code=204) +@users_router.post("/requests/{request_id}/accept", status_code=204, dependencies=[Depends(require_admin)]) async def accept_request(request_id: int, db: Session = Depends(get_session)): """ Accept a coach request @@ -63,7 +64,7 @@ async def accept_request(request_id: int, db: Session = Depends(get_session)): logic.accept_request(db, request_id) -@users_router.post("/requests/{request_id}/reject", status_code=204) +@users_router.post("/requests/{request_id}/reject", status_code=204, dependencies=[Depends(require_admin)]) async def reject_request(request_id: int, db: Session = Depends(get_session)): """ Reject a coach request diff --git a/backend/src/app/utils/dependencies.py b/backend/src/app/utils/dependencies.py index 84f1929c8..8d7800f2b 100644 --- a/backend/src/app/utils/dependencies.py +++ b/backend/src/app/utils/dependencies.py @@ -25,8 +25,6 @@ def get_edition(edition_id: int, database: Session = Depends(get_session)) -> Ed async def get_current_active_user(db: Session = Depends(get_session), token: str = Depends(oauth2_scheme)) -> User: """Check which user is making a request by decoding its token This function is used as a dependency for other functions - TODO check if user has any pending coach requests - requires coach request logic to be done """ try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) @@ -47,11 +45,23 @@ async def get_current_active_user(db: Session = Depends(get_session), token: str raise InvalidCredentialsException() from jwt_err -# Alias that is easier to read in the dependency list when -# the return value isn't required -# Require the user to be authorized, coach or admin doesn't matter -require_authorization = get_current_active_user -require_auth = get_current_active_user +async def require_auth(user: User = Depends(get_current_active_user)) -> User: + """Dependency to check if a user is at least a coach + This dependency should be used to check for resources that aren't linked to + editions + + The function checks if the user is either an admin, or a coach with at least + one UserRole (meaning they have been accepted for at least one edition) + """ + # Admins can see everything + if user.admin: + return user + + # Coach is not in any editions (yet) + if len(user.editions) == 0: + raise MissingPermissionsException() + + return user async def require_admin(user: User = Depends(get_current_active_user)) -> User: @@ -62,6 +72,21 @@ async def require_admin(user: User = Depends(get_current_active_user)) -> User: return user +async def require_coach(edition: Edition = Depends(get_edition), user: User = Depends(get_current_active_user)) -> User: + """Dependency to check if a user can see a given resource + This comes down to checking if a coach is linked to an edition or not + """ + # Admins can see everything in any edition + if user.admin: + return user + + # Coach is not part of this edition + if edition not in user.editions: + raise MissingPermissionsException() + + return user + + def get_invite_link(invite_uuid: str, db: Session = Depends(get_session)) -> InviteLink: """Get an invite link from the database, given the id in the path""" return get_invite_link_by_uuid(db, invite_uuid) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index fb1797b33..cda2243b8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -11,6 +11,8 @@ from src.database.database import get_session from src.database.engine import engine +from tests.utils.authorization import AuthClient + @pytest.fixture(scope="session") def tables(): @@ -55,3 +57,18 @@ def override_get_session() -> Generator[Session, None, None]: # Replace get_session with a call to this method instead app.dependency_overrides[get_session] = override_get_session return TestClient(app) + + +@pytest.fixture +def auth_client(database_session: Session) -> AuthClient: + """Fixture to get a TestClient that handles authentication""" + + def override_get_session() -> Generator[Session, None, None]: + """Inner function to override the Session used in the app + A session provided by a fixture will be used instead + """ + yield database_session + + # Replace get_session with a call to this method instead + app.dependency_overrides[get_session] = override_get_session + return AuthClient(database_session, app) diff --git a/backend/tests/test_routers/test_editions/test_editions/test_editions.py b/backend/tests/test_routers/test_editions/test_editions/test_editions.py index 6f20b3ac3..86643afa8 100644 --- a/backend/tests/test_routers/test_editions/test_editions/test_editions.py +++ b/backend/tests/test_routers/test_editions/test_editions/test_editions.py @@ -1,24 +1,25 @@ -from json import dumps from sqlalchemy.orm import Session from starlette import status -from starlette.testclient import TestClient from src.database.models import Edition +from tests.utils.authorization import AuthClient -def test_get_editions(database_session: Session, test_client: TestClient): - """Performe tests on getting editions +def test_get_editions(database_session: Session, auth_client: AuthClient): + """Perform tests on getting editions Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ - edition = Edition(year = 2022) + edition = Edition(year=2022) database_session.add(edition) database_session.commit() + auth_client.coach(edition) + # Make the get request - response = test_client.get("/editions/") + response = auth_client.get("/editions/") assert response.status_code == status.HTTP_200_OK response = response.json() @@ -26,63 +27,149 @@ def test_get_editions(database_session: Session, test_client: TestClient): assert response["editions"][0]["editionId"] == 1 -def test_get_edition_by_id(database_session: Session, test_client: TestClient): - """Performe tests on getting editions by ids +def test_get_edition_by_id_admin(database_session: Session, auth_client: AuthClient): + """Test getting an edition as an admin""" + auth_client.admin() + + edition = Edition(year=2022) + database_session.add(edition) + database_session.commit() + + response = auth_client.get(f"/editions/{edition.edition_id}") + assert response.status_code == status.HTTP_200_OK + + +def test_get_edition_by_id_coach(database_session: Session, auth_client: AuthClient): + """Perform tests on getting editions by ids Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ - edition = Edition(year = 2022) + edition = Edition(year=2022) database_session.add(edition) database_session.commit() - database_session.refresh(edition) + + auth_client.coach(edition) # Make the get request - response = test_client.get(f"/editions/{edition.edition_id}") + response = auth_client.get(f"/editions/{edition.edition_id}") assert response.status_code == status.HTTP_200_OK assert response.json()["year"] == 2022 assert response.json()["editionId"] == edition.edition_id -def test_create_edition(database_session: Session, test_client: TestClient): - """Performe tests on creating editions +def test_get_edition_by_id_unauthorized(database_session: Session, auth_client: AuthClient): + """Test getting an edition without access token""" + edition = Edition(year=2022) + database_session.add(edition) + database_session.commit() + + assert auth_client.get(f"/editions/1").status_code == status.HTTP_401_UNAUTHORIZED + + +def test_get_edition_by_id_not_coach(database_session: Session, auth_client: AuthClient): + """Test getting an edition without being a coach in it""" + edition = Edition(year=2022) + database_session.add(edition) + + coach_edition = Edition(year=2021) + database_session.add(coach_edition) + + database_session.commit() + + # Sign in as a coach in a different edition + auth_client.coach(coach_edition) + + assert auth_client.get(f"/editions/{edition.edition_id}").status_code == status.HTTP_403_FORBIDDEN + + +def test_create_edition_admin(database_session: Session, auth_client: AuthClient): + """Test creating an edition as an admin""" + auth_client.admin() - Args: - database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls - """ # Verify that editions doesn't exist yet - assert test_client.get("/editions/1/").status_code == status.HTTP_404_NOT_FOUND + assert auth_client.get("/editions/1/").status_code == status.HTTP_404_NOT_FOUND # Make the post request - response = test_client.post("/editions/", json={"year": 2022}) + response = auth_client.post("/editions/", json={"year": 2022}) + assert response.status_code == status.HTTP_201_CREATED + assert auth_client.get("/editions/").json()["editions"][0]["year"] == 2022 + assert auth_client.get("/editions/").json()["editions"][0]["editionId"] == 1 + assert auth_client.get("/editions/1/").status_code == status.HTTP_200_OK + + +def test_create_edition_unauthorized(database_session: Session, auth_client: AuthClient): + """Test creating an edition without any credentials""" + assert auth_client.post("/editions/", json={"year": 2022}).status_code == status.HTTP_401_UNAUTHORIZED + + +def test_create_edition_coach(database_session: Session, auth_client: AuthClient): + """Test creating an edition as a coach""" + edition = Edition(year=2022) + database_session.add(edition) + database_session.commit() + + auth_client.coach(edition) + + assert auth_client.post("/editions/", json={"year": 2022}).status_code == status.HTTP_403_FORBIDDEN + + +def test_create_edition_existing_year(database_session: Session, auth_client: AuthClient): + """Test that creating an edition for a year that already exists throws an error""" + auth_client.admin() + + response = auth_client.post("/editions/", json={"year": 2022}) assert response.status_code == status.HTTP_201_CREATED - assert test_client.get("/editions/").json()["editions"][0]["year"] == 2022 - assert test_client.get("/editions/").json()["editions"][0]["editionId"] == 1 - assert test_client.get("/editions/1/").status_code == status.HTTP_200_OK # Try to make an edition in the same year - #response = test_client.post("/editions/", json={"year": 2022}) - #assert response.status_code == status.HTTP_409_CONFLICT + response = auth_client.post("/editions/", json={"year": 2022}) + assert response.status_code == status.HTTP_409_CONFLICT -def test_delete_edition(database_session: Session, test_client: TestClient): - """Performe tests on deleting editions +def test_delete_edition_admin(database_session: Session, auth_client: AuthClient): + """Perform tests on deleting editions Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ - edition = Edition(year = 2022) + auth_client.admin() + + edition = Edition(year=2022) database_session.add(edition) database_session.commit() - database_session.refresh(edition) # Make the delete request - response = test_client.delete(f"/editions/{edition.edition_id}") + response = auth_client.delete(f"/editions/{edition.edition_id}") assert response.status_code == status.HTTP_204_NO_CONTENT - # Try to make a delete on an editions that doesn't exist - response = test_client.delete("/edition/1") - assert response.status_code == status.HTTP_404_NOT_FOUND \ No newline at end of file + +def test_delete_edition_unauthorized(database_session: Session, auth_client: AuthClient): + """Test deleting an edition without any credentials""" + edition = Edition(year=2022) + database_session.add(edition) + database_session.commit() + + # Make the delete request + assert auth_client.delete(f"/editions/{edition.edition_id}").status_code == status.HTTP_401_UNAUTHORIZED + + +def test_delete_edition_coach(database_session: Session, auth_client: AuthClient): + """Test deleting an edition as a coach""" + edition = Edition(year=2022) + database_session.add(edition) + database_session.commit() + + auth_client.coach(edition) + + # Make the delete request + assert auth_client.delete(f"/editions/{edition.edition_id}").status_code == status.HTTP_403_FORBIDDEN + + +def test_delete_edition_non_existing(database_session: Session, auth_client: AuthClient): + """Delete an edition that doesn't exist""" + auth_client.admin() + + response = auth_client.delete("/edition/1") + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/backend/tests/test_routers/test_editions/test_invites/test_invites.py b/backend/tests/test_routers/test_editions/test_invites/test_invites.py index 429600bd1..d82b481fa 100644 --- a/backend/tests/test_routers/test_editions/test_invites/test_invites.py +++ b/backend/tests/test_routers/test_editions/test_invites/test_invites.py @@ -3,31 +3,33 @@ from sqlalchemy.orm import Session from starlette import status -from starlette.testclient import TestClient from src.database.models import Edition, InviteLink +from tests.utils.authorization import AuthClient -def test_get_empty_invites(database_session: Session, test_client: TestClient): +def test_get_empty_invites(database_session: Session, auth_client: AuthClient): """Test endpoint for getting invites when db is empty""" + auth_client.admin() database_session.add(Edition(year=2022)) database_session.commit() - response = test_client.get("/editions/1/invites") + response = auth_client.get("/editions/1/invites") assert response.status_code == status.HTTP_200_OK assert response.json() == {"inviteLinks": []} -def test_get_invites(database_session: Session, test_client: TestClient): +def test_get_invites(database_session: Session, auth_client: AuthClient): """Test endpoint for getting invites when db is not empty""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() database_session.add(InviteLink(target_email="test@ema.il", edition=edition)) database_session.commit() - response = test_client.get("/editions/1/invites") + response = auth_client.get("/editions/1/invites") assert response.status_code == status.HTTP_200_OK json = response.json() @@ -38,47 +40,51 @@ def test_get_invites(database_session: Session, test_client: TestClient): assert link["editionId"] == 1 -def test_create_invite_valid(database_session: Session, test_client: TestClient): +def test_create_invite_valid(database_session: Session, auth_client: AuthClient): """Test endpoint for creating invites when data is valid""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() # Create POST request - response = test_client.post("/editions/1/invites/", data=dumps({"email": "test@ema.il"})) + response = auth_client.post("/editions/1/invites/", data=dumps({"email": "test@ema.il"})) assert response.status_code == status.HTTP_201_CREATED json = response.json() assert "mailTo" in json assert json["mailTo"].startswith("mailto:test@ema.il") # New entry made in database - json = test_client.get("/editions/1/invites/").json() + json = auth_client.get("/editions/1/invites/").json() assert len(json["inviteLinks"]) == 1 new_uuid = json["inviteLinks"][0]["uuid"] - assert test_client.get(f"/editions/1/invites/{new_uuid}/").status_code == status.HTTP_200_OK + assert auth_client.get(f"/editions/1/invites/{new_uuid}/").status_code == status.HTTP_200_OK -def test_create_invite_invalid(database_session: Session, test_client: TestClient): +def test_create_invite_invalid(database_session: Session, auth_client: AuthClient): """Test endpoint for creating invites when data is invalid""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() # Invalid POST will send invalid status code - response = test_client.post("/editions/1/invites/", data=dumps({"email": "invalid field"})) + response = auth_client.post("/editions/1/invites/", data=dumps({"email": "invalid field"})) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Verify that no new entry was made after the error - assert len(test_client.get("/editions/1/invites/").json()["inviteLinks"]) == 0 + assert len(auth_client.get("/editions/1/invites/").json()["inviteLinks"]) == 0 -def test_delete_invite_invalid(database_session: Session, test_client: TestClient): +def test_delete_invite_invalid(database_session: Session, auth_client: AuthClient): """Test endpoint for deleting invites when uuid is malformed""" - assert test_client.delete("/editions/1/invites/1").status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + auth_client.admin() + assert auth_client.delete("/editions/1/invites/1").status_code == status.HTTP_422_UNPROCESSABLE_ENTITY -def test_delete_invite_valid(database_session: Session, test_client: TestClient): +def test_delete_invite_valid(database_session: Session, auth_client: AuthClient): """Test endpoint for deleting invites when uuid is valid""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() @@ -86,7 +92,7 @@ def test_delete_invite_valid(database_session: Session, test_client: TestClient) debug_uuid = "123e4567-e89b-12d3-a456-426614174000" # Not present yet - assert test_client.delete(f"/editions/1/invites/{debug_uuid}").status_code == status.HTTP_404_NOT_FOUND + assert auth_client.delete(f"/editions/1/invites/{debug_uuid}").status_code == status.HTTP_404_NOT_FOUND # Create new entry in db invite_link = InviteLink(target_email="test@ema.il", edition=edition, uuid=UUID(debug_uuid)) @@ -94,34 +100,37 @@ def test_delete_invite_valid(database_session: Session, test_client: TestClient) database_session.commit() # Remove - assert test_client.delete(f"/editions/1/invites/{invite_link.uuid}").status_code == status.HTTP_204_NO_CONTENT + assert auth_client.delete(f"/editions/1/invites/{invite_link.uuid}").status_code == status.HTTP_204_NO_CONTENT # Not found anymore - assert test_client.get(f"/editions/1/invites/{invite_link.uuid}/").status_code == status.HTTP_404_NOT_FOUND + assert auth_client.get(f"/editions/1/invites/{invite_link.uuid}/").status_code == status.HTTP_404_NOT_FOUND -def test_get_invite_malformed_uuid(database_session: Session, test_client: TestClient): +def test_get_invite_malformed_uuid(database_session: Session, auth_client: AuthClient): """Test endpoint for fetching invites when uuid is malformed""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() # Verify malformed uuid (1) - assert test_client.get("/editions/1/invites/1/").status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert auth_client.get("/editions/1/invites/1/").status_code == status.HTTP_422_UNPROCESSABLE_ENTITY -def test_get_invite_non_existing(database_session: Session, test_client: TestClient): +def test_get_invite_non_existing(database_session: Session, auth_client: AuthClient): """Test endpoint for fetching invites when uuid is valid but doesn't exist""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() - assert test_client.get( + assert auth_client.get( "/editions/1/invites/123e4567-e89b-12d3-a456-426614174000").status_code == status.HTTP_404_NOT_FOUND -def test_get_invite_present(database_session: Session, test_client: TestClient): +def test_get_invite_present(database_session: Session, auth_client: AuthClient): """Test endpoint to fetch an invite when one is present""" + auth_client.admin() edition = Edition(year=2022) database_session.add(edition) database_session.commit() @@ -134,7 +143,7 @@ def test_get_invite_present(database_session: Session, test_client: TestClient): database_session.commit() # Found the correct result now - response = test_client.get(f"/editions/1/invites/{debug_uuid}") + response = auth_client.get(f"/editions/1/invites/{debug_uuid}") json = response.json() assert response.status_code == status.HTTP_200_OK assert json["uuid"] == debug_uuid diff --git a/backend/tests/test_routers/test_editions/test_webhooks/test_webhooks.py b/backend/tests/test_routers/test_editions/test_webhooks/test_webhooks.py index c64693267..30cc5d3d9 100644 --- a/backend/tests/test_routers/test_editions/test_webhooks/test_webhooks.py +++ b/backend/tests/test_routers/test_editions/test_webhooks/test_webhooks.py @@ -7,6 +7,7 @@ from starlette import status from src.database.models import Edition, WebhookURL, Student +from tests.utils.authorization import AuthClient from .data import create_webhook_event, WEBHOOK_EVENT_BAD_FORMAT, WEBHOOK_MISSING_QUESTION @@ -26,15 +27,17 @@ def webhook(edition: Edition, database_session: Session) -> WebhookURL: return webhook -def test_new_webhook(test_client: TestClient, edition: Edition): - response = test_client.post(f"/editions/{edition.edition_id}/webhooks/") +def test_new_webhook(auth_client: AuthClient, edition: Edition): + auth_client.admin() + response = auth_client.post(f"/editions/{edition.edition_id}/webhooks/") assert response.status_code == status.HTTP_201_CREATED assert 'uuid' in response.json() assert UUID(response.json()['uuid']) -def test_new_webhook_invalid_edition(test_client: TestClient, edition: Edition): - response = test_client.post(f"/editions/0/webhooks/") +def test_new_webhook_invalid_edition(auth_client: AuthClient, edition: Edition): + auth_client.admin() + response = auth_client.post(f"/editions/0/webhooks/") assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/backend/tests/test_routers/test_skills/test_skills.py b/backend/tests/test_routers/test_skills/test_skills.py index a8cc4360f..bc9b51b06 100644 --- a/backend/tests/test_routers/test_skills/test_skills.py +++ b/backend/tests/test_routers/test_skills/test_skills.py @@ -1,24 +1,25 @@ from json import dumps from sqlalchemy.orm import Session from starlette import status -from starlette.testclient import TestClient from src.database.models import Skill +from tests.utils.authorization import AuthClient -def test_get_skills(database_session: Session, test_client: TestClient): +def test_get_skills(database_session: Session, auth_client: AuthClient): """Performe tests on getting skills Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ + auth_client.admin() skill = Skill(name="Backend", description = "Must know react") database_session.add(skill) database_session.commit() # Make the get request - response = test_client.get("/skills/") + response = auth_client.get("/skills/") assert response.status_code == status.HTTP_200_OK response = response.json() @@ -26,35 +27,43 @@ def test_get_skills(database_session: Session, test_client: TestClient): assert response["skills"][0]["description"] == "Must know react" -def test_create_skill(database_session: Session, test_client: TestClient): +def test_create_skill(database_session: Session, auth_client: AuthClient): """Performe tests on creating skills Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ + auth_client.admin() + # Make the post request - response = test_client.post("/skills/", data=dumps({"name": "Backend", "description": "must know react"})) + response = auth_client.post("/skills/", data=dumps({"name": "Backend", "description": "must know react"})) assert response.status_code == status.HTTP_201_CREATED - assert test_client.get("/skills/").json()["skills"][0]["name"] == "Backend" - assert test_client.get("/skills/").json()["skills"][0]["description"] == "must know react" + assert auth_client.get("/skills/").json()["skills"][0]["name"] == "Backend" + assert auth_client.get("/skills/").json()["skills"][0]["description"] == "must know react" -def test_delete_skill(database_session: Session, test_client: TestClient): +def test_delete_skill(database_session: Session, auth_client: AuthClient): """Performe tests on deleting skills Args: database_session (Session): a connection with the database - test_client (TestClient): a client used to do rest calls + auth_client (AuthClient): a client used to do rest calls """ + auth_client.admin() + skill = Skill(name="Backend", description = "Must know react") database_session.add(skill) database_session.commit() database_session.refresh(skill) - response = test_client.delete(f"/skills/{skill.skill_id}") + response = auth_client.delete(f"/skills/{skill.skill_id}") assert response.status_code == status.HTTP_204_NO_CONTENT - # Try to make a delete on an editions that doesn't exist - response = test_client.delete("/skills/1") - assert response.status_code == status.HTTP_404_NOT_FOUND \ No newline at end of file + +def test_delete_skill_non_existing(database_session: Session, auth_client: AuthClient): + """Delete a skill that doesn't exist""" + auth_client.admin() + + response = auth_client.delete("/skills/1") + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/backend/tests/test_routers/test_users/test_users.py b/backend/tests/test_routers/test_users/test_users.py index 502d1a026..7191e968c 100644 --- a/backend/tests/test_routers/test_users/test_users.py +++ b/backend/tests/test_routers/test_users/test_users.py @@ -8,12 +8,12 @@ from src.database import models from src.database.models import user_editions, CoachRequest +from tests.utils.authorization import AuthClient @pytest.fixture -def data(database_session: Session) -> dict[str, str]: +def data(database_session: Session) -> dict[str, str | int]: """Fill database with dummy data""" - # Create users user1 = models.User(name="user1", email="user1@mail.com", admin=True) @@ -43,85 +43,87 @@ def data(database_session: Session) -> dict[str, str]: } -def test_get_all_users(database_session: Session, test_client: TestClient, data: dict[str, str]): +def test_get_all_users(database_session: Session, auth_client: AuthClient, data: dict[str, str | int]): """Test endpoint for getting a list of users""" - + auth_client.admin() # All users - response = test_client.get("/users") + response = auth_client.get("/users") assert response.status_code == status.HTTP_200_OK user_ids = [user["userId"] for user in response.json()['users']] + user_ids.remove(auth_client.user.user_id) assert len(user_ids) == 2 assert data["user1"] in user_ids assert data["user2"] in user_ids -def test_get_all_admins(database_session: Session, test_client: TestClient, data: dict[str, str]): +def test_get_all_admins(database_session: Session, auth_client: AuthClient, data: dict[str, str | int]): """Test endpoint for getting a list of admins""" - + auth_client.admin() # All admins - response = test_client.get("/users?admin=true") + response = auth_client.get("/users?admin=true") assert response.status_code == status.HTTP_200_OK user_ids = [user["userId"] for user in response.json()['users']] + user_ids.remove(auth_client.user.user_id) assert [data["user1"]] == user_ids -def test_get_users_from_edition(database_session: Session, test_client: TestClient, data: dict[str, str]): +def test_get_users_from_edition(database_session: Session, auth_client: AuthClient, data: dict[str, str | int]): """Test endpoint for getting a list of users from a given edition""" - + auth_client.admin() # All users from edition - response = test_client.get(f"/users?edition={data['edition2']}") + response = auth_client.get(f"/users?edition={data['edition2']}") assert response.status_code == status.HTTP_200_OK user_ids = [user["userId"] for user in response.json()['users']] assert [data["user2"]] == user_ids -def test_get_admins_from_edition(database_session: Session, test_client: TestClient, data: dict[str, str]): +def test_get_admins_from_edition(database_session: Session, auth_client: AuthClient, data: dict[str, str | int]): """Test endpoint for getting a list of admins from a given edition""" - + auth_client.admin() # All admins from edition - response = test_client.get(f"/users?admin=true&edition={data['edition1']}") + response = auth_client.get(f"/users?admin=true&edition={data['edition1']}") assert response.status_code == status.HTTP_200_OK user_ids = [user["userId"] for user in response.json()['users']] assert [data["user1"]] == user_ids - response = test_client.get(f"/users?admin=true&edition={data['edition2']}") + response = auth_client.get(f"/users?admin=true&edition={data['edition2']}") assert response.status_code == status.HTTP_200_OK assert len(response.json()['users']) == 0 -def test_get_users_invalid(database_session: Session, test_client: TestClient, data: dict[str, str]): +def test_get_users_invalid(database_session: Session, auth_client: AuthClient, data: dict[str, str | int]): """Test endpoint for unvalid input""" - + auth_client.admin() # Invalid input - response = test_client.get("/users?admin=INVALID") + response = auth_client.get("/users?admin=INVALID") assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - response = test_client.get("/users?edition=INVALID") + response = auth_client.get("/users?edition=INVALID") assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY -def test_edit_admin_status(database_session: Session, test_client: TestClient): +def test_edit_admin_status(database_session: Session, auth_client: AuthClient): """Test endpoint for editing the admin status of a user""" - + auth_client.admin() # Create user user = models.User(name="user1", email="user1@mail.com", admin=False) database_session.add(user) database_session.commit() - response = test_client.patch(f"/users/{user.user_id}", + response = auth_client.patch(f"/users/{user.user_id}", data=dumps({"admin": True})) assert response.status_code == status.HTTP_204_NO_CONTENT assert user.admin - response = test_client.patch(f"/users/{user.user_id}", + response = auth_client.patch(f"/users/{user.user_id}", data=dumps({"admin": False})) assert response.status_code == status.HTTP_204_NO_CONTENT assert not user.admin -def test_coach(database_session: Session, test_client: TestClient): +def test_add_coach(database_session: Session, auth_client: AuthClient): """Test endpoint for adding coaches""" - + auth_client.admin() # Create user user = models.User(name="user1", email="user1@mail.com", admin=False) database_session.add(user) @@ -133,16 +135,16 @@ def test_coach(database_session: Session, test_client: TestClient): database_session.commit() # Add coach - response = test_client.post(f"/users/{user.user_id}/editions/{edition.edition_id}") + response = auth_client.post(f"/users/{user.user_id}/editions/{edition.edition_id}") assert response.status_code == status.HTTP_204_NO_CONTENT coach = database_session.query(user_editions).one() assert coach.user_id == user.user_id assert coach.edition_id == edition.edition_id -def test_remove_coach(database_session: Session, test_client: TestClient): +def test_remove_coach(database_session: Session, auth_client: AuthClient): """Test endpoint for removing coaches""" - + auth_client.admin() # Create user user = models.User(name="user1", email="user1@mail.com") database_session.add(user) @@ -160,14 +162,15 @@ def test_remove_coach(database_session: Session, test_client: TestClient): database_session.commit() # Remove coach - response = test_client.delete(f"/users/{user.user_id}/editions/{edition.edition_id}") + response = auth_client.delete(f"/users/{user.user_id}/editions/{edition.edition_id}") assert response.status_code == status.HTTP_204_NO_CONTENT coach = database_session.query(user_editions).all() assert len(coach) == 0 -def test_get_all_requests(database_session: Session, test_client: TestClient): +def test_get_all_requests(database_session: Session, auth_client: AuthClient): """Test endpoint for getting all userrequests""" + auth_client.admin() # Create user user1 = models.User(name="user1", email="user1@mail.com") @@ -191,7 +194,7 @@ def test_get_all_requests(database_session: Session, test_client: TestClient): database_session.commit() - response = test_client.get(f"/users/requests") + response = auth_client.get(f"/users/requests") assert response.status_code == status.HTTP_200_OK user_ids = [request["user"]["userId"] for request in response.json()['requests']] assert len(user_ids) == 2 @@ -199,8 +202,9 @@ def test_get_all_requests(database_session: Session, test_client: TestClient): assert user2.user_id in user_ids -def test_get_all_requests_from_edition(database_session: Session, test_client: TestClient): +def test_get_all_requests_from_edition(database_session: Session, auth_client: AuthClient): """Test endpoint for getting all userrequests of a given edition""" + auth_client.admin() # Create user user1 = models.User(name="user1", email="user1@mail.com") @@ -224,22 +228,22 @@ def test_get_all_requests_from_edition(database_session: Session, test_client: T database_session.commit() - response = test_client.get(f"/users/requests?edition={edition1.edition_id}") + response = auth_client.get(f"/users/requests?edition={edition1.edition_id}") assert response.status_code == status.HTTP_200_OK requests = response.json()['requests'] assert len(requests) == 1 assert user1.user_id == requests[0]["user"]["userId"] - response = test_client.get(f"/users/requests?edition={edition2.edition_id}") + response = auth_client.get(f"/users/requests?edition={edition2.edition_id}") assert response.status_code == status.HTTP_200_OK requests = response.json()['requests'] assert len(requests) == 1 assert user2.user_id == requests[0]["user"]["userId"] -def test_accept_request(database_session, test_client: TestClient): +def test_accept_request(database_session, auth_client: AuthClient): """Test endpoint for accepting a coach request""" - + auth_client.admin() # Create user user1 = models.User(name="user1", email="user1@mail.com") database_session.add(user1) @@ -256,16 +260,16 @@ def test_accept_request(database_session, test_client: TestClient): database_session.commit() - response = test_client.post(f"users/requests/{request1.request_id}/accept") + response = auth_client.post(f"users/requests/{request1.request_id}/accept") assert response.status_code == status.HTTP_204_NO_CONTENT assert len(user1.editions) == 1 assert user1.editions[0].edition_id == edition1.edition_id -def test_reject_request(database_session, test_client: TestClient): +def test_reject_request(database_session, auth_client: AuthClient): """Test endpoint for rejecting a coach request""" - + auth_client.admin() # Create user user1 = models.User(name="user1", email="user1@mail.com") database_session.add(user1) @@ -282,11 +286,11 @@ def test_reject_request(database_session, test_client: TestClient): database_session.commit() - response = test_client.post(f"users/requests/{request1.request_id}/reject") + response = auth_client.post(f"users/requests/{request1.request_id}/reject") assert response.status_code == status.HTTP_204_NO_CONTENT requests = database_session.query(CoachRequest).all() assert len(requests) == 0 - response = test_client.post("users/requests/INVALID/reject") + response = auth_client.post("users/requests/INVALID/reject") assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/backend/tests/test_routers/test_editions/test_users/__init__.py b/backend/tests/utils/__init__.py similarity index 100% rename from backend/tests/test_routers/test_editions/test_users/__init__.py rename to backend/tests/utils/__init__.py diff --git a/backend/tests/utils/authorization/__init__.py b/backend/tests/utils/authorization/__init__.py new file mode 100644 index 000000000..c1f4796ca --- /dev/null +++ b/backend/tests/utils/authorization/__init__.py @@ -0,0 +1 @@ +from .auth_client import AuthClient diff --git a/backend/tests/utils/authorization/auth_client.py b/backend/tests/utils/authorization/auth_client.py new file mode 100644 index 000000000..2b120c772 --- /dev/null +++ b/backend/tests/utils/authorization/auth_client.py @@ -0,0 +1,92 @@ +from datetime import timedelta +from typing import Text + +from requests import Response +from sqlalchemy.orm import Session +from starlette.testclient import TestClient + +from src.app.logic.security import create_access_token +from src.database.models import User, Edition + + +class AuthClient(TestClient): + """Custom TestClient that handles authentication to make tests more compact""" + user: User | None = None + headers: dict[str, str] | None = None + session: Session + + def __init__(self, session: Session, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.session = session + + def invalid(self): + """Sign in with an invalid access token""" + self.headers = { + "Authorization": "Bearer If I can't scuba, then what has this all been about? What am I working towards?" + } + + def admin(self): + """Sign in as an admin for all future requests""" + # Create a new user in the db + admin = User(name="Pytest Admin", email="admin@pytest.email", admin=True) + self.session.add(admin) + self.session.commit() + + self.login(admin) + + def coach(self, edition: Edition): + """Sign in as a coach for all future requests + Assigns the coach to the edition + """ + # Create a new user in the db + coach = User(name="Pytest Coach", email="coach@pytest.email", admin=False) + + # Link the coach to the edition + coach.editions.append(edition) + self.session.add(coach) + self.session.commit() + + self.login(coach) + + def login(self, user: User): + """Sign in as a user for all future requests""" + self.user = user + + access_token_expires = timedelta(hours=24) + access_token = create_access_token( + data={"sub": str(user.user_id)}, expires_delta=access_token_expires + ) + + # Add auth headers into dict + self.headers = {"Authorization": f"Bearer {access_token}"} + + def delete(self, url: Text | None, **kwargs) -> Response: + if self.headers is not None: + kwargs["headers"] = self.headers + + return super().delete(url, **kwargs) + + def get(self, url: Text | None, **kwargs) -> Response: + if self.headers is not None: + kwargs["headers"] = self.headers + + return super().get(url, **kwargs) + + def patch(self, url: Text | None, **kwargs) -> Response: + if self.headers is not None: + kwargs["headers"] = self.headers + + return super().patch(url, **kwargs) + + def post(self, url: Text | None, **kwargs) -> Response: + if self.headers is not None: + kwargs["headers"] = self.headers + + return super().post(url, **kwargs) + + def put(self, url: Text | None, **kwargs) -> Response: + if self.headers is not None: + kwargs["headers"] = self.headers + + return super().put(url, **kwargs)