Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix improper touched buffer assignment of Pass MergeSharedMemoryAllocations #17438

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions src/tir/transforms/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ namespace tir {
using runtime::StorageRank;
using runtime::StorageScope;

bool IsDynamicSharedMemory(Var buffer_var) {
static bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

bool IsStaticSharedMemory(Var buffer_var) {
static bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == "";
}
Expand Down Expand Up @@ -125,7 +125,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
scope_[scope_.size() - 1].touched.push_back(buf);
}
}
StmtEntry e = scope_.back();
Expand Down Expand Up @@ -156,29 +156,18 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
scope_[scope_.size() - 1].touched.push_back(buf);
}
}
}

void VisitExpr_(const CallNode* op) final {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this visit function: it only allows visiting the indices, which causes some buffer load statements not to be traced.

if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
for (const auto& index : load->indices) {
this->VisitExpr(index);
}
} else {
StmtExprVisitor::VisitExpr_(op);
}
}

void VisitExpr_(const VarNode* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
scope_[scope_.size() - 1].touched.push_back(buf);
}
}
}
Expand Down Expand Up @@ -291,7 +280,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
for (const VarNode* buffer : e->allocs[i]) {
const AllocateNode* alloc = shmem_allocs_[buffer];
align[i] = std::max(align[i], alloc->dtype.bytes());
align[i] = std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes());
}
}
}
Expand All @@ -303,7 +292,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
for (const VarNode* buffer : e->allocs[i]) {
const AllocateNode* alloc = shmem_allocs_[buffer];
buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
inner_offset += alloc->extents[0] * alloc->dtype.bytes();
inner_offset += alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes();
inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
}
max_inner_offset = max(max_inner_offset, inner_offset);
Expand Down Expand Up @@ -426,7 +415,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
auto it = buffer_byte_offsets_.find(buffer_var.get());
ICHECK(it != buffer_byte_offsets_.end());
return indexdiv(it->second, dtype.bytes());
return indexdiv(it->second, dtype.bytes() * dtype.lanes());
}

// Wrapper function to determine if the shared memory allocation for a variable is appropriate.
Expand Down
42 changes: 40 additions & 2 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ namespace tir {
using runtime::StorageRank;
using runtime::StorageScope;

/*!
 * \brief Collect the mapping from each shared-memory buffer var to its Allocate node,
 *        separated into dynamic (".dyn" tag) and static (empty tag) shared memory.
 */
class AllocateCollector : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    // Classify the allocation by its storage scope: only shared-memory
    // allocations are recorded, split by the ".dyn" tag.
    StorageScope storage_scope =
        runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
    if (storage_scope.rank == runtime::StorageRank::kShared) {
      if (storage_scope.tag == ".dyn") {
        dyn_shmem_allocs_[op->buffer_var.get()] = op;
      } else if (storage_scope.tag == "") {
        static_shmem_allocs_[op->buffer_var.get()] = op;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }
  // Mapping from buffer var to its Allocate for dynamic shared memory.
  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
  // Mapping from buffer var to its Allocate for static shared memory.
  std::unordered_map<const VarNode*, const AllocateNode*> static_shmem_allocs_;
};

// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
Expand Down Expand Up @@ -1733,7 +1763,15 @@ Pass StorageRewrite() {
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
if (merge_static_smem) {

AllocateCollector collector;
collector(f->body);
bool has_dynamic = collector.dyn_shmem_allocs_.size() > 1;
if (has_dynamic || merge_static_smem) {
// For IRModule utilizing dynamic shared memory, reuse is not enabled
// Because dynamic doesn't require maintaining the readability and
// it benefits from a more optimized allocation strategy through the
// Pass `MergeSharedMemoryAllocations`.
// When `merge_static_smem` is true, we will reuse and merge shared
// memory in a dedicated pass `MergeSharedMemoryAllocations`.
// And so we don't enable reuse in this pass.
Expand All @@ -1755,7 +1793,7 @@ Pass StorageRewrite() {
// padded out to 32 bits) would require either rewriting
// AllocateConst::data, or would require the code generators to
// handle vectorized constants.
return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false,
return PointerValueTypeRewrite(std::move(f), true, false, false, false, true, true, false,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fourth condition must be false, as the vectorized buffer merge (for example, merging half B_shared[1024] into halfx8 B_shared[128]) will occasionally lead to unhandled behavior during the async copy lowering phase.

false);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
Expand Down
Loading