Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/device_lower/analysis/device_version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ void MinimumDeviceVersion::dispatch(Val* val) {
}

void MinimumDeviceVersion::handle(MmaOp* mma_op) {
GpuLower::current()->setHasMma(true);
if (isTuring(mma_op->macro())) {
ensureVersion({7, 5}, "Fusion contains a Turing MMA macro");
} else if (isAmpere(mma_op->macro())) {
Expand Down
80 changes: 78 additions & 2 deletions csrc/device_lower/analysis/tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
namespace nvfuser {

std::ostream& operator<<(std::ostream& os, const TMADim& d) {
os << "TMADim{"
<< "partitioned="
os << "TMADim{" << "partitioned="
<< (d.partitioned ? d.partitioned->toString() : "nullptr")
<< ", box=" << (d.box ? d.box->toString() : "nullptr")
<< ", tile=" << (d.tile ? d.tile->toString() : "nullptr")
Expand Down Expand Up @@ -1069,6 +1068,77 @@ std::vector<TMADim> run(

} // namespace collapse_tma_domain

// Validate broadcast usage with TMA loads
//
// TMA auto-fills out-of-bounds accesses with zeros. When broadcast dimensions
// are merged with non-broadcast and loaded with TMA, TMA treats
// the broadcast as a physical dimension and loads extra rows/cols as zeros,
// breaking broadcast semantics (which should replicate values, not fill zeros).
//
// See test BroadcastDownstreamOfTMALoad for detailed example.

void validateTMAConsumerBroadcasts(TensorView* smem_tv) {
  // Reject TMA loads whose Bulk-parallelized (tile) dimensions depend on
  // broadcast dimensions. TMA zero-fills out-of-bounds accesses, so a
  // broadcast folded into a TMA tile would be filled with zeros instead of
  // replicated values, silently breaking broadcast semantics.
  //
  // Uses PERMISSIVE mode which maps broadcast to non-broadcast. We find
  // ValGroups containing broadcasts and check if Bulk-parallelized ValGroups
  // are reachable from them through transitive dependencies.

  // Build PERMISSIVE IdModel graph - maps broadcast to non-broadcast through
  // transformations
  IdModel id_model(smem_tv->fusion(), /*build_graphs=*/false);
  id_model.maybeBuildGraph(IdMappingMode::PERMISSIVE);
  const ValGraph& permissive_graph =
      id_model.idGraph(IdMappingMode::PERMISSIVE);

  // Collect the source (broadcast-containing) and target (Bulk-parallelized)
  // ValGroups in a single pass over the disjoint sets instead of two.
  ValGroups broadcast_groups;
  ValGroups bulk_groups;
  for (const ValGroup& val_group :
       permissive_graph.disjointValSets().disjointSets()) {
    bool has_broadcast = false;
    bool has_bulk = false;
    for (Val* val : *val_group) {
      if (!val->isA<IterDomain>()) {
        continue;
      }
      auto* id = val->as<IterDomain>();
      has_broadcast = has_broadcast || id->isBroadcast();
      has_bulk = has_bulk || id->getParallelType() == ParallelType::Bulk;
      if (has_broadcast && has_bulk) {
        break; // both properties found; no need to scan the rest
      }
    }
    if (has_broadcast) {
      broadcast_groups.pushBack(val_group);
    }
    if (has_bulk) {
      bulk_groups.pushBack(val_group);
    }
  }

  // Check if any Bulk ValGroup is reachable from broadcast ValGroups.
  // This captures both direct and transitive dependencies (broadcast ->
  // intermediate -> Bulk).
  auto reachable_bulk_groups = getReachableValsFrom<ValGraphBFS>(
      broadcast_groups.vector(),
      bulk_groups.vector(),
      Direction::Forward,
      permissive_graph);

  // Any reachable Bulk group means a TMA tile dimension depends on a
  // broadcast -- reject the fusion.
  NVF_ERROR(
      reachable_bulk_groups.empty(),
      "Broadcast may interfere with TMA loading of ",
      smem_tv->toString(),
      ". Bulk-parallelized dimensions are reachable from broadcast "
      "dimensions through transformations.");
}

TMAInfo getTMAInfo(LoadStoreOp* ldst) {
auto* producer_tv = ldst->in()->as<TensorView>();
// In case the producer is aliased, use the alias instead
Expand Down Expand Up @@ -1121,6 +1191,12 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
"(this is always the case for nvFuser now)",
", the first element of elementStrides must be one.");

// Validate broadcast usage: TMA auto-fills out-of-bounds with zeros,
// breaking broadcast semantics when broadcast dims participate in tile shape.
if (!GpuLower::current()->hasMma()) {
validateTMAConsumerBroadcasts(smem_tv);
}

MmaInputSmemSwizzle swizzle = getSwizzle(smem_tv);

// Handle "defining box by compositing" by collapsing some dimensions in the
Expand Down
11 changes: 11 additions & 0 deletions csrc/device_lower/lower2device.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ class GpuLower : public NonCopyable {
return mbarrier_map_;
}

// Returns whether an MmaOp has been recorded for this fusion via setHasMma()
// (set during device-version analysis when an MmaOp is dispatched).
bool hasMma() const {
  return has_mma_;
}

// Records whether the fusion contains an MmaOp. Called from
// MinimumDeviceVersion::handle(MmaOp*) during lowering analysis.
void setHasMma(bool has_mma) {
  has_mma_ = has_mma;
}

bool isNvFuserZeroEnabled() {
if (isOptionDisabled(DisableOption::MagicZero)) {
return false;
Expand Down Expand Up @@ -434,6 +442,9 @@ class GpuLower : public NonCopyable {
// The shared cluster reduction mbarrier tensor allocated during allocation
// pass
TensorView* cluster_reduction_mbarrier_tensor_ = nullptr;

// Whether the fusion contains an MmaOp; set via setHasMma() during
// device-version analysis.
bool has_mma_ = false;
};

#define NVFUSER_LOWER_VALIDATE(cond, ...) \
Expand Down
Loading