Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LCFS-1654: Protect routes for compliance report #1686

Merged
merged 6 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions backend/lcfs/tests/fuel_export/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock
from lcfs.web.api.compliance_report.schema import CompliancePeriodSchema, ComplianceReportHistorySchema, ComplianceReportOrganizationSchema, ComplianceReportStatusSchema, ComplianceReportUserSchema, SummarySchema
from lcfs.web.api.fuel_export.repo import FuelExportRepository
from lcfs.web.api.fuel_code.repo import FuelCodeRepository
from lcfs.web.api.fuel_export.services import FuelExportServices
Expand Down Expand Up @@ -45,6 +47,49 @@ def mock_compliance_report_repo():
repo = AsyncMock(spec=ComplianceReportRepository)
return repo

@pytest.fixture
def compliance_period_schema():
return CompliancePeriodSchema(
compliance_period_id=1,
description="2024",
effective_date=datetime(2024, 1, 1),
expiration_date=datetime(2024, 3, 31),
display_order=1,
)

@pytest.fixture
def compliance_report_organization_schema():
return ComplianceReportOrganizationSchema(
organization_id=1, name="Acme Corporation"
)

@pytest.fixture
def summary_schema():
return SummarySchema(summary_id=1, is_locked=False)

@pytest.fixture
def compliance_report_status_schema():
return ComplianceReportStatusSchema(compliance_report_status_id=1, status="Draft")

@pytest.fixture
def compliance_report_user_schema(compliance_report_organization_schema):
return ComplianceReportUserSchema(
first_name="John",
last_name="Doe",
organization=compliance_report_organization_schema,
)

@pytest.fixture
def compliance_report_history_schema(
compliance_report_status_schema, compliance_report_user_schema
):
return ComplianceReportHistorySchema(
compliance_report_history_id=1,
compliance_report_id=1,
status=compliance_report_status_schema,
user_profile=compliance_report_user_schema,
create_date=datetime(2024, 4, 1, 12, 0, 0),
)

@pytest.fixture
def mock_fuel_code_repo():
Expand Down
23 changes: 19 additions & 4 deletions backend/lcfs/tests/fuel_export/test_fuel_exports_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder

from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.compliance_report.schema import ChainedComplianceReportSchema
from lcfs.web.api.fuel_export.schema import (
FuelExportSchema,
FuelExportCreateUpdateSchema,
Expand Down Expand Up @@ -68,18 +70,24 @@ async def test_get_fuel_exports_invalid_payload(

@pytest.mark.anyio
async def test_get_fuel_exports_paginated_success(
client: AsyncClient, fastapi_app: FastAPI, set_mock_user
client: AsyncClient, fastapi_app: FastAPI, set_mock_user, compliance_report_base_schema
):
with patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_fuel_exports_paginated"
) as mock_get_fuel_exports_paginated, patch(
"lcfs.web.api.fuel_export.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

mock_get_fuel_exports_paginated.return_value = FuelExportsSchema(
fuel_exports=[]
)
mock_validate_organization_access.return_value = True

mock_compliance_report = compliance_report_base_schema()

mock_get_compliance_report_by_id.return_value = mock_compliance_report
set_mock_user(fastapi_app, [RoleEnum.ANALYST])

url = fastapi_app.url_path_for("get_fuel_exports")
Expand All @@ -98,16 +106,23 @@ async def test_get_fuel_exports_paginated_success(

@pytest.mark.anyio
async def test_get_fuel_exports_list_success(
client: AsyncClient, fastapi_app: FastAPI, set_mock_user
client: AsyncClient, fastapi_app: FastAPI, set_mock_user, compliance_report_base_schema
):
with patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_fuel_export_list"
) as mock_get_fuel_export_list, patch(
"lcfs.web.api.fuel_export.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.fuel_export.views.FuelExportServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

mock_get_fuel_export_list.return_value = FuelExportsSchema(fuel_exports=[])
mock_validate_organization_access.return_value = True

mock_compliance_report = compliance_report_base_schema()

mock_get_compliance_report_by_id.return_value = mock_compliance_report

set_mock_user(fastapi_app, [RoleEnum.ANALYST])

url = fastapi_app.url_path_for("get_fuel_exports")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from lcfs.db.base import UserTypeEnum, ActionTypeEnum
from lcfs.db.models.user.Role import RoleEnum
from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.base import ComplianceReportRequestSchema
from lcfs.web.api.notional_transfer.schema import (
PaginatedNotionalTransferRequestSchema,
Expand Down Expand Up @@ -70,12 +71,20 @@ async def test_get_notional_transfers(
):
with patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access,patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_compliance_report_access"
) as mock_validate_compliance_report_access, patch(
"lcfs.web.api.notional_transfer.views.NotionalTransferServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])
url = fastapi_app.url_path_for("get_notional_transfers")
payload = ComplianceReportRequestSchema(compliance_report_id=1).model_dump()

mock_validate_organization_access.return_value = True

mock_get_compliance_report_by_id.return_value = compliance_report_base_schema
mock_validate_compliance_report_access.return_value = True

mock_notional_transfer_service.get_notional_transfers.return_value = {
"notionalTransfers": []
}
Expand Down
12 changes: 11 additions & 1 deletion backend/lcfs/tests/other_uses/test_other_uses_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from lcfs.db.base import UserTypeEnum, ActionTypeEnum
from lcfs.db.models.user.Role import RoleEnum
from lcfs.tests.compliance_report.conftest import compliance_report_base_schema
from lcfs.web.api.base import ComplianceReportRequestSchema
from lcfs.web.api.other_uses.schema import (
PaginatedOtherUsesRequestSchema,
Expand Down Expand Up @@ -69,12 +70,21 @@ async def test_get_other_uses(
):
with patch(
"lcfs.web.api.other_uses.views.ComplianceReportValidation.validate_organization_access"
) as mock_validate_organization_access:
) as mock_validate_organization_access, patch(
"lcfs.web.api.notional_transfer.views.ComplianceReportValidation.validate_compliance_report_access"
) as mock_validate_compliance_report_access, patch(
"lcfs.web.api.notional_transfer.views.NotionalTransferServices.get_compliance_report_by_id"
) as mock_get_compliance_report_by_id:

set_mock_user(fastapi_app, [RoleEnum.SUPPLIER])
url = fastapi_app.url_path_for("get_other_uses")
payload = ComplianceReportRequestSchema(compliance_report_id=1).model_dump()

mock_validate_organization_access.return_value = True

mock_get_compliance_report_by_id.return_value = compliance_report_base_schema
mock_validate_compliance_report_access.return_value = True

mock_other_uses_service.get_other_uses.return_value = {"otherUses": []}

fastapi_app.dependency_overrides[OtherUsesServices] = (
Expand Down
20 changes: 19 additions & 1 deletion backend/lcfs/web/api/allocation_agreement/services.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
import structlog
from typing import List
from fastapi import Depends
from fastapi import Depends, HTTPException, status
from datetime import datetime

from lcfs.web.api.allocation_agreement.repo import AllocationAgreementRepository
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository
from lcfs.web.core.decorators import service_handler
from lcfs.db.models.compliance.AllocationAgreement import AllocationAgreement
from lcfs.web.api.base import PaginationRequestSchema, PaginationResponseSchema
Expand Down Expand Up @@ -34,9 +35,11 @@ def __init__(
self,
repo: AllocationAgreementRepository = Depends(AllocationAgreementRepository),
fuel_repo: FuelCodeRepository = Depends(),
compliance_report_repo: ComplianceReportRepository = Depends(),
) -> None:
self.repo = repo
self.fuel_repo = fuel_repo
self.compliance_report_repo = compliance_report_repo

async def convert_to_model(
self, allocation_agreement: AllocationAgreementCreateSchema
Expand Down Expand Up @@ -350,3 +353,18 @@ async def create_allocation_agreement(
async def delete_allocation_agreement(self, allocation_agreement_id: int) -> str:
"""Delete an Allocation agreement"""
return await self.repo.delete_allocation_agreement(allocation_agreement_id)

@service_handler
async def get_compliance_report_by_id(self, compliance_report_id: int):
"""Get compliance report by period with status"""
compliance_report = await self.compliance_report_repo.get_compliance_report_by_id(
compliance_report_id,
)

if not compliance_report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Compliance report not found for this period"
)

return compliance_report
30 changes: 26 additions & 4 deletions backend/lcfs/web/api/allocation_agreement/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import (
APIRouter,
Body,
HTTPException,
status,
Request,
Response,
Expand Down Expand Up @@ -69,10 +70,31 @@ async def get_allocation_agreements(
report_validate: ComplianceReportValidation = Depends(),
):
"""Endpoint to get list of allocation agreements for a compliance report"""
await report_validate.validate_organization_access(
request_data.compliance_report_id
)
return await service.get_allocation_agreements(request_data.compliance_report_id)
try:
compliance_report_id = request_data.compliance_report_id

compliance_report = await service.get_compliance_report_by_id(compliance_report_id)
if not compliance_report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Compliance report not found"
)

await report_validate.validate_compliance_report_access(compliance_report)
await report_validate.validate_organization_access(
request_data.compliance_report_id
)
return await service.get_allocation_agreements(request_data.compliance_report_id)
except HTTPException as http_ex:
# Re-raise HTTP exceptions to preserve status code and message
raise http_ex
except Exception as e:
# Log and handle unexpected errors
logger.exception("Error occurred", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred while processing your request"
)


@router.post(
Expand Down
15 changes: 15 additions & 0 deletions backend/lcfs/web/api/compliance_report/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import Depends, HTTPException, Request
from lcfs.db.models.user.Role import RoleEnum
from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository
from fastapi import status
from lcfs.web.api.role.schema import user_has_roles
Expand Down Expand Up @@ -41,3 +42,17 @@ async def validate_organization_access(self, compliance_report_id: int):
)

return compliance_report

async def validate_compliance_report_access(self, compliance_report):
"""Validates government user access to draft reports"""
is_government = user_has_roles(self.request.user, [RoleEnum.GOVERNMENT])

if compliance_report:
status_enum = ComplianceReportStatusEnum(compliance_report.current_status.status)
is_draft = status_enum == ComplianceReportStatusEnum.Draft

if is_government and is_draft:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Government users cannot access draft compliance reports"
)
3 changes: 2 additions & 1 deletion backend/lcfs/web/api/compliance_report/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def get_compliance_report_by_id(
service: ComplianceReportServices = Depends(),
validate: ComplianceReportValidation = Depends(),
) -> ChainedComplianceReportSchema:
await validate.validate_organization_access(report_id)
compliance_report = await validate.validate_organization_access(report_id)
await validate.validate_compliance_report_access(compliance_report)

mask_statuses = not user_has_roles(request.user, [RoleEnum.GOVERNMENT])

Expand Down
26 changes: 16 additions & 10 deletions backend/lcfs/web/api/final_supply_equipment/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,23 @@ async def get_organization_names(self, organization) -> List[str]:
Returns:
List[str]: A list of unique organization names.
"""
organization_names = (
await self.db.execute(
select(distinct(FinalSupplyEquipment.organization_name))
.join(ComplianceReport, FinalSupplyEquipment.compliance_report_id == ComplianceReport.compliance_report_id)
.filter(ComplianceReport.organization_id == organization.organization_id)
.filter(FinalSupplyEquipment.organization_name.isnot(None))
)
).all()
try:
if not organization or not organization.organization_id:
return []

organization_names = (
await self.db.execute(
select(distinct(FinalSupplyEquipment.organization_name))
.join(ComplianceReport, FinalSupplyEquipment.compliance_report_id == ComplianceReport.compliance_report_id)
.filter(ComplianceReport.organization_id == organization.organization_id)
.filter(FinalSupplyEquipment.organization_name.isnot(None))
)
).all()

# Extract strings from the list of tuples
return [name[0] for name in organization_names]
return [name[0] for name in organization_names]
except Exception as e:
logger.error("Error getting organization names", error=str(e))
return []

@repo_handler
async def get_intended_user_by_name(self, intended_user: str) -> EndUseType:
Expand Down
Loading
Loading