Skip to content

Commit 7412aff

Browse files
kausv authored and facebook-github-bot committed
Change util to return a set of bucket metadata per shard (#2920)
Summary: Pull Request resolved: #2920. Based on discussions, we are modifying the util to give a master set of metadata about buckets for the given shards. We will stack a follow-up diff for ShardedEC to return this data. Reviewed By: emlin. Differential Revision: D73522460. fbshipit-source-id: 89648ef3531756c9a6d43422877c12454bd1ffcb
1 parent 07947d4 commit 7412aff

File tree

3 files changed

+72
-37
lines changed

3 files changed

+72
-37
lines changed

torchrec/distributed/tests/test_utils.py

+43-26
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
ModuleSharder,
3030
MultiPassPrefetchConfig,
3131
ParameterSharding,
32+
ShardingBucketMetadata,
3233
ShardMetadata,
3334
)
3435
from torchrec.distributed.utils import (
3536
add_params_from_parameter_sharding,
3637
convert_to_fbgemm_types,
37-
get_bucket_offsets_from_shard_metadata,
38+
get_bucket_metadata_from_shard_metadata,
3839
get_unsharded_module_names,
3940
merge_fused_params,
4041
)
@@ -472,15 +473,15 @@ def test_convert_to_fbgemm_types(self) -> None:
472473
self.assertFalse(isinstance(per_table_fused_params["output_dtype"], DataType))
473474

474475

475-
class TestBucketOffsets(unittest.TestCase):
476-
def test_bucket_offsets(self) -> None:
476+
class TestBucketMetadata(unittest.TestCase):
477+
def test_bucket_metadata(self) -> None:
477478
# Given no shards
478-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
479+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
479480
# Then an error should be raised
480481
self.assertRaisesRegex(
481482
AssertionError,
482483
"Shards cannot be empty",
483-
get_bucket_offsets_from_shard_metadata,
484+
get_bucket_metadata_from_shard_metadata,
484485
[],
485486
num_buckets=4,
486487
)
@@ -490,11 +491,13 @@ def test_bucket_offsets(self) -> None:
490491
ShardMetadata(shard_offsets=[0], shard_sizes=[5], placement="rank:0/cuda:0")
491492
]
492493

493-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
494-
bucket_offsets = get_bucket_offsets_from_shard_metadata(shards, num_buckets=5)
494+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
495+
bucket_metadata = get_bucket_metadata_from_shard_metadata(shards, num_buckets=5)
495496
# Then we should get metadata with 5 buckets in the single shard, bucket offset 0 and bucket size 1
496-
expected_offsets = [0]
497-
self.assertEqual(bucket_offsets, expected_offsets)
497+
expected_metadata = ShardingBucketMetadata(
498+
num_buckets_per_shard=[5], bucket_offsets_per_shard=[0], bucket_size=1
499+
)
500+
self.assertEqual(bucket_metadata, expected_metadata)
498501

499502
# Given 2 shards of size 5 and 4 buckets
500503
shards = [
@@ -506,12 +509,12 @@ def test_bucket_offsets(self) -> None:
506509
),
507510
]
508511

509-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
512+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
510513
# Then an error should be raised
511514
self.assertRaisesRegex(
512515
AssertionError,
513516
"Table size '10' must be divisible by num_buckets '4'",
514-
get_bucket_offsets_from_shard_metadata,
517+
get_bucket_metadata_from_shard_metadata,
515518
shards,
516519
num_buckets=4,
517520
)
@@ -526,12 +529,12 @@ def test_bucket_offsets(self) -> None:
526529
),
527530
]
528531

529-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
532+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
530533
# Then an error should be raised
531534
self.assertRaisesRegex(
532535
AssertionError,
533536
"Table size '4' must be divisible by num_buckets '5'",
534-
get_bucket_offsets_from_shard_metadata,
537+
get_bucket_metadata_from_shard_metadata,
535538
shards,
536539
num_buckets=5,
537540
)
@@ -546,12 +549,12 @@ def test_bucket_offsets(self) -> None:
546549
),
547550
]
548551

549-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
552+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
550553
# Then an error should be raised
551554
self.assertRaisesRegex(
552555
AssertionError,
553556
r"Shard shard_offsets\[1\] '5' is not 0. Table should be only row-wise sharded for bucketization",
554-
get_bucket_offsets_from_shard_metadata,
557+
get_bucket_metadata_from_shard_metadata,
555558
shards,
556559
num_buckets=2,
557560
)
@@ -566,17 +569,17 @@ def test_bucket_offsets(self) -> None:
566569
),
567570
]
568571

569-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
572+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
570573
# Then an error should be raised
571574
self.assertRaisesRegex(
572575
AssertionError,
573576
r"Shard size\[0\] '10' is not divisible by bucket size '4'",
574-
get_bucket_offsets_from_shard_metadata,
577+
get_bucket_metadata_from_shard_metadata,
575578
shards,
576579
num_buckets=5,
577580
)
578581

579-
# Given 2 shards of size 20 and 4 buckets
582+
# Given 2 shards of size 20 and 10 buckets
580583
shards = [
581584
ShardMetadata(
582585
shard_offsets=[0], shard_sizes=[20], placement="rank:0/cuda:0"
@@ -585,13 +588,20 @@ def test_bucket_offsets(self) -> None:
585588
shard_offsets=[20], shard_sizes=[20], placement="rank:0/cuda:0"
586589
),
587590
]
588-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
589-
bucket_offsets = get_bucket_offsets_from_shard_metadata(
591+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
592+
bucket_metadata = get_bucket_metadata_from_shard_metadata(
590593
shards,
591594
num_buckets=10,
592595
)
593-
# Then bucket offsets should be set to [0, 5]
594-
self.assertEqual(bucket_offsets, [0, 5])
596+
# Then num_buckets_per_shard should be [5, 5], bucket_offsets_per_shard [0, 5] and bucket_size 4
597+
self.assertEqual(
598+
bucket_metadata,
599+
ShardingBucketMetadata(
600+
num_buckets_per_shard=[5, 5],
601+
bucket_offsets_per_shard=[0, 5],
602+
bucket_size=4,
603+
),
604+
)
595605

596606
# Given 3 uneven shards of sizes 12, 16 and 20 and 12 buckets
597607
shards = [
@@ -606,10 +616,17 @@ def test_bucket_offsets(self) -> None:
606616
),
607617
]
608618

609-
# When we get bucket offsets from get_bucket_offsets_from_shard_metadata
610-
bucket_offsets = get_bucket_offsets_from_shard_metadata(
619+
# When we get bucket metadata from get_bucket_metadata_from_shard_metadata
620+
bucket_metadata = get_bucket_metadata_from_shard_metadata(
611621
shards,
612622
num_buckets=12,
613623
)
614-
# Then bucket offsets should be set to [0, 3, 7]
615-
self.assertEqual(bucket_offsets, [0, 3, 7])
624+
# Then num_buckets_per_shard should be [3, 4, 5], bucket_offsets_per_shard [0, 3, 7] and bucket_size 4
625+
self.assertEqual(
626+
bucket_metadata,
627+
ShardingBucketMetadata(
628+
num_buckets_per_shard=[3, 4, 5],
629+
bucket_offsets_per_shard=[0, 3, 7],
630+
bucket_size=4,
631+
),
632+
)

torchrec/distributed/types.py

+16
Original file line numberDiff line numberDiff line change
@@ -1223,3 +1223,19 @@ class ObjectPoolShardingType(Enum):
12231223
class ObjectPoolShardingPlan(ModuleShardingPlan):
12241224
sharding_type: ObjectPoolShardingType
12251225
inference: bool = False
1226+
1227+
1228+
@dataclass
1229+
class ShardingBucketMetadata:
1230+
"""
1231+
If a table is row-wise sharded with bucketization, this class contains the bucket information for the table.
1232+
1233+
Attributes:
1234+
num_buckets_per_shard (List[int]): Number of buckets in each shard of the table.
1235+
bucket_offsets_per_shard (List[int]): Index of the first bucket in each shard.
1236+
bucket_size (int): Number of rows in one bucket.
1237+
"""
1238+
1239+
num_buckets_per_shard: List[int]
1240+
bucket_offsets_per_shard: List[int]
1241+
bucket_size: int

torchrec/distributed/utils.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
EmbeddingEvent,
2929
ParameterSharding,
3030
ShardedModule,
31+
ShardingBucketMetadata,
3132
ShardingType,
3233
ShardMetadata,
3334
)
@@ -567,32 +568,32 @@ def create_global_tensor_shape_stride_from_metadata(
567568
return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1))
568569

569570

570-
def get_bucket_offsets_from_shard_metadata(
571+
def get_bucket_metadata_from_shard_metadata(
571572
shards: List[ShardMetadata],
572573
num_buckets: int,
573-
) -> List[int]:
574+
) -> ShardingBucketMetadata:
574575
"""
575-
Calculate the bucket offsets from shard metadata.
576+
Calculate the bucket metadata from shard metadata.
576577
577578
This function assumes the table is row-wise sharded into equal-sized buckets, with every shard boundary falling on a bucket boundary.
578-
It computes the sequential bucket offsets for each shard. It ensures that the table size
579-
is divisible by the number of buckets and that each shard size is divisible
580-
by the bucket size.
579+
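It computes the number of buckets per shard, the index of the first bucket in each shard, and the bucket size. It asserts that the table size is divisible by the number of buckets and that each shard size is divisible by the bucket size.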
581580
582581
Args:
583-
shards (List[ShardMetadata]): A list of shard metadata objects.
582+
shards (List[ShardMetadata]): Shard metadata for all shards of a table.
584583
num_buckets (int): The number of buckets to divide the table into.
585584
586585
Returns:
587-
List[int]: A list of bucket offsets.
586+
ShardingBucketMetadata: An object containing the number of buckets per shard, the bucket offset of each shard, and the bucket size.
588587
"""
589588
assert len(shards) > 0, "Shards cannot be empty"
590589
table_size = shards[-1].shard_offsets[0] + shards[-1].shard_sizes[0]
591590
assert (
592591
table_size % num_buckets == 0
593592
), f"Table size '{table_size}' must be divisible by num_buckets '{num_buckets}'"
594-
bucket_offsets: List[int] = []
595593
bucket_size = table_size // num_buckets
594+
bucket_metadata: ShardingBucketMetadata = ShardingBucketMetadata(
595+
num_buckets_per_shard=[], bucket_offsets_per_shard=[], bucket_size=bucket_size
596+
)
596597
current_bucket_offset = 0
597598
for shard in shards:
598599
assert (
@@ -602,7 +603,8 @@ def get_bucket_offsets_from_shard_metadata(
602603
shard.shard_sizes[0] % bucket_size == 0
603604
), f"Shard size[0] '{shard.shard_sizes[0]}' is not divisible by bucket size '{bucket_size}'"
604605
num_buckets_in_shard = shard.shard_sizes[0] // bucket_size
605-
bucket_offsets.append(current_bucket_offset)
606+
bucket_metadata.num_buckets_per_shard.append(num_buckets_in_shard)
607+
bucket_metadata.bucket_offsets_per_shard.append(current_bucket_offset)
606608
current_bucket_offset += num_buckets_in_shard
607609

608-
return bucket_offsets
610+
return bucket_metadata

0 commit comments

Comments
 (0)