@@ -135,6 +135,7 @@ def __call__(
135
135
indices_dtype : torch .dtype = torch .int64 ,
136
136
offsets_dtype : torch .dtype = torch .int64 ,
137
137
lengths_dtype : torch .dtype = torch .int64 ,
138
+ random_seed : Optional [int ] = None ,
138
139
) -> Tuple ["ModelInput" , List ["ModelInput" ]]: ...
139
140
140
141
@@ -152,6 +153,7 @@ def __call__(
152
153
indices_dtype : torch .dtype = torch .int64 ,
153
154
offsets_dtype : torch .dtype = torch .int64 ,
154
155
lengths_dtype : torch .dtype = torch .int64 ,
156
+ random_seed : Optional [int ] = None ,
155
157
) -> Tuple ["ModelInput" , List ["ModelInput" ]]: ...
156
158
157
159
@@ -180,8 +182,10 @@ def gen_model_and_input(
180
182
global_constant_batch : bool = False ,
181
183
num_inputs : int = 1 ,
182
184
input_type : str = "kjt" , # "kjt" or "td"
185
+ random_seed : Optional [int ] = None ,
183
186
) -> Tuple [nn .Module , List [Tuple [ModelInput , List [ModelInput ]]]]:
184
- torch .manual_seed (0 )
187
+ if random_seed is not None :
188
+ torch .manual_seed (random_seed )
185
189
if dedup_feature_names :
186
190
model = model_class (
187
191
tables = cast (
@@ -224,6 +228,7 @@ def gen_model_and_input(
224
228
indices_dtype = indices_dtype ,
225
229
offsets_dtype = offsets_dtype ,
226
230
lengths_dtype = lengths_dtype ,
231
+ random_seed = random_seed ,
227
232
)
228
233
)
229
234
elif generate == ModelInput .generate :
@@ -242,6 +247,7 @@ def gen_model_and_input(
242
247
indices_dtype = indices_dtype ,
243
248
offsets_dtype = offsets_dtype ,
244
249
lengths_dtype = lengths_dtype ,
250
+ random_seed = random_seed ,
245
251
)
246
252
)
247
253
else :
@@ -259,6 +265,7 @@ def gen_model_and_input(
259
265
indices_dtype = indices_dtype ,
260
266
offsets_dtype = offsets_dtype ,
261
267
lengths_dtype = lengths_dtype ,
268
+ random_seed = random_seed ,
262
269
)
263
270
)
264
271
return (model , inputs )
@@ -718,6 +725,7 @@ def sharding_single_rank_test_single_process(
718
725
indices_dtype : torch .dtype = torch .int64 ,
719
726
offsets_dtype : torch .dtype = torch .int64 ,
720
727
lengths_dtype : torch .dtype = torch .int64 ,
728
+ random_seed : int = 0 ,
721
729
) -> None :
722
730
batch_size = random .randint (0 , batch_size ) if allow_zero_batch_size else batch_size
723
731
# Generate model & inputs.
@@ -746,7 +754,9 @@ def sharding_single_rank_test_single_process(
746
754
indices_dtype = indices_dtype ,
747
755
offsets_dtype = offsets_dtype ,
748
756
lengths_dtype = lengths_dtype ,
757
+ random_seed = random_seed ,
749
758
)
759
+
750
760
global_model = global_model .to (device )
751
761
global_input = inputs [0 ][0 ].to (device )
752
762
local_input = inputs [0 ][1 ][rank ].to (device )
@@ -794,6 +804,7 @@ def sharding_single_rank_test_single_process(
794
804
constraints = constraints ,
795
805
)
796
806
plan : ShardingPlan = planner .collective_plan (local_model , sharders , pg )
807
+
797
808
"""
798
809
Simulating multiple nodes on a single node. However, metadata information and
799
810
tensor placement must still be consistent. Here we overwrite this to do so.
@@ -973,6 +984,7 @@ def sharding_single_rank_test(
973
984
indices_dtype : torch .dtype = torch .int64 ,
974
985
offsets_dtype : torch .dtype = torch .int64 ,
975
986
lengths_dtype : torch .dtype = torch .int64 ,
987
+ random_seed : int = 100 ,
976
988
) -> None :
977
989
with MultiProcessContext (rank , world_size , backend , local_size ) as ctx :
978
990
assert ctx .pg is not None
@@ -1006,6 +1018,7 @@ def sharding_single_rank_test(
1006
1018
indices_dtype = indices_dtype ,
1007
1019
offsets_dtype = offsets_dtype ,
1008
1020
lengths_dtype = lengths_dtype ,
1021
+ random_seed = random_seed ,
1009
1022
)
1010
1023
1011
1024
0 commit comments