Skip to content

Commit

Permalink
Adding failing test to verify theory. (#80)
Browse files Browse the repository at this point in the history
* Adding failing test to verify theory.

* Adding failing test.

* Fixing the test.

* Refactoring.

* Adding test for document passage matches.

* Refactoring test.

---------

Co-authored-by: Mark <[email protected]>
  • Loading branch information
THOR300 and Mark authored Dec 6, 2023
1 parent 191b39e commit a0bf482
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 3 deletions.
55 changes: 54 additions & 1 deletion src/index/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
from typing import Any
from pathlib import Path
import numpy as np
from datetime import datetime

from cpr_data_access.parser_models import ParserOutput
from cpr_data_access.parser_models import (
ParserOutput,
BackendDocument,
PDFData,
PDFTextBlock,
BlockType,
PDFPageMetadata,
)


def read_local_json_file(file_path: str) -> dict:
Expand All @@ -22,6 +30,51 @@ def read_local_npy_file(file_path: str) -> Any:
return np.load(file_path)


def get_parser_output(document_id: int, family_id: int) -> ParserOutput:
    """Build a ParserOutput test fixture keyed by the given document and family ids.

    The returned object always has one PDF page containing one text block, so
    tests can rely on a 1:1 mapping between documents and passages.
    """
    import_id = f"CCLW.executive.{document_id}.0"
    family_import_id = f"CCLW.family.{family_id}.0"

    # A single text block on page 1 with fixed bounding-box coordinates.
    text_block = PDFTextBlock(
        text=[f"Example text for {import_id}"],
        text_block_id="p_1_b_0",
        type=BlockType.TEXT,
        type_confidence=1.0,
        coords=[
            (89.58967590332031, 243.0702667236328),
            (519.2817077636719, 243.0702667236328),
            (519.2817077636719, 303.5213928222656),
            (89.58967590332031, 303.5213928222656),
        ],
        page_number=1,
    )

    pdf_data = PDFData(
        page_metadata=[PDFPageMetadata(page_number=1, dimensions=(612.0, 792.0))],
        md5sum="123",
        text_blocks=[text_block],
    )

    backend_document = BackendDocument(
        name="Example name",
        description="Example description.",
        import_id=import_id,
        slug="",
        family_import_id=family_import_id,
        family_slug="",
        publication_ts=datetime.now(),
        type="",
        source="",
        category="",
        geography="",
        languages=[],
        metadata={},
    )

    return ParserOutput(
        document_id=import_id,
        document_name="Example name",
        document_description="Example description.",
        document_slug="",
        document_content_type="application/pdf",
        pdf_data=pdf_data,
        document_metadata=backend_document,
    )


@pytest.fixture
def s3_bucket_and_region() -> dict:
return {
Expand Down
50 changes: 50 additions & 0 deletions src/index/test/test_vespa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import pytest
from unittest.mock import patch
from cloudpathlib import S3Path
from pathlib import Path

from src.index.vespa_ import _get_vespa_instance, VespaConfigError
from src import config
from src.index.vespa_ import (
get_document_generator,
FAMILY_DOCUMENT_SCHEMA,
DOCUMENT_PASSAGE_SCHEMA,
)
from src.utils import read_npy_file
from conftest import get_parser_output


def test_get_vespa_instance() -> None:
Expand All @@ -20,3 +30,43 @@ def test_get_vespa_instance() -> None:
with pytest.raises(VespaConfigError) as context:
_get_vespa_instance()
assert expected_error_string not in str(context.value)


@patch("src.index.vespa_.read_npy_file")
def test_get_document_generator(mock_read_npy_file):
    """Assert that the vespa document generator works as expected.

    Every task should yield exactly one uniquely-identified family document,
    and each family document should be referenced by exactly one passage.
    """
    mock_read_npy_file.return_value = read_npy_file(
        Path("src/index/test/data/CCLW.executive.10002.4495.npy")
    )

    embeddings_dir = S3Path("s3://path/to/embeddings")

    # Three parser outputs; the first two belong to the same family.
    parser_outputs = [
        get_parser_output(document_id=0, family_id=0),
        get_parser_output(document_id=1, family_id=0),
        get_parser_output(document_id=2, family_id=1),
    ]

    family_doc_ids = []
    passage_family_refs = []
    for schema, doc_id, fields in get_document_generator(parser_outputs, embeddings_dir):
        if schema == FAMILY_DOCUMENT_SCHEMA:
            family_doc_ids.append(doc_id)
        elif schema == DOCUMENT_PASSAGE_SCHEMA:
            passage_family_refs.append(fields["family_document_ref"])

    # Every family document id is unique, and there is one per task.
    assert len(set(family_doc_ids)) == len(family_doc_ids)
    assert len(family_doc_ids) == len(parser_outputs)

    # Each family document is referenced by exactly one passage
    # (each fixture document carries a single text block).
    assert len(passage_family_refs) == len(family_doc_ids)
    for family_ref in passage_family_refs:
        # A ref such as 'id:doc_search:family_document::CCLW.executive.0.0'
        # carries the family document id as its final colon-separated part.
        assert family_ref.split(":")[-1] in family_doc_ids
2 changes: 1 addition & 1 deletion src/index/test/test_vespa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_get_document_generator(

schema, document_id, data = document
assert schema == FAMILY_DOCUMENT_SCHEMA
assert document_id == parser_output.document_metadata.family_import_id
assert document_id == parser_output.document_metadata.import_id
assert isinstance(data, dict)
VespaFamilyDocument.model_validate(data)

Expand Down
2 changes: 1 addition & 1 deletion src/index/vespa_.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_document_generator(
)
embeddings = read_npy_file(task_array_file_path)

family_document_id = DocumentID(task.document_metadata.family_import_id)
family_document_id = DocumentID(task.document_metadata.import_id)
family_document = VespaFamilyDocument(
search_weights_ref=f"id:{_NAMESPACE}:search_weights::{search_weights_id}",
family_name=task.document_name,
Expand Down

0 comments on commit a0bf482

Please sign in to comment.