[WIP] Addressed comments in #193 apart from resolving .iloc pattern; test currently failing.

Signed-off-by: Shuoyang Ding <[email protected]>
Showing 15 changed files with 416 additions and 274 deletions.
@@ -0,0 +1,157 @@
import csv
from typing import List, Optional, Tuple, Union

import dask.dataframe as dd
import pandas as pd

from nemo_curator.datasets.doc_dataset import DocumentDataset
from nemo_curator.utils.distributed_utils import write_to_disk
from nemo_curator.utils.file_utils import remove_path_extension
from nemo_curator.utils.import_utils import gpu_only_import

cudf = gpu_only_import("cudf")
dask_cudf = gpu_only_import("dask_cudf")

class ParallelDataset(DocumentDataset):
    """
    An extension of the standard `DocumentDataset` with a special method that loads simple bitext.

    For data with more complicated metadata, please convert your data into jsonl/parquet/pickle format
    and use interfaces defined in `DocumentDataset`.
    """

    def persist(self):
        return ParallelDataset(self.df.persist())

    @classmethod
    def read_simple_bitext(
        cls,
        src_input_files: Union[str, List[str]],
        tgt_input_files: Union[str, List[str]],
        src_lang: str,
        tgt_lang: str,
        backend: str = "pandas",
        add_filename: bool = False,
        partition_size: Optional[Union[int, str]] = "100MB",
    ):
        """See the `read_single_simple_bitext_file_pair` docstring for what "simple bitext" means and for the usage of the other parameters.

        Args:
            src_input_files (Union[str, List[str]]): one or several input files, in the source language
            tgt_input_files (Union[str, List[str]]): one or several input files, in the target language

        Raises:
            TypeError: If the types of `src_input_files` and `tgt_input_files` don't agree.

        Returns:
            ParallelDataset: A `ParallelDataset` object with `self.df` holding the ingested simple bitext.
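
        Example (an illustrative sketch; the file paths below are hypothetical):

            ```
            dataset = ParallelDataset.read_simple_bitext(
                src_input_files=["data/train.de", "data/valid.de"],
                tgt_input_files=["data/train.en", "data/valid.en"],
                src_lang="de",
                tgt_lang="en",
            )
            ```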
""" | ||
|
        if isinstance(src_input_files, str) and isinstance(tgt_input_files, str):
            src_input_files = [src_input_files]
            tgt_input_files = [tgt_input_files]
        elif not isinstance(src_input_files, list) or not isinstance(
            tgt_input_files, list
        ):
            raise TypeError("Both file inputs must be strings or lists.")

        # TODO: use default doc id for now,
        # but it might be useful to allow customizing doc id by passing a prefix
        df = dd.from_map(
            ParallelDataset.read_single_simple_bitext_file_pair,
            list(zip(src_input_files, tgt_input_files)),
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            backend=backend,
            add_filename=add_filename,
        )

        # if partition_size:
        #     df = df.repartition(partition_size=partition_size)
        return cls(df)

    def to_bitext(
        self,
        output_file_dir,
        write_to_filename=False,
    ):
        """See the `nemo_curator.utils.distributed_utils.write_to_disk` docstring for parameter usage."""
        write_to_disk(
            df=self.df,
            output_file_dir=output_file_dir,
            write_to_filename=write_to_filename,
            output_type="bitext",
        )

    @staticmethod
    def read_single_simple_bitext_file_pair(
        input_file_pair: Tuple[str, str],
        src_lang: str,
        tgt_lang: str,
        doc_id: Optional[str] = None,
        backend: str = "cudf",
        add_filename: bool = False,
    ) -> Union["pd.DataFrame", "cudf.DataFrame"]:
        """This function reads a pair of "simple bitext" files into a single data frame.

        A simple bitext is a common data format in machine translation.
        It consists of two plain text files with the same number of lines, where each pair of lines is a translation of the other. For example:

        data.de:

        ```
        Wir besitzen keine Reisetaschen aus Leder.
        Die Firma produziert Computer für den deutschen Markt.
        ...
        ```

        data.en:

        ```
        We don't own duffel bags made of leather.
        The company produces computers for the German market.
        ...
        ```

        For simplicity, we also assume that the names of the two text files share a common prefix and differ only in the language code used as the file extension.

        Args:
            input_file_pair (Tuple[str, str]): A pair of file paths pointing to the input files
            src_lang (str): Source language, in ISO-639-1 (two-character) format (e.g. 'en')
            tgt_lang (str): Target language, in ISO-639-1 (two-character) format (e.g. 'de')
            doc_id (str, optional): A string document id to assign to every segment in the file. Defaults to None.
            backend (str, optional): Backend of the data frame. Defaults to "cudf".
            add_filename (bool, optional): Add the filename as an extra field to every segment in the file. Defaults to False.

        Returns:
            Union[pd.DataFrame, cudf.DataFrame]
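
        Example (an illustrative sketch; the file names are hypothetical):

            ```
            df = ParallelDataset.read_single_simple_bitext_file_pair(
                ("data/train.de", "data/train.en"),
                src_lang="de",
                tgt_lang="en",
                backend="pandas",
            )
            # columns: src, tgt, doc_id, src_lang, tgt_lang
            ```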
""" | ||
        src_input_file, tgt_input_file = input_file_pair
        assert remove_path_extension(src_input_file) == remove_path_extension(
            tgt_input_file
        ), f"Assuming source and target filenames would have a common prefix before the language code, but got {src_input_file} and {tgt_input_file}."

        if not doc_id:
            doc_id = "▁".join([src_input_file, tgt_input_file])

        # TODO: it seems like cudf.read_table can only take one file max,
        # so maybe we shouldn't pass more than one
        if backend == "cudf":
            df = cudf
        else:
            df = pd

        df_src = df.read_table(src_input_file, names=["src"], quoting=csv.QUOTE_NONE)
        df_tgt = df.read_table(tgt_input_file, names=["tgt"], quoting=csv.QUOTE_NONE)
        assert len(df_src) == len(
            df_tgt
        ), f"We assume the source and target files have the same number of lines, but got {len(df_src)} and {len(df_tgt)}."
        df_combined = df.concat([df_src, df_tgt], axis=1)
        df_combined["doc_id"] = doc_id
        df_combined["src_lang"] = src_lang
        df_combined["tgt_lang"] = tgt_lang

        if add_filename:
            df_combined["filename"] = remove_path_extension(src_input_file)

        return df_combined
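
For context, here is a minimal end-to-end usage sketch of the new class, assuming `ParallelDataset` is importable from `nemo_curator.datasets` and that `data/train.de`/`data/train.en` are a line-aligned file pair (all paths here are hypothetical):

```
from nemo_curator.datasets import ParallelDataset

# Read one source/target file pair into a ParallelDataset,
# recording the shared filename prefix for use when writing.
dataset = ParallelDataset.read_simple_bitext(
    src_input_files="data/train.de",
    tgt_input_files="data/train.en",
    src_lang="de",
    tgt_lang="en",
    backend="pandas",
    add_filename=True,
)

# Write the dataset back to disk in bitext format.
dataset.to_bitext(output_file_dir="out/", write_to_filename=True)
```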