Remove max_text_bytes_per_part #385

Status: Open. Wants to merge 5 commits into base: main. Showing changes from 2 commits.

79 changes: 3 additions & 76 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -56,10 +56,7 @@
     filter_text_rows_by_bucket_batch,
     merge_left_to_shuffled_right,
 )
-from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
-    build_partition,
-    get_agg_text_bytes_df,
-)
+from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import get_agg_text_bytes_df
 from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import write_partitioned_file


@@ -813,84 +810,15 @@ def __init__(
         else:
             self._logger = logger
 
-    @staticmethod
-    def _get_output_part_ids_with_approx_equal_sum(
-        bucket_text_bytes_df: cudf.DataFrame,
-        max_text_bytes_per_part: int,
-        buckets_column: str,
-        bytes_column: str,
-        output_partition_column: str,
-    ) -> cudf.DataFrame:
-        """
-        Create a output_series that maps the ser.index into `nparts`
-        so that the total sum of bucket_val_counts_df
-        for each output id are all most equal and
-        less than max_text_bytes_per_part
-        This is used downstream for creating equal output_ids
-        """
-        sizes = bucket_text_bytes_df[bytes_column].values
-        bucket_output_ar = build_partition(
-            sizes=sizes.get(), max_size=max_text_bytes_per_part
-        )
-        df = cudf.DataFrame()
-        df[buckets_column] = bucket_text_bytes_df[buckets_column]
-        df[output_partition_column] = bucket_output_ar
-        return df
-
     def _get_output_map_from_text_bytes_per_bucket(
         self,
         ddf_bk_text_bytes,
-        bytes_column,
         output_partition_column="_output_partition_id",
     ):
-        # String bytes limit for cuDF
-        # https://github.com/rapidsai/cudf/issues/13733
-        max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
-
-        self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}")
-        # Increasing in an attempt to prevent hitting
-        # ulimits
-        output_map_df_meta = cudf.DataFrame(
-            {self.bucket_field: [0], output_partition_column: [1]}
-        )
-        output_map_df_meta = output_map_df_meta.astype(
-            {self.bucket_field: np.uint64, output_partition_column: np.int32}
-        )
-
-        output_map_df = ddf_bk_text_bytes.map_partitions(
-            _MapBuckets._get_output_part_ids_with_approx_equal_sum,
-            max_text_bytes_per_part=max_text_bytes_per_part,
-            buckets_column=self.bucket_field,
-            bytes_column=bytes_column,
-            output_partition_column=output_partition_column,
-            meta=output_map_df_meta,
-        )
+        output_map_df = ddf_bk_text_bytes.assign(**{output_partition_column: 0})
         output_map_df = output_map_df.persist()
-        self._logger.info(
-            f"Step 1 of output_map_df of len: {len(output_map_df)} computed"
-        )
-        lower_bounds = (
-            output_map_df[output_partition_column]
-            .map_partitions(lambda s: (s.max() + 1))
-            .compute()
-        )
-        lower_bounds = np.cumsum(lower_bounds)
-
-        def update_id(df, lower_bound):
-            df[output_partition_column] += lower_bound
-            return df
-
-        updated_parts = [
-            output_map_df.get_partition(i).map_partitions(
-                update_id, lower_bounds[i - 1]
-            )
-            for i in range(1, len(lower_bounds))
-        ]
-        updated_parts.append(output_map_df.get_partition(0))
-        output_map_df = dask_cudf.concat(updated_parts)
-        output_map_df = output_map_df.persist()
         self._logger.info(
-            f"All steps of output_map_df of len: {len(output_map_df)} computed"
+            f"Output map computed with no max limit. Len: {len(output_map_df)}"
         )
         return output_map_df

@@ -923,7 +851,6 @@ def _get_output_map_based_on_str_bytes(
         del buckets_df
         output_map_df = self._get_output_map_from_text_bytes_per_bucket(
             ddf_bk_text_bytes=ddf_bk_text_bytes,
-            bytes_column=bytes_column,
         )
         return output_map_df
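
For scale, the limit being removed throughout this PR is derived from cuDF's int32 string-size ceiling (see the rapidsai/cudf#13733 link in the deleted comment). A quick back-of-the-envelope check by the editor, not part of the PR itself:

import numpy as np

# The deleted budget: roughly three int32 ranges of characters,
# i.e. about 6 GiB of text per output partition.
max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
print(max_text_bytes_per_part)          # 6442450941
print(max_text_bytes_per_part / 2**30)  # ~6.0 GiB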

26 changes: 0 additions & 26 deletions nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py
@@ -61,29 +61,3 @@ def get_agg_text_bytes_df(
     agg_df_len = len(agg_df)
 
     return agg_df, agg_df_len
-
-
-# next-fit-descending bin packing
-# https://en.wikipedia.org/wiki/Next-fit-decreasing_bin_packing
-@numba.jit(nopython=True)
-def build_partition(sizes: np.ndarray, max_size: int) -> np.ndarray:
-    """
-    Given an array of items and a max bin size this method
-    attempts to return a grouping of items such that no group exceeds
-    the max bin size using the Next-fit-decreasing bin packing approach.
-    """
-    i: int = 0
-    count: int = 0
-    current: int = 0
-    size: int = 0
-    partition = np.empty(sizes.shape, dtype=np.int32)
-    for i in range(len(sizes)):
-        size = sizes[i]
-        if current + size < max_size:
-            partition[i] = count
-            current += size
-        else:
-            count += 1
-            current = size
-            partition[i] = count
-    return partition
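
To make the removed helper concrete, here is a minimal pure-NumPy sketch of the same next-fit packing on a toy input. The @numba.jit decorator is dropped and the sizes/max_size values are invented for illustration; this is an editor's sketch, not code from the PR:

import numpy as np

def build_partition(sizes: np.ndarray, max_size: int) -> np.ndarray:
    # Next-fit packing: keep filling the current bin until the next item
    # would reach max_size, then open a new bin.
    count = 0
    current = 0
    partition = np.empty(sizes.shape, dtype=np.int32)
    for i in range(len(sizes)):
        size = sizes[i]
        if current + size < max_size:
            partition[i] = count
            current += size
        else:
            count += 1
            current = size
            partition[i] = count
    return partition

sizes = np.array([7, 2, 5, 9, 3, 4], dtype=np.int64)  # toy byte counts per bucket
print(build_partition(sizes, max_size=10))            # [0 0 1 2 3 3]

(The helper itself does not sort its input, so any descending order has to come from the caller.)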
28 changes: 0 additions & 28 deletions nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
@@ -20,7 +20,6 @@
 from packaging.version import Version
 
 from nemo_curator._compat import query_planning_enabled
-from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import build_partition
 
 dask_cuda_version = Version(dask_cuda.__version__)
 USE_EXCOMMS = (
@@ -95,30 +94,3 @@ def rearange_by_column_direct(
         npartitions=npartitions,
         ignore_index=ignore_index,
     )
-
-
-def get_shuffle_part_ids_df(
-    agg_df,
-    partition_on,
-    output_col,
-    size_col,
-    num_workers=0,
-):
-    sizes = agg_df[size_col].values
-    max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
-
-    # Adjust max_text_bytes_per_part if the number of output
-    # partitions is small compared to the number of workers.
-    # Sometimes we just have very few output partitions to
-    # deal with, and just need a larger batch
-    npartitions_min = max(1, int(num_workers * 0.8))
-    while True:
-        output_ar = build_partition(sizes.get(), max_text_bytes_per_part)
-        if output_ar.max() > npartitions_min or max_text_bytes_per_part < 2**24:
-            break
-        max_text_bytes_per_part = int(max_text_bytes_per_part // 2.0)
-
-    df = cudf.DataFrame()
-    df[partition_on] = agg_df[partition_on]
-    df[output_col] = output_ar
-    return df
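
Similarly, a rough sketch of what the deleted while-loop was doing: repeatedly halving the byte budget until the packing yields more output partitions than roughly 80% of the workers, or the budget falls below 2**24 bytes (16 MiB). It reuses the pure-NumPy build_partition sketch from the previous section, and the sizes and worker count are invented; this is an editor's illustration, not the library's API:

import numpy as np

def pick_output_partitions(sizes: np.ndarray, num_workers: int) -> np.ndarray:
    # CPU-only stand-in for the core loop of the deleted get_shuffle_part_ids_df.
    max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
    npartitions_min = max(1, int(num_workers * 0.8))
    while True:
        output_ar = build_partition(sizes, max_text_bytes_per_part)
        # Stop once there are enough output partitions for the worker pool,
        # or the per-partition budget has shrunk below 16 MiB.
        if output_ar.max() > npartitions_min or max_text_bytes_per_part < 2**24:
            break
        max_text_bytes_per_part = int(max_text_bytes_per_part // 2.0)
    return output_ar

sizes = np.array([10_000_000] * 64, dtype=np.int64)   # 64 buckets of ~10 MB each
groups = pick_output_partitions(sizes, num_workers=8)
print(groups.max() + 1)                               # 13 output groups for this toy input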