Skip to content

Rocm jaxlib v0.5.0 warpsize #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: rocm-jaxlib-v0.5.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions xla/backends/gpu/codegen/emitters/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,9 @@ ColumnReductionFusion::ColumnReductionFusion(const HloFusionAnalysis& analysis)
reduction_dimensions_.dimensions[1],
reduction_dimensions_.dimensions[2]};
vector_size_ = GetVectorSizeForMlir(
analysis, /*minor_dim=*/input_shape_.back(), WarpSize());
int64_t num_warps_per_column = WarpSize();
num_threads_ = {num_warps_per_column, WarpSize()};
analysis, /*minor_dim=*/input_shape_.back(), kTileSize);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we need to change as tile size 32 instead of WarpSize(device_info) here? may I ask why?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temporary, I believe that reduction algorithm needs modifications in order to work with warp_size==64.
Without this some tests fail.

Copy link

@pemeliya pemeliya Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I also did not find a good solution here. This only applies for column-wise reductions.
They work as follows: one block of 1024 threads (32x32) performs column reduction for 1 vertical stripe of N rows and 32 columns. Basically each warp loads and reduces N/32 rows (each having 32 elements) and writes its resulting reduced row to a shared memory. As a result, we have 32 rows of 32 elements written to shared memory.

After that, we do syncthreads and each warp reads 1 vertical column from shared memory and performs warp-level reduction on it. So, finally each warp just writes its 1 reduced element back to global mem. As a result we have Nx32 stripe reduced to 1x32 row.

To make it working for warp_size=64, we could have 16 warps (16*64 = 1024) processing 1 vertical stripe of N rows and 64 columns. But each warp shall then process N/16 rows and perform 4 writes to shared memory (instead of 1). As a result. we would then have 1 large shared mem array of size 64x64 to be transposed. But I don't have a clear idea how to express this in terms of Indexing maps they is in the reduction emitter.

int64_t num_warps_per_column = kTileSize;
num_threads_ = {num_warps_per_column, kTileSize};
int64_t num_col_elements_per_thread =
CeilOfRatio(reduction_dimensions_
.dimensions[ReductionDimensions::kColReducedDimension],
Expand All @@ -612,7 +612,7 @@ ColumnReductionFusion::ColumnReductionFusion(const HloFusionAnalysis& analysis)
reduction_dimensions_
.dimensions[ReductionDimensions::kColMinorKeptDimension];
int64_t num_blocks_per_row =
CeilOfRatio(minor_kept_dim, WarpSize() * vector_size_);
CeilOfRatio(minor_kept_dim, kTileSize * vector_size_);
num_blocks_ = {major_kept_dim, num_blocks_per_row};
}

Expand All @@ -625,7 +625,7 @@ IndexingMap ColumnReductionFusion::ComputeReductionOutputIndexing(
auto vector_index = getAffineSymbolExpr(0, ctx);
SmallVector<AffineExpr, 2> results{
block_id[0],
(block_id[1] * WarpSize() + thread_id[0]) * vector_size_ + vector_index};
(block_id[1] * kTileSize + thread_id[0]) * vector_size_ + vector_index};
IndexingMap projected_index =
GetIndexingMap(results, /*symbol_sizes=*/{vector_size_});
projected_index.AddConstraint(thread_id[1], {0, 0});
Expand All @@ -643,7 +643,7 @@ IndexingMap ColumnReductionFusion::ComputeReductionInputIndexing(

SmallVector<AffineExpr, 3> results{
block_id[0], thread_id[0] + element_index * num_threads_[1],
(block_id[1] * WarpSize() + thread_id[1]) * vector_size_ + vector_index};
(block_id[1] * kTileSize + thread_id[1]) * vector_size_ + vector_index};
IndexingMap map = GetIndexingMap(results, tile_sizes_per_thread_);
for (auto [result, dim_size] :
llvm::zip(results, reduction_dimensions_.dimensions)) {
Expand Down Expand Up @@ -691,20 +691,20 @@ SmallColumnReductionFusion::SmallColumnReductionFusion(
// We emit a single loop over the dimensions 1 and 2, so we use their total
// size when computing the vector size.
vector_size_ = GetVectorSizeForMlir(
analysis, /*minor_dim=*/input_shape_[1] * input_shape_[2], WarpSize());
analysis, /*minor_dim=*/input_shape_[1] * input_shape_[2], kTileSize);
num_threads_ = {128};
shared_rows_ = vector_size_ * num_threads_[0] / input_shape_[kColMinorKept];

// If we have more than 32 shared rows, we'd have to go through shared
// memory one extra time. We don't currently support that, and it's not been
// tried, so we have to reduce the vector size/number of threads.
while (shared_rows_ > WarpSize() && vector_size_ > 1) {
while (shared_rows_ > kTileSize && vector_size_ > 1) {
vector_size_ /= 2;
shared_rows_ /= 2;
}
if (shared_rows_ > WarpSize()) {
num_threads_[0] /= (shared_rows_ / WarpSize());
shared_rows_ = WarpSize();
if (shared_rows_ > kTileSize) {
num_threads_[0] /= (shared_rows_ / kTileSize);
shared_rows_ = kTileSize;
}

num_blocks_ = {input_shape_[kColMajorKept]};
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/gpu/codegen/emitters/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ class ColumnReductionFusion : public ReductionFusion {
IndexingMap GetSharedMemoryReductionReadMap(
mlir::MLIRContext* ctx) const override;
IndexingMap GetSharedMemoryWriteMap(mlir::MLIRContext* ctx) const override;

const int64_t kTileSize = 32;
};

// Special emitter for column reductions whose minor reduced dimension divides
Expand All @@ -227,6 +229,8 @@ class SmallColumnReductionFusion : public ReductionFusion {
mlir::MLIRContext* ctx) const override;
IndexingMap GetSharedMemoryWriteMap(mlir::MLIRContext* ctx) const override;

const int64_t kTileSize = 32;

int64_t shared_rows_;
int64_t loop_size_;
};
Expand Down
9 changes: 5 additions & 4 deletions xla/backends/gpu/codegen/emitters/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ using mlir::ValueRange;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;

constexpr int kNumRows = 4;
constexpr int kNumThreadsPerBlock = 128;
constexpr int kMaxVectorizedBytes = 4;
constexpr int kTileSize = 32;
constexpr int kNumRows = 8;
constexpr int kNumThreadsPerBlock = kNumRows * kTileSize;
constexpr int kMaxVectorizedBytes = 16;

} // namespace

Expand All @@ -87,7 +88,7 @@ TransposeFusion::TransposeFusion(const HloFusionAnalysis& analysis)
permutation_(transpose_.permutation),
input_shape_(
Permute(transpose_.dimensions, InversePermutation(permutation_))),
base_block_size_(WarpSize(analysis_.device_info())) {
base_block_size_(kTileSize) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

ConstHloInstructionSet transposes_to_tile;
int index = 0;
int64_t shmem_usage = 0;
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ absl::StatusOr<int64_t> GetMaxRegistersPerBlock(hipDevice_t device) {
}

absl::StatusOr<int64_t> GetThreadsPerWarp(hipDevice_t device) {
// TODO(ROCm): This is almost certainly wrong but tests seem to rely on it.
return 32;
return 64;
}

absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) {
Expand Down