
Commit b3a6e22

isururanawaka authored and facebook-github-bot committed
Add random_seed for regular model parallel tests to ensure actual randomness in generating embeddings/inputs etc... (#3158)
Summary: Add random_seed as an optional parameter to the gen_model_and_input method so that any other testing method can use it.

Differential Revision: D77742701

1 parent: 538bfa4
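The change follows a standard pattern: seed the RNGs only when the caller asks for reproducibility, and otherwise leave them untouched so repeated runs generate genuinely random embeddings and inputs. A minimal standalone sketch of the idea (the helper name maybe_seed is hypothetical, not TorchRec code):

import random
from typing import Optional

import torch


def maybe_seed(random_seed: Optional[int] = None) -> None:
    # Seed both the torch and Python RNGs only when a seed is supplied;
    # with the default of None, generation stays random from run to run.
    if random_seed is not None:
        torch.manual_seed(random_seed)
        random.seed(random_seed)


maybe_seed(100)
a = torch.rand(3)
maybe_seed(100)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed -> identical tensors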

File tree

3 files changed: +22, -3 lines


torchrec/distributed/test_utils/test_model.py

Lines changed: 7 additions & 2 deletions
@@ -100,11 +100,14 @@ def generate(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
+        if random_seed is not None:
+            torch.manual_seed(random_seed)
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -751,9 +754,11 @@ def generate_variable_batch_input(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
-        torch.manual_seed(100)
-        random.seed(100)
+        if random_seed is not None:
+            torch.manual_seed(random_seed)
+            random.seed(random_seed)
         dedup_factor = 2

         global_kjt, local_kjts = ModelInput._generate_variable_batch_features(
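Before this change, generate_variable_batch_input reseeded with a hardcoded 100 on every call, so it always produced the same batches; now it reseeds only when asked. A toy illustration of the behavioral difference (not TorchRec code):

import torch


def gen(seed=None):
    # Stand-in for the new conditional seeding in generate_variable_batch_input.
    if seed is not None:
        torch.manual_seed(seed)
    return torch.randint(0, 1000, (8,))  # fake feature indices


x = gen(seed=42)
y = gen(seed=42)
assert torch.equal(x, y)  # seeded calls are reproducible
z = gen()  # unseeded: varies from call to call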

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 14 additions & 1 deletion
@@ -135,6 +135,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...

@@ -152,6 +153,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...

@@ -180,8 +182,10 @@ def gen_model_and_input(
     global_constant_batch: bool = False,
     num_inputs: int = 1,
     input_type: str = "kjt",  # "kjt" or "td"
+    random_seed: Optional[int] = None,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
-    torch.manual_seed(0)
+    if random_seed is not None:
+        torch.manual_seed(random_seed)
     if dedup_feature_names:
         model = model_class(
             tables=cast(
@@ -224,6 +228,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     elif generate == ModelInput.generate:
@@ -242,6 +247,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     else:
@@ -259,6 +265,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     return (model, inputs)
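The pattern in gen_model_and_input is to accept the optional seed once, apply it before model construction so parameter initialization is repeatable, and forward it to whichever generate branch runs. A reduced sketch under those assumptions (function and parameter names are illustrative, not the real signature):

from typing import Any, Callable, Optional, Tuple

import torch


def gen_model_and_input_sketch(
    make_model: Callable[[], torch.nn.Module],
    generate: Callable[..., Any],
    random_seed: Optional[int] = None,
) -> Tuple[torch.nn.Module, Any]:
    # Seed once up front so model init is deterministic when requested ...
    if random_seed is not None:
        torch.manual_seed(random_seed)
    model = make_model()
    # ... and forward the seed so input generation follows the same policy.
    inputs = generate(random_seed=random_seed)
    return model, inputs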
@@ -718,6 +725,7 @@ def sharding_single_rank_test_single_process(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: int = 0,
 ) -> None:
     batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size
     # Generate model & inputs.
@@ -746,7 +754,9 @@ def sharding_single_rank_test_single_process(
         indices_dtype=indices_dtype,
         offsets_dtype=offsets_dtype,
         lengths_dtype=lengths_dtype,
+        random_seed=random_seed,
     )
+
     global_model = global_model.to(device)
     global_input = inputs[0][0].to(device)
     local_input = inputs[0][1][rank].to(device)
@@ -794,6 +804,7 @@ def sharding_single_rank_test_single_process(
         constraints=constraints,
     )
     plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg)
+
     """
     Simulating multiple nodes on a single node. However, metadata information and
     tensor placement must still be consistent. Here we overwrite this to do so.
@@ -973,6 +984,7 @@ def sharding_single_rank_test(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: int = 100,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         assert ctx.pg is not None
@@ -1006,6 +1018,7 @@ def sharding_single_rank_test(
             indices_dtype=indices_dtype,
             offsets_dtype=offsets_dtype,
             lengths_dtype=lengths_dtype,
+            random_seed=random_seed,
         )
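Note the defaults: sharding_single_rank_test_single_process seeds with 0 and sharding_single_rank_test with 100, matching the constants previously hardcoded in gen_model_and_input and generate_variable_batch_input, so existing tests keep their behavior. A fixed shared seed also matters across ranks: every rank must regenerate the same global batch to slice out a consistent local portion. A toy single-process illustration of that invariant (not TorchRec code):

import torch


def rank_batch(rank: int, seed: int) -> torch.Tensor:
    # Every rank seeds identically, regenerates the same "global" batch,
    # and keeps only its own row as the local batch.
    torch.manual_seed(seed)
    global_batch = torch.rand(4, 8)  # 4 ranks x 8 features
    return global_batch[rank]


rows = [rank_batch(r, seed=100) for r in range(4)]
torch.manual_seed(100)
assert torch.equal(torch.stack(rows), torch.rand(4, 8))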

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -608,6 +608,7 @@ def test_ssd_mixed_kernels_with_vbe(
             },
             constraints=constraints,
             variable_batch_per_feature=True,
+            random_seed=100,
         )

     @unittest.skipIf(
