From 35d26881f2c2c85fabdb48cc4de33a37f30d2671 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 18 Dec 2024 06:23:35 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/fused_attn/test_fused_attn.py | 65 ++++++++++++++++-----
 1 file changed, 51 insertions(+), 14 deletions(-)

diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index b04624fd89..588e6e4ecd 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -427,8 +427,12 @@ def test_dpa_mla(dtype, model_configs, model):
     "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
     "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
     "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
-    "mask_10_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
+    "mask_10_0": ModelConfig(
+        2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "mask_10_1": ModelConfig(
+        2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
 }
 
 
@@ -643,18 +647,51 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
-    "layout_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
-    "layout_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4,4)),
-    "layout_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4,4)),
-    "layout_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4,4)),
-    "layout_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4,0)),
-    "layout_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4,0)),
-    "layout_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4,0)),
-    "layout_5_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4,0)),
-    "layout_5_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4,0)),
-    "layout_5_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4,0)),
+    "layout_2_0": ModelConfig(
+        2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "layout_2_1": ModelConfig(
+        2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "layout_2_2": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
+    ),
+    "layout_3_0": ModelConfig(
+        2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
+    ),
+    "layout_3_1": ModelConfig(
+        2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
+    ),
+    "layout_3_2": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
+    ),
+    "layout_4_0": ModelConfig(
+        2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
+    ),
+    "layout_4_1": ModelConfig(
+        2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
+    ),
+    "layout_4_2": ModelConfig(
+        2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
+    ),
+    "layout_5_0": ModelConfig(
+        2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
+    ),
+    "layout_5_1": ModelConfig(
+        2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
+    ),
+    "layout_5_2": ModelConfig(
+        2,
+        24,
+        24,
+        128,
+        2048,
+        4096,
+        0.0,
+        "padding_causal_bottom_right",
+        "no_bias",
+        window_size=(4, 0),
+    ),
 }
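
Note (not part of the patch): the ModelConfig entries above are built from positional arguments whose meaning is defined elsewhere in tests/pytorch/fused_attn/test_fused_attn.py. As a rough orientation only, the sketch below shows what those positions appear to encode, judging from how the values vary across entries. The class name ModelConfigSketch, the field names, and the default window_size are illustrative assumptions, not the test suite's actual definition.

from dataclasses import dataclass
from typing import Tuple


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for the test suite's ModelConfig; field names are guesses."""

    batch_size: int       # e.g. 2 in every entry above
    num_heads: int        # query heads, e.g. 16 or 24
    num_gqa_groups: int   # KV heads: 1 suggests MQA, equal to num_heads suggests MHA
    head_dim: int         # per-head dimension, e.g. 64, 128, 256
    max_seqlen_q: int     # query sequence length, e.g. 2048
    max_seqlen_kv: int    # key/value sequence length, e.g. 2048 or 4096
    dropout_p: float      # attention dropout probability, 0.0 here
    attn_mask_type: str   # "padding", "padding_causal", "padding_causal_bottom_right", ...
    attn_bias_type: str   # "no_bias" in every entry above
    window_size: Tuple[int, int] = (-1, -1)  # (left, right) sliding window; assumed default = unlimited


# Mirrors the reformatted "layout_5_2" entry from the patch:
layout_5_2 = ModelConfigSketch(
    2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias",
    window_size=(4, 0),
)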