diff --git a/src/cpr_sdk/parser_models.py b/src/cpr_sdk/parser_models.py index 9f48c3f..d5fcfb9 100644 --- a/src/cpr_sdk/parser_models.py +++ b/src/cpr_sdk/parser_models.py @@ -378,7 +378,7 @@ def from_flat_json(data: dict): return ParserOutput.model_validate(unflattened) - def to_passage_level_json(self) -> list[dict[str, Any]]: + def to_passage_level_json(self, include_empty: bool = True) -> list[dict[str, Any]]: """ Convert the parser output to a passage-level JSON format. @@ -394,8 +394,13 @@ def to_passage_level_json(self) -> list[dict[str, Any]]: model_dump_json method and then reloading with json.load is as objects like Enums and child pydantic objects persist when using the model_dump method. We don't want these when we push to huggingface. + + :param include_empty: Whether to output the document metadata if there are no + text blocks in the ParserOutput. If True, outputs a single dict with None values + for each text block related field. If False, returns an empty list. """ - if self.text_blocks is None: + + if not self.text_blocks and not include_empty: return [] fixed_fields_dict = json.loads( @@ -407,6 +412,19 @@ def to_passage_level_json(self) -> list[dict[str, Any]]: ) ) + empty_html_text_block_keys: list[str] = list(HTMLTextBlock.model_fields.keys()) + empty_pdf_text_block_keys: list[str] = list(PDFTextBlock.model_fields.keys()) + + if not self.text_blocks: + passages_array_filled = [ + {key: None for key in empty_html_text_block_keys} + | {key: None for key in empty_pdf_text_block_keys} + | fixed_fields_dict + | {"block_index": 0, PDF_PAGE_METADATA_KEY: None} + ] + + return passages_array_filled + passages_array = [ fixed_fields_dict | json.loads(block.model_dump_json(exclude={"text"})) @@ -422,9 +440,6 @@ def to_passage_level_json(self) -> list[dict[str, Any]]: else None ) - empty_html_text_block_keys: list[str] = list(HTMLTextBlock.model_fields.keys()) - empty_pdf_text_block_keys: list[str] = list(PDFTextBlock.model_fields.keys()) - passages_array_filled = [] for passage in passages_array: for key in empty_html_text_block_keys: diff --git a/tests/test_parser_models.py b/tests/test_parser_models.py index 0fbc148..1da4384 100644 --- a/tests/test_parser_models.py +++ b/tests/test_parser_models.py @@ -160,11 +160,21 @@ def test_parser_output_object( parser_output = ParserOutput.from_flat_json(parser_output_json_flat) +@pytest.mark.parametrize("null_text_blocks", [True, False]) +@pytest.mark.parametrize("include_empty_text_blocks", [True, False]) def test_to_passage_level_json_method( parser_output_json_pdf: dict, parser_output_json_html: dict, + null_text_blocks: bool, + include_empty_text_blocks: bool, ) -> None: - """Test that we can successfully create a passage level array from the text blocks.""" + """ + Test that we can successfully create a passage level array from the text blocks. + + :param null_text_blocks: Whether to set the text blocks to None in the parser output + :param include_empty_text_blocks: The setting for the `include_empty_text_blocks` + kwarg in the `to_passage_level_json` method + """ expected_top_level_fields = set( list(TextBlock.model_fields.keys()) + list(HTMLTextBlock.model_fields.keys()) @@ -184,13 +194,34 @@ def test_to_passage_level_json_method( expected_pdf_data_fields.remove(field) parser_output_pdf = ParserOutput.model_validate(parser_output_json_pdf) - passage_level_array_pdf = parser_output_pdf.to_passage_level_json() - + assert parser_output_pdf.pdf_data is not None parser_output_html = ParserOutput.model_validate(parser_output_json_html) - passage_level_array_html = parser_output_html.to_passage_level_json() + assert parser_output_html.html_data is not None + + if null_text_blocks: + parser_output_pdf.pdf_data.text_blocks = [] + parser_output_html.html_data.text_blocks = [] + + passage_level_array_pdf = parser_output_pdf.to_passage_level_json( + include_empty=include_empty_text_blocks + ) + passage_level_array_html = parser_output_html.to_passage_level_json( + include_empty=include_empty_text_blocks + ) + + if null_text_blocks: + if include_empty_text_blocks: + assert len(passage_level_array_pdf) == 1 + assert len(passage_level_array_html) == 1 + else: + assert len(passage_level_array_pdf) == 0 + assert len(passage_level_array_html) == 0 - assert len(passage_level_array_pdf) == len(parser_output_pdf.text_blocks) - assert len(passage_level_array_html) == len(parser_output_html.text_blocks) + # the resulting output is [] so we can stop the test here + return + else: + assert len(passage_level_array_pdf) == len(parser_output_pdf.text_blocks) + assert len(passage_level_array_html) == len(parser_output_html.text_blocks) for passage_level_array in [passage_level_array_pdf, passage_level_array_html]: first_doc_keys = set(passage_level_array[0].keys()) @@ -204,7 +235,8 @@ def test_to_passage_level_json_method( ) if passage["document_content_type"] == CONTENT_TYPE_PDF: - assert passage[PDF_PAGE_METADATA_KEY] is not None + if not (null_text_blocks and include_empty_text_blocks): + assert passage[PDF_PAGE_METADATA_KEY] is not None assert set(passage["pdf_data"].keys()) == expected_pdf_data_fields elif passage["document_content_type"] == CONTENT_TYPE_HTML: assert set(passage["html_data"].keys()) == expected_html_data_fields