Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Merge pull request #59 from climatepolicyradar/RND-479-dataset-filter…
Browse files Browse the repository at this point in the history
…-by-corpus

Dataset - filter_by_corpus
  • Loading branch information
mpjuhasz authored Oct 16, 2023
2 parents 0875e03 + 24d53a9 commit c6c9b88
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/cpr_data_access/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,12 @@ def filter(self, attribute: str, value: Any) -> "Dataset":

return Dataset(**instance_attributes, documents=documents)

def filter_by_corpus(self, corpus_name: str) -> "Dataset":
    """
    Return the documents whose document ID starts with the given corpus name.

    The match is case-insensitive: the corpus name and each document ID are
    lower-cased before the prefix comparison. The work is delegated to
    `self.filter`, which is handed a predicate over the `document_id`
    attribute.

    :param corpus_name: corpus prefix to look for, e.g. "CCLW"
    :return: the Dataset returned by `self.filter` for that predicate
    """
    # Lower-case the prefix once here rather than once per document inside
    # the predicate.
    corpus_prefix = corpus_name.lower()

    return self.filter(
        "document_id", lambda doc_id: doc_id.lower().startswith(corpus_prefix)
    )

def filter_by_language(self, language: str) -> "Dataset":
    """
    Return the documents whose only language is the given language.

    Delegates to `self.filter` on the `languages` attribute, passing a
    single-element list that holds `language`.
    """
    sole_language = [language]
    return self.filter("languages", sole_language)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,17 @@ def test_dataset_filter_by_language(test_dataset):
assert dataset.documents[1].languages == ["en"]


def test_dataset_filter_by_corpus(test_dataset):
    """Test Dataset.filter_by_corpus, including its case-insensitive matching."""
    # No document in the fixture matches this corpus name.
    unfccc_dataset = test_dataset.filter_by_corpus("UNFCCC")
    assert len(unfccc_dataset) == 0

    # Three documents in the fixture match the CCLW corpus name.
    cclw_dataset = test_dataset.filter_by_corpus("CCLW")
    assert len(cclw_dataset) == 3

    # filter_by_corpus lower-cases both sides of the comparison, so a
    # lower-case corpus name must select the same documents.
    lowercase_dataset = test_dataset.filter_by_corpus("cclw")
    assert len(lowercase_dataset) == 3


def test_dataset_get_all_text_blocks(test_dataset):
text_blocks = test_dataset.get_all_text_blocks()
num_text_blocks = sum(
Expand Down

0 comments on commit c6c9b88

Please sign in to comment.