csrc/scheduler/normalization_inner_outer_utils.cpp (22 changes: 20 additions & 2 deletions)
@@ -208,6 +208,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams(
   std::unordered_map<TensorView*, std::pair<int64_t, int64_t>>
       required_size_bit_regs_smem_map;
   int64_t total_smem_buffer_size_bit = 0;
+  int64_t total_regs_buffer_size_bit = 0;
   for (auto buffer : buffers) {
     int64_t buffer_size_regs_bit =
         scheduler_utils::getPersistentBufferSizeBitOfTensor(
@@ -227,16 +228,33 @@ PersistentBufferStorageParams getPersistentBufferStorageParams(
     required_size_bit_regs_smem_map[buffer] =
         std::make_pair(buffer_size_regs_bit, buffer_size_smem_bit);
     total_smem_buffer_size_bit += buffer_size_smem_bit;
+    total_regs_buffer_size_bit += buffer_size_regs_bit;
   }
+  // Prefer shared memory persistent buffers when cached inputs are persistent,
+  // enabling efficient CpAsync or TMA direct copies from global memory to
+  // shared memory. Otherwise, register buffers are preferred to avoid the
+  // inefficient data path: gmem (input) → regs (cached input) → smem
+  // (persistent buffer) → regs (computation), where the regs → smem → regs path
+  // is redundant.
+  bool cached_inputs_are_persistent = buffer_params.project_to_input ||
+      std::all_of(buffers.begin(), buffers.end(), [](TensorView* tv) {
+        return tv->isFusionInput();
+      });
+  if (cached_inputs_are_persistent) {
+    buffer_params.smem_buffer_size_bit += total_smem_buffer_size_bit;
+  } else {
+    buffer_params.regs_buffer_size_bit += total_regs_buffer_size_bit;
+  }
-  buffer_params.smem_buffer_size_bit = total_smem_buffer_size_bit;
   buffer_params.regs_buffer_size_bit +=
       partialOuterReductionBufferSizeBit(reduction_tvs, runtime_info);
   buffer_params.circular_buffered_smem_size_bit =
       buffer_params.smem_buffer_size_bit -
       buffer_params.non_circular_buffered_smem_size_bit;
   if (buffer_params.regs_buffer_size_bit <= available_regs_bit &&
       buffer_params.smem_buffer_size_bit <= available_smem_bit) {
-    buffer_params.smem_persistent_buffers = buffers;
+    if (cached_inputs_are_persistent) {
+      buffer_params.smem_persistent_buffers = buffers;
+    }
     buffer_params.has_enough_regs_and_smem = true;
     return buffer_params;
   }
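Read outside the scheduler, the change above boils down to a single placement decision: persistent buffers go to shared memory only when every buffer is (or is projected to) a cached fusion input, so data can be staged gmem → smem directly via CpAsync or TMA; otherwise they stay in registers to skip the redundant regs → smem → regs round trip. The standalone sketch below restates that decision with plain stand-in types; BufferInfo, PlacementDecision, and decidePlacement are illustrative names for this note, not nvfuser's actual structures or API.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-ins for the scheduler's per-buffer bookkeeping.
struct BufferInfo {
  bool is_fusion_input;  // is this persistent buffer a cached fusion input?
  int64_t regs_size_bit; // footprint if kept in registers
  int64_t smem_size_bit; // footprint if kept in shared memory
};

struct PlacementDecision {
  bool use_smem_persistent = false;
  int64_t regs_buffer_size_bit = 0;
  int64_t smem_buffer_size_bit = 0;
};

// Mirrors the heuristic in the diff: prefer shared memory persistence only
// when the cached inputs themselves are the persistent buffers, so the
// gmem -> smem copy can bypass the register round trip.
PlacementDecision decidePlacement(
    const std::vector<BufferInfo>& buffers,
    bool project_to_input) {
  int64_t total_regs_bit = 0;
  int64_t total_smem_bit = 0;
  for (const auto& b : buffers) {
    total_regs_bit += b.regs_size_bit;
    total_smem_bit += b.smem_size_bit;
  }
  bool cached_inputs_are_persistent = project_to_input ||
      std::all_of(buffers.begin(), buffers.end(), [](const BufferInfo& b) {
        return b.is_fusion_input;
      });
  PlacementDecision decision;
  decision.use_smem_persistent = cached_inputs_are_persistent;
  if (cached_inputs_are_persistent) {
    decision.smem_buffer_size_bit = total_smem_bit; // gmem -> smem via TMA
  } else {
    decision.regs_buffer_size_bit = total_regs_bit; // stay in registers
  }
  return decision;
}

Only one of the two totals is committed, which is why the diff accumulates total_regs_buffer_size_bit alongside total_smem_buffer_size_bit and then adds exactly one of them to buffer_params.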
tests/cpp/test_combined_inner_outer_reduction.cpp (31 changes: 31 additions & 0 deletions)
@@ -1820,4 +1820,35 @@ TEST_F(CombinedSchedulerTest, IllegalSizeToUseTMA) {
   EXPECT_FALSE(heuristic_params->as<ReductionParams>()->tma_warp_specialized);
   testValidate(&fusion_copy, cg_outputs, {t0, t1}, __LINE__, __FILE__);
 }
+
+// Avoid shared memory persistent buffers when cached inputs are not
+// persistent; otherwise the data would take the inefficient path gmem (input)
+// → regs (cached input) → smem (persistent buffer) → regs (computation),
+// where the regs → smem → regs round trip is redundant. This change improves
+// the performance of this test from 50.5% SOL to 61.6% SOL on GB200.
+TEST_F(CombinedSchedulerTest, CachedInputsAreNotPersistentFusedReshape) {
+  auto dtype = DataType::Float;
+  constexpr auto dim0 = 2;
+  constexpr auto dim1 = 1024;
+  constexpr auto dim2 = 8192;
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  auto tv0 = makeContigTensor(3, dtype);
+  fusion->addInput(tv0);
+  auto tv1 = reshape(tv0, {dim0, dim1, dim2}, {dim0 * dim1, dim2});
+  auto tv2 = sum(tv1, {1});
+  auto tv3 = broadcast(tv2, {false, true});
+  auto tv4 = add(tv1, tv3);
+  auto tv5 = sum(tv1, {0});
+  fusion->addOutput(tv4);
+  fusion->addOutput(tv5);
+  auto fusion_copy = *fusion;
+
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({dim0, dim1, dim2}, options);
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({t0});
+  testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__);
+}
 } // namespace nvfuser
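For reference, the fusion in the new test performs a combined inner and outer reduction on the reshaped input: an inner sum over the last axis that is broadcast and added back, plus an outer sum over the first axis. A rough eager-mode ATen equivalent is sketched below as a reading aid; referenceCombinedReduction is a hypothetical helper written for this note, not code from the PR.

#include <ATen/ATen.h>
#include <vector>

// Eager-mode reference for the test fusion above (names mirror the tvN
// TensorViews in the fusion definition).
std::vector<at::Tensor> referenceCombinedReduction(const at::Tensor& t0) {
  // t0: [dim0, dim1, dim2] -> t1: [dim0 * dim1, dim2]
  at::Tensor t1 = t0.reshape({t0.size(0) * t0.size(1), t0.size(2)});
  at::Tensor t2 = t1.sum(/*dim=*/1);    // inner reduction over the last axis
  at::Tensor t4 = t1 + t2.unsqueeze(1); // broadcast the row sums and add back
  at::Tensor t5 = t1.sum(/*dim=*/0);    // outer reduction over the first axis
  return {t4, t5};
}

Here the persistent buffer is the reshaped tensor tv1 rather than the fusion input itself, which matches the "cached inputs are not persistent" situation the test name describes, so the scheduler change keeps the buffer in registers instead of shared memory.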