diff --git a/csrc/scheduler/normalization_inner_outer_utils.cpp b/csrc/scheduler/normalization_inner_outer_utils.cpp
index 511b40dec63..3e1adaf0e06 100644
--- a/csrc/scheduler/normalization_inner_outer_utils.cpp
+++ b/csrc/scheduler/normalization_inner_outer_utils.cpp
@@ -208,6 +208,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams(
   std::unordered_map<TensorView*, std::pair<int64_t, int64_t>>
       required_size_bit_regs_smem_map;
   int64_t total_smem_buffer_size_bit = 0;
+  int64_t total_regs_buffer_size_bit = 0;
   for (auto buffer : buffers) {
     int64_t buffer_size_regs_bit =
         scheduler_utils::getPersistentBufferSizeBitOfTensor(
@@ -227,8 +228,23 @@ PersistentBufferStorageParams getPersistentBufferStorageParams(
     required_size_bit_regs_smem_map[buffer] =
         std::make_pair(buffer_size_regs_bit, buffer_size_smem_bit);
     total_smem_buffer_size_bit += buffer_size_smem_bit;
+    total_regs_buffer_size_bit += buffer_size_regs_bit;
+  }
+  // Prefer shared memory persistent buffers when cached inputs are persistent,
+  // enabling efficient CpAsync or TMA direct copies from global memory to
+  // shared memory. Otherwise, register buffers are preferred to avoid the
+  // inefficient data path: gmem (input) → regs (cached input) → smem
+  // (persistent buffer) → regs (computation), where the regs → smem → regs
+  // path is redundant.
+  bool cached_inputs_are_persistent = buffer_params.project_to_input ||
+      std::all_of(buffers.begin(), buffers.end(), [](TensorView* tv) {
+        return tv->isFusionInput();
+      });
+  if (cached_inputs_are_persistent) {
+    buffer_params.smem_buffer_size_bit += total_smem_buffer_size_bit;
+  } else {
+    buffer_params.regs_buffer_size_bit += total_regs_buffer_size_bit;
   }
-  buffer_params.smem_buffer_size_bit = total_smem_buffer_size_bit;
   buffer_params.regs_buffer_size_bit +=
       partialOuterReductionBufferSizeBit(reduction_tvs, runtime_info);
   buffer_params.circular_buffered_smem_size_bit =
@@ -236,7 +252,9 @@ PersistentBufferStorageParams getPersistentBufferStorageParams(
       buffer_params.non_circular_buffered_smem_size_bit;
   if (buffer_params.regs_buffer_size_bit <= available_regs_bit &&
       buffer_params.smem_buffer_size_bit <= available_smem_bit) {
-    buffer_params.smem_persistent_buffers = buffers;
+    if (cached_inputs_are_persistent) {
+      buffer_params.smem_persistent_buffers = buffers;
+    }
     buffer_params.has_enough_regs_and_smem = true;
     return buffer_params;
   }
diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp
index 31849df0b8a..0144505046b 100644
--- a/tests/cpp/test_combined_inner_outer_reduction.cpp
+++ b/tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -1820,4 +1820,35 @@ TEST_F(CombinedSchedulerTest, IllegalSizeToUseTMA) {
   EXPECT_FALSE(heuristic_params->as<ReductionParams>()->tma_warp_specialized);
   testValidate(&fusion_copy, cg_outputs, {t0, t1}, __LINE__, __FILE__);
 }
+
+// Avoid shared memory persistent buffers when the cached inputs are not
+// persistent, since that creates the inefficient data path: gmem (input) →
+// regs (cached input) → smem (persistent buffer) → regs (computation), where
+// the regs → smem → regs hop is redundant. With this change, this test
+// improves from 50.5% SOL to 61.6% SOL on GB200.
+TEST_F(CombinedSchedulerTest, CachedInputsAreNotPersistentFusedReshape) {
+  auto dtype = DataType::Float;
+  constexpr auto dim0 = 2;
+  constexpr auto dim1 = 1024;
+  constexpr auto dim2 = 8192;
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  auto tv0 = makeContigTensor(3, dtype);
+  fusion->addInput(tv0);
+  auto tv1 = reshape(tv0, {dim0, dim1, dim2}, {dim0 * dim1, dim2});
+  auto tv2 = sum(tv1, {1});
+  auto tv3 = broadcast(tv2, {false, true});
+  auto tv4 = add(tv1, tv3);
+  auto tv5 = sum(tv1, {0});
+  fusion->addOutput(tv4);
+  fusion->addOutput(tv5);
+  auto fusion_copy = *fusion;
+
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({dim0, dim1, dim2}, options);
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({t0});
+  testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__);
+}
 } // namespace nvfuser
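// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the buffer-placement decision
// introduced above, isolated into a self-contained program. `TensorViewStub`
// and `preferSmemPersistence` are hypothetical names used only for this
// sketch; the real logic runs on nvfuser's TensorView* buffers inside
// getPersistentBufferStorageParams.

#include <algorithm>
#include <cstdio>
#include <vector>

struct TensorViewStub {
  bool is_fusion_input = false;
  bool isFusionInput() const {
    return is_fusion_input;
  }
};

// Shared memory persistence pays off only when the persistent buffers are the
// cached inputs themselves (or are projected to inputs): then gmem → smem can
// be a direct CpAsync/TMA copy. Otherwise, registers avoid the redundant
// regs → smem → regs round trip.
bool preferSmemPersistence(
    bool project_to_input,
    const std::vector<TensorViewStub*>& buffers) {
  return project_to_input ||
      std::all_of(buffers.begin(), buffers.end(), [](TensorViewStub* tv) {
        return tv->isFusionInput();
      });
}

int main() {
  TensorViewStub input{true};
  TensorViewStub intermediate{false};
  // All persistent buffers are fusion inputs → smem preferred (prints 1).
  std::vector<TensorViewStub*> only_inputs{&input};
  // An intermediate (e.g. the output of a fused reshape, as in the test
  // above) is persistent → registers preferred (prints 0).
  std::vector<TensorViewStub*> with_intermediate{&input, &intermediate};
  std::printf(
      "%d %d\n",
      preferSmemPersistence(false, only_inputs),
      preferSmemPersistence(false, with_intermediate));
  return 0;
}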