Add support for parallel data curation #193

Status: Open. Wants to merge 50 commits into base: main. Changes shown are from 38 of the 50 commits.
Commits:

c7a6423  add data interface to read simple bitext (shuoyangd, Jul 30, 2024)
4b3dc97  adding ParallelScoreFilter (nverma1, Jul 30, 2024)
114716e  add test for ParallelScoreFilter, small style change for ParallelData… (shuoyangd, Jul 31, 2024)
cbab143  allow ParallelScoreFilter to take different filters for source and ta… (shuoyangd, Jul 31, 2024)
82f5486  add JointScoreFilter and LengthRatioFilter (nverma1, Jul 31, 2024)
f9a0535  [WIP] add heuristic filter w/o test (shuoyangd, Jul 31, 2024)
8f25988  merge with main (shuoyangd, Jul 31, 2024)
612249c  add test for histogram filter, fix a few bugs (shuoyangd, Jul 31, 2024)
2fe4973  length ratio, joint score filter testing (nverma1, Jul 31, 2024)
b61d7f1  fix typing in joint test (nverma1, Jul 31, 2024)
f63a1f9  add a fake comet qe filter as an initial step (shuoyangd, Aug 1, 2024)
76bced7  [WIP] adding bitext cleaning tutorial (nverma1, Aug 1, 2024)
1a2bb1e  [WIP] fixing example (nverma1, Aug 2, 2024)
74698d5  fix slow histogram filter, fix faulty bitext loading (shuoyangd, Aug 2, 2024)
bf2e6ac  tutorial running (nverma1, Aug 2, 2024)
62d1242  [WIP] documentation of bitext tutorial (nverma1, Aug 2, 2024)
c413ea2  add tested version of comet-qe filter (shuoyangd, Aug 2, 2024)
5a90038  fix ParallelDataset bug where single file name is not accepted, and d… (shuoyangd, Aug 5, 2024)
f8046dd  add docstring to explain simple bitext format, fix a bug where file e… (shuoyangd, Aug 5, 2024)
6c7aea4  remove print line for debug (shuoyangd, Aug 5, 2024)
a457995  add comet filter to tutorial (shuoyangd, Aug 5, 2024)
c5a6f1c  refactor COMET QE filter to decouple model from filter, make sure Joi… (shuoyangd, Aug 5, 2024)
61713e4  use refactored qe filter (shuoyangd, Aug 5, 2024)
a4d2bb3  wrap_qe_input should be a static method (shuoyangd, Aug 5, 2024)
0674400  use conditional import for comet, formatting changes (shuoyangd, Aug 6, 2024)
6936f9a  [WIP] add cometoid (shuoyangd, Aug 6, 2024)
da96d29  [WIP] attempt to resolve device conflict but is failing (shuoyangd, Aug 7, 2024)
14b7d70  [WIP] playing with cometoid arguments (shuoyangd, Aug 7, 2024)
b02b56d  [WIP] -d 0 doesn't look necessary (shuoyangd, Aug 7, 2024)
6c1e719  tested arguments for Cometoid (shuoyangd, Aug 8, 2024)
70a7fe8  use proper safe import, make sure test doesn't crash sans comet/pymarian (shuoyangd, Aug 8, 2024)
c66d7f9  falling back to comet for tutorial since that's easier to set up, upp… (shuoyangd, Aug 8, 2024)
861bd4d  give credit to original fairseq implementation of histogram filtering… (shuoyangd, Aug 8, 2024)
52ba08e  fix pre-commit complaint (shuoyangd, Aug 8, 2024)
62c254b  fix small bug (shuoyangd, Aug 11, 2024)
91ea9fa  fix another occurrence of the same bug (shuoyangd, Aug 13, 2024)
12783ec  introduce shard limit to a single PyMarian API call to avoid memory l… (shuoyangd, Aug 13, 2024)
a65588a  repartition after reading simple bitext data (shuoyangd, Aug 16, 2024)
3f1d09b  -d 0 is actually needed for pymarian (shuoyangd, Aug 16, 2024)
102429a  remove duplicate LengthRatioFilter definition (shuoyangd, Sep 5, 2024)
8a367dd  refactor repeated code segment in file writing, change classifier to … (shuoyangd, Sep 20, 2024)
396d7ba  [WIP] addressed comments in #193 apart from resolving .iloc pattern, … (shuoyangd, Sep 20, 2024)
eb4f4df  refactor to resolve .loc pattern, test passing (shuoyangd, Oct 1, 2024)
3addf44  add missing file (shuoyangd, Oct 1, 2024)
a14a78a  revert changes in setup.py (shuoyangd, Oct 1, 2024)
6b8dfa0  fix a small bug in parallel dataset, explain why repartition is disab… (shuoyangd, Oct 1, 2024)
bb4f148  add api guide, small change on bitext/parallel score filter docstring (shuoyangd, Oct 1, 2024)
d309744  fix read_simple_bitext test issues (shuoyangd, Oct 1, 2024)
21676bd  Merge branch 'main' into main (shuoyangd, Oct 1, 2024)
7797925  reinstate dependencies lost during merging (shuoyangd, Oct 2, 2024)
2 changes: 1 addition & 1 deletion nemo_curator/datasets/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .doc_dataset import DocumentDataset
from .doc_dataset import DocumentDataset, ParallelDataset

__all__ = ["DocumentDataset"]
shuoyangd marked this conversation as resolved.
Show resolved Hide resolved
70 changes: 69 additions & 1 deletion nemo_curator/datasets/doc_dataset.py
@@ -16,7 +16,11 @@

import dask.dataframe as dd

from nemo_curator.utils.distributed_utils import read_data, write_to_disk
from nemo_curator.utils.distributed_utils import (
    read_data,
    read_simple_bitext_data,
    write_to_disk,
)
from nemo_curator.utils.file_utils import get_all_files_paths_under


@@ -252,3 +256,67 @@ def _read_json_or_parquet(
raise TypeError("File input must be a string or list.")

return raw_data


class ParallelDataset(DocumentDataset):
    """
    An extension of the standard `DocumentDataset` with a special method that loads simple bitext.

    For data with more complicated metadata, please convert your data into jsonl/parquet/pickle format
    and use interfaces defined in `DocumentDataset`.
    """

    def persist(self):
        return ParallelDataset(self.df.persist())

    @classmethod
    def read_simple_bitext(
        cls,
        src_input_files: Union[str, List[str]],
        tgt_input_files: Union[str, List[str]],
        src_lang: str,
        tgt_lang: str,
        backend: str = "pandas",
        add_filename: bool = False,
        partition_size: Optional[Union[int, str]] = "100MB",
    ):
        if isinstance(src_input_files, list) and isinstance(tgt_input_files, list):
            df = read_simple_bitext_data(
                src_input_files,
                tgt_input_files,
                src_lang,
                tgt_lang,
                backend,
                add_filename,
            )
        elif isinstance(src_input_files, str) and isinstance(tgt_input_files, str):
            df = read_simple_bitext_data(
                [src_input_files],
                [tgt_input_files],
                src_lang,
                tgt_lang,
                backend,
                add_filename,
            )
        else:
            raise TypeError("Both file inputs must be strings or lists.")

        if partition_size:
            df = df.repartition(partition_size=partition_size)
        return cls(df)

    def to_bitext(
        self,
        output_file_dir,
        write_to_filename=False,
    ):
        """
        See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
        """
        write_to_disk(
            df=self.df,
            output_file_dir=output_file_dir,
            write_to_filename=write_to_filename,
            output_type="bitext",
        )
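
For orientation, a minimal usage sketch of the class above, assuming the simple bitext format (two plain-text files, one sentence per line, line i of the source aligned with line i of the target) and hypothetical file names:

```python
from nemo_curator.datasets import ParallelDataset

# data.en / data.de are hypothetical, line-aligned bitext files.
dataset = ParallelDataset.read_simple_bitext(
    src_input_files="data.en",
    tgt_input_files="data.de",
    src_lang="en",
    tgt_lang="de",
)
print(dataset.df.head())                   # inspect the loaded sentence pairs
dataset.to_bitext(output_file_dir="out/")  # write the pairs back out as bitext
```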
10 changes: 0 additions & 10 deletions nemo_curator/download/__init__.py
@@ -47,17 +47,7 @@
"import_downloader",
"import_extractor",
"import_iterator",
"download_common_crawl",
[Review comment from a Collaborator]: I assume this was modified by mistake, so probably revert it unless there's something I'm missing.
"CommonCrawlWARCDownloader",
"CommonCrawlWARCExtractor",
"CommonCrawlWARCIterator",
"CommonCrawlWARCDownloaderExtractOnly",
"JusTextExtractor",
"ResiliparseExtractor",
"download_wikipedia",
"WikipediaDownloader",
"WikipediaIterator",
"WikipediaExtractor",
"batch_download",
"download_arxiv",
"ArxivDownloader",
9 changes: 8 additions & 1 deletion nemo_curator/filters/__init__.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .classifier_filter import FastTextLangId, FastTextQualityFilter
from .classifier_filter import (
    FastTextLangId,
    FastTextQualityFilter,
    QualityEstimationFilter,
)
from .code import (
    AlphaFilter,
    GeneralCommentToCodeFilter,
@@ -29,6 +33,8 @@
    BulletsFilter,
    CommonEnglishWordsFilter,
    EllipsisFilter,
    HistogramFilter,
    LengthRatioFilter,
    LongWordFilter,
    MeanWordLengthFilter,
    NonAlphaNumericFilter,
@@ -84,4 +90,5 @@
"AlphaFilter",
"HTMLBoilerplateFilter",
"PerExtensionFilter",
"LengthRatioFilter",
shuoyangd marked this conversation as resolved.
Show resolved Hide resolved
]
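
With the exports above in place, the new bitext filters can be imported directly from the package:

```python
from nemo_curator.filters import (
    HistogramFilter,
    LengthRatioFilter,
    QualityEstimationFilter,
)
```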
99 changes: 99 additions & 0 deletions nemo_curator/filters/classifier_filter.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import dask
import fasttext
import numpy as np
import pandas as pd

from nemo_curator.filters.doc_filter import DocumentFilter
from nemo_curator.filters.models.qe_models import COMETQEModel, PyMarianQEModel
from nemo_curator.utils.decorators import batched
from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker

@@ -99,3 +102,99 @@ def keep_document(self, score):

    def _load_model(self):
        return fasttext.load_model(self._model_path)


class QualityEstimationFilter(DocumentFilter):
    """
    Filters bitext based on a cross-lingual quality-estimation (QE) model score and a cutoff.
    """

    # a mapping from supported model names to their corresponding model class
    SUPPORTED_MODELS = {
        "comet-qe": COMETQEModel,
        "cometoid-wmt23": PyMarianQEModel,
        "cometoid-wmt23-mqm": PyMarianQEModel,
    }

    def __init__(self, model_name, cutoff, mode="always_en_x", gpu=False):
        if model_name in self.SUPPORTED_MODELS:
            self._name = model_name
        else:
            raise NotImplementedError(
                f"Only the following models are currently supported: {str(self.SUPPORTED_MODELS.keys())}"
            )

        self._model_path = None
        self._mode = mode
        self._cutoff = cutoff
        self._gpu = gpu

    def _score_document_with_qe(
        self, model, df: pd.Series, mode="always_en_x"
    ) -> List[float]:

        def _is_en_x(src_lang: str, tgt_lang: str):
            return src_lang == "en" and tgt_lang != "en"

        def _has_en(src_lang: str, tgt_lang: str):
            return src_lang == "en" or tgt_lang == "en"

        model_class = self.SUPPORTED_MODELS[self._name]

        if mode == "simple":
            inputs = [
                model_class.wrap_qe_input(src, tgt)
                for src, tgt in zip(df["src"], df["tgt"])
            ]
            return model.predict(inputs)
        elif mode == "always_en_x":
            # if English is included but it's on the target side, flip to make sure we are scoring with en-x
            # this strategy was proposed in: https://aclanthology.org/2023.wmt-1.50.pdf
            inputs = [
                model_class.wrap_qe_input(
                    src,
                    tgt,
                    reverse=(
                        _has_en(src_lang, tgt_lang) and not _is_en_x(src_lang, tgt_lang)
                    ),
                )
                for src, tgt, src_lang, tgt_lang in zip(
                    df["src"], df["tgt"], df["src_lang"], df["tgt_lang"]
                )
            ]
            return model.predict(inputs)
        elif mode == "bidi":
            # score twice -- once forward and once backward
            fwd_inputs = [
                model_class.wrap_qe_input(src, tgt)
                for src, tgt in zip(df["src"], df["tgt"])
            ]
            rev_inputs = [
                model_class.wrap_qe_input(src, tgt, reverse=True)
                for src, tgt in zip(df["src"], df["tgt"])
            ]
            # making one call to take advantage of batching
            scores = model.predict(fwd_inputs + rev_inputs)
            # first half is forward score, second half is reverse score -- now we unpack and average
            fwd_scores = scores[: len(df)]
            rev_scores = scores[len(df) :]
            return [(fs + rs) / 2 for fs, rs in zip(fwd_scores, rev_scores)]
        else:
            raise NotImplementedError(f"Unsupported scoring mode: {mode}")

    @batched
    def score_document(self, df: pd.Series):
        model_attr = f"{self._name}_{self._model_path}"
        try:
            model = load_object_on_worker(
                model_attr,
                self.SUPPORTED_MODELS[self._name].load_model,
                {"model_name": self._name, "gpu": self._gpu},
            )
        except NoWorkerError:
            return pd.Series([-1.0 for _ in range(len(df))])

        scores = self._score_document_with_qe(model, df, self._mode)

        return pd.Series(scores, index=df.index)

    def keep_document(self, score):
        return score >= self._cutoff

[Review comment from a Collaborator, on keep_document]: Just an idea, you can probably decorate this method with @batched too to get a slight perf bump.
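
A minimal sketch of driving this filter directly on one batch. This is illustrative only: the column names mirror what `_score_document_with_qe` reads, the cutoff value is a made-up assumption, and in a full pipeline the filter would be wrapped by the bitext filter helpers added elsewhere in this PR:

```python
import pandas as pd

from nemo_curator.filters import QualityEstimationFilter

# Hypothetical batch; cutoff=0.75 is illustrative, not a recommended value.
batch = pd.DataFrame(
    {
        "src": ["Hello world."],
        "tgt": ["Hallo Welt."],
        "src_lang": ["en"],
        "tgt_lang": ["de"],
    }
)
qe = QualityEstimationFilter("cometoid-wmt23", cutoff=0.75)
scores = qe.score_document(batch)    # loads the QE model on the Dask worker;
                                     # falls back to -1.0 per row when no worker exists
keep = scores.map(qe.keep_document)  # True where score >= cutoff
```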
121 changes: 121 additions & 0 deletions nemo_curator/filters/heuristic_filter.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
import tarfile

import regex
import requests
from platformdirs import user_cache_dir

from nemo_curator.filters.doc_filter import DocumentFilter, import_filter
from nemo_curator.utils.constants import (
@@ -633,3 +638,119 @@ def score_document(self, text):

    def keep_document(self, score):
        return score != 1


class LengthRatioFilter(DocumentFilter):
    """
    For bitext cleaning. Discards a sentence pair when the ratio between its source and
    target token counts, taken in whichever direction is larger (src/tgt or tgt/src),
    reaches `max_ratio`. See mosesdecoder/scripts/training/clean-corpus-n.perl for the
    reference implementation.
    """

    def __init__(self, max_ratio=3.0, src_lang="en", tgt_lang="en"):
        super().__init__()
        self._max_ratio = float(max_ratio)
        self._src_word_splitter = get_word_splitter(src_lang)
        self._tgt_word_splitter = get_word_splitter(tgt_lang)
        self._name = "length_ratio"

    def score_document(self, bitext_tuple):
        src_len = len(self._src_word_splitter(bitext_tuple.iloc[0].strip()))
        tgt_len = len(self._tgt_word_splitter(bitext_tuple.iloc[1].strip()))
        return max(src_len / tgt_len, tgt_len / src_len)

    def keep_document(self, score):
        return score < self._max_ratio
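
A quick worked example of the ratio rule (hypothetical sentence pair; this assumes `get_word_splitter` produces roughly whitespace tokenization for these languages):

```python
import pandas as pd

ratio_filter = LengthRatioFilter(max_ratio=3.0, src_lang="en", tgt_lang="de")
pair = pd.Series(["the quick brown fox jumps", "der schnelle braune Fuchs"])
score = ratio_filter.score_document(pair)  # max(5/4, 4/5) = 1.25
assert ratio_filter.keep_document(score)   # 1.25 < 3.0, so the pair is kept
```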


class HistogramFilter(DocumentFilter):
    """
    Histogram filter used by the NLLB paper (https://arxiv.org/pdf/2207.04672). See p30 for details.

    Written with reference to the original fairseq implementation at:
    https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py.
    """

    def __init__(self, lang="en", threshold=0.8, cache_dir="", threshold_char="]"):
        super().__init__()
        self._lang = lang
        self._threshold = threshold
        self._cache_dir = cache_dir if cache_dir else user_cache_dir()
        self._threshold_char = threshold_char
        self._name = "histogram"

        if not os.path.isdir(os.path.join(self._cache_dir, "histograms")):
            self._download_histograms()

        self._read_hist()

    def _download_histograms(self):
        # Send a GET request to the URL
        response = requests.get(
            "https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz"
        )

        # Check if the request was successful
        if response.status_code != 200:
            raise requests.exceptions.RequestException(
                f"Failed to download histogram file. Status code: {response.status_code}"
            )

        # Open a file to write the content
        os.makedirs(self._cache_dir, exist_ok=True)
        download_dest_path = os.path.join(self._cache_dir, "histograms.tar.gz")
        with open(download_dest_path, "wb") as file:
            file.write(response.content)

        extract_path = os.path.join(self._cache_dir, "histograms")
        with tarfile.open(download_dest_path, "r:gz") as tar:
            # Extract all the contents into the specified directory
            tar.extractall(path=extract_path)

    def _read_hist(self):
        self._histogram = []
        with open(
            os.path.join(
                self._cache_dir,
                "histograms",
                "checkpoint",
                "edunov",
                "cc60_multilingual",
                "clean_hists",
                self._lang,
            )
        ) as f:
            for line in f:
                c = line[0]
                if c == self._threshold_char:
                    break
                self._histogram.append(c)
        self._histogram = set(self._histogram)

    def score_document(self, text):
        cnt = len([c for c in text.strip() if c in self._histogram])
        return 1 if cnt / len(text) > self._threshold else 0

    def keep_document(self, score):
        return score == 1
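
Restating the scoring rule as a self-contained sketch (the character set below is a stand-in for the per-language histogram that `_read_hist` loads):

```python
# Stand-in for the frequent-character set read by _read_hist.
histogram = set("abcdefghijklmnopqrstuvwxyz ")
threshold = 0.8

text = "hello world"
in_hist = sum(1 for c in text.strip() if c in histogram)
score = 1 if in_hist / len(text) > threshold else 0  # 11/11 = 1.0 > 0.8
assert score == 1  # the document would be kept
```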

