diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py index b0edfc7fc5..db2d2e6197 100644 --- a/test/test_low_bit_optim.py +++ b/test/test_low_bit_optim.py @@ -419,8 +419,8 @@ def test_optim_cpu_offload_save_load(self): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) + @parametrize("device", _DEVICES) def test_optim_bf16_stochastic_round_correctness(self): - device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(2024) model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device)