29
29
ModuleSharder ,
30
30
MultiPassPrefetchConfig ,
31
31
ParameterSharding ,
32
+ ShardingBucketMetadata ,
32
33
ShardMetadata ,
33
34
)
34
35
from torchrec .distributed .utils import (
35
36
add_params_from_parameter_sharding ,
36
37
convert_to_fbgemm_types ,
37
- get_bucket_offsets_from_shard_metadata ,
38
+ get_bucket_metadata_from_shard_metadata ,
38
39
get_unsharded_module_names ,
39
40
merge_fused_params ,
40
41
)
@@ -472,15 +473,15 @@ def test_convert_to_fbgemm_types(self) -> None:
472
473
self .assertFalse (isinstance (per_table_fused_params ["output_dtype" ], DataType ))
473
474
474
475
475
- class TestBucketOffsets (unittest .TestCase ):
476
- def test_bucket_offsets (self ) -> None :
476
+ class TestBucketMetadata (unittest .TestCase ):
477
+ def test_bucket_metadata (self ) -> None :
477
478
# Given no shards
478
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
479
+ # When we get bucket metadata from get_bucket_metadata_from_shard_metadata
479
480
# Then an error should be raised
480
481
self .assertRaisesRegex (
481
482
AssertionError ,
482
483
"Shards cannot be empty" ,
483
- get_bucket_offsets_from_shard_metadata ,
484
+ get_bucket_metadata_from_shard_metadata ,
484
485
[],
485
486
num_buckets = 4 ,
486
487
)
@@ -490,11 +491,13 @@ def test_bucket_offsets(self) -> None:
490
491
ShardMetadata (shard_offsets = [0 ], shard_sizes = [5 ], placement = "rank:0/cuda:0" )
491
492
]
492
493
493
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
494
- bucket_offsets = get_bucket_offsets_from_shard_metadata (shards , num_buckets = 5 )
494
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
495
+ bucket_metadata = get_bucket_metadata_from_shard_metadata (shards , num_buckets = 5 )
495
496
# Then we should get 1 offset with value 0
496
- expected_offsets = [0 ]
497
- self .assertEqual (bucket_offsets , expected_offsets )
497
+ expected_metadata = ShardingBucketMetadata (
498
+ num_buckets_per_shard = [5 ], bucket_offsets_per_shard = [0 ], bucket_size = 1
499
+ )
500
+ self .assertEqual (bucket_metadata , expected_metadata )
498
501
499
502
# Given 2 shards of size 5 and 4 buckets
500
503
shards = [
@@ -506,12 +509,12 @@ def test_bucket_offsets(self) -> None:
506
509
),
507
510
]
508
511
509
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
512
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
510
513
# Then an error should be raised
511
514
self .assertRaisesRegex (
512
515
AssertionError ,
513
516
"Table size '10' must be divisible by num_buckets '4'" ,
514
- get_bucket_offsets_from_shard_metadata ,
517
+ get_bucket_metadata_from_shard_metadata ,
515
518
shards ,
516
519
num_buckets = 4 ,
517
520
)
@@ -526,12 +529,12 @@ def test_bucket_offsets(self) -> None:
526
529
),
527
530
]
528
531
529
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
532
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
530
533
# Then an error should be raised
531
534
self .assertRaisesRegex (
532
535
AssertionError ,
533
536
"Table size '4' must be divisible by num_buckets '5'" ,
534
- get_bucket_offsets_from_shard_metadata ,
537
+ get_bucket_metadata_from_shard_metadata ,
535
538
shards ,
536
539
num_buckets = 5 ,
537
540
)
@@ -546,12 +549,12 @@ def test_bucket_offsets(self) -> None:
546
549
),
547
550
]
548
551
549
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
552
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
550
553
# Then an error should be raised
551
554
self .assertRaisesRegex (
552
555
AssertionError ,
553
556
r"Shard shard_offsets\[1\] '5' is not 0. Table should be only row-wise sharded for bucketization" ,
554
- get_bucket_offsets_from_shard_metadata ,
557
+ get_bucket_metadata_from_shard_metadata ,
555
558
shards ,
556
559
num_buckets = 2 ,
557
560
)
@@ -566,17 +569,17 @@ def test_bucket_offsets(self) -> None:
566
569
),
567
570
]
568
571
569
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
572
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
570
573
# Then an error should be raised
571
574
self .assertRaisesRegex (
572
575
AssertionError ,
573
576
r"Shard size\[0\] '10' is not divisible by bucket size '4'" ,
574
- get_bucket_offsets_from_shard_metadata ,
577
+ get_bucket_metadata_from_shard_metadata ,
575
578
shards ,
576
579
num_buckets = 5 ,
577
580
)
578
581
579
- # Given 2 shards of size 20 and 4 buckets
582
+ # Given 2 shards of size 20 and 10 buckets
580
583
shards = [
581
584
ShardMetadata (
582
585
shard_offsets = [0 ], shard_sizes = [20 ], placement = "rank:0/cuda:0"
@@ -585,13 +588,20 @@ def test_bucket_offsets(self) -> None:
585
588
shard_offsets = [20 ], shard_sizes = [20 ], placement = "rank:0/cuda:0"
586
589
),
587
590
]
588
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
589
- bucket_offsets = get_bucket_offsets_from_shard_metadata (
591
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
592
+ bucket_metadata = get_bucket_metadata_from_shard_metadata (
590
593
shards ,
591
594
num_buckets = 10 ,
592
595
)
593
- # Then bucket offsets should be set to [0, 5]
594
- self .assertEqual (bucket_offsets , [0 , 5 ])
596
+ # Then num_buckets_per_shard should be set to [5, 5]
597
+ self .assertEqual (
598
+ bucket_metadata ,
599
+ ShardingBucketMetadata (
600
+ num_buckets_per_shard = [5 , 5 ],
601
+ bucket_offsets_per_shard = [0 , 5 ],
602
+ bucket_size = 4 ,
603
+ ),
604
+ )
595
605
596
606
# Given 3 uneven shards of sizes 12, 16 and 20 and 12 buckets
597
607
shards = [
@@ -606,10 +616,17 @@ def test_bucket_offsets(self) -> None:
606
616
),
607
617
]
608
618
609
- # When we get bucket offsets from get_bucket_offsets_from_shard_metadata
610
- bucket_offsets = get_bucket_offsets_from_shard_metadata (
619
+ # When we get bucket offsets from get_bucket_metadata_from_shard_metadata
620
+ bucket_metadata = get_bucket_metadata_from_shard_metadata (
611
621
shards ,
612
622
num_buckets = 12 ,
613
623
)
614
- # Then bucket offsets should be set to [0, 3, 7]
615
- self .assertEqual (bucket_offsets , [0 , 3 , 7 ])
624
+ # Then num_buckets_per_shard should be set to [3, 4, 5]
625
+ self .assertEqual (
626
+ bucket_metadata ,
627
+ ShardingBucketMetadata (
628
+ num_buckets_per_shard = [3 , 4 , 5 ],
629
+ bucket_offsets_per_shard = [0 , 3 , 7 ],
630
+ bucket_size = 4 ,
631
+ ),
632
+ )
0 commit comments