Skip to content

Commit

Permalink
return document metadata in to_passage_level_json when text is empty (#…
Browse files Browse the repository at this point in the history
…120)

* return document metadata in to_passage_level_json when text is empty

* bump sdk version
  • Loading branch information
kdutia authored Sep 30, 2024
1 parent 2f8cdef commit 972cb02
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 12 deletions.
25 changes: 20 additions & 5 deletions src/cpr_sdk/parser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def from_flat_json(data: dict):

return ParserOutput.model_validate(unflattened)

def to_passage_level_json(self) -> list[dict[str, Any]]:
def to_passage_level_json(self, include_empty: bool = True) -> list[dict[str, Any]]:
"""
Convert the parser output to a passage-level JSON format.
Expand All @@ -394,8 +394,13 @@ def to_passage_level_json(self) -> list[dict[str, Any]]:
model_dump_json method and then reloading with json.load is as objects like
Enums and child pydantic objects persist when using the model_dump method.
We don't want these when we push to huggingface.
:param include_empty: Whether to output the document metadata if there are no
text blocks in the ParserOutput. If True, outputs a single dict with None values
for each text block related field. If False, returns an empty list.
"""
if self.text_blocks is None:

if not self.text_blocks and not include_empty:
return []

fixed_fields_dict = json.loads(
Expand All @@ -407,6 +412,19 @@ def to_passage_level_json(self) -> list[dict[str, Any]]:
)
)

empty_html_text_block_keys: list[str] = list(HTMLTextBlock.model_fields.keys())
empty_pdf_text_block_keys: list[str] = list(PDFTextBlock.model_fields.keys())

if not self.text_blocks:
passages_array_filled = [
{key: None for key in empty_html_text_block_keys}
| {key: None for key in empty_pdf_text_block_keys}
| fixed_fields_dict
| {"block_index": 0, PDF_PAGE_METADATA_KEY: None}
]

return passages_array_filled

passages_array = [
fixed_fields_dict
| json.loads(block.model_dump_json(exclude={"text"}))
Expand All @@ -422,9 +440,6 @@ def to_passage_level_json(self) -> list[dict[str, Any]]:
else None
)

empty_html_text_block_keys: list[str] = list(HTMLTextBlock.model_fields.keys())
empty_pdf_text_block_keys: list[str] = list(PDFTextBlock.model_fields.keys())

passages_array_filled = []
for passage in passages_array:
for key in empty_html_text_block_keys:
Expand Down
46 changes: 39 additions & 7 deletions tests/test_parser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,21 @@ def test_parser_output_object(
parser_output = ParserOutput.from_flat_json(parser_output_json_flat)


@pytest.mark.parametrize("null_text_blocks", [True, False])
@pytest.mark.parametrize("include_empty_text_blocks", [True, False])
def test_to_passage_level_json_method(
parser_output_json_pdf: dict,
parser_output_json_html: dict,
null_text_blocks: bool,
include_empty_text_blocks: bool,
) -> None:
"""Test that we can successfully create a passage level array from the text blocks."""
"""
Test that we can successfully create a passage level array from the text blocks.
:param null_text_blocks: Whether to set the text blocks to None in the parser output
:param include_empty_text_blocks: The setting for the `include_empty_text_blocks`
kwarg in the `to_passage_level_json` method
"""
expected_top_level_fields = set(
list(TextBlock.model_fields.keys())
+ list(HTMLTextBlock.model_fields.keys())
Expand All @@ -184,13 +194,34 @@ def test_to_passage_level_json_method(
expected_pdf_data_fields.remove(field)

parser_output_pdf = ParserOutput.model_validate(parser_output_json_pdf)
passage_level_array_pdf = parser_output_pdf.to_passage_level_json()

assert parser_output_pdf.pdf_data is not None
parser_output_html = ParserOutput.model_validate(parser_output_json_html)
passage_level_array_html = parser_output_html.to_passage_level_json()
assert parser_output_html.html_data is not None

if null_text_blocks:
parser_output_pdf.pdf_data.text_blocks = []
parser_output_html.html_data.text_blocks = []

passage_level_array_pdf = parser_output_pdf.to_passage_level_json(
include_empty=include_empty_text_blocks
)
passage_level_array_html = parser_output_html.to_passage_level_json(
include_empty=include_empty_text_blocks
)

if null_text_blocks:
if include_empty_text_blocks:
assert len(passage_level_array_pdf) == 1
assert len(passage_level_array_html) == 1
else:
assert len(passage_level_array_pdf) == 0
assert len(passage_level_array_html) == 0

assert len(passage_level_array_pdf) == len(parser_output_pdf.text_blocks)
assert len(passage_level_array_html) == len(parser_output_html.text_blocks)
# the resulting output is [] so we can stop the test here
return
else:
assert len(passage_level_array_pdf) == len(parser_output_pdf.text_blocks)
assert len(passage_level_array_html) == len(parser_output_html.text_blocks)

for passage_level_array in [passage_level_array_pdf, passage_level_array_html]:
first_doc_keys = set(passage_level_array[0].keys())
Expand All @@ -204,7 +235,8 @@ def test_to_passage_level_json_method(
)

if passage["document_content_type"] == CONTENT_TYPE_PDF:
assert passage[PDF_PAGE_METADATA_KEY] is not None
if not (null_text_blocks and include_empty_text_blocks):
assert passage[PDF_PAGE_METADATA_KEY] is not None
assert set(passage["pdf_data"].keys()) == expected_pdf_data_fields
elif passage["document_content_type"] == CONTENT_TYPE_HTML:
assert set(passage["html_data"].keys()) == expected_html_data_fields
Expand Down

0 comments on commit 972cb02

Please sign in to comment.