Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Merge pull request #94 from climatepolicyradar/feature/dataset-sample…
Browse files Browse the repository at this point in the history
…-text-blocks

add dataset.sample_text_blocks method
  • Loading branch information
kdutia authored Nov 16, 2023
2 parents ee83671 + b67ac0b commit 8feafcf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
30 changes: 25 additions & 5 deletions src/cpr_data_access/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,11 +1097,31 @@ def sample(self, n: Union[float, int], random_state: int = 42) -> "Dataset":

return Dataset(**instance_attributes, documents=documents)

def sample_text(
self, n: int, document_ids: Optional[Sequence[str]], replace: bool = False
):
"""Randomly sample a number of text blocks. Used for e.g. negative sampling for text classification."""
raise NotImplementedError
def sample_text_blocks(
self, n: int, with_document_context: bool = False
) -> Union[List[TextBlock], Tuple[List[TextBlock], dict]]: #  type: ignore
"""
Randomly sample a number of text blocks. Used for e.g. negative sampling for text classification.
For reproducibility you may want to set `random.seed` before calling this function.
:param n: number of text blocks to sample
:param with_document_context: If True, include document context in the output. Defaults to False
:return: list of text blocks or (text block, document context) tuples.
"""

all_blocks = self.get_all_text_blocks(
with_document_context=with_document_context
)

if n >= len(all_blocks):
LOGGER.warning(
"Requested number of text blocks is >= the number of text blocks in the dataset. Returning all text blocks."
)
return all_blocks

else:
return random.sample(all_blocks, n) # type: ignore

def get_all_text_blocks(
self, with_document_context: bool = False
Expand Down
13 changes: 13 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def test_dataset_get_all_text_blocks(test_dataset):
assert all(["text_blocks" not in i[1] for i in text_blocks_with_document_context])


def test_dataset_sample_text_blocks(test_dataset):
text_blocks = test_dataset.sample_text_blocks(2)
num_text_blocks = sum(
[
len(doc.text_blocks) if doc.text_blocks is not None else 0
for doc in test_dataset.documents
]
)

assert len(text_blocks) == 2
assert len(text_blocks) < num_text_blocks


def test_text_block_add_valid_spans(test_document, test_spans_valid):
block_1 = test_document.text_blocks[0]
block_2 = test_document.text_blocks[1]
Expand Down

0 comments on commit 8feafcf

Please sign in to comment.