From dbea1dcc4979c88353d661f898a6e2f0b047ba4e Mon Sep 17 00:00:00 2001 From: pengpeng-yu <2649535611@qq.com> Date: Tue, 24 Mar 2026 14:18:42 +0800 Subject: [PATCH] Fix the ScatterD issue in predicated_tile_iterator --- .../threadblock/predicated_tile_iterator.h | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 2502c2e03f..1be5bff9cd 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -364,12 +364,16 @@ class PredicatedTileIterator { } if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; + if (!ScatterD) { + byte_pointer += params_.increment_group; + } } } if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; + if (!ScatterD) { + byte_pointer += params_.increment_cluster; + } } } } @@ -727,8 +731,14 @@ class PredicatedTileIterator { int increment_row = state_[0] / ThreadMap::Count::kRow; state_[0] = state_[0] % ThreadMap::Count::kRow; - byte_pointer_ += (params_.advance_row * increment); - store_byte_pointer_ += (params_.advance_row * increment); + if (!ScatterD) { + byte_pointer_ += (params_.advance_row * increment); + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += (params_.advance_row * increment); + } + thread_start_row_ += (ThreadMap::Shape::kRow * increment); // Group @@ -736,8 +746,14 @@ class PredicatedTileIterator { int increment_group = state_[1] / ThreadMap::Count::kGroup; state_[1] = state_[1] % ThreadMap::Count::kGroup; - byte_pointer_ += (params_.advance_group * increment_row); - store_byte_pointer_ += (params_.advance_group * increment_row); + if (!ScatterD) { + byte_pointer_ += (params_.advance_group * increment_row); + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += (params_.advance_group * increment_row); + } + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * @@ -750,8 +766,14 @@ class PredicatedTileIterator { int increment_cluster = state_[2] / ThreadMap::Count::kCluster; state_[2] = state_[2] % ThreadMap::Count::kCluster; - byte_pointer_ += (params_.advance_cluster * increment_group); - store_byte_pointer_ += (params_.advance_cluster * increment_group); + if (!ScatterD) { + byte_pointer_ += (params_.advance_cluster * increment_group); + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += (params_.advance_cluster * increment_group); + } + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * @@ -760,8 +782,14 @@ class PredicatedTileIterator { increment_group; // Tile - byte_pointer_ += (params_.advance_tile * increment_cluster); - store_byte_pointer_ += (params_.advance_tile * increment_cluster); + if (!ScatterD) { + byte_pointer_ += (params_.advance_tile * increment_cluster); + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += (params_.advance_tile * increment_cluster); + } + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow *