Fix GPU error messages for fuzzy deduplication #387

Open
wants to merge 7 commits into main
3 changes: 3 additions & 0 deletions nemo_curator/datasets/doc_dataset.py
@@ -38,6 +38,9 @@ def __len__(self) -> int:
def persist(self) -> "DocumentDataset":
return DocumentDataset(self.df.persist())

def to_backend(self, backend: Optional[str] = None) -> "DocumentDataset":
return DocumentDataset(self.df.to_backend(backend))

@wraps(dd.DataFrame.repartition)
def repartition(self, *args, **kwargs) -> "DocumentDataset":
return self.__class__(self.df.repartition(*args, **kwargs))
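
The new to_backend helper is the conversion path that the GPU error messages below point users to. A minimal usage sketch, assuming a placeholder JSONL corpus and the existing DocumentDataset.read_json reader; the "gpu" backend string follows the error messages added in this PR:

from nemo_curator.datasets import DocumentDataset

# Hypothetical corpus; any JSONL files with "id" and "text" columns work.
dataset = DocumentDataset.read_json("my_corpus/*.jsonl")

# Move the underlying Dask DataFrame to the GPU (cuDF) backend so that the
# GPU-only modules below accept it.
gpu_dataset = dataset.to_backend("gpu")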
75 changes: 54 additions & 21 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -196,6 +196,7 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if not MINHASH_PERMUTED_AVAILABLE:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
@@ -224,6 +225,12 @@ def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
-------
DocumentDataset containing IDs of all documents and the corresponding MinHash Signature
"""
if "cudf" not in str(type(dataset.df)):
raise TypeError(
"Dask-cuDF DataFrame is required to run minhashes. "
'Please convert your DocumentDataset by using .to_backend("gpu").'
)

result = dataset.df[[self.id_field]]
result["_minhash_signature"] = dataset.df[self.text_field].map_partitions(
self.minhash_method,
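
Each GPU module below gains the same guard: inspect the type string of the wrapped Dask DataFrame and fail fast with an actionable message. A standalone sketch of how the check behaves on a CPU-backed dataset (the toy DataFrame is illustrative only):

import dask.dataframe as dd
import pandas as pd

from nemo_curator.datasets import DocumentDataset

# A pandas-backed (CPU) DocumentDataset for illustration.
cpu_dataset = DocumentDataset(
    dd.from_pandas(pd.DataFrame({"id": [0], "text": ["hello"]}), npartitions=1)
)

# Dask-cuDF partitions report a type such as dask_cudf.core.DataFrame, so the
# guard looks for "cudf" in the type string and raises otherwise.
if "cudf" not in str(type(cpu_dataset.df)):
    raise TypeError(
        "Dask-cuDF DataFrame is required to run minhashes. "
        'Please convert your DocumentDataset by using .to_backend("gpu").'
    )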
@@ -492,6 +499,12 @@ def _write_bucket_parquet(
return wrote_buckets, are_buckets_empty

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
if "cudf" not in str(type(dataset.df)):
raise TypeError(
"Dask-cuDF DataFrame is required to run locality-sensitive hashing. "
'Please convert your DocumentDataset by using .to_backend("gpu").'
)

df = dataset.df

write_path = os.path.join(self.cache_dir, "_buckets.parquet")
@@ -613,23 +626,28 @@ def __call__(self, dataset: DocumentDataset):
DocumentDataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are near duplicates.
"""
if "cudf" not in str(type(dataset.df)):
raise TypeError(
"Dask-cuDF DataFrame is required to run fuzzy deduplication. "
'Please convert your DocumentDataset by using .to_backend("gpu").'
)

# Minhash + LSH
stage_num = 1
print(f"Stage{stage_num}: Starting Minhash + LSH computation")
print(f"Stage {stage_num}: Starting Minhash + LSH computation")
minhashLSH = Sequential([self.minhash, self.lsh])
buckets_df = minhashLSH(dataset)
print(f"Stage{stage_num}: Minhash + LSH complete!")
print(f"Stage {stage_num}: Minhash + LSH complete!")
if buckets_df is None:
print(
f"Stage{stage_num}: No potential duplicate documents found during LSH"
f"Stage {stage_num}: No potential duplicate documents found during LSH"
)
return None
stage_num += 1

if self.config.false_positive_check:
# Map buckets to lower cardinality distribution
print(f"Stage{stage_num} (False Positive Check): Starting Map_Buckets")
print(f"Stage {stage_num} (False Positive Check): Starting Map_Buckets")
t0 = time.time()
mapped_buckets_w_anchors_path = os.path.join(
self.config.cache_dir, "anchor_docs_with_bk.parquet"
@@ -647,14 +665,14 @@ def __call__(self, dataset: DocumentDataset):
mapped_buckets_w_anchors_path, write_index=False, overwrite=True
)
self._logger.info(
f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
f"Time taken for Map_Buckets: {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
)

print(f"Stage{stage_num} (False Postive Check): Map_Buckets Complete!")
print(f"Stage {stage_num} (False Positive Check): Map_Buckets complete!")
stage_num += 1

# Shuffle documents based on mapped buckets
print(f"Stage{stage_num} (False Postive Check): Shuffle docs")
print(f"Stage {stage_num} (False Positive Check): Shuffle Documents")
shuffled_docs_path = os.path.join(
self.config.cache_dir, "shuffled_docs.parquet"
)
@@ -666,12 +684,14 @@ def __call__(self, dataset: DocumentDataset):
parts_per_worker=self.config.parts_per_worker,
bucket_parts_per_worker=self.config.bucket_parts_per_worker,
)
print(f"Stage{stage_num} (False Postive Check): Shuffle docs complete!")
print(
f"Stage {stage_num} (False Positive Check): Shuffling Documents complete!"
)
stage_num += 1

# Jaccard comparison within buckets
print(
f"Stage{stage_num} (False Postive Check): Jaccard Similarity in Buckets"
f"Stage {stage_num} (False Positive Check): Jaccard Similarity in Buckets"
)
jaccard_pairs_path = os.path.join(
self.config.cache_dir, "jaccard_similarity_results.parquet"
@@ -691,26 +711,28 @@ def __call__(self, dataset: DocumentDataset):
overwrite=True,
)
self._logger.info(
f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}"
f"Time taken for Jaccard Similarity: {time.time()-t0}s and output written at {jaccard_pairs_path}"
)

print(
f"Stage{stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!"
f"Stage {stage_num} (False Positive Check): Jaccard Similarity in Buckets complete!"
)
stage_num += 1

else:
# Map buckets to lower cardinality distribution
print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist")
print(f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist")
self.buckets_to_edges(buckets_df)
print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist Complete!")
print(
f"Stage {stage_num}: Starting LSH Buckets to Graph Edgelist complete!"
)
stage_num += 1

# Connected components across buckets
print(f"Stage{stage_num}: Connected Components across buckets")
print(f"Stage {stage_num}: Connected Components Across Buckets")
cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet")
self.connected_components.cc_workflow(cc_path)
print(f"Stage{stage_num}: Connected Components across buckets complete!")
self.connected_components(cc_path)
print(f"Stage{stage_num}: Connected Components Across Buckets complete!")
stage_num += 1

return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
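
Taken together, the pipeline now expects a GPU Dask client and a cuDF-backed dataset before any stage runs. A hedged end-to-end sketch, assuming the top-level FuzzyDuplicates and FuzzyDuplicatesConfig exports and placeholder paths and field names:

from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client

# A GPU-based Dask client is required for MinHash, LSH, and connected components.
client = get_client(cluster_type="gpu")

# Placeholder corpus; convert to the GPU backend before calling the module.
dataset = DocumentDataset.read_json("my_corpus/*.jsonl").to_backend("gpu")

config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",  # placeholder cache directory
    id_field="id",
    text_field="text",
    false_positive_check=False,
)
duplicates = FuzzyDuplicates(config=config)(dataset)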
@@ -808,6 +830,12 @@ def buckets_to_edges(
return result_df

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
if "cudf" not in str(type(dataset.df)):
raise TypeError(
"Dask-cuDF DataFrame is required to run buckets to edges. "
'Please convert your DocumentDataset by using .to_backend("gpu").'
)

buckets_df = dataset.df
self._logger.info(f"Starting conversion of LSH Buckets to Graph Edgelist")
if len(self.id_fields) > 1:
@@ -1351,9 +1379,6 @@ def __init__(
self.right_id = f"{self.id_field}_y"
self.ngram_width = ngram_width

def __call__(DocumentDataset):
raise NotImplementedError

def jaccard_compute(self, shuffled_docs_path):
paths = [
entry.path
@@ -1539,7 +1564,7 @@ def __init__(
else:
self._logger = logger

def cc_workflow(self, output_path):
def __call__(self, output_path):
deduped_parsed_id_path = self._write_dedup_parsed_id()
encoded_jaccard_pair_path = self._write_encoded_jaccard_pair(
deduped_parsed_id_path
@@ -1563,7 +1588,15 @@ def _run_connected_components(
self.profile_dir, "connected-components-run"
):

Comms.initialize(p2p=False)
try:
Comms.initialize(p2p=False)
except ValueError:
raise TypeError(
"A GPU-based Dask client is required to run connected components. "
'Please initialize your client with get_client(cluster_type="gpu") '
"or with a LocalCUDACluster."
)

df = dask_cudf.read_parquet(
deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
)
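
The ValueError from Comms.initialize is what users hit when the Dask client has no GPU workers, so the new message names the two supported setups. A sketch of both, with cluster sizing options omitted:

# Option 1: the NeMo Curator helper suggested by the new error message.
from nemo_curator.utils.distributed_utils import get_client

client = get_client(cluster_type="gpu")

# Option 2: an explicit dask_cuda LocalCUDACluster.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster()
client = Client(cluster)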
@@ -43,7 +43,7 @@ def main(args):
logger=args.log_dir,
profile_dir=args.profile_path,
)
components_stage.cc_workflow(output_path=output_path)
components_stage(output_path=output_path)
print(f"All done in {time.time()-st:.1f} seconds")
print(f"Results written to {output_path}")

@@ -2751,7 +2751,7 @@
" id_column=id_field,\n",
" jaccard_threshold=jaccard_threshold,\n",
")\n",
"components_stage.cc_workflow(output_path=output_path)\n",
"components_stage(output_path=output_path)\n",
"print(f\"Connected Component took {time.time()-t0} seconds\")"
]
},
@@ -4455,7 +4455,7 @@
" id_column=id_field,\n",
" jaccard_threshold=jaccard_threshold,\n",
")\n",
"components_stage.cc_workflow(output_path=output_path)\n",
"components_stage(output_path=output_path)\n",
"print(f\"Connected Component took {time.time()-t0} seconds\")"
]
},
2 changes: 1 addition & 1 deletion tutorials/single_node_tutorial/single_gpu_tutorial.ipynb
@@ -1753,7 +1753,7 @@
")\n",
"\n",
"#Load and run connected component\n",
"components_stage.cc_workflow(output_path=connected_component_output_path)\n",
"components_stage(output_path=connected_component_output_path)\n",
"print(f\"Time taken for Connected Component: {time.time()-t0} s\")"
]
},
@@ -41,5 +41,5 @@
)

# Load and run connected components
components_stage.cc_workflow(output_path=connected_component_output_path)
components_stage(output_path=connected_component_output_path)
logging.info(f"Time taken for Connected Components: {time.time() - t0:.2f} s")