@@ -158,7 +158,9 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
+def test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False
+):
     device = mesh.device_type

     toy_model = ToyModel().to(device)
@@ -197,6 +199,9 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
         },
     )

+    if compile:
+        tp_model = torch.compile(tp_model)
+
     x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
@@ -217,6 +222,10 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
     )


+def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+
+
 if __name__ == "__main__":
     # float8 only works on CUDA H100 so we only test cuda and we follow
     # other test files to not use TestCase but instead just add the test
@@ -227,7 +236,8 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
         test_fp8_redistribute,
         test_dtensor_cast_to_fp8,
         test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism,
+        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_compile,
     ]

     for test in tqdm(tests, desc="Running tests"):
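A minimal sketch of how the new compile variant could be exercised on its own; the mesh construction and the torchrun launch below are assumptions, since the file's actual __main__ mesh setup sits outside these hunks.

# sketch only: launch with `torchrun --nproc-per-node=<num_gpus> <this test file>`
# assumes the test file's existing imports plus the functions defined above
import torch
from torch.distributed.device_mesh import init_device_mesh

world_size = torch.cuda.device_count()
mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh over all local GPUs
test_fp8_mlp_tensor_parallelism_compile(mesh, size=16)  # runs the torch.compile'd TP path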