Add codepath for computing buckets without int conversion #326
base: main
Commits: ccb1e31, f2b1888, 816940b, 30f383c, d7a2617, 954a043, 3b51aad, d119740, 8dbc48a
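For orientation, here is a hedged sketch of how the new flag introduced by this PR is meant to be used. Only `buckets_as_int` comes from this change; the other constructor arguments are taken from the signature visible in the diff below, and the import path plus the argument values are assumptions, not part of the PR.

```python
# Sketch only: illustrates the intent of the new `buckets_as_int` flag.
from nemo_curator import LSH  # import path assumed to be exported at package level

lsh = LSH(
    cache_dir="./lsh_cache",             # illustrative value
    num_hashes=260,                      # illustrative value
    num_buckets=20,                      # illustrative value
    buckets_per_shuffle=1,
    buckets_as_int=False,                # new parameter: skip bucket_id_to_int()
    id_fields=["id"],
    minhash_field="_minhash_signature",
)
```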
@@ -47,6 +47,7 @@
 from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
 from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
     aggregated_anchor_docs_with_bk_read,
+    check_empty_buckets,
     get_restart_offsets,
     update_restart_offsets,
 )
@@ -198,6 +199,7 @@ def __init__(
         num_hashes: int,
         num_buckets: int,
         buckets_per_shuffle: int = 1,
+        buckets_as_int: bool = False,
         logger: Union[logging.LoggerAdapter, str] = "./",
         id_fields: Union[str, list] = "id",
         minhash_field: str = "_minhash_signature",
@@ -228,6 +230,7 @@ def __init__(
         self.bucket_ranges = self._generate_bucket_ranges(
             self.num_buckets, self.num_hashes
         )
+        self.buckets_as_int = buckets_as_int

         if cache_dir is None:
             raise ValueError(
@@ -320,6 +323,8 @@ def lsh(
         """
         Computes buckets and writes them as parquet files to the write_path
         """
+        buckets_isempty = True
+
         meta = self._minhash_to_bucket_meta(df)
         df = df.map_partitions(
             self.minhash_to_buckets,
@@ -343,17 +348,19 @@
             ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])

             df2 = df2.reset_index(drop=True)
-            df2, end_id = self.bucket_id_to_int(
-                df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
-            )
-            # If bucketing return empty dataframe
-            if end_id < bucket_start_id:
-                continue
-            bucket_start_id = end_id + 1
+            if self.buckets_as_int:
+                df2, end_id = self.bucket_id_to_int(
+                    df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
+                )
+                # If bucketing return empty dataframe
+                if end_id < bucket_start_id:
+                    continue
+                bucket_start_id = end_id + 1
+            buckets_isempty = False

             # Workaround for dtype mismatches with empty partitions
-            dtypes = df2.dtypes.to_dict()
-            df2 = df2.map_partitions(lambda x: x.astype(dtypes))
+            # dtypes = df2.dtypes.to_dict()
+            # df2 = df2.map_partitions(lambda x: x.astype(dtypes))

             if i == 0:
                 if os.path.exists(write_path):
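For readers unfamiliar with the conversion that becomes optional here, below is a rough pandas sketch of what mapping bucket ids to contiguous integers with a running offset looks like. It only illustrates the `end_id` / `bucket_start_id` bookkeeping in the hunk above; it is not the repository's `bucket_id_to_int` implementation, and the function name is made up.

```python
import pandas as pd

def bucket_id_to_int_sketch(df: pd.DataFrame, bucket_col: str, start_id: int):
    # Assign each distinct bucket id an integer, numbering from start_id.
    codes, uniques = pd.factorize(df[bucket_col])
    df = df.assign(**{bucket_col: codes + start_id})
    # For an empty partition, end_id < start_id, mirroring the
    # `if end_id < bucket_start_id: continue` check in the diff.
    end_id = start_id + len(uniques) - 1
    return df, end_id
```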
@@ -362,21 +369,42 @@ def lsh(
                     )
                 df2.to_parquet(write_path, write_index=False, overwrite=True)
             else:
-                df2.to_parquet(write_path, write_index=False, append=True)
+                df2.to_parquet(
+                    write_path,
+                    write_index=False,
+                    overwrite=buckets_isempty,
+                    append=not buckets_isempty,
+                )

-            self._logger.info(f"Wrote data for buckets: {value_vars}")
+            if os.path.exists(write_path) and buckets_isempty:
+                buckets_isempty = check_empty_buckets(write_path)
+
+            if buckets_isempty:
+                self._logger.info(
+                    f"No duplicate documents found for buckets: {value_vars}"
+                )
+            else:
+                self._logger.info(f"Wrote data for buckets: {value_vars}")
+
+        if buckets_isempty:
+            self._logger.info("No duplicate documents found during LSH")
+            import shutil
+
+            shutil.rmtree(write_path)
+        return buckets_isempty

     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         df = dataset.df

         write_path = os.path.join(self.cache_dir, "_buckets.parquet")
         t0 = time.time()
         with performance_report_if_with_ts_suffix(self.profile_dir, f"lsh-profile"):
-            self.lsh(write_path=write_path, df=df)
+            empty_result = self.lsh(write_path=write_path, df=df)
         self._logger.info(
             f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
         )

+        if empty_result:
+            return None
         buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
         return DocumentDataset(buckets_df)
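A short, hedged usage sketch of the changed return contract of `LSH.__call__` after this hunk; variable names are illustrative and `lsh` refers to an already-constructed LSH stage.

```python
# After this change, callers of the LSH stage must handle the None case.
buckets = lsh(minhash_dataset)  # minhash_dataset: DocumentDataset of minhash signatures
if buckets is None:
    # lsh() found no duplicate candidates and removed _buckets.parquet.
    print("No potential duplicates found during LSH")
else:
    buckets_df = buckets.df  # dask_cudf DataFrame read back from _buckets.parquet
```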
@@ -425,6 +453,8 @@ def __init__(
             num_hashes=self.config.num_hashes,
             num_buckets=self.config.num_buckets,
             buckets_per_shuffle=self.config.buckets_per_shuffle,
+            # Only convert buckets to int if we are running false positive check
+            buckets_as_int=self.config.false_positive_check,
             logger=self._logger,
             id_fields=[self.config.id_field],
             profile_dir=self.config.profile_dir,
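To make the coupling in the comment above concrete: integer bucket ids are now produced only when the false positive check, their sole consumer, is enabled. A hypothetical configuration follows; the config class name and any fields not visible in this diff are assumptions.

```python
# Hypothetical config: with the false positive check disabled, LSH is
# constructed with buckets_as_int=False and _bucket_id is never remapped
# to contiguous integers.
config = FuzzyDuplicatesConfig(        # class name assumed, not shown in this diff
    cache_dir="./fuzzy_dedup_cache",   # assumed field
    id_field="id",
    num_buckets=20,                    # illustrative value
    buckets_per_shuffle=1,
    false_positive_check=False,        # -> buckets_as_int=False
)
```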
@@ -494,6 +524,11 @@ def __call__(self, dataset: DocumentDataset):
         minhashLSH = Sequential([self.minhash, self.lsh])
         buckets_df = minhashLSH(dataset)
         print(f"Stage{stage_num}: Minhash + LSH complete!")
+        if buckets_df is None:
+            print(
+                f"Stage{stage_num}: No potential duplicate documents found during LSH"
+            )
+            return None
Review thread on the added `return None`:
- Should this return None or an empty […]
- I prefer returning […]
- Makes sense, but then for […]
- I haven't seen […]
         stage_num += 1

         if self.config.false_positive_check:
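The alternative raised in the review thread above, returning an empty dataset instead of None, would look roughly like the following. The schema and helper name are assumptions, and this is not what the PR implements.

```python
import cudf
import dask_cudf
from nemo_curator.datasets import DocumentDataset  # import path assumed

def empty_buckets_dataset(id_field: str = "id") -> DocumentDataset:
    # Assumed schema: the id column(s) plus the _bucket_id column.
    empty = cudf.DataFrame({id_field: [], "_bucket_id": []})
    return DocumentDataset(dask_cudf.from_cudf(empty, npartitions=1))
```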
@@ -677,6 +712,7 @@ def buckets_to_edges(

     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+        buckets_df = dataset.df
         self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
         if len(self.id_fields) > 1:
             buckets_df = buckets_df.map_partitions(
                 BucketsToEdges._combine_multiple_ids,
@@ -201,3 +201,16 @@ def strip_trailing_sep(path: str):
     Strips a path string of trailing path seperators like `/` if any.
     """
     return path.rstrip(os.path.sep)
+
+
+def check_empty_buckets(bucket_path):
+    """
+    Inspects parquet metadata of the buckets dataset to check if it's an empty dataset.
+    """
+    from pyarrow.dataset import dataset
+
+    ds = dataset(bucket_path, format="parquet")
+    for fragment in ds.get_fragments():
+        if fragment.metadata.num_rows > 0:
+            return False
Review comment on lines +212 to +215:
- This logic can probably be simplified by using a global metadata file when writing out the parquet dataset.
+    return True
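A rough sketch of the simplification suggested in the comment above, assuming the bucket dataset were written with a global `_metadata` file (for example via dask's `to_parquet(..., write_metadata_file=True)`); the function name is hypothetical and this is not part of the PR.

```python
import os
import pyarrow.parquet as pq

def check_empty_buckets_via_global_metadata(bucket_path: str) -> bool:
    # num_rows in the aggregated _metadata file covers every row group,
    # so no per-fragment iteration is needed.
    metadata = pq.read_metadata(os.path.join(bucket_path, "_metadata"))
    return metadata.num_rows == 0
```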
Review thread:
- Not for this PR, but just a highlight from our google docs convo: good place to leverage fsspec.
- Agreed. Decided to go via this route for now (since other places also use shutil). Aligned that the refactor to be more remote-friendly should leverage fsspec utilities where possible.
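For reference, the fsspec-based cleanup mentioned in this thread could look roughly like the sketch below; the helper name is hypothetical and this is not part of the PR.

```python
import fsspec

def remove_output_path(path: str) -> None:
    # Remote-friendly replacement for shutil.rmtree(path): resolves the
    # filesystem from the URL scheme (local, s3://, gs://, ...) and deletes
    # the directory recursively.
    fs, fs_path = fsspec.core.url_to_fs(path)
    fs.rm(fs_path, recursive=True)
```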