Skip to content

Commit

Permalink
Creating a family or collection now used the organisation of the user (
Browse files Browse the repository at this point in the history
…#32)

* Create a family using the organisation of the user

* Create collections with the org of the logged in user

* fixed after review

* fixed after review
  • Loading branch information
diversemix authored Oct 24, 2023
1 parent 2fb8513 commit 8c38808
Show file tree
Hide file tree
Showing 21 changed files with 146 additions and 85 deletions.
5 changes: 3 additions & 2 deletions app/api/api_v1/routers/collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Endpoints for managing the Collection entity."""
import logging
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, Request, status
from app.errors import RepositoryError, ValidationError

from app.model.collection import (
Expand Down Expand Up @@ -130,6 +130,7 @@ async def update_collection(
status_code=status.HTTP_201_CREATED,
)
async def create_collection(
request: Request,
new_collection: CollectionCreateDTO,
) -> str:
"""
Expand All @@ -139,7 +140,7 @@ async def create_collection(
:return str: returns the import_id of the new collection.
"""
try:
return collection_service.create(new_collection)
return collection_service.create(new_collection, request.state.user.email)
except ValidationError as e:
_LOGGER.error(e.message)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
Expand Down
5 changes: 3 additions & 2 deletions app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
implemented directly accesses the "repository" layer.
"""
import logging
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, Request, status
from app.errors import RepositoryError, ValidationError

from app.model.family import FamilyCreateDTO, FamilyReadDTO, FamilyWriteDTO
Expand Down Expand Up @@ -119,6 +119,7 @@ async def update_family(

@r.post("/families", response_model=str, status_code=status.HTTP_201_CREATED)
async def create_family(
request: Request,
new_family: FamilyCreateDTO,
) -> str:
"""
Expand All @@ -128,7 +129,7 @@ async def create_family(
:return FamilyDTO: returns a FamilyDTO of the new family.
"""
try:
family = family_service.create(new_family)
family = family_service.create(new_family, request.state.user.email)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
2 changes: 0 additions & 2 deletions app/model/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,3 @@ class CollectionCreateDTO(BaseModel):

title: str
description: str
# families: list[str] TODO: Ask Patrick if we want this as an option?
organisation: str
22 changes: 15 additions & 7 deletions app/model/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,33 @@ class FamilyReadDTO(BaseModel):


class FamilyWriteDTO(BaseModel):
"""A JSON representation of a family for writing."""
"""
A JSON representation of a family for writing.
Note:
- import_id is given from the request
- organisation is immutable
"""

# import_id: not included as this is in the request path
title: str
summary: str
geography: str
category: str
metadata: Json
# organisation: not included as once created is immutable


class FamilyCreateDTO(BaseModel):
"""A JSON representation of a family for creating."""
"""
A JSON representation of a family for creating.
Note:
- import_id is auto generated
- slug is auto generated
- organisation comes from the user's organisation
"""

# import_id: not included as generated
title: str
summary: str
geography: str
category: str
metadata: Json
# slug: not included as this is generated from title
organisation: str # FIXME: https://linear.app/climate-policy-radar/issue/PDCT-494
16 changes: 15 additions & 1 deletion app/repository/app_user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, cast
from sqlalchemy.orm import Session

from app.clients.db.models.app.users import AppUser, Organisation, OrganisationUser
Expand Down Expand Up @@ -29,3 +29,17 @@ def get_app_user_authorisation(
.join(Organisation, Organisation.id == OrganisationUser.organisation_id)
)
return [(r[0], r[1]) for r in query.all()]


def get_org_id(db: Session, user_email: str) -> Optional[int]:
"""Gets the organisation id given the user's email"""
result = (
db.query(Organisation.id)
.select_from(Organisation)
.join(OrganisationUser, Organisation.id == OrganisationUser.organisation_id)
.join(AppUser, AppUser.email == user_email)
.filter(AppUser.email == OrganisationUser.appuser_email)
.scalar()
)
if result is not None:
return cast(int, result)
4 changes: 3 additions & 1 deletion app/repository/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def generate_import_id(
if type(org) == str:
org_name = org
else:
org_name = db.query(Organisation.name).filter(Organisation.id == org)
org_name = (
db.query(Organisation.name).filter(Organisation.id == org).scalar_subquery()
)

counter: EntityCounter = (
db.query(EntityCounter).filter(EntityCounter.prefix == org_name).one()
Expand Down
11 changes: 11 additions & 0 deletions app/service/app_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy.orm import Session
from app.errors import ValidationError
from app.repository import app_user_repo


def get_organisation(db: Session, user_email: str) -> int:
"""Gets a user's organisation"""
org_id = app_user_repo.get_org_id(db, user_email)
if org_id is None:
raise ValidationError(f"Could not get the organisation for user {user_email}")
return org_id
9 changes: 6 additions & 3 deletions app/service/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sqlalchemy.orm import Session

from app.service import id
from app.service import organisation
from app.service import app_user


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -120,7 +120,9 @@ def update(

@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def create(collection: CollectionCreateDTO, db: Session = db_session.get_db()) -> str:
def create(
collection: CollectionCreateDTO, user_email: str, db: Session = db_session.get_db()
) -> str:
"""
Creates a new collection with the values passed.
Expand All @@ -130,7 +132,8 @@ def create(collection: CollectionCreateDTO, db: Session = db_session.get_db()) -
:return str: The new import_id for the collection.
"""
try:
org_id = organisation.get_id(db, collection.organisation)
# Get the organisation from the user's email
org_id = app_user.get_organisation(db, user_email)

return collection_repo.create(db, collection, org_id)

Expand Down
12 changes: 10 additions & 2 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from app.service import category
from app.service import organisation
from app.service import metadata
from app.service import app_user

from app.repository import family_repo

Expand Down Expand Up @@ -116,7 +117,9 @@ def update(

@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def create(family: FamilyCreateDTO, db: Session = db_session.get_db()) -> str:
def create(
family: FamilyCreateDTO, user_email: str, db: Session = db_session.get_db()
) -> str:
"""
Creates a new Family with the values passed.
Expand All @@ -125,12 +128,17 @@ def create(family: FamilyCreateDTO, db: Session = db_session.get_db()) -> str:
:raises ValidationError: raised should the import_id be invalid.
:return Optional[FamilyDTO]: The new created Family or None if unsuccessful.
"""

# Get the organisation from the user's email
org_id = app_user.get_organisation(db, user_email)

# Validate geography
geo_id = geography.get_id(db, family.geography)

# Validate category
category.validate(family.category)

# Validate organisation
org_id = organisation.get_id(db, family.organisation)
metadata.validate(db, org_id, family.metadata)

return family_repo.create(db, family, geo_id, org_id)
Expand Down
19 changes: 0 additions & 19 deletions integration_tests/collection/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,3 @@ def test_create_collection_when_db_error(
data = response.json()
assert data["detail"] == "Bad Repo"
assert bad_collection_repo.create.call_count == 1


def test_create_collection_when_org_invalid(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
new_collection = create_collection_create_dto(
title="Title",
description="test test test",
)
new_collection.organisation = "chicken"
response = client.post(
"/api/v1/collections",
json=new_collection.model_dump(),
headers=user_header_token,
)
assert response.status_code == 400
data = response.json()
assert data["detail"] == "The organisation name chicken is invalid!"
17 changes: 0 additions & 17 deletions integration_tests/family/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,3 @@ def test_create_family_when_invalid_category(
assert response.status_code == 400
data = response.json()
assert data["detail"] == "Invalid is not a valid FamilyCategory"


def test_create_family_when_invalid_org(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
new_family = create_family_create_dto(
title="Title",
summary="test test test",
)
new_family.organisation = "chicken"
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == 400
data = response.json()
assert data["detail"] == "The organisation name chicken is invalid!"
20 changes: 19 additions & 1 deletion integration_tests/setup_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import cast
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.clients.db.models.app.users import Organisation
from app.clients.db.models.app.users import AppUser, Organisation, OrganisationUser
from app.clients.db.models.document.physical_document import (
LanguageSource,
PhysicalDocument,
Expand Down Expand Up @@ -221,6 +221,24 @@ def _setup_organisation(test_db: Session) -> int:
)
)
test_db.flush()

# Also link to the test user
test_db.add(
AppUser(
email="[email protected]", name="Test", hashed_password="", is_superuser=False
)
)
test_db.flush()
test_db.add(
OrganisationUser(
appuser_email="[email protected]",
organisation_id=org.id,
job_title="",
is_active=True,
is_admin=False,
)
)
test_db.commit()
return cast(int, org.id)


Expand Down
10 changes: 10 additions & 0 deletions unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import app.service.token as token_service
import app.service.analytics as analytics_service
import app.service.event as event_service
import app.service.app_user as app_user_service

from app.repository import (
family_repo,
geography_repo,
Expand All @@ -41,6 +43,7 @@
from unit_tests.mocks.repos.config_repo import mock_config_repo
from unit_tests.mocks.repos.event_repo import mock_event_repo

from unit_tests.mocks.services.app_user_service import mock_app_user_service
from unit_tests.mocks.services.family_service import mock_family_service
from unit_tests.mocks.services.collection_service import mock_collection_service
from unit_tests.mocks.services.document_service import mock_document_service
Expand Down Expand Up @@ -125,6 +128,13 @@ def event_repo_mock(monkeypatch, mocker):
# ----- Mock services


@pytest.fixture
def app_user_service_mock(monkeypatch, mocker):
"""Mocks the service for a single test."""
mock_app_user_service(app_user_service, monkeypatch, mocker)
yield app_user_service


@pytest.fixture
def family_service_mock(monkeypatch, mocker):
"""Mocks the service for a single test."""
Expand Down
1 change: 0 additions & 1 deletion unit_tests/helpers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,4 @@ def create_collection_create_dto(
return CollectionCreateDTO(
title=title,
description=description,
organisation="CCLW",
)
1 change: 0 additions & 1 deletion unit_tests/helpers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def create_family_create_dto(
geography=geography,
category=category,
metadata=metadata,
organisation="CCLW",
)


Expand Down
11 changes: 9 additions & 2 deletions unit_tests/mocks/repos/app_user_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
HASH_PASSWORD = auth_service.get_password_hash(PLAIN_PASSWORD)
VALID_USERNAME = "[email protected]"

ORG_ID = 1234


def mock_app_user_repo(app_user_repo, monkeypatch: MonkeyPatch, mocker):
app_user_repo.user_active = True

app_user_repo.error = False

def mock_get_app_user_authorisation(
_, __
) -> list[Tuple[OrganisationUser, Organisation]]:
Expand All @@ -28,15 +31,19 @@ def mock_get_user_by_email(_, __) -> MaybeAppUser:
is_superuser=True,
)

def mock_get_org_id(_, user_email: str) -> int:
return ORG_ID

def mock_is_active(_, email: str) -> bool:
return app_user_repo.user_active

app_user_repo.error = False
monkeypatch.setattr(app_user_repo, "get_user_by_email", mock_get_user_by_email)
monkeypatch.setattr(app_user_repo, "get_org_id", mock_get_org_id)
monkeypatch.setattr(app_user_repo, "is_active", mock_is_active)
monkeypatch.setattr(
app_user_repo, "get_app_user_authorisation", mock_get_app_user_authorisation
)
mocker.spy(app_user_repo, "get_user_by_email")
mocker.spy(app_user_repo, "get_org_id")
mocker.spy(app_user_repo, "is_active")
mocker.spy(app_user_repo, "get_app_user_authorisation")
11 changes: 11 additions & 0 deletions unit_tests/mocks/services/app_user_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pytest import MonkeyPatch

ORG_ID = 1234


def mock_app_user_service(app_user_service, monkeypatch: MonkeyPatch, mocker):
def mock_get_organisation(_, user_email: str) -> int:
return ORG_ID

monkeypatch.setattr(app_user_service, "get_organisation", mock_get_organisation)
mocker.spy(app_user_service, "get_organisation")
4 changes: 3 additions & 1 deletion unit_tests/mocks/services/collection_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def mock_update_collection(
return create_collection_read_dto(import_id, data.title, data.description)

# TODO: Think: Is this return value correct?
def mock_create_collection(data: CollectionWriteDTO) -> Optional[str]:
def mock_create_collection(
data: CollectionWriteDTO, user_email: str
) -> Optional[str]:
maybe_throw()
if not collection_service.missing:
return "test.new.collection.0"
Expand Down
Loading

0 comments on commit 8c38808

Please sign in to comment.