diff --git a/app/repository/document.py b/app/repository/document.py index 94028fa8..a4b390ed 100644 --- a/app/repository/document.py +++ b/app/repository/document.py @@ -43,7 +43,11 @@ def _get_query(db: Session) -> Query: return ( db.query(FamilyDocument, PhysicalDocument, Slug) .filter(FamilyDocument.physical_document_id == PhysicalDocument.id) - .filter(Slug.family_document_import_id == FamilyDocument.import_id) + .join( + Slug, + Slug.family_document_import_id == FamilyDocument.import_id, + isouter=True, + ) ) @@ -56,7 +60,7 @@ def _document_to_dto(doc_tuple: DocumentTuple) -> DocumentReadDTO: status=cast(DocumentStatus, fd.document_status), role=cast(FamilyDocumentRole, fd.document_role), type=cast(FamilyDocumentType, fd.document_type), - slug=cast(str, slug.name), + slug=cast(str, slug.name) if slug is not None else "", physical_id=cast(int, pd.id), title=cast(str, pd.title), md5_sum=cast(str, pd.md5_sum), @@ -146,7 +150,7 @@ def search(db: Session, search_term: str) -> list[DocumentReadDTO]: search = or_(PhysicalDocument.title.ilike(term)) found = _get_query(db).filter(search).all() - return [_document_to_dto(*d) for d in found] + return [_document_to_dto(d) for d in found] def update(db: Session, document: DocumentReadDTO) -> bool: @@ -160,30 +164,41 @@ def update(db: Session, document: DocumentReadDTO) -> bool: """ # TODO: Implement this: - # new_values = document.model_dump() - - # original_document = ( - # db.query(FamilyDocument) - # .filter(FamilyDocument.import_id == document.import_id) - # .one_or_none() - # ) - - # if original_document is None: # Not found the document to update - # _LOGGER.error(f"Unable to find document for update {document}") - # return False - - # result = db.execute( - # db_update(FamilyDocument) - # .where(FamilyDocument.import_id == document.import_id) - # .values( - # title=new_values["title"], - # description=new_values["description"], - # ) - # ) - # if result.rowcount == 0: # type: ignore - # msg = f"Could not update document fields: {document}" - # _LOGGER.error(msg) - # raise RepositoryError(msg) + new_values = document.model_dump() + + original_fd = ( + db.query(FamilyDocument) + .filter(FamilyDocument.import_id == document.import_id) + .one_or_none() + ) + + if original_fd is None: # Not found the document to update + _LOGGER.error(f"Unable to find document for update {document.import_id}") + return False + + original_pd = ( + db.query(PhysicalDocument) + .filter(PhysicalDocument.id == original_fd.physical_document_id) + .one_or_none() + ) + + if original_pd is None: # Not found the document to update + _LOGGER.error( + f"Unable to find document for update {original_fd.physical_document_id}" + ) + return False + + result = db.execute( + db_update(PhysicalDocument) + .where(PhysicalDocument.id == original_pd.id) + .values( + title=new_values["title"], + ) + ) + if result.rowcount == 0: # type: ignore + msg = f"Could not update document fields: {document}" + _LOGGER.error(msg) + raise RepositoryError(msg) return True diff --git a/integration_tests/collection/test_delete.py b/integration_tests/collection/test_delete.py index 55150bd3..67dec255 100644 --- a/integration_tests/collection/test_delete.py +++ b/integration_tests/collection/test_delete.py @@ -2,7 +2,7 @@ from fastapi import status from sqlalchemy.orm import Session from app.clients.db.models.law_policy.collection import Collection -from integration_tests.setup_db import setup_db +from integration_tests.setup_db import EXPECTED_NUM_COLLECTIONS, setup_db def test_delete_collection( @@ -14,7 +14,7 @@ def test_delete_collection( ) assert response.status_code == status.HTTP_200_OK n = test_db.query(Collection).count() - assert n == 1 + assert n == EXPECTED_NUM_COLLECTIONS - 1 def test_delete_collection_when_not_authenticated(client: TestClient, test_db: Session): @@ -37,7 +37,7 @@ def test_delete_collection_rollback( ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE n = test_db.query(Collection).count() - assert n == 2 + assert n == EXPECTED_NUM_COLLECTIONS assert rollback_collection_repo.delete.call_count == 1 @@ -52,7 +52,7 @@ def test_delete_collection_when_not_found( data = response.json() assert data["detail"] == "Collection not deleted: C.0.0.22" n = test_db.query(Collection).count() - assert n == 2 + assert n == EXPECTED_NUM_COLLECTIONS def test_delete_collection_when_db_error( diff --git a/integration_tests/document/test_get.py b/integration_tests/document/test_get.py index 99389339..354bdf2d 100644 --- a/integration_tests/document/test_get.py +++ b/integration_tests/document/test_get.py @@ -2,16 +2,16 @@ from fastapi import status from sqlalchemy.orm import Session -from integration_tests.setup_db import EXPECTED_COLLECTIONS, setup_db +from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db # --- GET ALL -def test_get_all_collections(client: TestClient, test_db: Session, user_header_token): +def test_get_all_documents(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( - "/api/v1/collections", + "/api/v1/documents", headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK @@ -19,21 +19,19 @@ def test_get_all_collections(client: TestClient, test_db: Session, user_header_t assert type(data) is list assert len(data) == 2 ids_found = set([f["import_id"] for f in data]) - expected_ids = set(["C.0.0.1", "C.0.0.2"]) + expected_ids = set(["D.0.0.1", "D.0.0.2"]) assert ids_found.symmetric_difference(expected_ids) == set([]) sdata = sorted(data, key=lambda d: d["import_id"]) - assert sdata[0] == EXPECTED_COLLECTIONS[0] - assert sdata[1] == EXPECTED_COLLECTIONS[1] + assert sdata[0] == EXPECTED_DOCUMENTS[0] + assert sdata[1] == EXPECTED_DOCUMENTS[1] -def test_get_all_collections_when_not_authenticated( - client: TestClient, test_db: Session -): +def test_get_all_documents_when_not_authenticated(client: TestClient, test_db: Session): setup_db(test_db) response = client.get( - "/api/v1/collections", + "/api/v1/documents", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -41,45 +39,45 @@ def test_get_all_collections_when_not_authenticated( # --- GET -def test_get_collection(client: TestClient, test_db: Session, user_header_token): +def test_get_document(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( - "/api/v1/collections/C.0.0.1", + "/api/v1/documents/D.0.0.1", headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["import_id"] == "C.0.0.1" - assert data == EXPECTED_COLLECTIONS[0] + assert data["import_id"] == "D.0.0.1" + assert data == EXPECTED_DOCUMENTS[0] -def test_get_collection_when_not_authenticated(client: TestClient, test_db: Session): +def test_get_document_when_not_authenticated(client: TestClient, test_db: Session): setup_db(test_db) response = client.get( - "/api/v1/collections/C.0.0.1", + "/api/v1/documents/D.0.0.1", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_get_collection_when_not_found( +def test_get_document_when_not_found( client: TestClient, test_db: Session, user_header_token ): setup_db(test_db) response = client.get( - "/api/v1/collections/C.0.0.8", + "/api/v1/documents/D.0.0.8", headers=user_header_token, ) assert response.status_code == status.HTTP_404_NOT_FOUND data = response.json() - assert data["detail"] == "Collection not found: C.0.0.8" + assert data["detail"] == "Document not found: D.0.0.8" -def test_get_collection_when_id_invalid( +def test_get_document_when_id_invalid( client: TestClient, test_db: Session, user_header_token ): setup_db(test_db) response = client.get( - "/api/v1/collections/A008", + "/api/v1/documents/A008", headers=user_header_token, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -88,12 +86,12 @@ def test_get_collection_when_id_invalid( assert data["detail"] == expected_msg -def test_get_collection_when_db_error( - client: TestClient, test_db: Session, bad_collection_repo, user_header_token +def test_get_document_when_db_error( + client: TestClient, test_db: Session, bad_document_repo, user_header_token ): setup_db(test_db) response = client.get( - "/api/v1/collections/A.0.0.8", + "/api/v1/documents/A.0.0.8", headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE diff --git a/integration_tests/document/test_search.py b/integration_tests/document/test_search.py index c1130655..01cc2cd4 100644 --- a/integration_tests/document/test_search.py +++ b/integration_tests/document/test_search.py @@ -4,10 +4,10 @@ from integration_tests.setup_db import setup_db -def test_search_collection(client: TestClient, test_db: Session, user_header_token): +def test_search_document(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( - "/api/v1/collections/?q=big", + "/api/v1/documents/?q=big", headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK @@ -17,38 +17,38 @@ def test_search_collection(client: TestClient, test_db: Session, user_header_tok ids_found = set([f["import_id"] for f in data]) assert len(ids_found) == 1 - expected_ids = set(["C.0.0.1"]) + expected_ids = set(["D.0.0.1"]) assert ids_found.symmetric_difference(expected_ids) == set([]) -def test_search_collection_when_not_authorised(client: TestClient, test_db: Session): +def test_search_document_when_not_authorised(client: TestClient, test_db: Session): setup_db(test_db) response = client.get( - "/api/v1/collections/?q=orange", + "/api/v1/documents/?q=orange", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_search_collection_when_nothing_found( +def test_search_document_when_nothing_found( client: TestClient, test_db: Session, user_header_token ): setup_db(test_db) response = client.get( - "/api/v1/collections/?q=chicken", + "/api/v1/documents/?q=chicken", headers=user_header_token, ) assert response.status_code == status.HTTP_404_NOT_FOUND data = response.json() - assert data["detail"] == "Collections not found for term: chicken" + assert data["detail"] == "Documents not found for term: chicken" -def test_search_collection_when_db_error( - client: TestClient, test_db: Session, bad_collection_repo, user_header_token +def test_search_document_when_db_error( + client: TestClient, test_db: Session, bad_document_repo, user_header_token ): setup_db(test_db) response = client.get( - "/api/v1/collections/?q=chicken", + "/api/v1/documents/?q=chicken", headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - assert bad_collection_repo.search.call_count == 1 + assert bad_document_repo.search.call_count == 1 diff --git a/integration_tests/document/test_update.py b/integration_tests/document/test_update.py index 2f817500..de4647aa 100644 --- a/integration_tests/document/test_update.py +++ b/integration_tests/document/test_update.py @@ -1,150 +1,137 @@ +from typing import Tuple from fastapi.testclient import TestClient from fastapi import status from sqlalchemy.orm import Session -from app.clients.db.models.law_policy.collection import ( - Collection, - CollectionFamily, - CollectionOrganisation, -) -from integration_tests.setup_db import EXPECTED_COLLECTIONS, setup_db -from unit_tests.helpers.collection import create_collection_dto +from app.clients.db.models.document.physical_document import PhysicalDocument +from app.clients.db.models.law_policy.family import FamilyDocument +from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db +from unit_tests.helpers.document import create_document_dto -def test_update_collection(client: TestClient, test_db: Session, user_header_token): + +def _get_doc_tuple( + test_db: Session, import_id: str +) -> Tuple[FamilyDocument, PhysicalDocument]: + fd: FamilyDocument = ( + test_db.query(FamilyDocument) + .filter(FamilyDocument.import_id == import_id) + .one() + ) + assert fd is not None + + pd: PhysicalDocument = ( + test_db.query(PhysicalDocument) + .filter(PhysicalDocument.id == fd.physical_document_id) + .one_or_none() + ) + assert pd is not None + + return fd, pd + + +def test_update_document(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) - new_collection = create_collection_dto( - import_id="C.0.0.2", + new_document = create_document_dto( + import_id="D.0.0.2", + family_import_id="A.0.0.3", title="Updated Title", - description="just a test", ) response = client.put( - "/api/v1/collections", - json=new_collection.model_dump(), + "/api/v1/documents", + json=new_document.model_dump(), headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["title"] == "Updated Title" - assert data["description"] == "just a test" - db_collection: Collection = ( - test_db.query(Collection).filter(Collection.import_id == "C.0.0.2").one() - ) - assert db_collection.title == "Updated Title" - assert db_collection.description == "just a test" - families = test_db.query(CollectionFamily).filter( - CollectionFamily.collection_import_id == "C.0.0.2" - ) - assert families.count() == 2 - org: CollectionOrganisation = ( - test_db.query(CollectionOrganisation) - .filter(CollectionOrganisation.collection_import_id == "C.0.0.2") - .one() - ) - assert org is not None + fd, pd = _get_doc_tuple(test_db, "D.0.0.2") + assert pd.title == "Updated Title" -def test_update_collection_when_not_authorised(client: TestClient, test_db: Session): +def test_update_document_when_not_authorised(client: TestClient, test_db: Session): setup_db(test_db) - new_collection = create_collection_dto( - import_id="A.0.0.2", + new_document = create_document_dto( + import_id="D.0.0.2", + family_import_id="A.0.0.3", title="Updated Title", - description="just a test", ) - response = client.put("/api/v1/collections", json=new_collection.model_dump()) + response = client.put("/api/v1/documents", json=new_document.model_dump()) assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_update_collection_idempotent( +def test_update_document_idempotent( client: TestClient, test_db: Session, user_header_token ): setup_db(test_db) - collection = EXPECTED_COLLECTIONS[1] - response = client.put( - "/api/v1/collections", json=collection, headers=user_header_token - ) + document = EXPECTED_DOCUMENTS[1] + response = client.put("/api/v1/documents", json=document, headers=user_header_token) assert response.status_code == status.HTTP_200_OK + data = response.json() - assert data["title"] == EXPECTED_COLLECTIONS[1]["title"] - assert data["description"] == EXPECTED_COLLECTIONS[1]["description"] - db_collection: Collection = ( - test_db.query(Collection) - .filter(Collection.import_id == EXPECTED_COLLECTIONS[1]["import_id"]) - .one() - ) - assert db_collection.title == EXPECTED_COLLECTIONS[1]["title"] - assert db_collection.description == EXPECTED_COLLECTIONS[1]["description"] + assert data["title"] == EXPECTED_DOCUMENTS[1]["title"] + + _, pd = _get_doc_tuple(test_db, EXPECTED_DOCUMENTS[1]["import_id"]) + assert pd.title == EXPECTED_DOCUMENTS[1]["title"] -def test_update_collection_rollback( - client: TestClient, test_db: Session, rollback_collection_repo, user_header_token +def test_update_document_rollback( + client: TestClient, test_db: Session, rollback_document_repo, user_header_token ): setup_db(test_db) - new_collection = create_collection_dto( - import_id="C.0.0.2", + new_document = create_document_dto( + import_id="D.0.0.2", + family_import_id="A.0.0.3", title="Updated Title", - description="just a test", ) response = client.put( - "/api/v1/collections", - json=new_collection.model_dump(), + "/api/v1/documents", + json=new_document.model_dump(), headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - db_collection: Collection = ( - test_db.query(Collection).filter(Collection.import_id == "C.0.0.2").one() - ) - assert db_collection.title != "Updated Title" - assert db_collection.description != "just a test" + _, pd = _get_doc_tuple(test_db, "D.0.0.2") + assert pd.title != "Updated Title" - families = test_db.query(CollectionFamily).filter( - CollectionFamily.collection_import_id == "C.0.0.2" - ) - assert families.count() == 2 - org: CollectionOrganisation = ( - test_db.query(CollectionOrganisation) - .filter(CollectionOrganisation.collection_import_id == "C.0.0.2") - .one() - ) - assert org is not None - assert rollback_collection_repo.update.call_count == 1 + assert rollback_document_repo.update.call_count == 1 -def test_update_collection_when_not_found( +def test_update_document_when_not_found( client: TestClient, test_db: Session, user_header_token ): setup_db(test_db) - new_collection = create_collection_dto( - import_id="C.0.0.22", + new_document = create_document_dto( + import_id="D.0.0.22", + family_import_id="A.0.0.3", title="Updated Title", - description="just a test", ) response = client.put( - "/api/v1/collections", - json=new_collection.model_dump(), + "/api/v1/documents", + json=new_document.model_dump(), headers=user_header_token, ) assert response.status_code == status.HTTP_404_NOT_FOUND data = response.json() - assert data["detail"] == "Collection not updated: C.0.0.22" + assert data["detail"] == "Document not updated: D.0.0.22" -def test_update_collection_when_db_error( - client: TestClient, test_db: Session, bad_collection_repo, user_header_token +def test_update_document_when_db_error( + client: TestClient, test_db: Session, bad_document_repo, user_header_token ): setup_db(test_db) - new_collection = create_collection_dto( - import_id="C.0.0.2", + + new_document = create_document_dto( + import_id="D.0.0.2", + family_import_id="A.0.0.3", title="Updated Title", - description="just a test", ) response = client.put( - "/api/v1/collections", - json=new_collection.model_dump(), + "/api/v1/documents", + json=new_document.model_dump(), headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE data = response.json() assert data["detail"] == "Bad Repo" - assert bad_collection_repo.update.call_count == 1 + assert bad_document_repo.update.call_count == 1 diff --git a/integration_tests/setup_db.py b/integration_tests/setup_db.py index d73cb93d..e4ce917e 100644 --- a/integration_tests/setup_db.py +++ b/integration_tests/setup_db.py @@ -9,9 +9,7 @@ CollectionOrganisation, ) from app.clients.db.models.law_policy.family import ( - DocumentStatus, Family, - FamilyCategory, FamilyDocument, FamilyOrganisation, Slug, @@ -23,7 +21,7 @@ ) # TODO: Change this to use the service.family.create - so we don't miss anything here - +EXPECTED_NUM_FAMILIES = 3 EXPECTED_FAMILIES = [ { "import_id": "A.0.0.1", @@ -76,6 +74,7 @@ ] +EXPECTED_NUM_COLLECTIONS = 2 EXPECTED_COLLECTIONS = [ { "import_id": "C.0.0.1", @@ -93,6 +92,42 @@ }, ] +EXPECTED_NUM_DOCUMENTS = 2 +EXPECTED_DOCUMENTS = [ + { + "import_id": "D.0.0.1", + "family_import_id": "A.0.0.3", + "variant_name": "Original Language", + "status": "Created", + "role": "MAIN", + "type": "Law", + "slug": "", + "physical_id": 1, + "title": "big title1", + "md5_sum": "sum1", + "cdn_object": "obj1", + "source_url": "source1", + "content_type": "application/pdf", + "user_language_name": "TODO", + }, + { + "import_id": "D.0.0.2", + "family_import_id": "A.0.0.3", + "variant_name": "Original Language", + "status": "Created", + "role": "MAIN", + "type": "Law", + "slug": "", + "physical_id": 2, + "title": "title2", + "md5_sum": "sum2", + "cdn_object": "obj2", + "source_url": "source2", + "content_type": "application/pdf", + "user_language_name": "TODO", + }, +] + def setup_db(test_db: Session): with open("integration_tests/default-data.sql") as file: @@ -124,27 +159,21 @@ def _setup_organisation(test_db: Session) -> int: def _setup_collection_data(test_db: Session, org_id: int): - test_db.add( - Collection( - import_id="C.0.0.1", - title="Collection 1 a very big collection", - description="description one", + for index in range(EXPECTED_NUM_COLLECTIONS): + data = EXPECTED_COLLECTIONS[index] + test_db.add( + Collection( + import_id=data["import_id"], + title=data["title"], + description=data["description"], + ) ) - ) - test_db.add( - Collection( - import_id="C.0.0.2", title="Collection 2", description="description two" + test_db.add( + CollectionOrganisation( + collection_import_id=data["import_id"], organisation_id=org_id + ) ) - ) - - test_db.add( - CollectionOrganisation(collection_import_id="C.0.0.1", organisation_id=org_id) - ) - - test_db.add( - CollectionOrganisation(collection_import_id="C.0.0.2", organisation_id=org_id) - ) test_db.add( CollectionFamily(collection_import_id="C.0.0.2", family_import_id="A.0.0.1") @@ -156,38 +185,24 @@ def _setup_collection_data(test_db: Session, org_id: int): def _setup_family_data(test_db: Session, org_id: int): - FAMILY_1 = Family( - import_id="A.0.0.1", - title="apple", - description="", - geography_id=1, - family_category=FamilyCategory.UNFCCC, - ) - - FAMILY_2 = Family( - import_id="A.0.0.2", - title="apple orange banana", - description="", - geography_id=1, - family_category=FamilyCategory.UNFCCC, - ) - - FAMILY_3 = Family( - import_id="A.0.0.3", - title="title", - description="orange peas", - geography_id=1, - family_category=FamilyCategory.UNFCCC, - ) - - test_db.add(FAMILY_1) - test_db.add(FAMILY_2) - test_db.add(FAMILY_3) + for index in range(EXPECTED_NUM_FAMILIES): + data = EXPECTED_FAMILIES[index] + test_db.add( + Family( + import_id=data["import_id"], + title=data["title"], + description=data["summary"], + geography_id=1, + family_category=data["category"], + ) + ) - # Link the families to the org - test_db.add(FamilyOrganisation(family_import_id="A.0.0.1", organisation_id=org_id)) - test_db.add(FamilyOrganisation(family_import_id="A.0.0.2", organisation_id=org_id)) - test_db.add(FamilyOrganisation(family_import_id="A.0.0.3", organisation_id=org_id)) + # Link the families to the org + test_db.add( + FamilyOrganisation( + family_import_id=data["import_id"], organisation_id=org_id + ) + ) # Now a Taxonomy tax = MetadataTaxonomy( @@ -212,60 +227,45 @@ def _setup_family_data(test_db: Session, org_id: int): test_db.flush() # Now add the metadata onto the families - for index in range(3): + for index in range(EXPECTED_NUM_FAMILIES): + data = EXPECTED_FAMILIES[index] test_db.add( FamilyMetadata( - family_import_id=EXPECTED_FAMILIES[index]["import_id"], + family_import_id=data["import_id"], taxonomy_id=tax.id, - value=EXPECTED_FAMILIES[index]["metadata"], + value=data["metadata"], ) ) test_db.add( Slug( name=f"Slug{index+1}", - family_import_id=EXPECTED_FAMILIES[index]["import_id"], + family_import_id=data["import_id"], ) ) def _setup_document_data(test_db: Session, family_id: str) -> None: - pd1 = PhysicalDocument( - id=None, - title="title1", - md5_sum="sum1", - cdn_object="obj1", - source_url="url1", - content_type="type1", - ) - pd2 = PhysicalDocument( - id=None, - title="title2", - md5_sum="sum2", - cdn_object="obj2", - source_url="url2", - content_type="type2", - ) - test_db.add(pd1) - test_db.add(pd2) - test_db.flush() + for index in range(EXPECTED_NUM_DOCUMENTS): + data = EXPECTED_DOCUMENTS[index] + pd = PhysicalDocument( + id=None, + title=data["title"], + md5_sum=data["md5_sum"], + cdn_object=data["cdn_object"], + source_url=data["source_url"], + content_type=data["content_type"], + ) + test_db.add(pd) + test_db.flush() - fd1 = FamilyDocument( - family_import_id=family_id, - physical_document_id=pd1.id, - import_id="D.0.0.1", - variant_name="Original Language", - document_status=DocumentStatus.CREATED, - document_type="Law", - document_role="MAIN", - ) - fd2 = FamilyDocument( - family_import_id=family_id, - physical_document_id=pd2.id, - import_id="D.0.0.2", - variant_name="Original Language", - document_status=DocumentStatus.CREATED, - document_type="Law", - document_role="MAIN", - ) - test_db.add(fd1) - test_db.add(fd2) + fd = FamilyDocument( + family_import_id=family_id, + physical_document_id=pd.id, + import_id=data["import_id"], + variant_name=data["variant_name"], + document_status=data["status"], + document_type=data["type"], + document_role=data["role"], + ) + test_db.add(fd) + test_db.flush()