From 50487f2c0b98925bfd6ab0e6cbbc547dd96b9aa1 Mon Sep 17 00:00:00 2001 From: mpjuhasz Date: Mon, 16 Oct 2023 15:04:36 +0200 Subject: [PATCH 1/4] filter_by_corpus method added to Dataset --- src/cpr_data_access/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/cpr_data_access/models.py b/src/cpr_data_access/models.py index 858071f..f0c31bd 100644 --- a/src/cpr_data_access/models.py +++ b/src/cpr_data_access/models.py @@ -1009,6 +1009,10 @@ def filter(self, attribute: str, value: Any) -> "Dataset": return Dataset(**instance_attributes, documents=documents) + def filter_by_corpus(self, corpus_name: str) -> "Dataset": + """"Returns documents that are source from the corpus provided as per their document-id""" + return self.filter("document_id", lambda x: x.lower().startswith(corpus_name.lower())) + def filter_by_language(self, language: str) -> "Dataset": """Return documents whose only language is the given language.""" return self.filter("languages", [language]) From 358a7123f0aa389c6f71c24a25f36f4bf6b49d12 Mon Sep 17 00:00:00 2001 From: mpjuhasz Date: Mon, 16 Oct 2023 15:10:34 +0200 Subject: [PATCH 2/4] test added --- tests/test_models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 61ce4d3..df2ed9b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -181,6 +181,16 @@ def test_dataset_filter_by_language(test_dataset): assert dataset.documents[1].languages == ["en"] +def test_dataset_filter_by_corpus(test_dataset): + """Test Dataset.filter_by_corpus""" + dataset = test_dataset.filter_by_corpus('UNFCCC') + + assert len(dataset) == 0 + + dataset = test_dataset.filter_by_corpus('CCLW') + + assert len(dataset) == 3 + def test_dataset_get_all_text_blocks(test_dataset): text_blocks = test_dataset.get_all_text_blocks() num_text_blocks = sum( From dbc504ecc43562eccf8adc06088ea5e023cc9779 Mon Sep 17 00:00:00 2001 From: mpjuhasz Date: Mon, 16 Oct 2023 16:34:03 +0200 Subject: [PATCH 3/4] minor docstring fix --- src/cpr_data_access/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpr_data_access/models.py b/src/cpr_data_access/models.py index f0c31bd..e44b54d 100644 --- a/src/cpr_data_access/models.py +++ b/src/cpr_data_access/models.py @@ -1010,7 +1010,7 @@ def filter(self, attribute: str, value: Any) -> "Dataset": return Dataset(**instance_attributes, documents=documents) def filter_by_corpus(self, corpus_name: str) -> "Dataset": - """"Returns documents that are source from the corpus provided as per their document-id""" + """Returns documents that are source from the corpus provided as per their document-id""" return self.filter("document_id", lambda x: x.lower().startswith(corpus_name.lower())) def filter_by_language(self, language: str) -> "Dataset": From 24d53a98b9618eaee1413f643185b8916db54399 Mon Sep 17 00:00:00 2001 From: mpjuhasz Date: Mon, 16 Oct 2023 18:13:35 +0200 Subject: [PATCH 4/4] black reformatting --- src/cpr_data_access/models.py | 4 +++- tests/test_models.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/cpr_data_access/models.py b/src/cpr_data_access/models.py index e44b54d..1977202 100644 --- a/src/cpr_data_access/models.py +++ b/src/cpr_data_access/models.py @@ -1011,7 +1011,9 @@ def filter(self, attribute: str, value: Any) -> "Dataset": def filter_by_corpus(self, corpus_name: str) -> "Dataset": """Returns documents that are source from the corpus provided as per their document-id""" - return self.filter("document_id", lambda x: x.lower().startswith(corpus_name.lower())) + return self.filter( + "document_id", lambda x: x.lower().startswith(corpus_name.lower()) + ) def filter_by_language(self, language: str) -> "Dataset": """Return documents whose only language is the given language.""" diff --git a/tests/test_models.py b/tests/test_models.py index df2ed9b..9b2ca4e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -183,14 +183,15 @@ def test_dataset_filter_by_language(test_dataset): def test_dataset_filter_by_corpus(test_dataset): """Test Dataset.filter_by_corpus""" - dataset = test_dataset.filter_by_corpus('UNFCCC') + dataset = test_dataset.filter_by_corpus("UNFCCC") assert len(dataset) == 0 - dataset = test_dataset.filter_by_corpus('CCLW') + dataset = test_dataset.filter_by_corpus("CCLW") assert len(dataset) == 3 + def test_dataset_get_all_text_blocks(test_dataset): text_blocks = test_dataset.get_all_text_blocks() num_text_blocks = sum(