diff --git a/src/cpr_data_access/models.py b/src/cpr_data_access/models.py index 858071f..1977202 100644 --- a/src/cpr_data_access/models.py +++ b/src/cpr_data_access/models.py @@ -1009,6 +1009,12 @@ def filter(self, attribute: str, value: Any) -> "Dataset": return Dataset(**instance_attributes, documents=documents) + def filter_by_corpus(self, corpus_name: str) -> "Dataset": + """Returns documents that are source from the corpus provided as per their document-id""" + return self.filter( + "document_id", lambda x: x.lower().startswith(corpus_name.lower()) + ) + def filter_by_language(self, language: str) -> "Dataset": """Return documents whose only language is the given language.""" return self.filter("languages", [language]) diff --git a/tests/test_models.py b/tests/test_models.py index 61ce4d3..9b2ca4e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -181,6 +181,17 @@ def test_dataset_filter_by_language(test_dataset): assert dataset.documents[1].languages == ["en"] +def test_dataset_filter_by_corpus(test_dataset): + """Test Dataset.filter_by_corpus""" + dataset = test_dataset.filter_by_corpus("UNFCCC") + + assert len(dataset) == 0 + + dataset = test_dataset.filter_by_corpus("CCLW") + + assert len(dataset) == 3 + + def test_dataset_get_all_text_blocks(test_dataset): text_blocks = test_dataset.get_all_text_blocks() num_text_blocks = sum(