diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index 1b799d1d..0a6470ea 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -49,11 +49,12 @@ class committed_descriptor_impl; template std::vector compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, TIn input, Scalar* output, - TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue); + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, + Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, + const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, + IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, + const std::vector& dependencies, sycl::queue& queue); template sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct& kd_struct, @@ -150,8 +151,8 @@ class committed_descriptor_impl { template friend std::vector detail::compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, TIn input, - Scalar1* output, TIn input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, + Scalar1* output, const TIn& input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index d49ebe2c..5c4574df 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -171,245 +171,6 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc } } -/** - * Utility function to launch the kernel when the input is a buffer - * @tparam Scalar Scalar type - * @tparam Domain Domain of the compute - * @tparam LayoutIn Input layout - * @tparam LayoutOut Output layout - * @tparam SubgroupSize Subgroup size - * @param input input accessor - * @param output output USM pointer - * @param input_imag input accessor for imaginary data - * @param output_imag output USM pointer for imaginary data - * @param loc_for_input local memory for input - * @param loc_for_twiddles local memory for twiddles - * @param loc_for_store_modifier local memory for store modifier data - * @param multipliers_between_factors twiddles to be multiplied between factors - * @param impl_twiddles twiddles required for sub implementation - * @param factors pointer to global memory containing factors of the input - * @param inner_batches pointer to global memory containing the inner batch for each factor - * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors - * @param n_transforms batch size corresponding to the factor - * @param input_batch_offset offset for the input pointer - * @param launch_params launch configuration, the global and local range with which the kernel will get launched - * @param cgh associated command group handler - */ -template -void launch_kernel(sycl::accessor& input, Scalar* output, - sycl::accessor& input_imag, Scalar* output_imag, - sycl::local_accessor& loc_for_input, sycl::local_accessor& loc_for_twiddles, - sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, - const Scalar* impl_twiddles, const IdxGlobal* factors, const IdxGlobal* inner_batches, - const IdxGlobal* inclusive_scan, IdxGlobal n_transforms, IdxGlobal input_batch_offset, - std::pair, sycl::range<1>> launch_params, sycl::handler& cgh) { - PORTFFT_LOG_FUNCTION_ENTRY(); - auto [global_range, local_range] = launch_params; -#ifdef PORTFFT_KERNEL_LOG - sycl::stream s{1024 * 16, 1024, cgh}; -#endif - PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range[0], "local_size", - local_range[0]); - cgh.parallel_for>( - sycl::nd_range<1>(global_range, local_range), [= -#ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config -#endif - ](sycl::nd_item<1> it, sycl::kernel_handler kh) PORTFFT_REQD_SUBGROUP_SIZE(SubgroupSize) { - detail::global_data_struct global_data{ -#ifdef PORTFFT_KERNEL_LOG - s, global_logging_config, -#endif - it}; - dispatch_level( - &input[0] + input_batch_offset, output, &input_imag[0] + input_batch_offset, output_imag, impl_twiddles, - multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_store_modifier[0], factors, - inner_batches, inclusive_scan, n_transforms, global_data, kh); - }); -} - -/** - * TODO: Launch the kernel directly from compute_level and remove the duplicated launch_kernel - * Utility function to launch the kernel when the input is an USM - * @tparam Scalar Scalar type - * @tparam Domain Domain of the compute - * @tparam LayoutIn Input layout - * @tparam LayoutOut Output layout - * @tparam SubgroupSize Subgroup size - * @param input input pointer - * @param output output pointer - * @param input_imag input pointer for imaginary data - * @param output_imag output pointer for imaginary data - * @param loc_for_input local memory for input - * @param loc_for_twiddles local memory for twiddles - * @param loc_for_store_modifier local memory for store modifier data - * @param multipliers_between_factors twiddles to be multiplied between factors - * @param impl_twiddles twiddles required for sub implementation - * @param factors pointer to global memory containing factors of the input - * @param inner_batches pointer to global memory containing the inner batch for each factor - * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors - * @param n_transforms batch size corresponding to the factor - * @param input_batch_offset offset for the input pointer - * @param launch_params launch configuration, the global and local range with which the kernel will get launched - * @param cgh associated command group handler - */ -template -void launch_kernel(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag, - sycl::local_accessor& loc_for_input, sycl::local_accessor& loc_for_twiddles, - sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, - const Scalar* impl_twiddles, const IdxGlobal* factors, const IdxGlobal* inner_batches, - const IdxGlobal* inclusive_scan, IdxGlobal n_transforms, IdxGlobal input_batch_offset, - std::pair, sycl::range<1>> launch_params, sycl::handler& cgh) { - PORTFFT_LOG_FUNCTION_ENTRY(); -#ifdef PORTFFT_LOG - sycl::stream s{1024 * 16 * 16, 1024, cgh}; -#endif - auto [global_range, local_range] = launch_params; -#ifdef PORTFFT_KERNEL_LOG - sycl::stream s{1024 * 16, 1024, cgh}; -#endif - PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range[0], "local_size", - local_range[0]); - cgh.parallel_for>( - sycl::nd_range<1>(global_range, local_range), [= -#ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config -#endif - ](sycl::nd_item<1> it, sycl::kernel_handler kh) PORTFFT_REQD_SUBGROUP_SIZE(SubgroupSize) { - detail::global_data_struct global_data{ -#ifdef PORTFFT_KERNEL_LOG - s, global_logging_config, -#endif - it}; - dispatch_level( - &input[0] + input_batch_offset, output, &input_imag[0] + input_batch_offset, output_imag, impl_twiddles, - multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_store_modifier[0], factors, - inner_batches, inclusive_scan, n_transforms, global_data, kh); - }); -} - -/** - * TODO: Launch the kernel directly from transpose_level and remove the duplicated dispatch_transpose_kernel_impl - * Utility function to launch the transpose kernel, when the output is a buffer - * @tparam Scalar Scalar type - * @param input input pointer - * @param output output accessor - * @param loc 2D local memory - * @param factors pointer to global memory containing factors of the input - * @param inner_batches pointer to global memory containing the inner batch for each factor - * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors - * @param output_offset offset to output pointer - * @param ldb leading dimension of the output - * @param lda leading dimension of the input - * @param cgh associated command group handler - */ -template -static void dispatch_transpose_kernel_impl(const Scalar* input, - sycl::accessor& output, - sycl::local_accessor& loc, const IdxGlobal* factors, - const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, - IdxGlobal output_offset, IdxGlobal lda, IdxGlobal ldb, sycl::handler& cgh) { - PORTFFT_LOG_FUNCTION_ENTRY(); -#ifdef PORTFFT_KERNEL_LOG - sycl::stream s{1024 * 16, 1024, cgh}; -#endif - std::size_t lda_rounded = detail::round_up_to_multiple(static_cast(lda), static_cast(16)); - std::size_t ldb_rounded = detail::round_up_to_multiple(static_cast(ldb), static_cast(16)); - PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", lda_rounded, ldb_rounded, "local_size", 16, 16); - cgh.parallel_for>( - sycl::nd_range<2>({lda_rounded, ldb_rounded}, {16, 16}), [= -#ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config -#endif - ](sycl::nd_item<2> it, sycl::kernel_handler kh) { - detail::global_data_struct global_data{ -#ifdef PORTFFT_KERNEL_LOG - s, global_logging_config, -#endif - it}; - global_data.log_message_global("entering transpose kernel - buffer impl"); - complex_storage storage = kh.get_specialization_constant(); - Idx level_num = kh.get_specialization_constant(); - Idx num_factors = kh.get_specialization_constant(); - IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num); - for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) { - global_data.log_message_subgroup("iter_value: ", iter_value); - IdxGlobal outer_batch_offset = get_outer_batch_offset(factors, inner_batches, inclusive_scan, num_factors, - level_num, iter_value, outer_batch_product, storage); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::generic_transpose<2>(lda, ldb, 16, input + outer_batch_offset, - &output[0] + outer_batch_offset + output_offset, loc, global_data); - } else { - detail::generic_transpose<1>(lda, ldb, 16, input + outer_batch_offset, - &output[0] + outer_batch_offset + output_offset, loc, global_data); - } - } - global_data.log_message_global("exiting transpose kernel - buffer impl"); - }); -} - -/** - * Utility function to launch the transpose kernel, when the output is a buffer - * @tparam Scalar Scalar type - * @param input input pointer - * @param output output pointer - * @param loc 2D local memory - * @param factors pointer to global memory containing factors of the input - * @param inner_batches pointer to global memory containing the inner batch for each factor - * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors - * @param output_offset offset to output pointer - * @param ldb leading dimension of the output - * @param lda leading dimension of the input - * @param cgh associated command group handler - */ -template -static void dispatch_transpose_kernel_impl(const Scalar* input, Scalar* output, sycl::local_accessor& loc, - const IdxGlobal* factors, const IdxGlobal* inner_batches, - const IdxGlobal* inclusive_scan, IdxGlobal output_offset, IdxGlobal lda, - IdxGlobal ldb, sycl::handler& cgh) { - PORTFFT_LOG_FUNCTION_ENTRY(); -#ifdef PORTFFT_KERNEL_LOG - sycl::stream s{1024 * 16 * 16, 1024, cgh}; -#endif - std::size_t lda_rounded = detail::round_up_to_multiple(static_cast(lda), static_cast(16)); - std::size_t ldb_rounded = detail::round_up_to_multiple(static_cast(ldb), static_cast(16)); - PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", lda_rounded, ldb_rounded, "local_size", 16, 16); - cgh.parallel_for>( - sycl::nd_range<2>({lda_rounded, ldb_rounded}, {16, 16}), [= -#ifdef PORTFFT_KERNEL_LOG - , - global_logging_config = detail::global_logging_config -#endif - ](sycl::nd_item<2> it, sycl::kernel_handler kh) { - detail::global_data_struct global_data{ -#ifdef PORTFFT_KERNEL_LOG - s, global_logging_config, -#endif - it}; - global_data.log_message_global("entering transpose kernel - USM impl"); - complex_storage storage = kh.get_specialization_constant(); - Idx level_num = kh.get_specialization_constant(); - Idx num_factors = kh.get_specialization_constant(); - IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num); - for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) { - global_data.log_message_subgroup("iter_value: ", iter_value); - IdxGlobal outer_batch_offset = get_outer_batch_offset(factors, inner_batches, inclusive_scan, num_factors, - level_num, iter_value, outer_batch_product, storage); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::generic_transpose<2>(lda, ldb, 16, input + outer_batch_offset, - &output[0] + outer_batch_offset + output_offset, loc, global_data); - } else { - detail::generic_transpose<1>(lda, ldb, 16, input + outer_batch_offset, - &output[0] + outer_batch_offset + output_offset, loc, global_data); - } - } - global_data.log_message_global("exiting transpose kernel - USM impl"); - }); -} - /** * Prepares the launch of transposition at a particular level * @tparam Scalar Scalar type @@ -437,6 +198,7 @@ sycl::event transpose_level(const typename committed_descriptor_impl& events, complex_storage storage) { PORTFFT_LOG_FUNCTION_ENTRY(); + constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; const IdxGlobal vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1; std::vector transpose_events; IdxGlobal ld_input = kd_struct.factors.at(1); @@ -456,10 +218,53 @@ sycl::event transpose_level(const typename committed_descriptor_impl(batch_in_l2))); } + const Scalar* offset_input = input + vec_size * committed_size * batch_in_l2; + IdxGlobal output_offset_inner = output_offset + vec_size * committed_size * batch_in_l2; cgh.use_kernel_bundle(kd_struct.exec_bundle); - detail::dispatch_transpose_kernel_impl( - input + vec_size * committed_size * batch_in_l2, out_acc_or_usm, loc, factors_triple, inner_batches, - inclusive_scan, output_offset + vec_size * committed_size * batch_in_l2, ld_output, ld_input, cgh); +#ifdef PORTFFT_KERNEL_LOG + sycl::stream s{1024 * 16, 1024, cgh}; +#endif + std::size_t ld_output_rounded = + detail::round_up_to_multiple(static_cast(ld_output), static_cast(16)); + std::size_t ld_input_rounded = + detail::round_up_to_multiple(static_cast(ld_input), static_cast(16)); + PORTFFT_LOG_TRACE("Launching transpose kernel with global_size", ld_output_rounded, ld_input_rounded, + "local_size", 16, 16); + cgh.parallel_for>( + sycl::nd_range<2>({ld_output_rounded, ld_input_rounded}, {16, 16}), + [= +#ifdef PORTFFT_KERNEL_LOG + , + global_logging_config = detail::global_logging_config +#endif + ](sycl::nd_item<2> it, sycl::kernel_handler kh) { + detail::global_data_struct global_data{ +#ifdef PORTFFT_KERNEL_LOG + s, global_logging_config, +#endif + it}; + global_data.log_message_global("entering transpose kernel - buffer impl"); + complex_storage storage = kh.get_specialization_constant(); + Idx level_num = kh.get_specialization_constant(); + Idx num_factors = kh.get_specialization_constant(); + IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num); + for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) { + global_data.log_message_subgroup("iter_value: ", iter_value); + IdxGlobal outer_batch_offset = + get_outer_batch_offset(factors_triple, inner_batches, inclusive_scan, num_factors, level_num, + iter_value, outer_batch_product, storage); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + detail::generic_transpose<2>(ld_output, ld_input, 16, offset_input + outer_batch_offset, + &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, + global_data); + } else { + detail::generic_transpose<1>(ld_output, ld_input, 16, offset_input + outer_batch_offset, + &out_acc_or_usm[0] + outer_batch_offset + output_offset_inner, loc, + global_data); + } + } + global_data.log_message_global("exiting transpose kernel - buffer impl"); + }); })); } return queue.submit([&](sycl::handler& cgh) { @@ -502,13 +307,14 @@ sycl::event transpose_level(const typename committed_descriptor_impl std::vector compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn input, - Scalar* output, const TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, + Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue) { PORTFFT_LOG_FUNCTION_ENTRY(); + constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; IdxGlobal local_range = kd_struct.local_range; IdxGlobal global_range = kd_struct.global_range; IdxGlobal batch_size = kd_struct.batch_size; @@ -566,14 +372,34 @@ std::vector compute_level( Scalar* offset_output_imag = storage == complex_storage::INTERLEAVED_COMPLEX ? nullptr : output_imag + vec_size * batch_in_l2 * committed_size; - detail::launch_kernel( - in_acc_or_usm, output + vec_size * batch_in_l2 * committed_size, in_imag_acc_or_usm, offset_output_imag, - loc_for_input, loc_for_twiddles, loc_for_modifier, twiddles_ptr + intermediate_twiddle_offset, - subimpl_twiddles, factors_triple, inner_batches, inclusive_scan, batch_size, - vec_size * committed_size * batch_in_l2 + input_global_offset, - {sycl::range<1>(static_cast(global_range)), - sycl::range<1>(static_cast(local_range))}, - cgh); + Scalar* offset_output = output + vec_size * batch_in_l2 * committed_size; + const Scalar* multipliers_between_factors = twiddles_ptr + intermediate_twiddle_offset; + IdxGlobal input_batch_offset = vec_size * committed_size * batch_in_l2 + input_global_offset; +#ifdef PORTFFT_KERNEL_LOG + sycl::stream s{1024 * 16, 1024, cgh}; +#endif + PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size", + local_range); + cgh.parallel_for>( + sycl::nd_range<1>(sycl::range<1>(static_cast(global_range)), + sycl::range<1>(static_cast(local_range))), + [= +#ifdef PORTFFT_KERNEL_LOG + , + global_logging_config = detail::global_logging_config +#endif + ](sycl::nd_item<1> it, sycl::kernel_handler kh) PORTFFT_REQD_SUBGROUP_SIZE(SubgroupSize) { + detail::global_data_struct global_data{ +#ifdef PORTFFT_KERNEL_LOG + s, global_logging_config, +#endif + it}; + dispatch_level( + &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, + offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], + &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size, + global_data, kh); + }); })); } return events; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 093fba53..3c96a36f 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -345,16 +345,16 @@ struct committed_descriptor_impl::run_kernel_struct(factor_num) == dimension_data.num_factors - 1) { PORTFFT_LOG_TRACE("This is the last kernel"); - l2_events = - detail::compute_level( - current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), - desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, - factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), - static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); + l2_events = detail::compute_level( + current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), + desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, + factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, + static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), + static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); } else { l2_events = detail::compute_level( + detail::layout::BATCH_INTERLEAVED, SubgroupSize, const Scalar*>( current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size,