Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Merge pull request #60 from climatepolicyradar/fix/rnd-458-languages-…
Browse files Browse the repository at this point in the history
…not-set-properly-for-cpr-huggingface-dump

fix issue with document languages not being present & incorrect number of documents
  • Loading branch information
kdutia authored Oct 17, 2023
2 parents c6c9b88 + d3e16d5 commit 5694c63
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/cpr_data_access/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,13 @@ def to_huggingface(
info=dataset_info,
)

# Rename column to avoid confusion with the 'language' field, which is text block language
rename_map = {
"languages": "document_languages",
}

huggingface_dataset = huggingface_dataset.rename_columns(rename_map)

return huggingface_dataset

def _from_huggingface_parquet(
Expand All @@ -1127,15 +1134,26 @@ def _from_huggingface_parquet(

hf_dataframe: pd.DataFrame = huggingface_dataset.to_pandas()

# This undoes the renaming of columns done in to_huggingface()
hf_dataframe = hf_dataframe.rename(columns={"document_languages": "languages"})

# Create a dummy variable to group on combining document_id and translated.
# This way we get an accurate count in the progress bar.
hf_dataframe["_document_id_translated"] = hf_dataframe[
"document_id"
] + hf_dataframe["translated"].astype(str)

if limit is not None:
doc_ids = hf_dataframe["document_id"].unique()[:limit]
hf_dataframe = hf_dataframe[hf_dataframe["document_id"].isin(doc_ids)]
doc_ids = hf_dataframe["_document_id_translated"].unique()[:limit]
hf_dataframe = hf_dataframe[
hf_dataframe["_document_id_translated"].isin(doc_ids)
]

documents = []

for _, doc_df in tqdm(
hf_dataframe.groupby("document_id"),
total=hf_dataframe["document_id"].nunique(),
hf_dataframe.groupby("_document_id_translated"),
total=hf_dataframe["_document_id_translated"].nunique(),
unit="docs",
):
document_text_blocks = [
Expand Down
Binary file modified tests/test_data/GST_huggingface_data_sample.parquet
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def test_dataset_from_huggingface_gst(test_huggingface_dataset_gst):
assert isinstance(dataset, Dataset)
assert all(isinstance(doc, GSTDocument) for doc in dataset.documents)

assert any(doc.languages is not None for doc in dataset.documents)

# Check huggingface dataset has the same number of documents as the dataset
assert len(dataset) == len({d["document_id"] for d in test_huggingface_dataset_gst})

Expand Down

0 comments on commit 5694c63

Please sign in to comment.