diff --git a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
index ece31f6f7a..496ce1fc66 100644
--- a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
+++ b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
@@ -283,6 +283,10 @@ struct TestbedImpl {
     block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
     block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
 
+    // Zero-initialize output memory to prevent uninitialized values from USM reuse
+    compat::memset(block_O.get(), 0, block_O.size() * sizeof(ElementOutput));
+    compat::memset(block_ref_O.get(), 0, block_ref_O.size() * sizeof(ElementOutput));
+
     initialize_block(block_Q, seed + 2023);
     initialize_block(block_K, seed + 2022);
     initialize_block(block_V, seed + 2021);
@@ -595,6 +599,11 @@ struct TestbedImpl {
     CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size);
 #endif
     cutlass::device_memory::allocation workspace(workspace_size);
+
+    // Zero-initialize workspace to prevent uninitialized memory issues
+    if (workspace_size > 0) {
+      compat::memset(workspace.get(), 0, workspace_size);
+    }
 
 #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
     CUTLASS_TRACE_HOST("TestbedImpl::run: Calling FlashAttention::can_implement");
diff --git a/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp b/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp
index 095261ffbb..4a8005a948 100644
--- a/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp
@@ -52,7 +52,7 @@ TEST(TEST_NAME, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM));
 }
 
-TEST(GTEST_CONCAT_TOKEN_(DISABLED_, TEST_NAME), varlen_causal) {
+TEST(TEST_NAME, varlen_causal) {
   using Kernel = test::flash_attention::XE_Flash_Attention_Prefill::Kernel;
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM));
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
index b758d1b8fd..0b6fe6e4d6 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
@@ -220,6 +220,11 @@ struct TestbedImpl {
     block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo);
     block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
     block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
+
+    // Zero-initialize output buffers to prevent uninitialized memory issues
+    if (block_O.size() > 0) {
+      compat::memset(block_O.get(), 0, block_O.size() * sizeof(ElementOutput));
+    }
 
     if constexpr (UsePagedKV) {
       std::vector num_pages_per_seq{0};
@@ -632,6 +637,11 @@ struct TestbedImpl {
     CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size);
 #endif
     cutlass::device_memory::allocation workspace(workspace_size);
+
+    // Zero-initialize workspace to prevent uninitialized memory issues
+    if (workspace_size > 0) {
+      compat::memset(workspace.get(), 0, workspace_size);
+    }
 
 #if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
     CUTLASS_TRACE_HOST("TestbedImpl::run: Calling FlashPrefillCachedKV::can_implement");
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_128.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_128.cpp
index 09fc3344ac..8dca2d5be3 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_128.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_128.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_bf16_128, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(128));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_bf16_128, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_bf16_128, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _64>;
   using ShapePV = Shape<_128, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_192.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_192.cpp
index 4d037a91ba..8a23ab23be 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_192.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_192.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_bf16_192, noncausal) {
  EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(192));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_bf16_192, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_bf16_192, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_256, _64, _64>;
   using ShapePV = Shape<_256, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_64.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_64.cpp
index c938462e88..c47f0cc5e4 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_64.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_64.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_bf16_64, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(64));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_bf16_64, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_bf16_64, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _64>;
   using ShapePV = Shape<_128, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_96.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_96.cpp
index d74225e3a3..8bc4b078e1 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_96.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_bf16_fp32_fp32_96.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_bf16_96, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(96));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_bf16_96, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_bf16_96, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _32>;
   using ShapePV = Shape<_128, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_128.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_128.cpp
index 4ddba02387..eeb1fc1757 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_128.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_128.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_fp16_128, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(128));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_fp16_128, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_fp16_128, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _64>;
   using ShapePV = Shape<_128, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_192.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_192.cpp
index 3b1a950ed6..778c378012 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_192.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_192.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_fp16_192, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(192));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_fp16_192, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_fp16_192, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_256, _64, _64>;
   using ShapePV = Shape<_256, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_64.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_64.cpp
index 1ae9c8f453..643412db9f 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_64.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_64.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_fp16_64, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(64));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_fp16_64, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_fp16_64, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _64>;
   using ShapePV = Shape<_128, _32, _64>;
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_96.cpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_96.cpp
index 4b3f50c1eb..7783767588 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_96.cpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/xe_flash_prefill_cachedkv_fp16_fp32_fp32_96.cpp
@@ -61,7 +61,7 @@ TEST(XE_Flash_Attention_Prefill_fp16_96, noncausal) {
   EXPECT_TRUE(test::flash_attention::TestFlashPrefillCachedKVAll(96));
 }
 
-TEST(DISABLED_XE_Flash_Attention_Prefill_fp16_96, varlen_causal) {
+TEST(XE_Flash_Attention_Prefill_fp16_96, varlen_causal) {
   constexpr int PipelineStages = 2;
   using ShapeQK = Shape<_128, _64, _32>;
   using ShapePV = Shape<_128, _32, _64>;