From 6a6672f7d701f7ae7113c52e3ae9cf41c6c05dc7 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 27 Nov 2023 17:19:15 +0000 Subject: [PATCH 01/67] initial factorization changes --- src/portfft/descriptor.hpp | 35 +++++++-------- src/portfft/utils.hpp | 87 +++++++++++++++++++++++++++++++++----- 2 files changed, 91 insertions(+), 31 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 873bc4c5..c145da87 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -337,8 +337,9 @@ class committed_descriptor { IdxGlobal n_idx_global = detail::factorize(fft_size); if (detail::can_cast_safely(n_idx_global) && detail::can_cast_safely(fft_size / n_idx_global)) { + bool is_prime = false; if (n_idx_global == 1) { - throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported"); + is_prime = true; } Idx n = static_cast(n_idx_global); Idx m = static_cast(fft_size / n_idx_global); @@ -354,7 +355,7 @@ class committed_descriptor { // Checks for PACKED layout only at the moment, as the other layout will not be supported // by the global implementation. For such sizes, only PACKED layout will be supported if (detail::fits_in_wi(factor_wi_n) && detail::fits_in_wi(factor_wi_m) && - (local_memory_usage <= static_cast(local_memory_size))) { + (local_memory_usage <= static_cast(local_memory_size)) && !is_prime) { factors.push_back(factor_wi_n); factors.push_back(factor_sg_n); factors.push_back(factor_wi_m); @@ -366,7 +367,8 @@ class committed_descriptor { } } std::vector, std::vector>> param_vec; - auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { + auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true, + bool load_modifier_required = false) -> bool { if (detail::fits_in_wi(factor_size)) { param_vec.emplace_back(detail::level::WORKITEM, detail::get_ids(), @@ -375,25 +377,13 @@ class committed_descriptor { return true; } bool fits_in_local_memory_subgroup = [&]() { - Idx temp_num_sgs_in_wg; IdxGlobal factor_sg = detail::factorize_sg(factor_size, SubgroupSize); IdxGlobal factor_wi = factor_size / factor_sg; if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { - if (batch_interleaved_layout) { - return (2 * - num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * sizeof(Scalar)) < - static_cast(local_memory_size); - } - return (num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * sizeof(Scalar)) < - static_cast(local_memory_size); + return (detail::get_local_memory_usage(detail::level::SUBGROUP, static_cast(factor_size), + batch_interleaved_layout, load_modifier_required, true, SubgroupSize, + PORTFFT_SGS_IN_WG * SubgroupSize) * + static_cast(sizeof(Scalar))) <= local_memory_size; } return false; }(); @@ -408,7 +398,12 @@ class committed_descriptor { } return false; }; - detail::factorize_input(fft_size, check_and_select_target_level); + if (detail::factorize_input(fft_size, check_and_select_target_level)) { + param_vec.clear(); + IdxGlobal padded_fft_size = + static_cast(std::pow(2, ceil(log(static_cast(fft_size)) / log(2.0)))); + 
detail::factorize_input(padded_fft_size, check_and_select_target_level, true); + } return {detail::level::GLOBAL, param_vec}; } diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index 6e4e5a4e..16e5388a 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -26,6 +26,8 @@ #include #include +#include "common/memory_views.hpp" +#include "common/workgroup.hpp" #include "defines.hpp" #include "enums.hpp" @@ -96,26 +98,32 @@ constexpr bool can_cast_safely(const InputType& x) { * @param factor_size Length of the factor * @param check_and_select_target_level Function which checks whether the factor can fit in one of the existing * implementations - * The function should accept factor size and whether it would be have a BATCH_INTERLEAVED layout or not as an input, - * and should return a boolean indicating whether or not the factor size can fit in any of the implementation. + * The function should accept the factor size, whether the data layout in local memory would be in a batch + * interleaved format, and whether or not load modifiers will be used, and should return a boolean indicating whether + * or not the factor size can fit in any of the implementations. * @param transposed whether or not the factor will be computed in a BATCH_INTERLEAVED format - * @return + * @param encountered_prime whether or not a large prime was encountered during factorization + * @param requires_load_modifier whether or not load modifier will be required. + * @return Largest factor that was possible to fit in either the workitem or subgroup level FFTs */ template -IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, bool transposed) { +IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, bool transposed, + bool& encountered_prime, bool requires_load_modifier) { IdxGlobal fact_1 = factor_size; - if (check_and_select_target_level(fact_1, transposed)) { + if (check_and_select_target_level(fact_1, transposed, requires_load_modifier)) { return fact_1; } if ((detail::factorize(fact_1) == 1)) { - throw unsupported_configuration("Large prime sized factors are not supported at the moment"); + encountered_prime = true; + return factor_size; } do { fact_1 = detail::factorize(fact_1); if (fact_1 == 1) { - throw internal_error("Factorization Failed !"); + encountered_prime = true; + return factor_size; } - } while (!check_and_select_target_level(fact_1)); + } while (!check_and_select_target_level(fact_1, transposed, requires_load_modifier)); return fact_1; } @@ -127,16 +135,25 @@ IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_targe * implementations. The function should accept factor size and whether it would have a BATCH_INTERLEAVED layout or * not as an input, and should return a boolean indicating whether or not the factor size can fit in any of the * implementations. + * @param requires_load_modifier whether or not load modifier will be required. + * @return whether or not a large prime was encountered during factorization. 
*/ template -void factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { +bool factorize_input(IdxGlobal input_size, F&& check_and_select_target_level, bool requires_load_modifier = false) { + bool encountered_prime = false; if (detail::factorize(input_size) == 1) { - throw unsupported_configuration("Large Prime sized FFTs are currently not supported"); + encountered_prime = true; + return encountered_prime; } IdxGlobal temp = 1; while (input_size / temp != 1) { + if (encountered_prime) { + return encountered_prime; + } + temp *= factorize_input_impl(input_size / temp, check_and_select_target_level, true, encountered_prime, + requires_load_modifier); } + return encountered_prime; } /** @@ -175,6 +192,54 @@ inline std::shared_ptr make_shared(std::size_t size, sycl::queue& queue) { }); } +/** + * @brief Gets the cumulative local memory usage for a particular level + * @tparam Scalar Scalar type + * @param level level to get the cumulative local memory usage for + * @param factor_size Factor size + * @param is_batch_interleaved Will the data be in a batch interleaved format in local memory + * @param is_load_modifier_applied Is load modifier applied + * @param is_store_modifier_applied Is store modifier applied + * @param subgroup_size subgroup size with which the kernel will be launched + * @param workgroup_size workgroup size with which the kernel will be launched + * @return cumulative local memory usage in terms of the number of scalars in local memory. + */ +inline Idx get_local_memory_usage(detail::level level, Idx factor_size, bool is_batch_interleaved, + bool is_load_modifier_applied, bool is_store_modifier_applied, Idx subgroup_size, + Idx workgroup_size) { + Idx local_memory_usage = 0; + switch (level) { + case detail::level::WORKITEM: { + // This will use local memory for load / store modifiers in the future. + if (!is_batch_interleaved) { + local_memory_usage += detail::pad_local(2 * factor_size * workgroup_size, 1); + } + } break; + case detail::level::SUBGROUP: { + local_memory_usage += 2 * factor_size; + Idx fact_sg = factorize_sg(factor_size, subgroup_size); + Idx num_ffts_in_sg = subgroup_size / fact_sg; + Idx num_ffts_in_local_mem = + is_batch_interleaved ? workgroup_size / 2 : num_ffts_in_sg * (workgroup_size / subgroup_size); + local_memory_usage += detail::pad_local(2 * num_ffts_in_local_mem * factor_size, 1); + if (is_load_modifier_applied) { + local_memory_usage += detail::pad_local(2 * num_ffts_in_local_mem * factor_size, 1); + } + if (is_store_modifier_applied) { + local_memory_usage += detail::pad_local(2 * num_ffts_in_local_mem * factor_size, 1); + } + } break; + case detail::level::WORKGROUP: { + Idx n = detail::factorize(factor_size); + Idx m = factor_size / n; + Idx num_ffts_in_local_mem = is_batch_interleaved ? 
workgroup_size / 2 : 1; + local_memory_usage += detail::pad_local(2 * factor_size * num_ffts_in_local_mem, bank_lines_per_pad_wg(m)); + } break; + default: + break; + } + return local_memory_usage; +} + } // namespace detail } // namespace portfft #endif From a0b6218c06d6bbff16341647f7adda62d254643e Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 28 Nov 2023 11:16:32 +0000 Subject: [PATCH 02/67] passing vector to contain l2 events in global level FFTs --- src/portfft/common/global.hpp | 30 ++++++------ src/portfft/descriptor.hpp | 18 +++---- src/portfft/dispatcher/global_dispatcher.hpp | 50 +++++++++++--------- 3 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 71ba7395..f79fca1e 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -452,18 +452,20 @@ sycl::event transpose_level(const typename committed_descriptor: * @param batch_start start of the current global batch being processed * @param factor_id current factor being processed * @param total_factors total number of factors - * @param dependencies even dependencies + * @param in_dependencies input dependencies + * @param out_events std::vector to store one event per batch in l2 * @param queue queue - * @return vector events, one for each batch in l2 + * @return void */ template -std::vector compute_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, const TIn input, Scalar* output, - const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, const std::vector& dependencies, sycl::queue& queue) { +void compute_level(const typename committed_descriptor::kernel_data_struct& kd_struct, const TIn input, + Scalar* output, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, + IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, + IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, Idx total_factors, + const std::vector& in_dependencies, std::vector& out_events, + sycl::queue& queue) { IdxGlobal local_range = kd_struct.local_range; IdxGlobal global_range = kd_struct.global_range; IdxGlobal batch_size = kd_struct.batch_size; @@ -493,21 +495,20 @@ std::vector compute_level( }(); const IdxGlobal* inner_batches = factors_triple + total_factors; const IdxGlobal* inclusive_scan = factors_triple + 2 * total_factors; - std::vector events; for (Idx batch_in_l2 = 0; batch_in_l2 < num_batches_in_l2 && batch_in_l2 + batch_start < n_transforms; batch_in_l2++) { - events.push_back(queue.submit([&](sycl::handler& cgh) { + out_events[static_cast(batch_in_l2)] = queue.submit([&](sycl::handler& cgh) { sycl::local_accessor loc_for_input(local_memory_for_input, cgh); sycl::local_accessor loc_for_twiddles(loc_mem_for_twiddles, cgh); sycl::local_accessor loc_for_modifier(local_mem_for_store_modifier, cgh); auto in_acc_or_usm = detail::get_access(input, cgh); cgh.use_kernel_bundle(kd_struct.exec_bundle); - if (static_cast(dependencies.size()) < num_batches_in_l2) { - cgh.depends_on(dependencies); + if (static_cast(in_dependencies.size()) < num_batches_in_l2) { + cgh.depends_on(in_dependencies); } else { // If events is a vector, 
the order of events is assumed to correspond to the order of batches present in the last // level cache. - cgh.depends_on(dependencies.at(static_cast(batch_in_l2))); + cgh.depends_on(in_dependencies.at(static_cast(batch_in_l2))); } detail::launch_kernel( in_acc_or_usm, output + 2 * batch_in_l2 * committed_size, loc_for_input, loc_for_twiddles, loc_for_modifier, @@ -517,9 +518,8 @@ std::vector compute_level( {sycl::range<1>(static_cast(global_range)), sycl::range<1>(static_cast(local_range))}, cgh); - })); + }); } - return events; } } // namespace detail } // namespace portfft diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index c145da87..2805e9a8 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -42,12 +42,13 @@ class committed_descriptor; namespace detail { template -std::vector compute_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, Scalar* output, - const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, const std::vector& dependencies, sycl::queue& queue); +void compute_level(const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, + Scalar* output, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, + IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, + IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, Idx total_factors, + const std::vector& dependencies, std::vector& out_events, + sycl::queue& queue); template sycl::event transpose_level(const typename committed_descriptor::kernel_data_struct& kd_struct, @@ -178,12 +179,13 @@ class committed_descriptor { friend struct descriptor; template - friend std::vector detail::compute_level( + friend void detail::compute_level( const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, Scalar1* output, const Scalar1* twiddles_ptr, const IdxGlobal* factors_triple, Scalar1 scale_factor, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, const std::vector& dependencies, sycl::queue& queue); + Idx total_factors, const std::vector& dependencies, std::vector& out_events, + sycl::queue& queue); template friend sycl::event detail::transpose_level( diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 772c35cb..a4288044 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -293,8 +293,11 @@ struct committed_descriptor::run_kernel_struct(desc.params.lengths[0]); Idx num_transposes = num_factors - 1; - std::vector l2_events; - sycl::event event = desc.queue.submit([&](sycl::handler& cgh) { + std::vector current_events; + std::vector previous_events; + current_events.resize(static_cast(desc.dimensions.at(0).num_batches_in_l2)); + previous_events.resize(static_cast(desc.dimensions.at(0).num_batches_in_l2)); + current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(dependencies); cgh.host_task([&]() {}); }); @@ 
-305,63 +308,64 @@ struct committed_descriptor::run_kernel_struct( + detail::compute_level( desc.dimensions.at(0).kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 2 * static_cast(i) * committed_size + input_offset, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, - desc.dimensions.at(0).num_factors, {event}, desc.queue); + desc.dimensions.at(0).num_factors, current_events, previous_events, desc.queue); intermediate_twiddles_offset += 2 * desc.dimensions.at(0).kernels.at(0).batch_size * static_cast(desc.dimensions.at(0).kernels.at(0).length); impl_twiddle_offset += detail::increment_twiddle_offset( desc.dimensions.at(0).kernels.at(0).level, static_cast(desc.dimensions.at(0).kernels.at(0).length)); + current_events.swap(previous_events); for (std::size_t factor_num = 1; factor_num < static_cast(desc.dimensions.at(0).num_factors); factor_num++) { if (static_cast(factor_num) == desc.dimensions.at(0).num_factors - 1) { - l2_events = - detail::compute_level( - desc.dimensions.at(0).kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), - desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - desc.dimensions.at(0).num_factors, l2_events, desc.queue); + detail::compute_level( + desc.dimensions.at(0).kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), + desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, + impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), + static_cast(num_batches), static_cast(i), static_cast(factor_num), + desc.dimensions.at(0).num_factors, current_events, previous_events, desc.queue); } else { - l2_events = detail::compute_level( + detail::compute_level( desc.dimensions.at(0).kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), static_cast(factor_num), - desc.dimensions.at(0).num_factors, l2_events, desc.queue); + desc.dimensions.at(0).num_factors, current_events, previous_events, desc.queue); intermediate_twiddles_offset += 2 * desc.dimensions.at(0).kernels.at(factor_num).batch_size * static_cast(desc.dimensions.at(0).kernels.at(factor_num).length); impl_twiddle_offset += detail::increment_twiddle_offset(desc.dimensions.at(0).kernels.at(factor_num).level, static_cast(desc.dimensions.at(0).kernels.at(factor_num).length)); + current_events.swap(previous_events); } } - event = desc.queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(l2_events); + current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(previous_events); cgh.host_task([&]() {}); }); for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { - event = detail::transpose_level( + current_events[0] = detail::transpose_level( desc.dimensions.at(0).kernels.at(static_cast(num_transpose) + static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, - num_factors, 0, desc.queue, desc.scratch_ptr_1, 
desc.scratch_ptr_2, {event}); - event.wait(); + num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events); + current_events[0].wait(); } - event = detail::transpose_level( + current_events[0] = detail::transpose_level( desc.dimensions.at(0).kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), out, factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, desc.queue, desc.scratch_ptr_1, - desc.scratch_ptr_2, {event}); + desc.scratch_ptr_2, current_events); } - return event; + return current_events[0]; } }; From d6dff45d72cffdb373d7fcf4b4a407b4a48653eb Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 28 Nov 2023 11:31:03 +0000 Subject: [PATCH 03/67] passing pre-allocated vector to transpose_level to collect llc events --- src/portfft/common/global.hpp | 12 ++++++------ src/portfft/descriptor.hpp | 4 ++-- src/portfft/dispatcher/global_dispatcher.hpp | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index f79fca1e..068f48bc 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -381,6 +381,7 @@ static void dispatch_transpose_kernel_impl(const Scalar* input, Scalar* output, * @param ptr1 shared_ptr for the first scratch pointer * @param ptr2 shared_ptr for the second scratch pointer * @param events event dependencies + * @param generated_events std::vector to collect all generated events during transpositions. * @return sycl::event */ template @@ -389,8 +390,7 @@ sycl::event transpose_level(const typename committed_descriptor: Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_num, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, std::shared_ptr& ptr1, std::shared_ptr& ptr2, - const std::vector& events) { - std::vector transpose_events; + const std::vector& events, std::vector& generated_events) { IdxGlobal ld_input = kd_struct.factors.at(1); IdxGlobal ld_output = kd_struct.factors.at(0); const IdxGlobal* inner_batches = factors_triple + total_factors; @@ -398,7 +398,7 @@ sycl::event transpose_level(const typename committed_descriptor: for (Idx batch_in_l2 = 0; batch_in_l2 < num_batches_in_l2 && (static_cast(batch_in_l2) + batch_start) < n_transforms; batch_in_l2++) { - transpose_events.push_back(queue.submit([&](sycl::handler& cgh) { + generated_events[static_cast(batch_in_l2)] = queue.submit([&](sycl::handler& cgh) { auto out_acc_or_usm = detail::get_access(output, cgh); sycl::local_accessor loc({16, 32}, cgh); if (static_cast(events.size()) < num_batches_in_l2) { @@ -412,16 +412,16 @@ sycl::event transpose_level(const typename committed_descriptor: detail::dispatch_transpose_kernel_impl( input + 2 * committed_size * batch_in_l2, out_acc_or_usm, loc, factors_triple, inner_batches, inclusive_scan, output_offset + 2 * committed_size * batch_in_l2, ld_output, ld_input, cgh); - })); + }); } if (factor_num != 0) { return queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(transpose_events); + cgh.depends_on(generated_events); cgh.host_task([&]() { ptr1.swap(ptr2); }); }); } return queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(transpose_events); + cgh.depends_on(generated_events); cgh.host_task([&]() {}); }); } diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 2805e9a8..8f87d306 100644 --- a/src/portfft/descriptor.hpp +++ 
b/src/portfft/descriptor.hpp @@ -56,7 +56,7 @@ sycl::event transpose_level(const typename committed_descriptor: Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_num, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, std::shared_ptr& ptr1, std::shared_ptr& ptr2, - const std::vector& events); + const std::vector& events, std::vector& generated_events); // kernel names // TODO: Remove all templates except Scalar, Domain and Memory and SubgroupSize @@ -193,7 +193,7 @@ class committed_descriptor { TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_num, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, std::shared_ptr& ptr1, std::shared_ptr& ptr2, - const std::vector& events); + const std::vector& events, std::vector& generated_events); descriptor params; sycl::queue queue; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index a4288044..bccf7d07 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -355,7 +355,7 @@ struct committed_descriptor::run_kernel_struct(num_factors)), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, - num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events); + num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); current_events[0].wait(); } current_events[0] = detail::transpose_level( @@ -363,7 +363,7 @@ struct committed_descriptor::run_kernel_struct(desc.scratch_ptr_1.get()), out, factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, desc.queue, desc.scratch_ptr_1, - desc.scratch_ptr_2, current_events); + desc.scratch_ptr_2, current_events, previous_events); } return current_events[0]; } From 74c8c80e64a34ad38da47f2f414cbdc71f9586dc Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 29 Nov 2023 17:37:05 +0000 Subject: [PATCH 04/67] passing dimension_struct instead of kernel_data_struct --- src/portfft/descriptor.hpp | 28 +++---- src/portfft/dispatcher/global_dispatcher.hpp | 81 +++++++++---------- .../dispatcher/subgroup_dispatcher.hpp | 21 ++--- .../dispatcher/workgroup_dispatcher.hpp | 23 +++--- .../dispatcher/workitem_dispatcher.hpp | 14 ++-- 5 files changed, 78 insertions(+), 89 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 873bc4c5..61b9b62d 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -490,18 +490,18 @@ class committed_descriptor { // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class template struct inner { - static Scalar* execute(committed_descriptor& desc, kernel_data_struct& kernel_data); + static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data); }; }; /** * Calculates twiddle factors for the implementation in use. 
* - * @param kernel_data data about the kernel the twiddles are needed for + * @param dimension_data data about the dimension for which twiddles are needed * @return Scalar* USM pointer to the twiddle factors */ - Scalar* calculate_twiddles(kernel_data_struct& kernel_data) { - return dispatch(kernel_data.level, kernel_data); + Scalar* calculate_twiddles(dimension_struct& dimension_data) { + return dispatch(dimension_data.level, dimension_data); } /** @@ -740,24 +740,14 @@ class committed_descriptor { std::size_t n_kernels = params.lengths.size(); for (std::size_t i = 0; i < n_kernels; i++) { dimensions.push_back(build_w_spec_const(i)); - if (dimensions.at(i).level == detail::level::GLOBAL) { - dimensions.back().kernels.at(0).twiddles_forward = std::shared_ptr( - dispatch(detail::level::GLOBAL, dimensions.back().kernels.at(0)), - [queue](Scalar* ptr) { - if (ptr != nullptr) { - sycl::free(ptr, queue); - } - }); - } else { - for (kernel_data_struct& kernel : dimensions.back().kernels) { - kernel.twiddles_forward = std::shared_ptr(calculate_twiddles(kernel), [queue](Scalar* ptr) { + dimensions.back().kernels.at(0).twiddles_forward = + std::shared_ptr(calculate_twiddles(dimensions.back()), [queue](Scalar* ptr) { if (ptr != nullptr) { sycl::free(ptr, queue); } }); - } - } } + bool is_scratch_required = false; Idx num_global_level_dimensions = 0; for (std::size_t i = 0; i < n_kernels; i++) { @@ -1354,7 +1344,7 @@ class committed_descriptor { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t forward_offset, std::size_t backward_offset, Scalar scale_factor, - std::vector& kernels); + dimension_struct& dimension_data); }; }; @@ -1405,7 +1395,7 @@ class committed_descriptor { dimension_data.level, detail::reinterpret(in), detail::reinterpret(out), detail::reinterpret(in_imag), detail::reinterpret(out_imag), dependencies, static_cast(n_transforms), static_cast(vec_multiplier * input_offset), - static_cast(vec_multiplier * output_offset), scale_factor, dimension_data.kernels); + static_cast(vec_multiplier * output_offset), scale_factor, dimension_data); } }; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 772c35cb..fa43ff71 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -105,10 +105,11 @@ inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) template template struct committed_descriptor::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor& desc, kernel_data_struct& /*kernel_data*/) { + static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data) { + auto& kernels = dimension_data.kernels; std::vector factors_idx_global; // Get factor sizes per level; - for (const auto& kernel_data : desc.dimensions.back().kernels) { + for (const auto& kernel_data : kernels) { factors_idx_global.push_back(static_cast( std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); } @@ -128,7 +129,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner::calculate_twiddles_struct::inner::calculate_twiddles_struct::inner(factors_idx_global.at(counter)); if (kernel_data.level == detail::level::WORKITEM) { @@ -211,7 +212,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner(1); } else { kernel_data.local_mem_required = 2 * 
static_cast(local_range * factors_idx_global.at(counter)); @@ -226,7 +227,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner( detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, tmp); @@ -282,15 +283,16 @@ struct committed_descriptor::run_kernel_struct& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - std::vector& kernel_data) { + dimension_struct& dimension_data) { (void)in_imag; (void)out_imag; - const Scalar* twiddles_ptr = static_cast(kernel_data.at(0).twiddles_forward.get()); - const IdxGlobal* factors_and_scan = static_cast(desc.dimensions.at(0).factors_and_scan.get()); + const auto& kernels = dimension_data.kernels; + const Scalar* twiddles_ptr = static_cast(kernels.at(0).twiddles_forward.get()); + const IdxGlobal* factors_and_scan = static_cast(dimension_data.factors_and_scan.get()); std::size_t num_batches = desc.params.number_of_transforms; - std::size_t max_batches_in_l2 = static_cast(desc.dimensions.at(0).num_batches_in_l2); + std::size_t max_batches_in_l2 = static_cast(dimension_data.num_batches_in_l2); IdxGlobal initial_impl_twiddle_offset = 0; - Idx num_factors = desc.dimensions.at(0).num_factors; + Idx num_factors = dimension_data.num_factors; IdxGlobal committed_size = static_cast(desc.params.lengths[0]); Idx num_transposes = num_factors - 1; std::vector l2_events; @@ -299,46 +301,43 @@ struct committed_descriptor::run_kernel_struct(num_factors - 1); i++) { - initial_impl_twiddle_offset += 2 * desc.dimensions.at(0).kernels.at(i).batch_size * - static_cast(desc.dimensions.at(0).kernels.at(i).length); + initial_impl_twiddle_offset += 2 * kernels.at(i).batch_size * static_cast(kernels.at(i).length); } for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { IdxGlobal intermediate_twiddles_offset = 0; IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; l2_events = detail::compute_level( - desc.dimensions.at(0).kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, - scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, + kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, + intermediate_twiddles_offset, impl_twiddle_offset, 2 * static_cast(i) * committed_size + input_offset, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, - desc.dimensions.at(0).num_factors, {event}, desc.queue); - intermediate_twiddles_offset += 2 * desc.dimensions.at(0).kernels.at(0).batch_size * - static_cast(desc.dimensions.at(0).kernels.at(0).length); - impl_twiddle_offset += detail::increment_twiddle_offset( - desc.dimensions.at(0).kernels.at(0).level, static_cast(desc.dimensions.at(0).kernels.at(0).length)); - for (std::size_t factor_num = 1; factor_num < static_cast(desc.dimensions.at(0).num_factors); + dimension_data.num_factors, {event}, desc.queue); + intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); + impl_twiddle_offset += + detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); + for (std::size_t factor_num = 1; factor_num < static_cast(dimension_data.num_factors); factor_num++) { - if (static_cast(factor_num) == desc.dimensions.at(0).num_factors - 1) { + if (static_cast(factor_num) == dimension_data.num_factors - 1) { l2_events = detail::compute_level( - desc.dimensions.at(0).kernels.at(factor_num), 
static_cast(desc.scratch_ptr_1.get()), + kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), static_cast(factor_num), - desc.dimensions.at(0).num_factors, l2_events, desc.queue); + dimension_data.num_factors, l2_events, desc.queue); } else { l2_events = detail::compute_level( - desc.dimensions.at(0).kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), - desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - desc.dimensions.at(0).num_factors, l2_events, desc.queue); - intermediate_twiddles_offset += 2 * desc.dimensions.at(0).kernels.at(factor_num).batch_size * - static_cast(desc.dimensions.at(0).kernels.at(factor_num).length); - impl_twiddle_offset += - detail::increment_twiddle_offset(desc.dimensions.at(0).kernels.at(factor_num).level, - static_cast(desc.dimensions.at(0).kernels.at(factor_num).length)); + kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, + committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), + static_cast(i), static_cast(factor_num), dimension_data.num_factors, l2_events, + desc.queue); + intermediate_twiddles_offset += + 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); + impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, + static_cast(kernels.at(factor_num).length)); } } event = desc.queue.submit([&](sycl::handler& cgh) { @@ -347,19 +346,17 @@ struct committed_descriptor::run_kernel_struct 0; num_transpose--) { event = detail::transpose_level( - desc.dimensions.at(0).kernels.at(static_cast(num_transpose) + - static_cast(num_factors)), + kernels.at(static_cast(num_transpose) + static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, {event}); event.wait(); } event = detail::transpose_level( - desc.dimensions.at(0).kernels.at(static_cast(num_factors)), - static_cast(desc.scratch_ptr_1.get()), out, factors_and_scan, committed_size, - static_cast(max_batches_in_l2), n_transforms, static_cast(i), 0, num_factors, - 2 * static_cast(i) * committed_size + output_offset, desc.queue, desc.scratch_ptr_1, - desc.scratch_ptr_2, {event}); + kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), out, + factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, + desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, {event}); } return event; } diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 08495863..e875fb3e 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -579,7 +579,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag template template 
struct committed_descriptor::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor& desc, kernel_data_struct& kernel_data) { + static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data) { + const auto& kernel_data = dimension_data.kernels.at(0); Idx factor_wi = kernel_data.factors[0]; Idx factor_sg = kernel_data.factors[1]; Scalar* res = sycl::aligned_alloc_device( @@ -607,20 +608,20 @@ struct committed_descriptor::run_kernel_struct& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - std::vector& kernel_data) { + dimension_struct& dimension_data) { constexpr detail::memory Mem = std::is_pointer::value ? detail::memory::USM : detail::memory::BUFFER; - Scalar* twiddles = kernel_data[0].twiddles_forward.get(); - Idx factor_sg = kernel_data[0].factors[1]; + auto& kernel_data = dimension_data.kernels.at(0); + Scalar* twiddles = kernel_data.twiddles_forward.get(); + Idx factor_sg = kernel_data.factors[1]; std::size_t local_elements = num_scalars_in_local_mem_struct::template inner::execute( - desc, kernel_data[0].length, kernel_data[0].used_sg_size, kernel_data[0].factors, - kernel_data[0].num_sgs_per_wg); + desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); std::size_t global_size = static_cast(detail::get_global_size_subgroup( - n_transforms, factor_sg, SubgroupSize, kernel_data[0].num_sgs_per_wg, desc.n_compute_units)); - std::size_t twiddle_elements = 2 * kernel_data[0].length; + n_transforms, factor_sg, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); + std::size_t twiddle_elements = 2 * kernel_data.length; return desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(dependencies); - cgh.use_kernel_bundle(kernel_data[0].exec_bundle); + cgh.use_kernel_bundle(kernel_data.exec_bundle); auto in_acc_or_usm = detail::get_access(in, cgh); auto out_acc_or_usm = detail::get_access(out, cgh); auto in_imag_acc_or_usm = detail::get_access(in_imag, cgh); @@ -631,7 +632,7 @@ struct committed_descriptor::run_kernel_struct>( - sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * kernel_data[0].num_sgs_per_wg)}}, + sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * kernel_data.num_sgs_per_wg)}}, [=](sycl::nd_item<1> it, sycl::kernel_handler kh) [[sycl::reqd_sub_group_size(SubgroupSize)]] { detail::global_data_struct global_data{ #ifdef PORTFFT_LOG diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index 106c664b..8b845e19 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -211,29 +211,29 @@ struct committed_descriptor::run_kernel_struct& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - std::vector& kernel_data) { + dimension_struct& dimension_data) { + auto& kernel_data = dimension_data.kernels.at(0); Idx num_batches_in_local_mem = [=]() { if constexpr (LayoutIn == detail::layout::BATCH_INTERLEAVED) { - return kernel_data[0].used_sg_size * PORTFFT_SGS_IN_WG / 2; + return kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2; } else { return 1; } }(); constexpr detail::memory Mem = std::is_pointer::value ? 
detail::memory::USM : detail::memory::BUFFER; - Scalar* twiddles = kernel_data[0].twiddles_forward.get(); + Scalar* twiddles = kernel_data.twiddles_forward.get(); std::size_t local_elements = num_scalars_in_local_mem_struct::template inner::execute( - desc, kernel_data[0].length, kernel_data[0].used_sg_size, kernel_data[0].factors, - kernel_data[0].num_sgs_per_wg); + desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); std::size_t global_size = static_cast(detail::get_global_size_workgroup( - n_transforms, SubgroupSize, kernel_data[0].num_sgs_per_wg, desc.n_compute_units)); - const Idx bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * - kernel_data[0].factors[2] * kernel_data[0].factors[3]); + n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); + const Idx bank_lines_per_pad = + bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * kernel_data.factors[2] * kernel_data.factors[3]); std::size_t sg_twiddles_offset = static_cast( - detail::pad_local(2 * static_cast(kernel_data[0].length) * num_batches_in_local_mem, bank_lines_per_pad)); + detail::pad_local(2 * static_cast(kernel_data.length) * num_batches_in_local_mem, bank_lines_per_pad)); return desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(dependencies); - cgh.use_kernel_bundle(kernel_data[0].exec_bundle); + cgh.use_kernel_bundle(kernel_data.exec_bundle); auto in_acc_or_usm = detail::get_access(in, cgh); auto out_acc_or_usm = detail::get_access(out, cgh); auto in_imag_acc_or_usm = detail::get_access(in_imag, cgh); @@ -292,7 +292,8 @@ struct committed_descriptor::num_scalars_in_local_mem_struct::in template template struct committed_descriptor::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor& desc, kernel_data_struct& kernel_data) { + static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data) { + const auto& kernel_data = dimension_data.kernels.at(0); Idx factor_wi_n = kernel_data.factors[0]; Idx factor_sg_n = kernel_data.factors[1]; Idx factor_wi_m = kernel_data.factors[2]; diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 48d85e6c..d7d7e229 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -280,17 +280,17 @@ struct committed_descriptor::run_kernel_struct& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - std::vector& kernel_data) { + dimension_struct& dimension_data) { constexpr detail::memory Mem = std::is_pointer::value ? 
detail::memory::USM : detail::memory::BUFFER; + auto& kernel_data = dimension_data.kernels.at(0); std::size_t local_elements = num_scalars_in_local_mem_struct::template inner::execute( - desc, kernel_data[0].length, kernel_data[0].used_sg_size, kernel_data[0].factors, - kernel_data[0].num_sgs_per_wg); + desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); std::size_t global_size = static_cast(detail::get_global_size_workitem( - n_transforms, SubgroupSize, kernel_data[0].num_sgs_per_wg, desc.n_compute_units)); + n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); return desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(dependencies); - cgh.use_kernel_bundle(kernel_data[0].exec_bundle); + cgh.use_kernel_bundle(kernel_data.exec_bundle); auto in_acc_or_usm = detail::get_access(in, cgh); auto out_acc_or_usm = detail::get_access(out, cgh); auto in_imag_acc_or_usm = detail::get_access(in_imag, cgh); @@ -300,7 +300,7 @@ struct committed_descriptor::run_kernel_struct>( - sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * kernel_data[0].num_sgs_per_wg)}}, + sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * kernel_data.num_sgs_per_wg)}}, [=](sycl::nd_item<1> it, sycl::kernel_handler kh) [[sycl::reqd_sub_group_size(SubgroupSize)]] { detail::global_data_struct global_data{ #ifdef PORTFFT_LOG @@ -346,7 +346,7 @@ struct committed_descriptor::num_scalars_in_local_mem_struct::in template template struct committed_descriptor::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor& /*desc*/, kernel_data_struct& /*kernel_data*/) { return nullptr; } + static Scalar* execute(committed_descriptor& /*desc*/, dimension_struct& /*dimension_data*/) { return nullptr; } }; } // namespace portfft From 386c98f788f95b6d4509eecd2b0e7eaf28ef3497 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 30 Nov 2023 13:44:50 +0000 Subject: [PATCH 05/67] further changes to descriptor --- src/portfft/descriptor.hpp | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 8f87d306..74eaf296 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -240,6 +240,7 @@ class committed_descriptor { Idx used_sg_size; Idx num_batches_in_l2; Idx num_factors; + bool is_prime; dimension_struct(std::vector kernels, detail::level level, std::size_t length, Idx used_sg_size) : kernels(kernels), level(level), length(length), used_sg_size(used_sg_size) {} @@ -405,6 +406,7 @@ class committed_descriptor { IdxGlobal padded_fft_size = static_cast(std::pow(2, ceil(log(static_cast(fft_size)) / log(2.0)))); detail::factorize_input(padded_fft_size, check_and_select_target_level, true); + detail::factorize_input(padded_fft_size, check_and_select_target_level, false); } return {detail::level::GLOBAL, param_vec}; } @@ -524,6 +526,7 @@ class committed_descriptor { std::vector result; if (is_compatible) { std::size_t counter = 0; + std::size_t dimension_size = 1; for (auto [level, ids, factors] : prepared_vec) { auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); if (top_level == detail::level::GLOBAL) { @@ -556,10 +559,12 @@ class committed_descriptor { is_compatible = false; break; } + dimension_size *= + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); counter++; } if (is_compatible) { - return {result, top_level, params.lengths[kernel_num], 
SubgroupSize}; + return {result, top_level, dimension_size, SubgroupSize}; } } } @@ -584,13 +589,30 @@ class committed_descriptor { break; } } + IdxGlobal dimension_size = static_cast(dimensions.at(global_dimension).length); + Idx num_forward_factors = 0; + Idx num_backward_factors = 0; + IdxGlobal temp_acc = 1; std::vector factors; std::vector sub_batches; std::vector inclusive_scan; std::size_t cache_required_for_twiddles = 0; for (const auto& kernel_data : dimensions.at(global_dimension).kernels) { - IdxGlobal factor_size = static_cast( + if (temp_acc == dimension_size) { + break; + } + temp_acc *= static_cast( std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); + num_forward_factors++; + } + if (dimensions.at(i).is_prime) { + num_backward_factors = static_cast(dimensions.at(global_dimension).kernels.length()) - num_forward_factors; + } + const auto& kernels_for_dimension = dimensions.at(global_dimension).kernels; + for (Idx i = 0; i < num_forward_factors; i++) { + IdxGlobal factor_size = static_cast(std::accumulate(kernels_for_dimension.at(i).factors.begin(), + kernels_for_dimension.at(i).factors.end(), 1, + std::multiplies())); cache_required_for_twiddles += static_cast(2 * factor_size * kernel_data.batch_size) * sizeof(Scalar); factors.push_back(factor_size); @@ -618,8 +640,8 @@ class committed_descriptor { for (std::size_t i = 1; i < factors.size(); i++) { inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); } - dimensions.at(global_dimension).factors_and_scan = - detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); + Idx mem_for_inclusive_scans = 3 * (num_forward_factors + num_backward_factors); + dimensions.at(global_dimension).factors_and_scan = detail::make_shared(mem_for_inclusive_scans, queue); queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), factors.size()); queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + factors.size(), sub_batches.size()); @@ -642,6 +664,11 @@ class committed_descriptor { std::vector{static_cast(factors.at(i)), static_cast(sub_batches.at(i))}, 1, 1, 1, std::shared_ptr(), detail::level::GLOBAL); } + if (num_backward_factors != 0) { + factors.clear(); + sub_batches.clear(); + inclusive_scan.clear(); + } } else { std::size_t max_encountered_global_size = 0; for (std::size_t i = 0; i < n_kernels; i++) { @@ -760,6 +787,9 @@ class committed_descriptor { for (std::size_t i = 0; i < n_kernels; i++) { if (dimensions.at(i).level == detail::level::GLOBAL) { is_scratch_required = true; + if (detail::factorize(static_cast(params.lengths.at(0))) == 1) { + dimensions.at(i).is_prime = true; + } num_global_level_dimensions++; } } From 6236bbaf08947ff582cbeb735a46d6ad11b6989c Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 7 Dec 2023 12:12:57 +0000 Subject: [PATCH 06/67] WIP --- src/portfft/common/workgroup.hpp | 22 -------- src/portfft/descriptor.hpp | 45 +++++++++++------ src/portfft/dispatcher/global_dispatcher.hpp | 50 +++++++++---------- .../dispatcher/workgroup_dispatcher.hpp | 9 ++-- src/portfft/utils.hpp | 23 ++++++++- 5 files changed, 79 insertions(+), 70 deletions(-) diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 3190ed88..d786e1bd 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -30,28 +30,6 @@ namespace portfft { -/** - * Calculate the number of groups or bank lines of PORTFFT_N_LOCAL_BANKS between each 
padding in local memory, - * specifically for reducing bank conflicts when reading values from the columns of a 2D data layout. e.g. If there are - * 64 complex elements in a row, then the consecutive values in the same column are 128 floats apart. There are 32 - * banks, each the size of a float, so we only want a padding float every 128/32=4 bank lines to read along the column - * without bank conflicts. - * - * @tparam T Input type to the function - * @param row_size the size in bytes of the row. 32 std::complex values would probably have a size of 256 bytes. - * @return the number of groups of PORTFFT_N_LOCAL_BANKS between each padding in local memory. - */ -template -constexpr T bank_lines_per_pad_wg(T row_size) { - constexpr T BankLineSize = sizeof(float) * PORTFFT_N_LOCAL_BANKS; - if (row_size % BankLineSize == 0) { - return row_size / BankLineSize; - } - // There is room for improvement here. E.G if row_size was half of BankLineSize then maybe you would still want 1 - // pad every bank group. - return 1; -} - namespace detail { /** * Calculate all dfts in one dimension of the data stored in local memory. diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 0f596fe9..71f44d5d 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -237,13 +237,20 @@ class committed_descriptor { std::shared_ptr factors_and_scan; detail::level level; std::size_t length; + std::size_t committed_length; + bool is_prime; Idx used_sg_size; Idx num_batches_in_l2; Idx num_factors; - bool is_prime; - dimension_struct(std::vector kernels, detail::level level, std::size_t length, Idx used_sg_size) - : kernels(kernels), level(level), length(length), used_sg_size(used_sg_size) {} + dimension_struct(std::vector kernels, detail::level level, std::size_t length, + std::size_t committed_length, bool is_prime, Idx used_sg_size) + : kernels(kernels), + level(level), + length(length), + committed_length(committed_length), + is_prime(is_prime), + used_sg_size(used_sg_size) {} }; std::vector dimensions; @@ -564,7 +571,11 @@ class committed_descriptor { counter++; } if (is_compatible) { - return {result, top_level, dimension_size, SubgroupSize}; + bool is_prime = false; + if (dimension_size != params.lengths[kernel_num]) { + is_prime = true; + } + return {result, top_level, dimension_size, params.lengths[kernel_num], is_prime, SubgroupSize}; } } } @@ -605,20 +616,20 @@ class committed_descriptor { std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); num_forward_factors++; } - if (dimensions.at(i).is_prime) { - num_backward_factors = static_cast(dimensions.at(global_dimension).kernels.length()) - num_forward_factors; + if (dimensions.at(global_dimension).is_prime) { + num_backward_factors = static_cast(dimensions.at(global_dimension).kernels.size()) - num_forward_factors; } const auto& kernels_for_dimension = dimensions.at(global_dimension).kernels; - for (Idx i = 0; i < num_forward_factors; i++) { + for (std::size_t i = 0; i < static_cast(num_forward_factors); i++) { IdxGlobal factor_size = static_cast(std::accumulate(kernels_for_dimension.at(i).factors.begin(), kernels_for_dimension.at(i).factors.end(), 1, std::multiplies())); cache_required_for_twiddles += - static_cast(2 * factor_size * kernel_data.batch_size) * sizeof(Scalar); + static_cast(2 * factor_size * kernels_for_dimension.at(i).batch_size) * sizeof(Scalar); factors.push_back(factor_size); - sub_batches.push_back(kernel_data.batch_size); + 
sub_batches.push_back(kernels_for_dimension.at(i).batch_size); } - dimensions.at(global_dimension).num_factors = static_cast(factors.size()); + dimensions.at(global_dimension).num_factors = num_forward_factors; std::size_t cache_space_left_for_batches = static_cast(llc_size) - cache_required_for_twiddles; // TODO: In case of multi-dim (single dim global sized), this should be batches corresponding to that dim dimensions.at(global_dimension).num_batches_in_l2 = static_cast(std::min( @@ -637,20 +648,22 @@ class committed_descriptor { static_cast(dimensions.at(global_dimension).num_batches_in_l2), queue); inclusive_scan.push_back(factors.at(0)); - for (std::size_t i = 1; i < factors.size(); i++) { + for (std::size_t i = 1; i < static_cast(num_forward_factors); i++) { inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); } Idx mem_for_inclusive_scans = 3 * (num_forward_factors + num_backward_factors); - dimensions.at(global_dimension).factors_and_scan = detail::make_shared(mem_for_inclusive_scans, queue); - queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), factors.size()); - queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + factors.size(), + dimensions.at(global_dimension).factors_and_scan = + detail::make_shared(static_cast(mem_for_inclusive_scans), queue); + queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), + static_cast(num_forward_factors)); + queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + num_forward_factors, sub_batches.size()); queue.copy(inclusive_scan.data(), - dimensions.at(global_dimension).factors_and_scan.get() + factors.size() + sub_batches.size(), + dimensions.at(global_dimension).factors_and_scan.get() + 2 * num_forward_factors, inclusive_scan.size()); queue.wait(); // build transpose kernels - std::size_t num_transposes_required = factors.size() - 1; + std::size_t num_transposes_required = static_cast(num_forward_factors - 1); for (std::size_t i = 0; i < num_transposes_required; i++) { std::vector ids; auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index b9aaf2cb..127665b3 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -310,16 +310,15 @@ struct committed_descriptor::run_kernel_struct( - kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, - scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, - 2 * static_cast(i) * committed_size + input_offset, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, - dimension_data.num_factors, current_events, previous_events, desc.queue); - intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * - static_cast(kernels.at(0).length); - impl_twiddle_offset += detail::increment_twiddle_offset( - kernels.at(0).level, static_cast(kernels.at(0).length)); + SubgroupSize>(kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, + scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, + 2 * static_cast(i) * committed_size + input_offset, committed_size, + static_cast(max_batches_in_l2), static_cast(num_batches), + static_cast(i), 0, dimension_data.num_factors, current_events, + previous_events, desc.queue); + intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * 
static_cast(kernels.at(0).length); + impl_twiddle_offset += + detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); current_events.swap(previous_events); for (std::size_t factor_num = 1; factor_num < static_cast(dimension_data.num_factors); factor_num++) { @@ -333,16 +332,15 @@ struct committed_descriptor::run_kernel_struct( - kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), - desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - dimension_data.num_factors, current_events, previous_events, desc.queue); - intermediate_twiddles_offset += 2 * kernels.at(factor_num).batch_size * - static_cast(kernels.at(factor_num).length); - impl_twiddle_offset += - detail::increment_twiddle_offset(kernels.at(factor_num).level, - static_cast(kernels.at(factor_num).length)); + kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, + committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), + static_cast(i), static_cast(factor_num), dimension_data.num_factors, current_events, + previous_events, desc.queue); + intermediate_twiddles_offset += + 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); + impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, + static_cast(kernels.at(factor_num).length)); current_events.swap(previous_events); } } @@ -352,19 +350,17 @@ struct committed_descriptor::run_kernel_struct 0; num_transpose--) { current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_transpose) + - static_cast(num_factors)), + kernels.at(static_cast(num_transpose) + static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); current_events[0].wait(); } current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_factors)), - static_cast(desc.scratch_ptr_1.get()), out, factors_and_scan, committed_size, - static_cast(max_batches_in_l2), n_transforms, static_cast(i), 0, num_factors, - 2 * static_cast(i) * committed_size + output_offset, desc.queue, desc.scratch_ptr_1, - desc.scratch_ptr_2, current_events, previous_events); + kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), out, + factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, + desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); } return current_events[0]; } diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index 8b845e19..f15baa57 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -30,6 +30,7 @@ #include "portfft/descriptor.hpp" #include "portfft/enums.hpp" #include "portfft/specialization_constant.hpp" +#include "portfft/utils.hpp" namespace portfft { namespace detail { @@ -119,7 +120,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, 
T* output, const T* /*input_i Idx factor_n = detail::factorize(fft_size); Idx factor_m = fft_size / factor_n; const T* wg_twiddles = twiddles + 2 * (factor_m + factor_n); - const Idx bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast(sizeof(T)) * factor_m); + const Idx bank_lines_per_pad = detail::bank_lines_per_pad_wg(2 * static_cast(sizeof(T)) * factor_m); auto loc_view = padded_view(loc, bank_lines_per_pad); global_data.log_message_global(__func__, "loading sg twiddles from global to local memory"); @@ -227,8 +228,8 @@ struct committed_descriptor::run_kernel_struct(detail::get_global_size_workgroup( n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); - const Idx bank_lines_per_pad = - bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * kernel_data.factors[2] * kernel_data.factors[3]); + const Idx bank_lines_per_pad = detail::bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * + kernel_data.factors[2] * kernel_data.factors[3]); std::size_t sg_twiddles_offset = static_cast( detail::pad_local(2 * static_cast(kernel_data.length) * num_batches_in_local_mem, bank_lines_per_pad)); return desc.queue.submit([&](sycl::handler& cgh) { @@ -284,7 +285,7 @@ struct committed_descriptor::num_scalars_in_local_mem_struct::in Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup(used_sg_size * PORTFFT_SGS_IN_WG); return detail::pad_local(static_cast(2 * num_batches_in_local_mem) * length, - bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * m)) + + detail::bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * m)) + 2 * (m + n); } }; diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index 16e5388a..fa669cac 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -27,7 +27,6 @@ #include #include "common/memory_views.hpp" -#include "common/workgroup.hpp" #include "defines.hpp" #include "enums.hpp" @@ -192,6 +191,28 @@ inline std::shared_ptr make_shared(std::size_t size, sycl::queue& queue) { }); } +/** + * Calculate the number of groups or bank lines of PORTFFT_N_LOCAL_BANKS between each padding in local memory, + * specifically for reducing bank conflicts when reading values from the columns of a 2D data layout. e.g. If there are + * 64 complex elements in a row, then the consecutive values in the same column are 128 floats apart. There are 32 + * banks, each the size of a float, so we only want a padding float every 128/32=4 bank lines to read along the column + * without bank conflicts. + * + * @tparam T Input type to the function + * @param row_size the size in bytes of the row. 32 std::complex values would probably have a size of 256 bytes. + * @return the number of groups of PORTFFT_N_LOCAL_BANKS between each padding in local memory. + */ +template +constexpr T bank_lines_per_pad_wg(T row_size) { + constexpr T BankLineSize = sizeof(float) * PORTFFT_N_LOCAL_BANKS; + if (row_size % BankLineSize == 0) { + return row_size / BankLineSize; + } + // There is room for improvement here. E.G if row_size was half of BankLineSize then maybe you would still want 1 + // pad every bank group. 
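+  // Illustration: with 32 four-byte banks, BankLineSize is 128 bytes, so a
+  // 256-byte row (64 floats) takes the branch above and returns 256 / 128 = 2,
+  // while a 96-byte row is not a multiple of a bank line and returns 1 below.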
+ return 1; +} + /** * @brief Gets the cumulative local memory usage for a particular level * @tparam Scalar Scalar type From dbc09198bd1b05d7223f2739c3b8f6ef29e04b6d Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Mon, 11 Dec 2023 12:56:35 +0000 Subject: [PATCH 07/67] correctly populating factors and sub batches vectors for backward factors, add transposes for backward factors --- src/portfft/descriptor.hpp | 60 ++++++++++++++------ src/portfft/dispatcher/global_dispatcher.hpp | 57 ++++++++++++++----- 2 files changed, 85 insertions(+), 32 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 71f44d5d..5101c7b0 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -241,7 +241,8 @@ class committed_descriptor { bool is_prime; Idx used_sg_size; Idx num_batches_in_l2; - Idx num_factors; + Idx forward_factors; + Idx backward_factors; dimension_struct(std::vector kernels, detail::level level, std::size_t length, std::size_t committed_length, bool is_prime, Idx used_sg_size) @@ -600,25 +601,14 @@ class committed_descriptor { break; } } - IdxGlobal dimension_size = static_cast(dimensions.at(global_dimension).length); - Idx num_forward_factors = 0; - Idx num_backward_factors = 0; - IdxGlobal temp_acc = 1; + std::vector factors; std::vector sub_batches; std::vector inclusive_scan; std::size_t cache_required_for_twiddles = 0; - for (const auto& kernel_data : dimensions.at(global_dimension).kernels) { - if (temp_acc == dimension_size) { - break; - } - temp_acc *= static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); - num_forward_factors++; - } - if (dimensions.at(global_dimension).is_prime) { - num_backward_factors = static_cast(dimensions.at(global_dimension).kernels.size()) - num_forward_factors; - } + Idx num_forward_factors = dimensions.at(global_dimension).forward_factors; + Idx num_backward_factors = dimensions.at(global_dimension).backward_factors; + const auto& kernels_for_dimension = dimensions.at(global_dimension).kernels; for (std::size_t i = 0; i < static_cast(num_forward_factors); i++) { IdxGlobal factor_size = static_cast(std::accumulate(kernels_for_dimension.at(i).factors.begin(), @@ -629,7 +619,7 @@ class committed_descriptor { factors.push_back(factor_size); sub_batches.push_back(kernels_for_dimension.at(i).batch_size); } - dimensions.at(global_dimension).num_factors = num_forward_factors; + std::size_t cache_space_left_for_batches = static_cast(llc_size) - cache_required_for_twiddles; // TODO: In case of mutli-dim (single dim global sized), this should be batches corresposding to that dim dimensions.at(global_dimension).num_batches_in_l2 = static_cast(std::min( @@ -681,6 +671,37 @@ class committed_descriptor { factors.clear(); sub_batches.clear(); inclusive_scan.clear(); + for (std::size_t i = 0; i < static_cast(num_backward_factors); i++) { + const auto& backward_factor_kernel = kernels_for_dimension.at(num_forward_factors + i); + factors.push_back( + static_cast(std::accumulate(backward_factor_kernel.factors.begin(), + backward_factor_kernel.factors.end(), 1, std::multiplies()))); + sub_batches.push_back(backward_factor_kernel.batch_size); + } + inclusive_scan.push_back(factors.at(0)); + for (std::size_t i = 1; i < static_cast(num_backward_factors); i++) { + inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); + } + IdxGlobal* offset_ptr = dimensions.at(global_dimension).factors_and_scan.get() + 3 * num_forward_factors; + queue.copy(factors.data(), 
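             /* offset_ptr points at the backward-factor block, 3 * num_forward_factors entries into factors_and_scan */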
offset_ptr, static_cast(num_backward_factors)); + queue.copy(sub_batches.data(), offset_ptr + num_backward_factors, sub_batches.size()); + queue.copy(inclusive_scan.data(), offset_ptr + 2 * num_backward_factors, inclusive_scan.size()); + queue.wait(); + + Idx num_backward_transposes = dimensions.at(global_dimension).backward_factors - 1; + for (std::size_t i = 0; i < num_backward_transposes; i++) { + std::vector ids; + auto in_bundle = sycl::get_kernel_bundle( + queue.get_context(), detail::get_transpose_kernel_ids()); + in_bundle.template set_specialization_constant(static_cast(i)); + in_bundle.template set_specialization_constant( + static_cast(factors.size())); + dimensions.at(global_dimension) + .kernels.emplace_back( + sycl::build(in_bundle), + std::vector{static_cast(factors.at(i)), static_cast(sub_batches.at(i))}, 1, 1, 1, + std::shared_ptr(), detail::level::GLOBAL); + } } } else { std::size_t max_encountered_global_size = 0; @@ -711,7 +732,7 @@ class committed_descriptor { for (std::size_t j = 1; j < factors.size(); j++) { inclusive_scan.push_back(inclusive_scan.at(j - 1) * factors.at(j)); } - dimensions.at(i).num_factors = static_cast(factors.size()); + dimensions.at(i).forward_factors = static_cast(factors.size()); dimensions.at(i).factors_and_scan = detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); queue.copy(factors.data(), dimensions.at(i).factors_and_scan.get(), factors.size()); @@ -777,6 +798,9 @@ class committed_descriptor { std::size_t n_kernels = params.lengths.size(); for (std::size_t i = 0; i < n_kernels; i++) { dimensions.push_back(build_w_spec_const(i)); + if (dimensions.back().committed_length != dimensions.back().length) { + dimensions.back().is_prime = true; + } dimensions.back().kernels.at(0).twiddles_forward = std::shared_ptr(calculate_twiddles(dimensions.back()), [queue](Scalar* ptr) { if (ptr != nullptr) { diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 127665b3..fe05bf84 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -108,12 +108,18 @@ struct committed_descriptor::calculate_twiddles_struct::inner factors_idx_global; + IdxGlobal temp_acc = 1; // Get factor sizes per level; for (const auto& kernel_data : kernels) { factors_idx_global.push_back(static_cast( std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); + temp_acc *= factors_idx_global.back(); + if (temp_acc == dimension_data.committed_length) { + break; + } } - + dimension_data.forward_factors = static_cast(factors_idx_global.size()); + dimension_data.backward_factors = static_cast(kernels.size()) - dimension_data.forward_factors; std::vector sub_batches; // Get sub batches for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { @@ -121,12 +127,31 @@ struct committed_descriptor::calculate_twiddles_struct::inner())); } sub_batches.push_back(factors_idx_global.at(factors_idx_global.size() - 2)); + // factors and inner batches for the backward factors; + if (dimension_data.backward_factors > 0) { + for (Idx i = 0; i < dimension_data.backward_factors; i++) { + const auto& kd_struct = kernels.at(static_cast(dimension_data.forward_factors + i)); + factors_idx_global.push_back(static_cast( + std::accumulate(kd_struct.factors.begin(), kd_struct.factors.end(), 1, std::multiplies()))); + } + for (Idx back_factor = 0; back_factor < dimension_data.backward_factors - 1; back_factor++) { + 
sub_batches.push_back(std::accumulate( + factors_idx_global.begin() + static_cast(dimension_data.forward_factors + back_factor + 1), + factors_idx_global.end(), IdxGlobal(1), std::multiplies())); + } + } // calculate total memory required for twiddles; IdxGlobal mem_required_for_twiddles = 0; // First calculate mem required for twiddles between factors; - for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { + for (std::size_t i = 0; i < static_cast(dimension_data.forward_factors - 1); i++) { mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); } + if (dimension_data.backward_factors > 0) { + for (std::size_t i = 0; i < static_cast(dimension_data.backward_factors - 1); i++) { + mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); + } + } + // Now calculate mem required for twiddles per implementation std::size_t counter = 0; for (const auto& kernel_data : kernels) { @@ -157,9 +182,14 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.forward_factors) - 1; i++) { calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); } + if (dimension_data.backward_factors > 0) { + for (std::size_t i = 0; i < static_cast(dimension_data.backward_factors) - 1; i++) { + calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); + } + } // Now calculate per twiddles. counter = 0; for (const auto& kernel_data : kernels) { @@ -292,8 +322,8 @@ struct committed_descriptor::run_kernel_struct(dimension_data.num_batches_in_l2); IdxGlobal initial_impl_twiddle_offset = 0; - Idx num_factors = dimension_data.num_factors; - IdxGlobal committed_size = static_cast(desc.params.lengths[0]); + Idx num_factors = dimension_data.forward_factors; + IdxGlobal committed_size = static_cast(dimension_data.length); Idx num_transposes = num_factors - 1; std::vector current_events; std::vector previous_events; @@ -314,29 +344,28 @@ struct committed_descriptor::run_kernel_struct(i) * committed_size + input_offset, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), - static_cast(i), 0, dimension_data.num_factors, current_events, - previous_events, desc.queue); + static_cast(i), 0, num_factors, current_events, previous_events, + desc.queue); intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); current_events.swap(previous_events); - for (std::size_t factor_num = 1; factor_num < static_cast(dimension_data.num_factors); - factor_num++) { - if (static_cast(factor_num) == dimension_data.num_factors - 1) { + for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { + if (static_cast(factor_num) == num_factors - 1) { detail::compute_level( dimension_data.kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - dimension_data.num_factors, current_events, previous_events, desc.queue); + static_cast(num_batches), static_cast(i), static_cast(factor_num), num_factors, + current_events, previous_events, desc.queue); } else { detail::compute_level( kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), 
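            /* input and output both alias scratch_ptr_1: intermediate factors run in place */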
twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), - static_cast(i), static_cast(factor_num), dimension_data.num_factors, current_events, - previous_events, desc.queue); + static_cast(i), static_cast(factor_num), num_factors, current_events, previous_events, + desc.queue); intermediate_twiddles_offset += 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, From 8de1139ebd0e59a55b77e932b1087e98c65c34f6 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Dec 2023 07:40:35 +0000 Subject: [PATCH 08/67] update setting of specialization constants --- src/portfft/descriptor.hpp | 116 ++++++++++++++++++++++++++----------- src/portfft/utils.hpp | 1 + 2 files changed, 83 insertions(+), 34 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 5101c7b0..cae19d21 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -510,6 +510,45 @@ class committed_descriptor { Scalar* calculate_twiddles(dimension_struct& dimension_data) { return dispatch(dimension_data.level, dimension_data); } + using in_bundle_and_metadata = + std::tuple, std::vector>; + /** + * Sets the spec constants for the global implementation. + * @param prepared_vec Vector returned by prepare_implementations + * @param first_uses_load_modifiers whether or not first kernel multiplies the modifier before dft compute + * @param last_uses_load_modifier whether or not first kernel multiplies the modifier after dft compute + * @param num_kernels number of factors + */ + void set_global_impl_spec_consts(std::vector& prepared_vec, + detail::elementwise_multiply first_uses_load_modifiers, + detail::elementwise_multiply last_uses_load_modifier, Idx num_kernels) { + Idx counter = 0; + for (auto& [level, in_bundle, factors] : prepared_vec) { + if (counter > num_kernels) { + break; + } + if (counter == num_kernels - 1) { + set_spec_constants( + detail::level::GLOBAL, in_bundle, + static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), + factors, detail::elementwise_multiply::NOT_APPLIED, last_uses_load_modifier, + detail::apply_scale_factor::APPLIED, level, static_cast(counter), static_cast(num_kernels)); + } else if (counter == 0) { + set_spec_constants( + detail::level::GLOBAL, in_bundle, + static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), + factors, first_uses_load_modifiers, detail::elementwise_multiply::APPLIED, + detail::apply_scale_factor::NOT_APPLIED, level, static_cast(counter), static_cast(num_kernels)); + } else { + set_spec_constants( + detail::level::GLOBAL, in_bundle, + static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), + factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, + detail::apply_scale_factor::NOT_APPLIED, level, static_cast(counter), static_cast(num_kernels)); + } + counter++; + } + } /** * Builds the kernel bundles with appropriate values of specialization constants for the first supported subgroup @@ -532,34 +571,48 @@ class committed_descriptor { } } std::vector result; + std::size_t counter = 0; + std::size_t dimension_size = 1; + for (auto [level, ids, factors] : prepared_vec) { + dimension_size *= + static_cast(std::accumulate(factors.begin(), factors.end(), 1, 
std::multiplies())); + counter++; + if (dimension_size != params.lengths[kernel_num]) { + break; + } + } + std::size_t backward_factors = prepared_vec.size() - counter; + std::vector in_bundles; if (is_compatible) { - std::size_t counter = 0; - std::size_t dimension_size = 1; for (auto [level, ids, factors] : prepared_vec) { - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - if (top_level == detail::level::GLOBAL) { - if (counter == prepared_vec.size() - 1) { - set_spec_constants(detail::level::GLOBAL, in_bundle, - static_cast( - std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), - factors, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::APPLIED, level, - static_cast(counter), static_cast(prepared_vec.size())); - } else { - set_spec_constants(detail::level::GLOBAL, in_bundle, - static_cast( - std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), - factors, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::APPLIED, detail::apply_scale_factor::NOT_APPLIED, level, - static_cast(counter), static_cast(prepared_vec.size())); - } - } else { - set_spec_constants(level, in_bundle, params.lengths[kernel_num], factors, - detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - detail::apply_scale_factor::APPLIED, level); + in_bundles.emplace_back(level, sycl::get_kernel_bundle(queue.get_context(), ids), + factors); + } + if (top_level == detail::level::GLOBAL) { + detail::elementwise_multiply first_uses_load_modifiers = + backward_factors > 0 ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; + detail::elementwise_multiply last_uses_store_modifiers = + backward_factors > 0 ? 
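          /* prime sizes take the Bluestein path: the first kernel applies a load modifier */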
detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; + set_global_impl_spec_consts(in_bundles, first_uses_load_modifiers, detail::elementwise_multiply::NOT_APPLIED, + counter); + if (backward_factors > 0) { + std::vector backward_kernels_slice(in_bundles.begin() + static_cast(counter), + in_bundles.end()); + set_global_impl_spec_consts(backward_kernels_slice, detail::elementwise_multiply::NOT_APPLIED, + last_uses_store_modifiers, backward_factors); + std::copy(backward_kernels_slice.begin(), backward_kernels_slice.end(), + in_bundles.begin() + static_cast(counter)); + } + } else { + for (auto& [level, in_bundle, factors] : in_bundles) { + set_spec_constants(top_level, in_bundle, dimension_size, factors, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::APPLIED, level); } + } + + for (const auto& [level, in_bundle_w_spec_const, factors] : in_bundles) { try { - result.emplace_back(sycl::build(in_bundle), factors, params.lengths[kernel_num], SubgroupSize, + result.emplace_back(sycl::build(in_bundle_w_spec_const), factors, params.lengths[kernel_num], SubgroupSize, PORTFFT_SGS_IN_WG, std::shared_ptr(), level); } catch (std::exception& e) { std::cerr << "Build for subgroup size " << SubgroupSize << " failed with message:\n" @@ -567,17 +620,12 @@ class committed_descriptor { is_compatible = false; break; } - dimension_size *= - static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); - counter++; } - if (is_compatible) { - bool is_prime = false; - if (dimension_size != params.lengths[kernel_num]) { - is_prime = true; - } - return {result, top_level, dimension_size, params.lengths[kernel_num], is_prime, SubgroupSize}; + bool is_prime = false; + if (backward_factors > 0) { + is_prime = true; } + return {result, top_level, dimension_size, params.lengths[kernel_num], is_prime, SubgroupSize}; } } if constexpr (sizeof...(OtherSGSizes) == 0) { @@ -689,7 +737,7 @@ class committed_descriptor { queue.wait(); Idx num_backward_transposes = dimensions.at(global_dimension).backward_factors - 1; - for (std::size_t i = 0; i < num_backward_transposes; i++) { + for (std::size_t i = 0; i < static_cast(num_backward_transposes); i++) { std::vector ids; auto in_bundle = sycl::get_kernel_bundle( queue.get_context(), detail::get_transpose_kernel_ids()); diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index fa669cac..b5e7fced 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -151,6 +151,7 @@ bool factorize_input(IdxGlobal input_size, F&& check_and_select_target_level, bo } temp *= factorize_input_impl(input_size / temp, check_and_select_target_level, true, encountered_prime, requires_load_modifier); + requires_load_modifier = false; } return encountered_prime; } From 69f674ea5e087e73dbde89aa9a9fd873fc5e2d07 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Dec 2023 10:11:34 +0000 Subject: [PATCH 09/67] added host naive_dft and generation of chirp signal --- src/portfft/common/bluestein.hpp | 60 ++++++++++++++++++++++++++++++++ src/portfft/common/host_fft.hpp | 43 +++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 src/portfft/common/bluestein.hpp create mode 100644 src/portfft/common/host_fft.hpp diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp new file mode 100644 index 00000000..55267dae --- /dev/null +++ b/src/portfft/common/bluestein.hpp @@ -0,0 +1,60 @@ 
+/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_BLUESTEIN_HPP +#define PORTFFT_COMMON_BLUESTEIN_HPP + +#include "portfft/common/host_fft.hpp" +#include "portfft/defines.hpp" + +#include +#include +#include + +namespace portfft { +namespace detail { +/** + * Utility function to get chirp signal and fft + * @tparam T Scalar Type + * @param ptr Device Pointer containing the load/store modifiers. + * @param committed_size original problem size + * @param dimension_size padded size + * @param queue queue with the committed descriptor + */ +template +void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, sycl::queue& queue) { + using ctype = std::complex; + ctype* chirp_signal = (ctype*)calloc(dimension_size, sizeof(ctype)); + ctype* chirp_fft = (ctype*)malloc(dimension_size * sizeof(ctype)); + for (IdxGlobal i = 0; i < committed_size; i++) { + double theta = M_PI * static_cast(i * i) / static_cast(committed_size); + chirp_signal[i] = ctype(std::cos(theta), std::sin(theta)); + } + IdxGlobal num_zeros = dimension_size - 2 * committed_size + 1; + for (IdxGlobal i = 0; i < committed_size; i++) { + chirp_signal[committed_size + num_zeros + i - 1] = chirp_signal[committed_size - i]; + } + naive_dft(chirp_signal, chirp_fft, dimension_size); + queue.copy(reinterpret_cast(&chirp_fft[0]), ptr, 2 * dimension_size).wait(); +} +} // namespace detail +} // namespace portfft + +#endif \ No newline at end of file diff --git a/src/portfft/common/host_fft.hpp b/src/portfft/common/host_fft.hpp new file mode 100644 index 00000000..16af0831 --- /dev/null +++ b/src/portfft/common/host_fft.hpp @@ -0,0 +1,43 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_HOST_FFT_HPP +#define PORTFFT_COMMON_HOST_FFT_HPP + +#include "portfft/defines.hpp" +#include + +namespace portfft { +namespace detail { +template +void naive_dft(T* input, T* output, IdxGlobal fft_size) { + for (int i = 0; i < fft_size; i++) { + ftype temp = 0; + for (int j = 0; j < fft_size; j++) { + ftype multiplier = ftype(std::cos((-2 * M_PI * i * j) / fft_size), std::sin((-2 * M_PI * i * j) / fft_size)); + temp += input[j] * multiplier; + } + output[i] = temp; + } +} +} // namespace detail +} // namespace portfft + +#endif \ No newline at end of file From 0f7f7cb8402ef632a14a894ea3713f4c96a7de99 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Dec 2023 16:39:12 +0000 Subject: [PATCH 10/67] further changes and resolved all warnings --- src/portfft/common/bluestein.hpp | 12 ++++++++++++ src/portfft/common/host_fft.hpp | 6 ++++-- src/portfft/descriptor.hpp | 7 ++++--- src/portfft/dispatcher/global_dispatcher.hpp | 3 ++- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp index 55267dae..17a82bbd 100644 --- a/src/portfft/common/bluestein.hpp +++ b/src/portfft/common/bluestein.hpp @@ -54,6 +54,18 @@ void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_ naive_dft(chirp_signal, chirp_fft, dimension_size); queue.copy(reinterpret_cast(&chirp_fft[0]), ptr, 2 * dimension_size).wait(); } + +template +void populate_input_and_output_modifiers(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, + sycl::queue& queue) { + using ctype = std::complex; + ctype* scratch = (ctype*)calloc(dimension_size, sizeof(ctype)); + for (IdxGlobal i = 0; i < committed_size; i++) { + double theta = -M_PI * static_cast(i * i) / static_cast(committed_size); + scratch[i] = ctype(std::cos(theta), std::sin(theta)); + } + queue.copy(reinterpret_cast(&scratch[0]), ptr, 2 * dimension_size); +} } // namespace detail } // namespace portfft diff --git a/src/portfft/common/host_fft.hpp b/src/portfft/common/host_fft.hpp index 16af0831..89f9774d 100644 --- a/src/portfft/common/host_fft.hpp +++ b/src/portfft/common/host_fft.hpp @@ -28,10 +28,12 @@ namespace portfft { namespace detail { template void naive_dft(T* input, T* output, IdxGlobal fft_size) { + using ftype = std::complex; for (int i = 0; i < fft_size; i++) { - ftype temp = 0; + ftype temp = ftype(0, 0); for (int j = 0; j < fft_size; j++) { - ftype multiplier = ftype(std::cos((-2 * M_PI * i * j) / fft_size), std::sin((-2 * M_PI * i * j) / fft_size)); + ftype multiplier = ftype(std::cos((-2 * M_PI * i * j) / static_cast(fft_size)), + std::sin((-2 * M_PI * i * j) / static_cast(fft_size))); temp += input[j] * multiplier; } output[i] = temp; diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index cae19d21..8bd8ecb6 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -594,12 +594,12 @@ class committed_descriptor { detail::elementwise_multiply last_uses_store_modifiers = backward_factors > 0 ? 
detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; set_global_impl_spec_consts(in_bundles, first_uses_load_modifiers, detail::elementwise_multiply::NOT_APPLIED, - counter); + static_cast(counter)); if (backward_factors > 0) { std::vector backward_kernels_slice(in_bundles.begin() + static_cast(counter), in_bundles.end()); set_global_impl_spec_consts(backward_kernels_slice, detail::elementwise_multiply::NOT_APPLIED, - last_uses_store_modifiers, backward_factors); + last_uses_store_modifiers, static_cast(backward_factors)); std::copy(backward_kernels_slice.begin(), backward_kernels_slice.end(), in_bundles.begin() + static_cast(counter)); } @@ -720,7 +720,8 @@ class committed_descriptor { sub_batches.clear(); inclusive_scan.clear(); for (std::size_t i = 0; i < static_cast(num_backward_factors); i++) { - const auto& backward_factor_kernel = kernels_for_dimension.at(num_forward_factors + i); + const auto& backward_factor_kernel = + kernels_for_dimension.at(static_cast(num_forward_factors) + i); factors.push_back( static_cast(std::accumulate(backward_factor_kernel.factors.begin(), backward_factor_kernel.factors.end(), 1, std::multiplies()))); diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index fe05bf84..a539cb20 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -25,6 +25,7 @@ #include +#include "portfft/common/bluestein.hpp" #include "portfft/common/global.hpp" #include "portfft/common/subgroup.hpp" #include "portfft/defines.hpp" @@ -114,7 +115,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner( std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); temp_acc *= factors_idx_global.back(); - if (temp_acc == dimension_data.committed_length) { + if (temp_acc == static_cast(dimension_data.committed_length)) { break; } } From 308351d46ebc26b3571a77ef27d0caf1e69c3f1e Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 12 Dec 2023 17:10:07 +0000 Subject: [PATCH 11/67] populate bluestein specific twiddles in device pointer --- src/portfft/common/bluestein.hpp | 18 +++++++++--------- src/portfft/common/host_fft.hpp | 10 +++++----- src/portfft/dispatcher/global_dispatcher.hpp | 10 ++++++++++ 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp index 17a82bbd..5619a82c 100644 --- a/src/portfft/common/bluestein.hpp +++ b/src/portfft/common/bluestein.hpp @@ -41,30 +41,30 @@ namespace detail { template void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, sycl::queue& queue) { using ctype = std::complex; - ctype* chirp_signal = (ctype*)calloc(dimension_size, sizeof(ctype)); - ctype* chirp_fft = (ctype*)malloc(dimension_size * sizeof(ctype)); + ctype* chirp_signal = (ctype*)calloc(static_cast(dimension_size), sizeof(ctype)); + ctype* chirp_fft = (ctype*)malloc(static_cast(dimension_size) * sizeof(ctype)); for (IdxGlobal i = 0; i < committed_size; i++) { double theta = M_PI * static_cast(i * i) / static_cast(committed_size); - chirp_signal[i] = ctype(std::cos(theta), std::sin(theta)); + chirp_signal[i] = ctype(static_cast(std::cos(theta)), static_cast(std::sin(theta))); } IdxGlobal num_zeros = dimension_size - 2 * committed_size + 1; for (IdxGlobal i = 0; i < committed_size; i++) { chirp_signal[committed_size + num_zeros + i - 1] = chirp_signal[committed_size - i]; } naive_dft(chirp_signal, 
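            /* O(n^2) reference DFT of the zero-padded chirp, computed once on the host */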
chirp_fft, dimension_size); - queue.copy(reinterpret_cast(&chirp_fft[0]), ptr, 2 * dimension_size).wait(); + queue.copy(reinterpret_cast(&chirp_fft[0]), ptr, static_cast(2 * dimension_size)).wait(); } template -void populate_input_and_output_modifiers(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, - sycl::queue& queue) { +void populate_bluestein_input_modifiers(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, + sycl::queue& queue) { using ctype = std::complex; - ctype* scratch = (ctype*)calloc(dimension_size, sizeof(ctype)); + ctype* scratch = (ctype*)calloc(static_cast(dimension_size), sizeof(ctype)); for (IdxGlobal i = 0; i < committed_size; i++) { double theta = -M_PI * static_cast(i * i) / static_cast(committed_size); - scratch[i] = ctype(std::cos(theta), std::sin(theta)); + scratch[i] = ctype(static_cast(std::cos(theta)), static_cast(std::sin(theta))); } - queue.copy(reinterpret_cast(&scratch[0]), ptr, 2 * dimension_size); + queue.copy(reinterpret_cast(&scratch[0]), ptr, static_cast(2 * dimension_size)); } } // namespace detail } // namespace portfft diff --git a/src/portfft/common/host_fft.hpp b/src/portfft/common/host_fft.hpp index 89f9774d..b7ebe70e 100644 --- a/src/portfft/common/host_fft.hpp +++ b/src/portfft/common/host_fft.hpp @@ -27,13 +27,13 @@ namespace portfft { namespace detail { template -void naive_dft(T* input, T* output, IdxGlobal fft_size) { - using ftype = std::complex; +void naive_dft(std::complex* input, std::complex* output, IdxGlobal fft_size) { + using ctype = std::complex; for (int i = 0; i < fft_size; i++) { - ftype temp = ftype(0, 0); + ctype temp = ctype(0, 0); for (int j = 0; j < fft_size; j++) { - ftype multiplier = ftype(std::cos((-2 * M_PI * i * j) / static_cast(fft_size)), - std::sin((-2 * M_PI * i * j) / static_cast(fft_size))); + ctype multiplier = ctype(static_cast(std::cos((-2 * M_PI * i * j) / static_cast(fft_size))), + static_cast(std::sin((-2 * M_PI * i * j) / static_cast(fft_size)))); temp += input[j] * multiplier; } output[i] = temp; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index a539cb20..3e6c927b 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -151,6 +151,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.backward_factors - 1); i++) { mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); } + mem_required_for_twiddles += static_cast(4 * dimension_data.length); } // Now calculate mem required for twiddles per implementation @@ -182,6 +183,15 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.committed_length), + static_cast(dimension_data.length), desc.queue); + offset += static_cast(2 * dimension_data.length); + detail::populate_bluestein_input_modifiers(device_twiddles + counter, + static_cast(dimension_data.committed_length), + static_cast(dimension_data.length), desc.queue); + } // calculate twiddles to be multiplied between factors for (std::size_t i = 0; i < static_cast(dimension_data.forward_factors) - 1; i++) { calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); From 8d506bdceb0ac3e05b75a2d6e5ca890c748a9c3c Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Dec 2023 15:15:20 +0000 Subject: [PATCH 12/67] fixed issue is setting dimension size --- src/portfft/descriptor.hpp | 9 +- src/portfft/dispatcher/global_dispatcher.hpp | 118 ++++++++++--------- 2 
files changed, 66 insertions(+), 61 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 8bd8ecb6..191a963c 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -332,6 +332,7 @@ class committed_descriptor { std::vector factors; IdxGlobal fft_size = static_cast(params.lengths[kernel_num]); if (detail::fits_in_wi(fft_size)) { + factors.push_back(static_cast(fft_size)); ids = detail::get_ids(); return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; } @@ -573,13 +574,10 @@ class committed_descriptor { std::vector result; std::size_t counter = 0; std::size_t dimension_size = 1; - for (auto [level, ids, factors] : prepared_vec) { + for (const auto& [level, ids, factors] : prepared_vec) { dimension_size *= static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); counter++; - if (dimension_size != params.lengths[kernel_num]) { - break; - } } std::size_t backward_factors = prepared_vec.size() - counter; std::vector in_bundles; @@ -847,9 +845,6 @@ class committed_descriptor { std::size_t n_kernels = params.lengths.size(); for (std::size_t i = 0; i < n_kernels; i++) { dimensions.push_back(build_w_spec_const(i)); - if (dimensions.back().committed_length != dimensions.back().length) { - dimensions.back().is_prime = true; - } dimensions.back().kernels.at(0).twiddles_forward = std::shared_ptr(calculate_twiddles(dimensions.back()), [queue](Scalar* ptr) { if (ptr != nullptr) { diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 3e6c927b..3c36d0e2 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -184,11 +184,12 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.committed_length), + detail::get_fft_chirp_signal(device_twiddles + offset, static_cast(dimension_data.committed_length), static_cast(dimension_data.length), desc.queue); offset += static_cast(2 * dimension_data.length); - detail::populate_bluestein_input_modifiers(device_twiddles + counter, + detail::populate_bluestein_input_modifiers(device_twiddles + offset, static_cast(dimension_data.committed_length), static_cast(dimension_data.length), desc.queue); } @@ -347,61 +348,70 @@ struct committed_descriptor::run_kernel_struct(num_factors - 1); i++) { initial_impl_twiddle_offset += 2 * kernels.at(i).batch_size * static_cast(kernels.at(i).length); } - for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { - IdxGlobal intermediate_twiddles_offset = 0; - IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; - detail::compute_level(kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, - scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, - 2 * static_cast(i) * committed_size + input_offset, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), - static_cast(i), 0, num_factors, current_events, previous_events, - desc.queue); - intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); - impl_twiddle_offset += - detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); - current_events.swap(previous_events); - for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { - if (static_cast(factor_num) == num_factors - 1) { - detail::compute_level( - dimension_data.kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), - 
desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), num_factors, - current_events, previous_events, desc.queue); - } else { - detail::compute_level( - kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), - twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, - committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), - static_cast(i), static_cast(factor_num), num_factors, current_events, previous_events, - desc.queue); - intermediate_twiddles_offset += - 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); - impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, - static_cast(kernels.at(factor_num).length)); - current_events.swap(previous_events); + + auto run_global = [&](const std::vector& kernels) { + for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { + IdxGlobal intermediate_twiddles_offset = 0; + IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; + if (dimension_data.is_prime) { + impl_twiddle_offset += static_cast(4 * dimension_data.length); + } + detail::compute_level( + kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, + intermediate_twiddles_offset, impl_twiddle_offset, + 2 * static_cast(i) * committed_size + input_offset, committed_size, + static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, + num_factors, current_events, previous_events, desc.queue); + intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); + impl_twiddle_offset += + detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); + current_events.swap(previous_events); + for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { + if (static_cast(factor_num) == num_factors - 1) { + detail::compute_level( + dimension_data.kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), + desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, + impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), + static_cast(num_batches), static_cast(i), static_cast(factor_num), + num_factors, current_events, previous_events, desc.queue); + } else { + detail::compute_level( + kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, + committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), + static_cast(i), static_cast(factor_num), num_factors, current_events, previous_events, + desc.queue); + intermediate_twiddles_offset += + 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); + impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, + static_cast(kernels.at(factor_num).length)); + current_events.swap(previous_events); + } + } + current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(previous_events); + cgh.host_task([&]() {}); + }); + for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { + current_events[0] = detail::transpose_level( + 
kernels.at(static_cast(num_transpose) + static_cast(num_factors)), + static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, + committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), + num_transpose, num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, + previous_events); + current_events[0].wait(); } - } - current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(previous_events); - cgh.host_task([&]() {}); - }); - for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_transpose) + static_cast(num_factors)), - static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, - committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, - num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); - current_events[0].wait(); + kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), + out, factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, + desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); } - current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), out, - factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, - static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, - desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); - } + }; + run_global.template operator()(kernels); return current_events[0]; } }; From b5306240a928d12de4f5c0b330b126cf56129213 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Dec 2023 16:28:45 +0000 Subject: [PATCH 13/67] added backward pass required for bluestein --- src/portfft/descriptor.hpp | 2 +- src/portfft/dispatcher/global_dispatcher.hpp | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 191a963c..0c4e9e36 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -591,7 +591,7 @@ class committed_descriptor { backward_factors > 0 ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; detail::elementwise_multiply last_uses_store_modifiers = backward_factors > 0 ? 
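          /* with a Bluestein backward pass, the final kernel also applies a store modifier */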
detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; - set_global_impl_spec_consts(in_bundles, first_uses_load_modifiers, detail::elementwise_multiply::NOT_APPLIED, + set_global_impl_spec_consts(in_bundles, first_uses_load_modifiers, last_uses_store_modifiers, static_cast(counter)); if (backward_factors > 0) { std::vector backward_kernels_slice(in_bundles.begin() + static_cast(counter), diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 3c36d0e2..5d599608 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -184,7 +184,6 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.committed_length), static_cast(dimension_data.length), desc.queue); @@ -411,7 +410,11 @@ struct committed_descriptor::run_kernel_struct(kernels); + run_global.template operator()(kernels); + if (dimension_data.is_prime) { + run_global.template operator()( + std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end())); + } return current_events[0]; } }; From 2d3c88afdea14c3118687de5e9ef3cfbf2655233 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 14 Dec 2023 19:49:33 +0000 Subject: [PATCH 14/67] added prime sized tests --- test/unit_test/instantiate_fft_tests.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index d8cbe2d8..88f696ec 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -116,6 +116,12 @@ INSTANTIATE_TEST_SUITE_P(GlobalTest, FFTTest, ::testing::Values(sizes_t{32768}, sizes_t{65536}, sizes_t{131072}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P(PrimeSizedTest, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + all_valid_global_placement_layouts, fwd_only, interleaved_storage, ::testing::Values(1, 3), + ::testing::Values(sizes_t{211}, sizes_t{523}, sizes_t{65537}))), + test_params_print()); + // Backward FFT test suite INSTANTIATE_TEST_SUITE_P(BackwardTest, FFTTest, ::testing::ConvertGenerator(::testing::Combine( From 67c3c7fc25a4d015930ecdcdf9e998cc871e972a Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 20 Dec 2023 09:28:11 +0000 Subject: [PATCH 15/67] fix compilation issues --- src/portfft/dispatcher/global_dispatcher.hpp | 123 +++++++++---------- 1 file changed, 61 insertions(+), 62 deletions(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 5d599608..9bd7762b 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -348,72 +348,71 @@ struct committed_descriptor::run_kernel_struct(kernels.at(i).length); } - auto run_global = [&](const std::vector& kernels) { - for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { - IdxGlobal intermediate_twiddles_offset = 0; - IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; - if (dimension_data.is_prime) { - impl_twiddle_offset += static_cast(4 * dimension_data.length); - } - detail::compute_level( - kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, - intermediate_twiddles_offset, impl_twiddle_offset, - 2 * static_cast(i) * committed_size + input_offset, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, - num_factors, current_events, previous_events, desc.queue); - 
intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); - impl_twiddle_offset += - detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); - current_events.swap(previous_events); - for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { - if (static_cast(factor_num) == num_factors - 1) { - detail::compute_level( - dimension_data.kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), - desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - num_factors, current_events, previous_events, desc.queue); - } else { - detail::compute_level( - kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), - twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, - committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), - static_cast(i), static_cast(factor_num), num_factors, current_events, previous_events, - desc.queue); - intermediate_twiddles_offset += - 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); - impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, - static_cast(kernels.at(factor_num).length)); - current_events.swap(previous_events); - } - } - current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(previous_events); - cgh.host_task([&]() {}); - }); - for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { - current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_transpose) + static_cast(num_factors)), - static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, - committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), - num_transpose, num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, - previous_events); - current_events[0].wait(); + auto run_global = [&](const std::vector& kernels, const std::size_t& i) { + IdxGlobal intermediate_twiddles_offset = 0; + IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; + if (dimension_data.is_prime) { + impl_twiddle_offset += static_cast(4 * dimension_data.length); + } + detail::compute_level( + kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, + intermediate_twiddles_offset, impl_twiddle_offset, + 2 * static_cast(i) * committed_size + input_offset, committed_size, + static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, + num_factors, current_events, previous_events, desc.queue); + intermediate_twiddles_offset += 2 * kernels.at(0).batch_size * static_cast(kernels.at(0).length); + impl_twiddle_offset += + detail::increment_twiddle_offset(kernels.at(0).level, static_cast(kernels.at(0).length)); + current_events.swap(previous_events); + for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { + if (static_cast(factor_num) == num_factors - 1) { + detail::compute_level( + dimension_data.kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), + desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, + impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), + static_cast(num_batches), 
static_cast(i), static_cast(factor_num), num_factors, + current_events, previous_events, desc.queue); + } else { + detail::compute_level( + kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, + committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), + static_cast(i), static_cast(factor_num), num_factors, current_events, previous_events, + desc.queue); + intermediate_twiddles_offset += + 2 * kernels.at(factor_num).batch_size * static_cast(kernels.at(factor_num).length); + impl_twiddle_offset += detail::increment_twiddle_offset(kernels.at(factor_num).level, + static_cast(kernels.at(factor_num).length)); + current_events.swap(previous_events); } + } + current_events[0] = desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(previous_events); + cgh.host_task([&]() {}); + }); + for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { current_events[0] = detail::transpose_level( - kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), - out, factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, - static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, - desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); + kernels.at(static_cast(num_transpose) + static_cast(num_factors)), + static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_2.get(), factors_and_scan, + committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_transpose, + num_factors, 0, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); + current_events[0].wait(); } + current_events[0] = detail::transpose_level( + kernels.at(static_cast(num_factors)), static_cast(desc.scratch_ptr_1.get()), out, + factors_and_scan, committed_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, + desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); }; - run_global.template operator()(kernels); - if (dimension_data.is_prime) { - run_global.template operator()( - std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end())); + for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { + run_global.template operator()(kernels, i); + if (dimension_data.is_prime) { + run_global.template operator()( + std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i); + } } return current_events[0]; } From b23be660a99a4608c4b14f56ed39a90c114d2065 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 20 Dec 2023 21:11:31 +0000 Subject: [PATCH 16/67] add copy function in global dispatcher --- src/portfft/dispatcher/global_dispatcher.hpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 9bd7762b..a2a8a4b3 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -101,6 +101,15 @@ inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) return 0; } +template +void trigger_device_copy(const T* src, T* dst, IdxGlobal num_elements_to_copy, IdxGlobal src_stride, + IdxGlobal dst_stride, Idx num_copies, std::vector& event_vector, + 
sycl::queue& queue) {
+  for (Idx i = 0; i < num_copies; i++) {
+    event_vector.at(i) = queue.copy(src + i * src_stride, dst + i * dst_stride, num_elements_to_copy);
+  }
+}
+
 }  // namespace detail
 
 template
@@ -408,10 +417,13 @@ struct committed_descriptor::run_kernel_struct(kernels, i);
       if (dimension_data.is_prime) {
+        run_global.template operator()(
+            std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i);
         run_global.template operator()(
             std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i);
+      } else {
+        run_global.template operator()(kernels, i);
       }
     }
     return current_events[0];

From 4a92d8cf41dd1fa3f764d6a5420f080fccb799c6 Mon Sep 17 00:00:00 2001
From: "atharva.dubey" 
Date: Thu, 21 Dec 2023 10:00:45 +0000
Subject: [PATCH 17/67] enable taking conjugate before and after fft compute

---
 src/portfft/common/workgroup.hpp           | 24 ++++-
 .../dispatcher/subgroup_dispatcher.hpp     | 26 ++++++
 .../dispatcher/workgroup_dispatcher.hpp    |  6 +-
 .../dispatcher/workitem_dispatcher.hpp     | 90 +++++++++++--------
 src/portfft/specialization_constant.hpp    |  3 +
 5 files changed, 105 insertions(+), 44 deletions(-)

diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp
index d786e1bd..1b5fa036 100644
--- a/src/portfft/common/workgroup.hpp
+++ b/src/portfft/common/workgroup.hpp
@@ -55,6 +55,8 @@ namespace detail {
 * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
 * @param MultiplyOnStore Whether the input data is multiplied with some data array after fft computation.
 * @param ApplyScaleFactor Whether or not the scale factor is applied
+ * @param take_conjugate_on_load Whether or not to take the conjugate of the input before computing the fft.
+ * @param take_conjugate_on_store Whether or not to take the conjugate of the result of the fft
 * @param global_data global data for the kernel
 */
template
@@ -63,7 +65,8 @@ __attribute__((always_inline)) inline void dimension_dft(
    Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel,
    Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, detail::layout layout_in,
    detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store,
-    detail::apply_scale_factor apply_scale_factor, global_data_struct<1> global_data) {
+    detail::apply_scale_factor apply_scale_factor, bool take_conjugate_on_load, bool take_conjugate_on_store,
+    global_data_struct<1> global_data) {
  static_assert(std::is_same_v, T>, "Real type mismatch");
  global_data.log_message_global(__func__, "entered", "DFTSize", dft_size, "stride_within_dft", stride_within_dft,
                                 "ndfts_in_outer_dimension", ndfts_in_outer_dimension, "max_num_batches_in_local_mem",
@@ -175,9 +178,19 @@
        }
      }
    }
-
+    if (take_conjugate_on_load) {
+      PORTFFT_UNROLL
+      for (Idx k = 0; k < fact_wi; k++) {
+        priv[2 * k + 1] *= -1;
+      }
+    }
    sg_dft(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch);
-
+    if (take_conjugate_on_store) {
+      PORTFFT_UNROLL
+      for (Idx k = 0; k < fact_wi; k++) {
+        priv[2 * k + 1] *= -1;
+      }
+    }
    if (working) {
      if (multiply_on_store == detail::elementwise_multiply::APPLIED) {
        // Store modifier data layout in global memory - n_transforms x N x FactorSG x FactorWI
@@ -237,6 +250,8 @@
 * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
 * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
 * @param apply_scale_factor Whether or not the scale factor is applied
+ * @param take_conjugate_on_load Whether or not to take the conjugate of the input before computing the fft.
+ * @param take_conjugate_on_store Whether or not to take the conjugate of the result of the fft
 * @param global_data global data for the kernel
 */
template
@@ -245,7 +260,8 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T
                           const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M,
                           detail::layout layout_in, detail::elementwise_multiply multiply_on_load,
                           detail::elementwise_multiply multiply_on_store,
-                           detail::apply_scale_factor apply_scale_factor, detail::global_data_struct<1> global_data) {
+                           detail::apply_scale_factor apply_scale_factor, bool take_conjugate_on_load,
+                           bool take_conjugate_on_store, detail::global_data_struct<1> global_data) {
  global_data.log_message_global(__func__, "entered", "FFTSize", fft_size, "N", N, "M", M,
                                 "max_num_batches_in_local_mem", max_num_batches_in_local_mem, "batch_num_in_local",
                                 batch_num_in_local);
diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp
index e875fb3e..2b4a99a4 100644
--- a/src/portfft/dispatcher/subgroup_dispatcher.hpp
+++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp
@@ -96,6 +96,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
  detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant();
  detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant();
  detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant();
+  bool take_conjugate_on_load = kh.get_specialization_constant();
+  bool take_conjugate_on_store = kh.get_specialization_constant();

  const Idx factor_wi = kh.get_specialization_constant();
  const Idx factor_sg = kh.get_specialization_constant();
@@ -255,7 +257,19 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
          }
        }
      }
+      if (take_conjugate_on_load) {
+        PORTFFT_UNROLL
+        for (Idx k = 0; k < factor_wi; k++) {
+          priv[2 * k + 1] *= -1;
+        }
+      }
      sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch);
+      if (take_conjugate_on_store) {
+        PORTFFT_UNROLL
+        for (Idx k = 0; k < factor_wi; k++) {
+          priv[2 * k + 1] *= -1;
+        }
+      }
      if (working_inner) {
        global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi);
      }
@@ -449,7 +463,19 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
          }
        }
      }
+      if (take_conjugate_on_load) {
+        PORTFFT_UNROLL
+        for (Idx k = 0; k < factor_wi; k++) {
+          priv[2 * k + 1] *= -1;
+        }
+      }
      sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch);
+      if (take_conjugate_on_store) {
+        PORTFFT_UNROLL
+        for (Idx k = 0; k < factor_wi; k++) {
+          priv[2 * k + 1] *= -1;
+        }
+      }
      if (working) {
        global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi);
      }
diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp
index f15baa57..7549ddc1 100644
--- a/src/portfft/dispatcher/workgroup_dispatcher.hpp
+++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp
@@ -109,6 +109,8 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i
  detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant();
  detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant();
  detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant();
+  bool take_conjugate_on_load = kh.get_specialization_constant();
+  bool take_conjugate_on_store =
kh.get_specialization_constant(); const Idx fft_size = kh.get_specialization_constant(); @@ -153,7 +155,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, sub_batch, offset / (2 * fft_size), load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, - global_data); + take_conjugate_on_load, take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); } if constexpr (LayoutOut == detail::layout::PACKED) { @@ -180,7 +182,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, offset / static_cast(2 * fft_size), load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, - apply_scale_factor, global_data); + apply_scale_factor, take_ take_conjugate_on_load, take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index d7d7e229..36ea1b66 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -116,6 +116,9 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); const Idx fft_size = kh.get_specialization_constant(); + bool take_conjugate_on_load = kh.get_specialization_constant(); + bool take_conjugate_on_store = kh.get_specialization_constant(); + global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms); bool interleaved_storage = storage == complex_storage::INTERLEAVED_COMPLEX; @@ -207,50 +210,61 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "applying load modifier"); detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } - wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); - global_data.log_dump_private("data in registers after computation:", priv, n_reals); - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) - // to ensure much lesser bank conflicts - global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); - } - if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { + if (take_conjugate_on_load) { PORTFFT_UNROLL - for (Idx idx = 0; idx < n_reals; idx += 2) { - priv[idx] *= scaling_factor; - priv[idx + 1] *= scaling_factor; + for (Idx j = 0; j < fft_size; j++) { + priv[2 * j + 1] *= -1; } - } - global_data.log_dump_private("data in registers after scaling:", priv, n_reals); - global_data.log_message_global(__func__, "loading data from private to local memory"); - if (LayoutOut == detail::layout::PACKED) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::offset_view offset_local_view{loc_view, local_offset + subgroup_local_id * n_reals}; - copy_wi(global_data, priv, 
offset_local_view, n_reals); - } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::offset_view local_real_view{loc_view, local_offset + subgroup_local_id * fft_size}; - detail::offset_view local_imag_view{loc_view, - local_offset + subgroup_local_id * fft_size + local_imag_offset}; - copy_wi(global_data, priv_real_view, local_real_view, fft_size); - copy_wi(global_data, priv_imag_view, local_imag_view, fft_size); + wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); + if (take_conjugate_on_store) { + PORTFFT_UNROLL + for (Idx j = 0; j < fft_size; j++) { + priv[2 * j + 1] *= -1; + } } - } else { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::strided_view output_view{output, n_transforms, i * 2}; - copy_wi<2>(global_data, priv, output_view, fft_size); + global_data.log_dump_private("data in registers after computation:", priv, n_reals); + if (multiply_on_store == detail::elementwise_multiply::APPLIED) { + // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) + // to ensure much lesser bank conflicts + global_data.log_message_global(__func__, "applying store modifier"); + detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); + } + if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { + PORTFFT_UNROLL + for (Idx idx = 0; idx < n_reals; idx += 2) { + priv[idx] *= scaling_factor; + priv[idx + 1] *= scaling_factor; + } + } + global_data.log_dump_private("data in registers after scaling:", priv, n_reals); + global_data.log_message_global(__func__, "loading data from private to local memory"); + if (LayoutOut == detail::layout::PACKED) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + detail::offset_view offset_local_view{loc_view, local_offset + subgroup_local_id * n_reals}; + copy_wi(global_data, priv, offset_local_view, n_reals); + } else { + detail::strided_view priv_real_view{priv, 2}; + detail::strided_view priv_imag_view{priv, 2, 1}; + detail::offset_view local_real_view{loc_view, local_offset + subgroup_local_id * fft_size}; + detail::offset_view local_imag_view{loc_view, + local_offset + subgroup_local_id * fft_size + local_imag_offset}; + copy_wi(global_data, priv_real_view, local_real_view, fft_size); + copy_wi(global_data, priv_imag_view, local_imag_view, fft_size); + } } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::strided_view output_real_view{output, n_transforms, i}; - detail::strided_view output_imag_view{output_imag, n_transforms, i}; - copy_wi(global_data, priv_real_view, output_real_view, fft_size); - copy_wi(global_data, priv_imag_view, output_imag_view, fft_size); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + detail::strided_view output_view{output, n_transforms, i * 2}; + copy_wi<2>(global_data, priv, output_view, fft_size); + } else { + detail::strided_view priv_real_view{priv, 2}; + detail::strided_view priv_imag_view{priv, 2, 1}; + detail::strided_view output_real_view{output, n_transforms, i}; + detail::strided_view output_imag_view{output_imag, n_transforms, i}; + copy_wi(global_data, priv_real_view, output_real_view, fft_size); + copy_wi(global_data, priv_imag_view, output_imag_view, fft_size); + } } } - } if (LayoutOut == detail::layout::PACKED) { sycl::group_barrier(global_data.sg); global_data.log_dump_local("computed data local memory:", loc, n_reals * n_working); diff --git 
a/src/portfft/specialization_constant.hpp b/src/portfft/specialization_constant.hpp index cefd5b37..5c18773b 100644 --- a/src/portfft/specialization_constant.hpp +++ b/src/portfft/specialization_constant.hpp @@ -44,6 +44,9 @@ constexpr static sycl::specialization_id GlobalSubImplSpecConst{}; constexpr static sycl::specialization_id GlobalSpecConstLevelNum{}; constexpr static sycl::specialization_id GlobalSpecConstNumFactors{}; +constexpr static sycl::specialization_id SpecConstTakeConjugateOnLoad{}; +constexpr static sycl::specialization_id SpecConstTakeConjugateOnStore{}; + } // namespace detail } // namespace portfft #endif From 651edd1010a4f35835fcadf00074d66b2f1654e1 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 21 Dec 2023 10:08:20 +0000 Subject: [PATCH 18/67] fix compilation issues --- src/portfft/common/workgroup.hpp | 6 +- .../dispatcher/workgroup_dispatcher.hpp | 2 +- .../dispatcher/workitem_dispatcher.hpp | 89 ++++++++++--------- 3 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 1b5fa036..8418b041 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -269,13 +269,15 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T detail::dimension_dft( loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, N, M, 1, layout_in, multiply_on_load, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, global_data); + detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, take_conjugate_on_load, + take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); // row-wise DFTs, including twiddle multiplications and scaling detail::dimension_dft( loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, layout_in, - detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, global_data); + detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, take_conjugate_on_load, + take_conjugate_on_store, global_data); global_data.log_message_global(__func__, "exited"); } diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index 7549ddc1..92129a98 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -182,7 +182,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, offset / static_cast(2 * fft_size), load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, - apply_scale_factor, take_ take_conjugate_on_load, take_conjugate_on_store, global_data); + apply_scale_factor, take_conjugate_on_load, take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 36ea1b66..b0a3db90 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ 
b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -215,56 +215,57 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag for (Idx j = 0; j < fft_size; j++) { priv[2 * j + 1] *= -1; } - wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); - if (take_conjugate_on_store) { - PORTFFT_UNROLL - for (Idx j = 0; j < fft_size; j++) { - priv[2 * j + 1] *= -1; - } + } + wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); + if (take_conjugate_on_store) { + PORTFFT_UNROLL + for (Idx j = 0; j < fft_size; j++) { + priv[2 * j + 1] *= -1; } - global_data.log_dump_private("data in registers after computation:", priv, n_reals); - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) - // to ensure much lesser bank conflicts - global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); + } + global_data.log_dump_private("data in registers after computation:", priv, n_reals); + if (multiply_on_store == detail::elementwise_multiply::APPLIED) { + // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) + // to ensure much lesser bank conflicts + global_data.log_message_global(__func__, "applying store modifier"); + detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); + } + if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { + PORTFFT_UNROLL + for (Idx idx = 0; idx < n_reals; idx += 2) { + priv[idx] *= scaling_factor; + priv[idx + 1] *= scaling_factor; } - if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { - PORTFFT_UNROLL - for (Idx idx = 0; idx < n_reals; idx += 2) { - priv[idx] *= scaling_factor; - priv[idx + 1] *= scaling_factor; - } + } + global_data.log_dump_private("data in registers after scaling:", priv, n_reals); + global_data.log_message_global(__func__, "loading data from private to local memory"); + if (LayoutOut == detail::layout::PACKED) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + detail::offset_view offset_local_view{loc_view, local_offset + subgroup_local_id * n_reals}; + copy_wi(global_data, priv, offset_local_view, n_reals); + } else { + detail::strided_view priv_real_view{priv, 2}; + detail::strided_view priv_imag_view{priv, 2, 1}; + detail::offset_view local_real_view{loc_view, local_offset + subgroup_local_id * fft_size}; + detail::offset_view local_imag_view{loc_view, + local_offset + subgroup_local_id * fft_size + local_imag_offset}; + copy_wi(global_data, priv_real_view, local_real_view, fft_size); + copy_wi(global_data, priv_imag_view, local_imag_view, fft_size); } - global_data.log_dump_private("data in registers after scaling:", priv, n_reals); - global_data.log_message_global(__func__, "loading data from private to local memory"); - if (LayoutOut == detail::layout::PACKED) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::offset_view offset_local_view{loc_view, local_offset + subgroup_local_id * n_reals}; - copy_wi(global_data, priv, offset_local_view, n_reals); - } else { - detail::strided_view priv_real_view{priv, 2}; - detail::strided_view priv_imag_view{priv, 2, 1}; - detail::offset_view local_real_view{loc_view, local_offset + subgroup_local_id * fft_size}; - detail::offset_view local_imag_view{loc_view, - local_offset + subgroup_local_id * fft_size + local_imag_offset}; - copy_wi(global_data, priv_real_view, 
local_real_view, fft_size);
+          copy_wi(global_data, priv_imag_view, local_imag_view, fft_size);
+        }
+      } else {
+        if (storage == complex_storage::INTERLEAVED_COMPLEX) {
+          detail::strided_view output_view{output, n_transforms, i * 2};
+          copy_wi<2>(global_data, priv, output_view, fft_size);
        } else {
-          detail::strided_view priv_real_view{priv, 2};
-          detail::strided_view priv_imag_view{priv, 2, 1};
-          detail::strided_view output_real_view{output, n_transforms, i};
-          detail::strided_view output_imag_view{output_imag, n_transforms, i};
-          copy_wi(global_data, priv_real_view, output_real_view, fft_size);
-          copy_wi(global_data, priv_imag_view, output_imag_view, fft_size);
+          detail::strided_view priv_real_view{priv, 2};
+          detail::strided_view priv_imag_view{priv, 2, 1};
+          detail::strided_view output_real_view{output, n_transforms, i};
+          detail::strided_view output_imag_view{output_imag, n_transforms, i};
+          copy_wi(global_data, priv_real_view, output_real_view, fft_size);
+          copy_wi(global_data, priv_imag_view, output_imag_view, fft_size);
        }
      }
+    }
    if (LayoutOut == detail::layout::PACKED) {
      sycl::group_barrier(global_data.sg);
      global_data.log_dump_local("computed data local memory:", loc, n_reals * n_working);

From 9fd1566dd4051f16b69da45a9ce32aff8c6632f7 Mon Sep 17 00:00:00 2001
From: "atharva.dubey" 
Date: Wed, 3 Jan 2024 11:44:07 +0000
Subject: [PATCH 19/67] refactor repetitive conjugate snippet into a utility function

---
 src/portfft/common/helpers.hpp             | 14 +++++++++++++
 src/portfft/common/workgroup.hpp           | 10 ++--------
 .../dispatcher/subgroup_dispatcher.hpp     | 20 ++++---------------
 .../dispatcher/workitem_dispatcher.hpp     | 10 ++--------
 4 files changed, 22 insertions(+), 32 deletions(-)

diff --git a/src/portfft/common/helpers.hpp b/src/portfft/common/helpers.hpp
index 4f8017ae..3716e9d6 100644
--- a/src/portfft/common/helpers.hpp
+++ b/src/portfft/common/helpers.hpp
@@ -184,6 +184,20 @@ PORTFFT_INLINE constexpr Idx int_log2(Idx x) {
  }
  return y;
}
+
+/**
+ * Takes the conjugate of the complex data in a private array
+ * @tparam T Scalar type
+ * @param priv pointer to the data in registers
+ * @param num_elements number of complex numbers in the private memory
+ */
+template
+PORTFFT_INLINE void take_conjugate(T* priv, Idx num_elements) {
+  PORTFFT_UNROLL
+  for (Idx i = 0; i < num_elements; i++) {
+    priv[2 * i + 1] *= -1;
+  }
+}
}  // namespace portfft::detail

#endif
diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp
index 8418b041..a6452c80 100644
--- a/src/portfft/common/workgroup.hpp
+++ b/src/portfft/common/workgroup.hpp
@@ -179,17 +179,11 @@ __attribute__((always_inline)) inline void dimension_dft(
      }
    }
    if (take_conjugate_on_load) {
-      PORTFFT_UNROLL
-      for (Idx k = 0; k < fact_wi; k++) {
-        priv[2 * k + 1] *= -1;
-      }
+      take_conjugate(priv, fact_wi);
    }
    sg_dft(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch);
    if (take_conjugate_on_store) {
-      PORTFFT_UNROLL
-      for (Idx k = 0; k < fact_wi; k++) {
-        priv[2 * k + 1] *= -1;
-      }
+      take_conjugate(priv, fact_wi);
    }
    if (working) {
      if (multiply_on_store == detail::elementwise_multiply::APPLIED) {
diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp
index 2b4a99a4..c8c31f3e 100644
--- a/src/portfft/dispatcher/subgroup_dispatcher.hpp
+++ 
b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -258,17 +258,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } if (take_conjugate_on_load) { - PORTFFT_UNROLL - for (Idx k = 0; k < factor_wi; k++) { - priv[2 * k + 1] *= -1; - } + take_conjugate(priv, factor_wi); } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); if (take_conjugate_on_store) { - PORTFFT_UNROLL - for (Idx k = 0; k < factor_wi; k++) { - priv[2 * k + 1] *= -1; - } + take_conjugate(priv, factor_wi); } if (working_inner) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); @@ -464,17 +458,11 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } if (take_conjugate_on_load) { - PORTFFT_UNROLL - for (Idx k = 0; k < factor_wi; k++) { - priv[2 * k + 1] *= -1; - } + take_conjugate(priv, factor_wi); } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); if (take_conjugate_on_store) { - PORTFFT_UNROLL - for (Idx k = 0; k < factor_wi; k++) { - priv[2 * k + 1] *= -1; - } + take_conjugate(priv, factor_wi); } if (working) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index b0a3db90..7cda0f3a 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -211,17 +211,11 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } if (take_conjugate_on_load) { - PORTFFT_UNROLL - for (Idx j = 0; j < fft_size; j++) { - priv[2 * j + 1] *= -1; - } + take_conjugate(priv, fft_size); } wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); if (take_conjugate_on_store) { - PORTFFT_UNROLL - for (Idx j = 0; j < fft_size; j++) { - priv[2 * j + 1] *= -1; - } + take_conjugate(priv, fft_size); } global_data.log_dump_private("data in registers after computation:", priv, n_reals); if (multiply_on_store == detail::elementwise_multiply::APPLIED) { From c4e6549b0421086c95a07a78bea997040bdd1048 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 3 Jan 2024 11:47:35 +0000 Subject: [PATCH 20/67] ignore templated lambda C++20 warning (for now) --- src/portfft/dispatcher/global_dispatcher.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index a2a8a4b3..bb650596 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -334,6 +334,8 @@ struct committed_descriptor::run_kernel_struct& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, dimension_struct& dimension_data) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wc++20-extensions" (void)in_imag; (void)out_imag; const auto& kernels = dimension_data.kernels; @@ -427,6 +429,7 @@ struct committed_descriptor::run_kernel_struct Date: Tue, 9 Jan 2024 11:34:07 +0000 Subject: [PATCH 21/67] added event dependencies for copy before and after compute for prime sized cases --- src/portfft/dispatcher/global_dispatcher.hpp | 38 +++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 
2223b198..0b561059 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -103,10 +103,10 @@ inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) } template -void trigger_device_copy(const T* src, T* dst, IdxGlobal num_elements_to_copy, IdxGlobal src_stride, - IdxGlobal dst_stride, Idx num_copies, std::vector& event_vector, +void trigger_device_copy(const T* src, T* dst, std::size_t num_elements_to_copy, std::size_t src_stride, + std::size_t dst_stride, std::size_t num_copies, std::vector& event_vector, sycl::queue& queue) { - for (Idx i = 0; i < num_copies; i++) { + for (std::size_t i = 0; i < num_copies; i++) { event_vector.at(i) = queue.copy(src + i * src_stride, dst + i * dst_stride, num_elements_to_copy); } } @@ -422,17 +422,45 @@ struct committed_descriptor::run_kernel_struct(i), 0, num_factors, 2 * static_cast(i) * committed_size + output_offset, desc.queue, desc.scratch_ptr_1, desc.scratch_ptr_2, current_events, previous_events); }; + for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { if (dimension_data.is_prime) { + desc.queue + .submit([&](sycl::handler& cgh) { + cgh.depends_on(current_events[0]); + auto in_acc_or_usm = detail::get_access(in, cgh); + cgh.host_task([&]() { + detail::trigger_device_copy(&in_acc_or_usm[0] + 2 * i * dimension_data.committed_length, + desc.scratch_ptr_1.get(), 2 * dimension_data.committed_length, + 2 * dimension_data.committed_length, 2 * dimension_data.length, + max_batches_in_l2, current_events, desc.queue); + }); + }) + .wait(); run_global.template operator()( - std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i); + std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), + kernels.begin() + static_cast(dimension_data.forward_factors)), + i); run_global.template operator()( std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i); + desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(current_events[0]); + auto out_acc_or_usm = detail::get_access(out, cgh); + cgh.host_task([&]() { + detail::trigger_device_copy( + desc.scratch_ptr_2.get(), &out_acc_or_usm[0] + 2 * i * dimension_data.committed_length, + 2 * dimension_data.committed_length, 2 * dimension_data.length, 2 * dimension_data.committed_length, + max_batches_in_l2, current_events, desc.queue); + }); + }); } else { run_global.template operator()(kernels, i); } } - return current_events[0]; + return desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(current_events); + cgh.host_task([]() {}); + }); #pragma clang diagnostic pop } }; From c56475f40f5089c12a795176cb72111fe6693004 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 9 Jan 2024 13:39:24 +0000 Subject: [PATCH 22/67] fix bugs in prepare_implementation and build_w_spec_const --- src/portfft/descriptor.hpp | 71 +++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 1d833062..78b1c602 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -317,11 +317,12 @@ class committed_descriptor { * * @tparam SubgroupSize size of the subgroup * @param kernel_num the consecutive number of the kernel to prepare - * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, - * vector of kernel ids, factors + * @return implementation to use for the dimension, 
the size for that dimension and a vector of tuples of: + * implementation to use for a kernel, vector of kernel ids, factors */ template - std::tuple, std::vector>>> + std::tuple, std::vector>>> prepare_implementation(std::size_t kernel_num) { // TODO: check and support all the parameter values if constexpr (Domain != domain::COMPLEX) { @@ -334,7 +335,7 @@ class committed_descriptor { if (detail::fits_in_wi(fft_size)) { factors.push_back(static_cast(fft_size)); ids = detail::get_ids(); - return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; + return {detail::level::WORKITEM, fft_size, {{detail::level::WORKITEM, ids, factors}}}; } if (detail::fits_in_sg(fft_size, SubgroupSize)) { Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); @@ -344,7 +345,7 @@ class committed_descriptor { factors.push_back(factor_wi); factors.push_back(factor_sg); ids = detail::get_ids(); - return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; + return {detail::level::SUBGROUP, fft_size, {{detail::level::SUBGROUP, ids, factors}}}; } IdxGlobal n_idx_global = detail::factorize(fft_size); if (detail::can_cast_safely(n_idx_global) && @@ -375,7 +376,7 @@ class committed_descriptor { // This factorization of N and M is duplicated in the dispatch logic on the device. // The CT and spec constant factors should match. ids = detail::get_ids(); - return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; + return {detail::level::WORKGROUP, fft_size, {{detail::level::WORKGROUP, ids, factors}}}; } } std::vector, std::vector>> param_vec; @@ -417,8 +418,9 @@ class committed_descriptor { static_cast(std::pow(2, ceil(log(static_cast(fft_size)) / log(2.0)))); detail::factorize_input(padded_fft_size, check_and_select_target_level, true); detail::factorize_input(padded_fft_size, check_and_select_target_level, false); + fft_size = padded_fft_size; } - return {detail::level::GLOBAL, param_vec}; + return {detail::level::GLOBAL, fft_size, param_vec}; } /** @@ -526,7 +528,7 @@ class committed_descriptor { detail::elementwise_multiply last_uses_load_modifier, Idx num_kernels) { Idx counter = 0; for (auto& [level, in_bundle, factors] : prepared_vec) { - if (counter > num_kernels) { + if (counter >= num_kernels) { break; } if (counter == num_kernels - 1) { @@ -564,7 +566,7 @@ class committed_descriptor { template dimension_struct build_w_spec_const(std::size_t kernel_num) { if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) { - auto [top_level, prepared_vec] = prepare_implementation(kernel_num); + auto [top_level, dimension_size, prepared_vec] = prepare_implementation(kernel_num); bool is_compatible = true; for (auto [level, ids, factors] : prepared_vec) { is_compatible = is_compatible && sycl::is_compatible(ids, dev); @@ -572,40 +574,48 @@ class committed_descriptor { break; } } - std::vector result; - std::size_t counter = 0; - std::size_t dimension_size = 1; + bool is_prime = false; + if (dimension_size != static_cast(params.lengths[kernel_num])) { + is_prime = true; + } + IdxGlobal num_forward_factors = 0; + IdxGlobal num_backward_factors = 0; + IdxGlobal temp = 1; for (const auto& [level, ids, factors] : prepared_vec) { - dimension_size *= - static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); - counter++; + if (temp == dimension_size) { + break; + } + temp *= static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); + num_forward_factors++; } - std::size_t 
backward_factors = prepared_vec.size() - counter; + num_backward_factors = static_cast(prepared_vec.size()) - num_forward_factors; + std::vector result; std::vector in_bundles; if (is_compatible) { - for (auto [level, ids, factors] : prepared_vec) { + for (const auto& [level, ids, factors] : prepared_vec) { in_bundles.emplace_back(level, sycl::get_kernel_bundle(queue.get_context(), ids), factors); } if (top_level == detail::level::GLOBAL) { detail::elementwise_multiply first_uses_load_modifiers = - backward_factors > 0 ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; + is_prime ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; detail::elementwise_multiply last_uses_store_modifiers = - backward_factors > 0 ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; + is_prime ? detail::elementwise_multiply::APPLIED : detail::elementwise_multiply::NOT_APPLIED; set_global_impl_spec_consts(in_bundles, first_uses_load_modifiers, last_uses_store_modifiers, - static_cast(counter)); - if (backward_factors > 0) { - std::vector backward_kernels_slice(in_bundles.begin() + static_cast(counter), - in_bundles.end()); + static_cast(num_forward_factors)); + if (is_prime) { + std::vector backward_kernels_slice( + in_bundles.begin() + static_cast(num_forward_factors), in_bundles.end()); set_global_impl_spec_consts(backward_kernels_slice, detail::elementwise_multiply::NOT_APPLIED, - last_uses_store_modifiers, static_cast(backward_factors)); + last_uses_store_modifiers, static_cast(num_backward_factors)); std::copy(backward_kernels_slice.begin(), backward_kernels_slice.end(), - in_bundles.begin() + static_cast(counter)); + in_bundles.begin() + static_cast(num_forward_factors)); } } else { for (auto& [level, in_bundle, factors] : in_bundles) { - set_spec_constants(top_level, in_bundle, dimension_size, factors, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::APPLIED, level); + set_spec_constants(top_level, in_bundle, static_cast(dimension_size), factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, + detail::apply_scale_factor::APPLIED, level); } } @@ -620,11 +630,8 @@ class committed_descriptor { break; } } - bool is_prime = false; - if (backward_factors > 0) { - is_prime = true; - } - return {result, top_level, dimension_size, params.lengths[kernel_num], is_prime, SubgroupSize}; + return {result, top_level, static_cast(dimension_size), params.lengths[kernel_num], + is_prime, SubgroupSize}; } } if constexpr (sizeof...(OtherSGSizes) == 0) { From 7e6dd9e8404d8675defa9156d01b81d9d025ad8e Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 9 Jan 2024 15:55:57 +0000 Subject: [PATCH 23/67] add option to increment store modifier pointer --- src/portfft/common/global.hpp | 11 +++++++---- src/portfft/specialization_constant.hpp | 2 ++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 6d05ab6c..6a6e4d89 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -140,26 +140,29 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc auto level = kh.get_specialization_constant(); Idx level_num = kh.get_specialization_constant(); Idx num_factors = kh.get_specialization_constant(); + bool increment_modifier_pointer = + kh.get_specialization_constant(); // Should it be a spec constant ? 
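// A minimal sketch of what the new constant selects (editor's illustration, not
// part of the patch): when it is set, outer batch `iter_value` of the dispatch
// below multiplies with its own slice of the modifier table,
//   const Scalar* modifiers =
//       store_modifier_data + (increment_modifier_pointer ? outer_batch_offset : 0);
// keeping the per-batch multipliers of the Bluestein path aligned with the data
// slice they scale, while the previous behaviour (every batch reading the slice
// at offset 0) is preserved when the constant is false.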
global_data.log_message_global(__func__, "dispatching sub implementation for factor num = ", level_num);

  IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num);
  for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) {
    IdxGlobal outer_batch_offset = get_outer_batch_offset(factors, inner_batches, inclusive_scan, num_factors,
                                                          level_num, iter_value, outer_batch_product);
+    auto store_modifier_offset = increment_modifier_pointer ? outer_batch_offset : 0;
    if (level == detail::level::WORKITEM) {
      workitem_impl(
          input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, batch_size,
-          scale_factor, global_data, kh, static_cast(nullptr), store_modifier_data,
-          static_cast(nullptr), store_modifier_loc);
+          scale_factor, global_data, kh, static_cast(nullptr),
+          store_modifier_data + store_modifier_offset, static_cast(nullptr), store_modifier_loc);
    } else if (level == detail::level::SUBGROUP) {
      subgroup_impl(
          input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, twiddles_loc,
          batch_size, implementation_twiddles, scale_factor, global_data, kh, static_cast(nullptr),
-          store_modifier_data, static_cast(nullptr), store_modifier_loc);
+          store_modifier_data + store_modifier_offset, static_cast(nullptr), store_modifier_loc);
    } else if (level == detail::level::WORKGROUP) {
      workgroup_impl(
          input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, twiddles_loc,
          batch_size, implementation_twiddles, scale_factor, global_data, kh, static_cast(nullptr),
-          store_modifier_data);
+          store_modifier_data + store_modifier_offset);
    }
    sycl::group_barrier(global_data.it.get_group());
  }
diff --git a/src/portfft/specialization_constant.hpp b/src/portfft/specialization_constant.hpp
index 5c18773b..7dbaea99 100644
--- a/src/portfft/specialization_constant.hpp
+++ b/src/portfft/specialization_constant.hpp
@@ -47,6 +47,8 @@
constexpr static sycl::specialization_id SpecConstTakeConjugateOnLoad{};
constexpr static sycl::specialization_id SpecConstTakeConjugateOnStore{};

+constexpr static sycl::specialization_id SpecConstIncrementModifierPointer{};
+
}  // namespace detail
}  // namespace portfft
#endif

From 78e1372ab6427b3f5350ce01abfb798ad336c5e2 Mon Sep 17 00:00:00 2001
From: "atharva.dubey" 
Date: Tue, 9 Jan 2024 17:07:23 +0000
Subject: [PATCH 24/67] remove readability-magic-numbers from clang-tidy

---
 .clang-tidy | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.clang-tidy b/.clang-tidy
index 0b3225d5..1fe9e338 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -14,6 +14,7 @@ Checks: >
  performance-*,
  -performance-avoid-endl,
  readability-*,
+  -readability-magic-numbers,
  -readability-function-cognitive-complexity,
  -readability-identifier-length,
  -readability-named-parameter,

From 730fecc26279ca11067873f6a0f8467c62585bfa Mon Sep 17 00:00:00 2001
From: "atharva.dubey" 
Date: Thu, 11 Jan 2024 17:29:54 +0000
Subject: [PATCH 25/67] further changes in global's calculate_twiddles_struct to accommodate bluestein

---
 src/portfft/common/bluestein.hpp             |  19 +-
 src/portfft/dispatcher/global_dispatcher.hpp | 321 +++++++++++--------
 2 files changed, 191 insertions(+), 149 deletions(-)

diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp
index 5619a82c..acc0c3d9 100644
--- a/src/portfft/common/bluestein.hpp
+++ b/src/portfft/common/bluestein.hpp
@@ -33,13 +33,12 @@ namespace detail {
/**
 * Utility function to get
chirp signal and fft * @tparam T Scalar Type - * @param ptr Device Pointer containing the load/store modifiers. + * @param ptr Host Pointer containing the load/store modifiers. * @param committed_size original problem size * @param dimension_size padded size - * @param queue queue with the committed descriptor */ template -void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, sycl::queue& queue) { +void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size) { using ctype = std::complex; ctype* chirp_signal = (ctype*)calloc(static_cast(dimension_size), sizeof(ctype)); ctype* chirp_fft = (ctype*)malloc(static_cast(dimension_size) * sizeof(ctype)); @@ -52,19 +51,25 @@ void get_fft_chirp_signal(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_ chirp_signal[committed_size + num_zeros + i - 1] = chirp_signal[committed_size - i]; } naive_dft(chirp_signal, chirp_fft, dimension_size); - queue.copy(reinterpret_cast(&chirp_fft[0]), ptr, static_cast(2 * dimension_size)).wait(); + std::memcpy(ptr, reinterpret_cast(&chirp_fft[0]), static_cast(2 * dimension_size) * sizeof(T)); } +/** + * Populates input modifiers required for bluestein + * @tparam T Scalar Type + * @param ptr Host Pointer containing the load/store modifiers. + * @param committed_size original problem size + * @param dimension_size padded size + */ template -void populate_bluestein_input_modifiers(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size, - sycl::queue& queue) { +void populate_bluestein_input_modifiers(T* ptr, IdxGlobal committed_size, IdxGlobal dimension_size) { using ctype = std::complex; ctype* scratch = (ctype*)calloc(static_cast(dimension_size), sizeof(ctype)); for (IdxGlobal i = 0; i < committed_size; i++) { double theta = -M_PI * static_cast(i * i) / static_cast(committed_size); scratch[i] = ctype(static_cast(std::cos(theta)), static_cast(std::sin(theta))); } - queue.copy(reinterpret_cast(&scratch[0]), ptr, static_cast(2 * dimension_size)); + std::memcpy(ptr, reinterpret_cast(&scratch[0]), static_cast(2 * dimension_size) * sizeof(T)); } } // namespace detail } // namespace portfft diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 0b561059..bd121c15 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -118,6 +118,151 @@ template struct committed_descriptor::calculate_twiddles_struct::inner { static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data) { auto& kernels = dimension_data.kernels; + + /** + * Helper Lambda to calculate twiddles + */ + auto calculate_twiddles = [](IdxGlobal N, IdxGlobal M, IdxGlobal& offset, Scalar* ptr) { + for (IdxGlobal i = 0; i < N; i++) { + for (IdxGlobal j = 0; j < M; j++) { + double theta = -2 * M_PI * static_cast(i * j) / static_cast(N * M); + ptr[offset++] = static_cast(std::cos(theta)); + ptr[offset++] = static_cast(std::sin(theta)); + } + } + }; + + /** + * Gets cumulative global memory requirements for provided set of factors and sub batches. + */ + auto get_cumulative_memory_requirememts = [&](std::vector& factors, std::vector& sub_batches, + direction dir) -> IdxGlobal { + // calculate sizes for modifiers + std::size_t num_factors = static_cast(dir == direction::FORWARD ? dimension_data.forward_factors + : dimension_data.backward_factors); + std::size_t offset = static_cast(dir == direction::FORWARD ? 
0 : dimension_data.backward_factors); + IdxGlobal total_memory = 0; + + // get memory for modifiers + for (std::size_t i = 0; i < num_factors - 1; i++) { + total_memory += 2 * factors.at(offset + i) * sub_batches.at(offset + i); + } + + // Get memory required for twiddles per sub-level + for (std::size_t i = 0; i < num_factors; i++) { + auto level = kernels.at(offset + i).level; + if (level == detail::level::SUBGROUP) { + total_memory += 2 * factors.at(offset + i); + } else if (level == detail::level::WORKGROUP) { + IdxGlobal factor_1 = detail::factorize(factors.at(offset + i)); + IdxGlobal factor_2 = factors.at(offset + i) / factor_1; + total_memory += 2 * (factor_1 * factor_2) + 2 * (factor_1 + factor_2); + } + } + return total_memory; + }; + + /** + * populates and rearranges twiddles on host pointer and populates + * the kernel specific metadata (launch params and local mem requirements for twiddles only) + */ + auto populate_twiddles_and_metadata = [&](Scalar* ptr, std::vector& factors, + std::vector& sub_batches, IdxGlobal& ptr_offset, + direction dir) -> void { + std::size_t num_factors = static_cast(dir == direction::FORWARD ? dimension_data.forward_factors + : dimension_data.backward_factors); + std::size_t offset = static_cast(dir == direction::FORWARD ? 0 : dimension_data.backward_factors); + Scalar* scratch_ptr = (Scalar*)malloc(2 * dimension_data.length * sizeof(Scalar)); + + // generate and rearrange store modifiers + for (std::size_t i = 0; i < num_factors - 1; i++) { + calculate_twiddles(sub_batches.at(offset + i), factors.at(offset + i), ptr_offset, ptr); + if (kernels.at(offset + i).level == detail::level::WORKITEM) { + // For the WI implementation, utilize coalesced loads from global as they are not being reused. + // shift them to local memory only for devices which do not have coalesced accesses. 
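// A small worked example of the rearrangement below (editor's sketch, values
// hypothetical): calculate_twiddles() emits the inter-factor table row-major,
// one row per sub-batch, e.g. for factor = 4 and sub_batches = 2
//   w00 w01 w02 w03
//   w10 w11 w12 w13
// and the transpose rewrites it so that, for each element of the factor, the
// values of neighbouring sub-batches sit next to each other,
//   w00 w10 w01 w11 w02 w12 w03 w13
// letting work-items that handle neighbouring sub-batches in the WI kernel
// issue coalesced loads straight from global memory.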
+ detail::complex_transpose(ptr + ptr_offset, scratch_ptr, factors.at(offset + i), sub_batches.at(offset + i), + factors.at(offset + i) * sub_batches.at(offset + i)); + std::memcpy(ptr + ptr_offset, scratch_ptr, + 2 * factors.at(offset + i) * sub_batches.at(offset + i) * sizeof(Scalar)); + } + } + + // Calculate twiddles for the implementation corresponding to per factor; + for (Idx i = 0; i < num_factors; i++) { + const auto& kernel_data = kernels.at(offset + i); + if (kernels.at(offset + i).level == detail::level::SUBGROUP) { + for (Idx j = 0; j < kernel_data.factors.at(0); j++) { + for (Idx k = 0; k < kernel_data.factors.at(1); k++) { + double theta = -2 * M_PI * static_cast(j * k) / + static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); + auto twiddle = + std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + ptr[offset + static_cast(j * kernel_data.factors.at(0) + i)] = twiddle.real(); + ptr[offset + static_cast((j + kernel_data.factors.at(1)) * kernel_data.factors.at(0) + i)] = + twiddle.imag(); + } + } + ptr_offset += 2 * kernel_data.factors.at(0) * kernel_data.factors.at(1); + } else if (kernels.at(offset + i).level == detail::level::WORKGROUP) { + Idx factor_n = kernel_data.factors.at(0) * kernel_data.factors.at(1); + Idx factor_m = kernel_data.factors.at(2) * kernel_data.factors.at(3); + calculate_twiddles(static_cast(kernel_data.factors.at(0)), + static_cast(kernel_data.factors.at(1)), ptr_offset, ptr); + calculate_twiddles(static_cast(kernel_data.factors.at(2)), + static_cast(kernel_data.factors.at(3)), ptr_offset, ptr); + // Calculate wg twiddles and transpose them + calculate_twiddles(static_cast(factor_n), static_cast(factor_m), ptr_offset, ptr); + for (Idx j = 0; j < factor_n; j++) { + detail::complex_transpose(ptr + offset + 2 * j * factor_n, scratch_ptr, factor_m, factor_n, + factor_n * factor_m); + std::memcpy(ptr + offset + 2 * j * factor_n, scratch_ptr, + static_cast(2 * factor_n * factor_m) * sizeof(float)); + } + } + } + + // Populate Metadata + for (std::size_t i = 0; i < num_factors; i++) { + auto& kernel_data = kernels.at(offset + i); + kernel_data.batch_size = sub_batches.at(offset + i); + kernel_data.length = static_cast(factors.at(offset + i)); + if (kernel_data.level == detail::level::WORKITEM) { + Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; + if (i < kernels.size() - 1) { + kernel_data.local_mem_required = static_cast(1); + } else { + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + detail::level::WORKITEM, static_cast(factors.at(offset + i)), kernel_data.used_sg_size, + {static_cast(factors.at(offset + i))}, num_sgs_in_wg); + } + auto [global_range, local_range] = + detail::get_launch_params(factors.at(offset + i), sub_batches.at(offset + i), detail::level::WORKITEM, + desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); + kernel_data.global_range = global_range; + kernel_data.local_range = local_range; + } else if (kernel_data.level == detail::level::SUBGROUP) { + Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; + IdxGlobal factor_sg = detail::factorize_sg(factors.at(offset + i), kernel_data.used_sg_size); + IdxGlobal factor_wi = factors.at(offset + i) / factor_sg; + if (i < kernels.size() - 1) { + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + detail::level::SUBGROUP, static_cast(factors.at(offset + i)), kernel_data.used_sg_size, + {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); + } else { + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + 
detail::level::SUBGROUP, static_cast(factors.at(offset + i)), kernel_data.used_sg_size, + {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); + } + auto [global_range, local_range] = + detail::get_launch_params(factors.at(offset + i), sub_batches.at(offset + i), detail::level::SUBGROUP, + desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); + kernel_data.global_range = global_range; + kernel_data.local_range = local_range; + } + } + free(scratch_ptr); + }; + std::vector factors_idx_global; IdxGlobal temp_acc = 1; // Get factor sizes per level; @@ -131,167 +276,59 @@ struct committed_descriptor::calculate_twiddles_struct::inner(factors_idx_global.size()); dimension_data.backward_factors = static_cast(kernels.size()) - dimension_data.forward_factors; + for (const auto& kernel_data : kernels.begin() + static_cast(factors_idx_global.size())) { + factors_idx_global.push_back(static_cast( + std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); + } + + // Get sub batches per direction std::vector sub_batches; - // Get sub batches - for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { - sub_batches.push_back(std::accumulate(factors_idx_global.begin() + static_cast(i + 1), - factors_idx_global.end(), IdxGlobal(1), std::multiplies())); + for (Idx i = 0; i < dimension_data.forward_factors - 1; i++) { + sub_batches.push_back( + std::accumulate(factors_idx_global.begin() + static_cast(i + 1), + factors_idx_global.begin() + static_cast(dimension_data.forward_factors), IdxGlobal(1), + std::multiplies())); } - sub_batches.push_back(factors_idx_global.at(factors_idx_global.size() - 2)); - // factors and inner batches for the backward factors; + sub_batches.push_back(factors_idx_global.at(static_cast(dimension_data.forward_factors - 2))); if (dimension_data.backward_factors > 0) { - for (Idx i = 0; i < dimension_data.backward_factors; i++) { - const auto& kd_struct = kernels.at(static_cast(dimension_data.forward_factors + i)); - factors_idx_global.push_back(static_cast( - std::accumulate(kd_struct.factors.begin(), kd_struct.factors.end(), 1, std::multiplies()))); - } - for (Idx back_factor = 0; back_factor < dimension_data.backward_factors - 1; back_factor++) { - sub_batches.push_back(std::accumulate( - factors_idx_global.begin() + static_cast(dimension_data.forward_factors + back_factor + 1), - factors_idx_global.end(), IdxGlobal(1), std::multiplies())); + for (Idx i = 0; i < dimension_data.backward_factors - 1; i++) { + sub_batches.push_back( + std::accumulate(factors_idx_global.begin() + static_cast(dimension_data.forward_factors + i + 1), + factors_idx_global.end(), IdxGlobal(1), std::multiplies())); } + sub_batches.push_back(factors_idx_global.at(factors_idx_global.size() - 2)); } - // calculate total memory required for twiddles; - IdxGlobal mem_required_for_twiddles = 0; - // First calculate mem required for twiddles between factors; - for (std::size_t i = 0; i < static_cast(dimension_data.forward_factors - 1); i++) { - mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); - } + + // Get total Global memory required to store all the twiddles and multipliers. 
+ IdxGlobal mem_required_for_twiddles = + get_cumulative_memory_requirememts(factors_idx_global, sub_batches, direction::FORWARD); if (dimension_data.backward_factors > 0) { - for (std::size_t i = 0; i < static_cast(dimension_data.backward_factors - 1); i++) { - mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); - } + mem_required_for_twiddles += + get_cumulative_memory_requirememts(factors_idx_global, sub_batches, direction::BACKWARD); + // Presence of backward factors signifies that Bluestein will be used. + // Thus take into account memory required for load modifiers as well. mem_required_for_twiddles += static_cast(4 * dimension_data.length); } - // Now calculate mem required for twiddles per implementation - std::size_t counter = 0; - for (const auto& kernel_data : kernels) { - if (kernel_data.level == detail::level::SUBGROUP) { - mem_required_for_twiddles += 2 * factors_idx_global.at(counter); - } else if (kernel_data.level == detail::level::WORKGROUP) { - IdxGlobal factor_1 = detail::factorize(factors_idx_global.at(counter)); - IdxGlobal factor_2 = factors_idx_global.at(counter) / factor_1; - mem_required_for_twiddles += 2 * (factor_1 * factor_2) + 2 * (factor_1 + factor_2); - } - counter++; - } std::vector host_memory(static_cast(mem_required_for_twiddles)); - std::vector scratch_space(static_cast(mem_required_for_twiddles)); Scalar* device_twiddles = sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); - // Helper Lambda to calculate twiddles - auto calculate_twiddles = [](IdxGlobal N, IdxGlobal M, IdxGlobal& offset, Scalar* ptr) { - for (IdxGlobal i = 0; i < N; i++) { - for (IdxGlobal j = 0; j < M; j++) { - double theta = -2 * M_PI * static_cast(i * j) / static_cast(N * M); - ptr[offset++] = static_cast(std::cos(theta)); - ptr[offset++] = static_cast(std::sin(theta)); - } - } - }; - IdxGlobal offset = 0; if (dimension_data.is_prime) { - // get bluestein specific modifiers. - detail::get_fft_chirp_signal(device_twiddles + offset, static_cast(dimension_data.committed_length), - static_cast(dimension_data.length), desc.queue); + // first populate load modifiers for bluestein. + detail::get_fft_chirp_signal(host_memory.data() + offset, static_cast(dimension_data.committed_length), + static_cast(dimension_data.length)); offset += static_cast(2 * dimension_data.length); - detail::populate_bluestein_input_modifiers(device_twiddles + offset, + detail::populate_bluestein_input_modifiers(host_memory.data() + offset, static_cast(dimension_data.committed_length), - static_cast(dimension_data.length), desc.queue); - } - // calculate twiddles to be multiplied between factors - for (std::size_t i = 0; i < static_cast(dimension_data.forward_factors) - 1; i++) { - calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); - } - if (dimension_data.backward_factors > 0) { - for (std::size_t i = 0; i < static_cast(dimension_data.backward_factors) - 1; i++) { - calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); - } - } - // Now calculate per twiddles. 
- counter = 0; - for (const auto& kernel_data : kernels) { - if (kernel_data.level == detail::level::SUBGROUP) { - for (Idx i = 0; i < kernel_data.factors.at(0); i++) { - for (Idx j = 0; j < kernel_data.factors.at(1); j++) { - double theta = -2 * M_PI * static_cast(i * j) / - static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); - auto twiddle = - std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); - host_memory[static_cast(offset + static_cast(j * kernel_data.factors.at(0) + i))] = - twiddle.real(); - host_memory[static_cast( - offset + static_cast((j + kernel_data.factors.at(1)) * kernel_data.factors.at(0) + i))] = - twiddle.imag(); - } - } - offset += 2 * kernel_data.factors.at(0) * kernel_data.factors.at(1); - } else if (kernel_data.level == detail::level::WORKGROUP) { - Idx factor_n = kernel_data.factors.at(0) * kernel_data.factors.at(1); - Idx factor_m = kernel_data.factors.at(2) * kernel_data.factors.at(3); - calculate_twiddles(static_cast(kernel_data.factors.at(0)), - static_cast(kernel_data.factors.at(1)), offset, host_memory.data()); - calculate_twiddles(static_cast(kernel_data.factors.at(2)), - static_cast(kernel_data.factors.at(3)), offset, host_memory.data()); - // Calculate wg twiddles and transpose them - calculate_twiddles(static_cast(factor_n), static_cast(factor_m), offset, - host_memory.data()); - for (Idx j = 0; j < factor_n; j++) { - detail::complex_transpose(host_memory.data() + offset + 2 * j * factor_n, scratch_space.data(), factor_m, - factor_n, factor_n * factor_m); - } - } - counter++; + static_cast(dimension_data.length)); + offset += static_cast(2 * dimension_data.length); } - // Rearrage the twiddles between factors for optimal access patters in shared memory - // Also take this opportunity to populate local memory size, and batch size, and launch params and local memory - // usage Note, global impl only uses store modifiers - // TODO: there is a heap corruption in workitem's access of loaded modifiers, hence loading from global directly for - // now. - counter = 0; - for (auto& kernel_data : kernels) { - kernel_data.batch_size = sub_batches.at(counter); - kernel_data.length = static_cast(factors_idx_global.at(counter)); - if (kernel_data.level == detail::level::WORKITEM) { - // See comments in workitem_dispatcher for layout requirments. - Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; - if (counter < kernels.size() - 1) { - kernel_data.local_mem_required = static_cast(1); - } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::WORKITEM, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factors_idx_global.at(counter))}, num_sgs_in_wg); - } - auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM, - desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); - kernel_data.global_range = global_range; - kernel_data.local_range = local_range; - } else if (kernel_data.level == detail::level::SUBGROUP) { - Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; - // See comments in subgroup_dispatcher for layout requirements. 
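// --- Editorial aside, not part of the patch: a sketch of the split detail::factorize_sg is
// used for below. It must return a divisor factor_sg of the factor with
// factor_sg <= subgroup_size, so that factor_sg work-items each compute a DFT of length
// factor_wi = factor / factor_sg; whether the real implementation picks the largest such
// divisor is an assumption here.
long factorize_sg_sketch(long factor, long subgroup_size) {
  for (long candidate = subgroup_size; candidate > 1; candidate--) {
    if (factor % candidate == 0) {
      return candidate;  // e.g. factor = 64, subgroup_size = 32 -> factor_sg = 32, factor_wi = 2
    }
  }
  return 1;  // a prime larger than the subgroup: the whole DFT stays in one work-item
}
// --- End aside.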
- IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size); - IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg; - if (counter < kernels.size() - 1) { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); - } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); - } - auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP, - desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); - kernel_data.global_range = global_range; - kernel_data.local_range = local_range; - } - counter++; + populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, direction::FORWARD); + if (dimension_data.backward_factors) { + populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, direction::BACKWARD); } desc.queue.copy(host_memory.data(), device_twiddles, static_cast(mem_required_for_twiddles)).wait(); return device_twiddles; From 1be4afc28876a384316651fb22c87502cd92914e Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 12 Jan 2024 10:54:27 +0000 Subject: [PATCH 26/67] fix compilation and warning --- src/portfft/dispatcher/global_dispatcher.hpp | 27 ++++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index bd121c15..469e4665 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -182,24 +182,27 @@ struct committed_descriptor::calculate_twiddles_struct::inner(2 * factors.at(offset + i) * sub_batches.at(offset + i)) * sizeof(Scalar)); } } // Calculate twiddles for the implementation corresponding to per factor; - for (Idx i = 0; i < num_factors; i++) { + for (std::size_t i = 0; i < num_factors; i++) { const auto& kernel_data = kernels.at(offset + i); if (kernels.at(offset + i).level == detail::level::SUBGROUP) { - for (Idx j = 0; j < kernel_data.factors.at(0); j++) { - for (Idx k = 0; k < kernel_data.factors.at(1); k++) { + for (std::size_t j = 0; j < std::size_t(kernel_data.factors.at(0)); j++) { + for (std::size_t k = 0; k < std::size_t(kernel_data.factors.at(1)); k++) { double theta = -2 * M_PI * static_cast(j * k) / static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); - ptr[offset + static_cast(j * kernel_data.factors.at(0) + i)] = twiddle.real(); - ptr[offset + static_cast((j + kernel_data.factors.at(1)) * kernel_data.factors.at(0) + i)] = - twiddle.imag(); + ptr[offset + j * static_cast(kernel_data.factors.at(0)) + i] = twiddle.real(); + ptr[offset + + (j + static_cast(kernel_data.factors.at(1))) * + static_cast(kernel_data.factors.at(0)) + + i] = twiddle.imag(); } } ptr_offset += 2 * kernel_data.factors.at(0) * kernel_data.factors.at(1); @@ -276,16 +279,18 @@ struct committed_descriptor::calculate_twiddles_struct::inner(factors_idx_global.size()); dimension_data.backward_factors = static_cast(kernels.size()) - dimension_data.forward_factors; - 
for (const auto& kernel_data : kernels.begin() + static_cast(factors_idx_global.size())) { + for (std::size_t i = 0; i < std::size_t(dimension_data.backward_factors); i++) { factors_idx_global.push_back(static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); + std::accumulate(kernels.at(i + static_cast(dimension_data.forward_factors)).factors.begin(), + kernels.at(i + static_cast(dimension_data.forward_factors)).factors.end(), 1, + std::multiplies()))); } // Get sub batches per direction std::vector sub_batches; for (Idx i = 0; i < dimension_data.forward_factors - 1; i++) { sub_batches.push_back( - std::accumulate(factors_idx_global.begin() + static_cast(i + 1), + std::accumulate(factors_idx_global.begin() + i + 1, factors_idx_global.begin() + static_cast(dimension_data.forward_factors), IdxGlobal(1), std::multiplies())); } From 4cf6ca739843ce1b3a963a20bbcabd850d904811 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 12 Jan 2024 11:34:13 +0000 Subject: [PATCH 27/67] updated apply_modifier in workitem impl after layout changes --- src/portfft/dispatcher/workitem_dispatcher.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index a852e759..37cbd4f8 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -65,11 +65,13 @@ IdxGlobal get_global_size_workitem(IdxGlobal n_transforms, Idx subgroup_size, Id * @param offset offset for the global modifier data pointer */ template -PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal offset) { +PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal modifier_stride, + IdxGlobal modifier_offset) { + using vec = sycl::vec; PORTFFT_UNROLL for (Idx j = 0; j < num_elements; j++) { - sycl::vec modifier_vec; - modifier_vec.load(0, detail::get_global_multi_ptr(&modifier_data[offset + 2 * j])); + vec modifier_vec; + modifier_vec = *reinterpret_cast(&modifier_data[j * modifier_stride + modifier_offset]); if (Dir == direction::BACKWARD) { modifier_vec[1] *= -1; } @@ -208,7 +210,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes load modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying load modifier"); - detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); + detail::apply_modifier(fft_size, priv, load_modifier_data, 2 * n_transforms, 2 * i); } if (take_conjugate_on_load) { take_conjugate(priv, fft_size); @@ -222,7 +224,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); + detail::apply_modifier(fft_size, priv, store_modifier_data, 2 * n_transforms, 2 * i); } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL From 935477d189cd68ae213903f3194fe760f128bc3c Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 12 Jan 2024 14:07:56 +0000 Subject: [PATCH 28/67] Revert "updated apply_modifier in workitem 
impl after layout changes" This reverts commit 4cf6ca739843ce1b3a963a20bbcabd850d904811. --- src/portfft/dispatcher/workitem_dispatcher.hpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 37cbd4f8..a852e759 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -65,13 +65,11 @@ IdxGlobal get_global_size_workitem(IdxGlobal n_transforms, Idx subgroup_size, Id * @param offset offset for the global modifier data pointer */ template -PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal modifier_stride, - IdxGlobal modifier_offset) { - using vec = sycl::vec; +PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal offset) { PORTFFT_UNROLL for (Idx j = 0; j < num_elements; j++) { - vec modifier_vec; - modifier_vec = *reinterpret_cast(&modifier_data[j * modifier_stride + modifier_offset]); + sycl::vec modifier_vec; + modifier_vec.load(0, detail::get_global_multi_ptr(&modifier_data[offset + 2 * j])); if (Dir == direction::BACKWARD) { modifier_vec[1] *= -1; } @@ -210,7 +208,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes load modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying load modifier"); - detail::apply_modifier(fft_size, priv, load_modifier_data, 2 * n_transforms, 2 * i); + detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } if (take_conjugate_on_load) { take_conjugate(priv, fft_size); @@ -224,7 +222,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, 2 * n_transforms, 2 * i); + detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL From eae33414d46895c700e92110093cd5d23371c17c Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 15 Jan 2024 16:24:58 +0000 Subject: [PATCH 29/67] bugfixes in calculate_twiddles and ensure coalesced accesses in workitem's apply_modifier --- src/portfft/dispatcher/global_dispatcher.hpp | 35 +++++++++---------- .../dispatcher/workitem_dispatcher.hpp | 10 +++--- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 469e4665..65b759c8 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -172,37 +172,34 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dir == direction::FORWARD ? dimension_data.forward_factors : dimension_data.backward_factors); std::size_t offset = static_cast(dir == direction::FORWARD ? 
0 : dimension_data.backward_factors); - Scalar* scratch_ptr = (Scalar*)malloc(2 * dimension_data.length * sizeof(Scalar)); + Scalar* scratch_ptr = (Scalar*)malloc(8 * dimension_data.length * sizeof(Scalar)); - // generate and rearrange store modifiers for (std::size_t i = 0; i < num_factors - 1; i++) { - calculate_twiddles(sub_batches.at(offset + i), factors.at(offset + i), ptr_offset, ptr); if (kernels.at(offset + i).level == detail::level::WORKITEM) { - // For the WI implementation, utilize coalesced loads from global as they are not being reused. - // shift them to local memory only for devices which do not have coalesced accesses. - detail::complex_transpose(ptr + ptr_offset, scratch_ptr, factors.at(offset + i), sub_batches.at(offset + i), - factors.at(offset + i) * sub_batches.at(offset + i)); - std::memcpy( - ptr + ptr_offset, scratch_ptr, - static_cast(2 * factors.at(offset + i) * sub_batches.at(offset + i)) * sizeof(Scalar)); + // Use coalesced loads from global memory and not local memory to ensure optimal accesses. + // Local memory option provided for devices which do not support coalesced accesses. + calculate_twiddles(factors.at(offset + i), sub_batches.at(offset + i), ptr_offset, ptr); + } else { + calculate_twiddles(sub_batches.at(offset + i), factors.at(offset + i), ptr_offset, ptr); } } // Calculate twiddles for the implementation corresponding to per factor; for (std::size_t i = 0; i < num_factors; i++) { const auto& kernel_data = kernels.at(offset + i); - if (kernels.at(offset + i).level == detail::level::SUBGROUP) { + if (kernel_data.level == detail::level::SUBGROUP) { for (std::size_t j = 0; j < std::size_t(kernel_data.factors.at(0)); j++) { for (std::size_t k = 0; k < std::size_t(kernel_data.factors.at(1)); k++) { double theta = -2 * M_PI * static_cast(j * k) / static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); - ptr[offset + j * static_cast(kernel_data.factors.at(0)) + i] = twiddle.real(); - ptr[offset + - (j + static_cast(kernel_data.factors.at(1))) * + ptr[static_cast(ptr_offset) + k * static_cast(kernel_data.factors.at(0)) + j] = + twiddle.real(); + ptr[static_cast(ptr_offset) + + (k + static_cast(kernel_data.factors.at(1))) * static_cast(kernel_data.factors.at(0)) + - i] = twiddle.imag(); + j] = twiddle.imag(); } } ptr_offset += 2 * kernel_data.factors.at(0) * kernel_data.factors.at(1); @@ -216,9 +213,9 @@ struct committed_descriptor::calculate_twiddles_struct::inner(factor_n), static_cast(factor_m), ptr_offset, ptr); for (Idx j = 0; j < factor_n; j++) { - detail::complex_transpose(ptr + offset + 2 * j * factor_n, scratch_ptr, factor_m, factor_n, + detail::complex_transpose(ptr + ptr_offset + 2 * j * factor_n, scratch_ptr, factor_m, factor_n, factor_n * factor_m); - std::memcpy(ptr + offset + 2 * j * factor_n, scratch_ptr, + std::memcpy(ptr + ptr_offset + 2 * j * factor_n, scratch_ptr, static_cast(2 * factor_n * factor_m) * sizeof(float)); } } @@ -250,11 +247,11 @@ struct committed_descriptor::calculate_twiddles_struct::inner( detail::level::SUBGROUP, static_cast(factors.at(offset + i)), kernel_data.used_sg_size, - {static_cast(factor_sg), static_cast(factor_wi)}, num_sgs_in_wg); + {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg); } else { kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::SUBGROUP, static_cast(factors.at(offset + i)), kernel_data.used_sg_size, - {static_cast(factor_sg), static_cast(factor_wi)}, 
num_sgs_in_wg); + {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg); } auto [global_range, local_range] = detail::get_launch_params(factors.at(offset + i), sub_batches.at(offset + i), detail::level::SUBGROUP, diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index a852e759..7d299020 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -65,11 +65,13 @@ IdxGlobal get_global_size_workitem(IdxGlobal n_transforms, Idx subgroup_size, Id * @param offset offset for the global modifier data pointer */ template -PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal offset) { +PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal modifier_stride, + IdxGlobal modifier_offset) { + using vec_t = sycl::vec; PORTFFT_UNROLL for (Idx j = 0; j < num_elements; j++) { sycl::vec modifier_vec; - modifier_vec.load(0, detail::get_global_multi_ptr(&modifier_data[offset + 2 * j])); + modifier_vec = *reinterpret_cast(modifier_data + j * modifier_stride + modifier_offset); if (Dir == direction::BACKWARD) { modifier_vec[1] *= -1; } @@ -208,7 +210,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes load modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying load modifier"); - detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); + detail::apply_modifier(fft_size, priv, load_modifier_data, 2 * n_transforms, 2 * i); } if (take_conjugate_on_load) { take_conjugate(priv, fft_size); @@ -222,7 +224,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); + detail::apply_modifier(fft_size, priv, store_modifier_data, 2 * n_transforms, 2 * i); } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL From 1861125662c53658bed7ce34463899fd6efd745a Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 16 Jan 2024 12:33:59 +0000 Subject: [PATCH 30/67] pass load modifier pointer and loc memory to launch_kernel and dispatch level --- src/portfft/common/global.hpp | 40 +++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 6a6e4d89..46bf025d 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -133,15 +133,15 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, */ template PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* implementation_twiddles, - const Scalar* store_modifier_data, Scalar* input_loc, Scalar* twiddles_loc, + const Scalar* load_modifier_data, const Scalar* store_modifier_data, + Scalar* input_loc, Scalar* twiddles_loc, Scalar* load_modifier_loc, Scalar* store_modifier_loc, const IdxGlobal* factors, const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, IdxGlobal batch_size, Scalar scale_factor, detail::global_data_struct<1> global_data, sycl::kernel_handler& kh) { auto 
level = kh.get_specialization_constant(); Idx level_num = kh.get_specialization_constant(); Idx num_factors = kh.get_specialization_constant(); - bool increment_modifier_pointer = - kh.get_specialization_constant(); // Should it be a spec constant ? + bool increment_modifier_pointer = kh.get_specialization_constant(); global_data.log_message_global(__func__, "dispatching sub implementation for factor num = ", level_num); IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num); for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) { @@ -151,17 +151,17 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc if (level == detail::level::WORKITEM) { workitem_impl( input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, batch_size, - scale_factor, global_data, kh, static_cast(nullptr), - store_modifier_data + store_modifier_offset, static_cast(nullptr), store_modifier_loc); + scale_factor, global_data, kh, load_modifier_data, store_modifier_data + store_modifier_offset, + load_modifier_loc, store_modifier_loc); } else if (level == detail::level::SUBGROUP) { subgroup_impl( input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, twiddles_loc, - batch_size, implementation_twiddles, scale_factor, global_data, kh, static_cast(nullptr), - store_modifier_data + store_modifier_offset, static_cast(nullptr), store_modifier_loc); + batch_size, implementation_twiddles, scale_factor, global_data, kh, load_modifier_data, + store_modifier_data + store_modifier_offset, load_modifier_loc, store_modifier_loc); } else if (level == detail::level::WORKGROUP) { workgroup_impl( input + outer_batch_offset, output + outer_batch_offset, nullptr, nullptr, input_loc, twiddles_loc, - batch_size, implementation_twiddles, scale_factor, global_data, kh, static_cast(nullptr), + batch_size, implementation_twiddles, scale_factor, global_data, kh, load_modifier_data, store_modifier_data + store_modifier_offset); } sycl::group_barrier(global_data.it.get_group()); @@ -182,6 +182,8 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc * @param loc_for_twiddles local memory for twiddles * @param loc_for_store_modifier local memory for store modifier data * @param multipliers_between_factors twiddles to be multiplied between factors + * @param loc_for_load_modifier local memory for load modifiers + * @param load_modifier pointer to global memory containing the load modifier data * @param impl_twiddles twiddles required for sub implementation * @param factors global memory pointer containing factors of the input * @param inner_batches global memory pointer containing the inner batch for each factor @@ -197,6 +199,7 @@ template & input, Scalar* output, sycl::local_accessor& loc_for_input, sycl::local_accessor& loc_for_twiddles, sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, + sycl::local_accessor& loc_for_load_modifier, const Scalar* load_modifier, const Scalar* impl_twiddles, const IdxGlobal* factors, const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, IdxGlobal n_transforms, Scalar scale_factor, IdxGlobal input_batch_offset, std::pair, sycl::range<1>> launch_params, @@ -214,9 +217,9 @@ void launch_kernel(sycl::accessor& in #endif it}; dispatch_level( - &input[0] + input_batch_offset, output, impl_twiddles, multipliers_between_factors, &loc_for_input[0], - &loc_for_twiddles[0], 
&loc_for_store_modifier[0], factors, inner_batches, inclusive_scan, n_transforms, - scale_factor, global_data, kh); + &input[0] + input_batch_offset, output, impl_twiddles, load_modifier, multipliers_between_factors, + &loc_for_input[0], &loc_for_twiddles[0], &loc_for_load_modifier[0], &loc_for_store_modifier[0], factors, + inner_batches, inclusive_scan, n_transforms, scale_factor, global_data, kh); }); } @@ -235,6 +238,8 @@ void launch_kernel(sycl::accessor& in * @param loc_for_twiddles local memory for twiddles * @param loc_for_store_modifier local memory for store modifier data * @param multipliers_between_factors twiddles to be multiplied between factors + * @param loc_for_load_modifier local memory for load modifiers + * @param load_modifier pointer to global memory containing the load modifier data * @param impl_twiddles twiddles required for sub implementation * @param factors global memory pointer containing factors of the input * @param inner_batches global memory pointer containing the inner batch for each factor @@ -250,6 +255,7 @@ template & loc_for_input, sycl::local_accessor& loc_for_twiddles, sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, + sycl::local_accessor& loc_for_load_modifier, const Scalar* load_modifier, const Scalar* impl_twiddles, const IdxGlobal* factors, const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, IdxGlobal n_transforms, Scalar scale_factor, IdxGlobal input_batch_offset, std::pair, sycl::range<1>> launch_params, @@ -267,9 +273,9 @@ void launch_kernel(const Scalar* input, Scalar* output, sycl::local_accessor( - &input[0] + input_batch_offset, output, impl_twiddles, multipliers_between_factors, &loc_for_input[0], - &loc_for_twiddles[0], &loc_for_store_modifier[0], factors, inner_batches, inclusive_scan, n_transforms, - scale_factor, global_data, kh); + &input[0] + input_batch_offset, output, impl_twiddles, load_modifier, multipliers_between_factors, + &loc_for_input[0], &loc_for_twiddles[0], &loc_for_load_modifier[0], &loc_for_store_modifier[0], factors, + inner_batches, inclusive_scan, n_transforms, scale_factor, global_data, kh); }); } @@ -504,6 +510,7 @@ void compute_level(const typename committed_descriptor::kernel_d sycl::local_accessor loc_for_input(local_memory_for_input, cgh); sycl::local_accessor loc_for_twiddles(loc_mem_for_twiddles, cgh); sycl::local_accessor loc_for_modifier(local_mem_for_store_modifier, cgh); + sycl::local_accessor loc_for_load_modifier(1, cgh); auto in_acc_or_usm = detail::get_access(input, cgh); cgh.use_kernel_bundle(kd_struct.exec_bundle); if (static_cast(in_dependencies.size()) < num_batches_in_l2) { @@ -519,8 +526,9 @@ void compute_level(const typename committed_descriptor::kernel_d const Scalar* subimpl_twiddles = using_wi_level ? 
nullptr : twiddles_ptr + subimpl_twiddle_offset; detail::launch_kernel( in_acc_or_usm, output + 2 * batch_in_l2 * committed_size, loc_for_input, loc_for_twiddles, loc_for_modifier, - twiddles_ptr + intermediate_twiddle_offset, subimpl_twiddles, factors_triple, inner_batches, inclusive_scan, - batch_size, scale_factor, 2 * committed_size * batch_in_l2 + input_global_offset, + twiddles_ptr + intermediate_twiddle_offset, loc_for_load_modifier, static_cast(nullptr), + subimpl_twiddles, factors_triple, inner_batches, inclusive_scan, batch_size, scale_factor, + 2 * committed_size * batch_in_l2 + input_global_offset, {sycl::range<1>(static_cast(global_range)), sycl::range<1>(static_cast(local_range))}, cgh); From 63dd4733b84474777d05a08aefcc9637ac26ac07 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 16 Jan 2024 14:31:01 +0000 Subject: [PATCH 31/67] add option for subgroup loads in apply_modifier --- src/portfft/dispatcher/workitem_dispatcher.hpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 7d299020..458bdad5 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -66,12 +66,16 @@ IdxGlobal get_global_size_workitem(IdxGlobal n_transforms, Idx subgroup_size, Id */ template PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal modifier_stride, - IdxGlobal modifier_offset) { + IdxGlobal modifier_offset, [[maybe_unused]] sycl::sub_group& sg) { using vec_t = sycl::vec; PORTFFT_UNROLL for (Idx j = 0; j < num_elements; j++) { sycl::vec modifier_vec; +#ifdef PORTFFT_USE_SG_TRANSFERS + modifier_vec = sg.load<2>(detail::get_global_multi_ptr(modifier_data + j * modifier_stride + modifier_offset)); +#else modifier_vec = *reinterpret_cast(modifier_data + j * modifier_stride + modifier_offset); +#endif if (Dir == direction::BACKWARD) { modifier_vec[1] *= -1; } @@ -210,7 +214,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes load modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying load modifier"); - detail::apply_modifier(fft_size, priv, load_modifier_data, 2 * n_transforms, 2 * i); + detail::apply_modifier(fft_size, priv, load_modifier_data, 2 * n_transforms, 2 * i, global_data.sg); } if (take_conjugate_on_load) { take_conjugate(priv, fft_size); @@ -224,7 +228,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) // to ensure much lesser bank conflicts global_data.log_message_global(__func__, "applying store modifier"); - detail::apply_modifier(fft_size, priv, store_modifier_data, 2 * n_transforms, 2 * i); + detail::apply_modifier(fft_size, priv, store_modifier_data, 2 * n_transforms, 2 * i, global_data.sg); } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL From 9ea81fe589da98896b62c7b898cea1f901f46844 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 16 Jan 2024 16:01:51 +0000 Subject: [PATCH 32/67] transpose load modifiers if required --- src/portfft/dispatcher/global_dispatcher.hpp | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/portfft/dispatcher/global_dispatcher.hpp 
b/src/portfft/dispatcher/global_dispatcher.hpp index 65b759c8..140e1c4a 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -168,11 +168,10 @@ struct committed_descriptor::calculate_twiddles_struct::inner& factors, std::vector& sub_batches, IdxGlobal& ptr_offset, - direction dir) -> void { + Scalar* scratch_ptr, direction dir) -> void { std::size_t num_factors = static_cast(dir == direction::FORWARD ? dimension_data.forward_factors : dimension_data.backward_factors); std::size_t offset = static_cast(dir == direction::FORWARD ? 0 : dimension_data.backward_factors); - Scalar* scratch_ptr = (Scalar*)malloc(8 * dimension_data.length * sizeof(Scalar)); for (std::size_t i = 0; i < num_factors - 1; i++) { if (kernels.at(offset + i).level == detail::level::WORKITEM) { @@ -315,6 +314,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner host_memory(static_cast(mem_required_for_twiddles)); Scalar* device_twiddles = sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); + Scalar* scratch_ptr = (Scalar*)malloc(8 * dimension_data.length * sizeof(Scalar)); IdxGlobal offset = 0; if (dimension_data.is_prime) { @@ -326,11 +326,24 @@ struct committed_descriptor::calculate_twiddles_struct::inner(dimension_data.committed_length), static_cast(dimension_data.length)); offset += static_cast(2 * dimension_data.length); + // set the layout of the load modifiers according to the requirements of the sub-impl. + if (kernels.at(0).level == detail::level::SUBGROUP) { + IdxGlobal base_offset = static_cast(2 * dimension_data.length); + for (IdxGlobal i = 0; i < kernels.at(0).batch_size; i++) { + detail::complex_transpose(host_memory.data() + base_offset, scratch_ptr, kernels.at(0).factors[0], + kernels.at(0).factors[1], kernels.at(0).factors[0] * kernels.at(0).factors[1]); + std::memcpy(host_memory.data() + base_offset, scratch_ptr, + 2 * kernels.at(0).factors[0] * kernels.at(0).factors[1] * sizeof(float)); + base_offset += 2 * kernels.at(0).factors[0] * kernels.at(0).factors[1]; + } + } } - populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, direction::FORWARD); + populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, scratch_ptr, + direction::FORWARD); if (dimension_data.backward_factors) { - populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, direction::BACKWARD); + populate_twiddles_and_metadata(host_memory.data(), factors_idx_global, sub_batches, offset, scratch_ptr, + direction::BACKWARD); } desc.queue.copy(host_memory.data(), device_twiddles, static_cast(mem_required_for_twiddles)).wait(); return device_twiddles; From 2d7017099c7236331c578b21992a965820544795 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 16 Jan 2024 16:37:22 +0000 Subject: [PATCH 33/67] setting of increment store pointer spec const and using less memory in scratch --- src/portfft/descriptor.hpp | 9 ++++++--- src/portfft/dispatcher/global_dispatcher.hpp | 7 ++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 78b1c602..3c56d53a 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -520,12 +520,12 @@ class committed_descriptor { * Sets the spec constants for the global implementation.
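 * (Editorial note, inferred from the branches below rather than stated by the patch: only the
 * kernel at index num_kernels - 1 applies the scale factor, its store-modifier spec constant
 * mirrors last_uses_store_modifier, and when that modifier is APPLIED the bundle also gets the
 * increment-store-pointer specialization constant set to true.)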
* @param prepared_vec Vector returned by prepare_implementations * @param first_uses_load_modifiers whether or not first kernel multiplies the modifier before dft compute - * @param last_uses_load_modifier whether or not first kernel multiplies the modifier after dft compute + * @param last_uses_store_modifier whether or not last kernel multiplies the modifier after dft compute * @param num_kernels number of factors */ void set_global_impl_spec_consts(std::vector& prepared_vec, detail::elementwise_multiply first_uses_load_modifiers, - detail::elementwise_multiply last_uses_load_modifier, Idx num_kernels) { + detail::elementwise_multiply last_uses_store_modifier, Idx num_kernels) { Idx counter = 0; for (auto& [level, in_bundle, factors] : prepared_vec) { if (counter >= num_kernels) { @@ -535,8 +535,11 @@ class committed_descriptor { set_spec_constants( detail::level::GLOBAL, in_bundle, static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), - factors, detail::elementwise_multiply::NOT_APPLIED, last_uses_load_modifier, + factors, detail::elementwise_multiply::NOT_APPLIED, last_uses_store_modifier, detail::apply_scale_factor::APPLIED, level, static_cast(counter), static_cast(num_kernels)); + if (last_uses_store_modifier == detail::elementwise_multiply::APPLIED) { + in_bundle.template set_specialization_constant(true); + } } else if (counter == 0) { set_spec_constants( detail::level::GLOBAL, in_bundle, diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 140e1c4a..9c73b98d 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -314,7 +314,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner host_memory(static_cast(mem_required_for_twiddles)); Scalar* device_twiddles = sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); - Scalar* scratch_ptr = (Scalar*)malloc(8 * dimension_data.length * sizeof(Scalar)); + Scalar* scratch_ptr = (Scalar*)malloc(2 * dimension_data.length * sizeof(Scalar)); IdxGlobal offset = 0; if (dimension_data.is_prime) { @@ -332,8 +332,9 @@ struct committed_descriptor::calculate_twiddles_struct::inner(2 * kernels.at(0).factors[0] * kernels.at(0).factors[1]) * sizeof(Scalar)); base_offset += 2 * kernels.at(0).factors[0] * kernels.at(0).factors[1]; } } From a07ec9cca05acb9db9cda781ccb971cfba42181f Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 16 Jan 2024 17:18:52 +0000 Subject: [PATCH 34/67] updated doxygens and descriptions --- src/portfft/common/host_fft.hpp | 8 +++++++ src/portfft/dispatcher/global_dispatcher.hpp | 23 +++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/portfft/common/host_fft.hpp b/src/portfft/common/host_fft.hpp index b7ebe70e..2ad6affc 100644 --- a/src/portfft/common/host_fft.hpp +++ b/src/portfft/common/host_fft.hpp @@ -26,6 +26,14 @@ namespace portfft { namespace detail { + +/** + * Host Naive DFT. 
Works OOP only + * @tparam T Scalar Type + * @param input input pointer + * @param output output pointer + * @param fft_size fft size + */ template void naive_dft(std::complex* input, std::complex* output, IdxGlobal fft_size) { using ctype = std::complex; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 9c73b98d..853a8d30 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -102,6 +102,18 @@ inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) return 0; } +/** + * Utility function to copy data between pointers with different distances between each batch. + * @tparam T scalar type + * @param src source pointer + * @param dst destination pointer + * @param num_elements_to_copy number of elements to copy + * @param src_stride stride of the source pointer + * @param dst_stride stride of the destination pointer + * @param num_copies number of batches to copy + * @param event_vector vector to store the generated events + * @param queue queue + */ template void trigger_device_copy(const T* src, T* dst, std::size_t num_elements_to_copy, std::size_t src_stride, std::size_t dst_stride, std::size_t num_copies, std::vector& event_vector, @@ -175,7 +187,7 @@ struct committed_descriptor::calculate_twiddles_struct::inner::run_kernel_struct(kernels.at(i).length); } - auto run_global = [&](const std::vector& kernels, const std::size_t& i) { + auto global_impl_driver = [&](const std::vector& kernels, + const std::size_t& i) { IdxGlobal intermediate_twiddles_offset = 0; IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; if (dimension_data.is_prime) { @@ -490,11 +503,11 @@ struct committed_descriptor::run_kernel_struct( + global_impl_driver.template operator()( std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.begin() + static_cast(dimension_data.forward_factors)), i); - run_global.template operator()( + global_impl_driver.template operator()( std::vector(kernels.begin() + static_cast(dimension_data.forward_factors), kernels.end()), i); desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(current_events[0]); @@ -507,7 +520,7 @@ struct committed_descriptor::run_kernel_struct(kernels, i); + global_impl_driver.template operator()(kernels, i); } } return desc.queue.submit([&](sycl::handler& cgh) { From 8a8371d6685afe724b037191266eb70f4c8a24ca Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Wed, 17 Jan 2024 17:55:27 +0000 Subject: [PATCH 35/67] remove direction --- src/portfft/common/global.hpp | 22 +- src/portfft/common/helpers.hpp | 14 ++ src/portfft/common/subgroup.hpp | 31 +-- src/portfft/common/workgroup.hpp | 35 +-- src/portfft/common/workitem.hpp | 30 +-- src/portfft/descriptor.hpp | 228 ++++++++++-------- src/portfft/dispatcher/global_dispatcher.hpp | 14 +- .../dispatcher/subgroup_dispatcher.hpp | 33 ++- .../dispatcher/workgroup_dispatcher.hpp | 33 +-- .../dispatcher/workitem_dispatcher.hpp | 26 +- src/portfft/specialization_constant.hpp | 3 + src/portfft/utils.hpp | 33 ++- 12 files changed, 272 insertions(+), 230 deletions(-) diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index fdd0c36b..70923a36 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -111,7 +111,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, /** * Device function responsible for calling the corresponding sub-implementation * - * @tparam Dir 
Direction of the FFT * @tparam Scalar Scalar type * @tparam LayoutIn Input layout * @tparam LayoutOut Output layout @@ -131,7 +130,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, * @param global_data global data * @param kh kernel handler */ -template +template PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* implementation_twiddles, const Scalar* store_modifier_data, Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc, const IdxGlobal* factors, const IdxGlobal* inner_batches, @@ -168,7 +167,6 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc /** * Utility function to launch the kernel when the input is a buffer * @tparam Scalar Scalar type - * @tparam Dir Direction of the FFT * @tparam Domain Domain of the compute * @tparam LayoutIn Input layout * @tparam LayoutOut Output layout @@ -189,8 +187,7 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc * @param launch_params launch configuration, the global and local range with which the kernel will get launched * @param cgh associated command group handler */ -template +template void launch_kernel(sycl::accessor& input, Scalar* output, sycl::local_accessor& loc_for_input, sycl::local_accessor& loc_for_twiddles, sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, @@ -210,7 +207,7 @@ void launch_kernel(sycl::accessor& in s, #endif it}; - dispatch_level( + dispatch_level( &input[0] + input_batch_offset, output, impl_twiddles, multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_store_modifier[0], factors, inner_batches, inclusive_scan, n_transforms, scale_factor, global_data, kh); @@ -221,7 +218,6 @@ void launch_kernel(sycl::accessor& in * TODO: Launch the kernel directly from compute_level and remove the duplicated launch_kernel * Utility function to launch the kernel when the input is an USM * @tparam Scalar Scalar type - * @tparam Dir Direction of the FFT * @tparam Domain Domain of the compute * @tparam LayoutIn Input layout * @tparam LayoutOut Output layout @@ -242,8 +238,7 @@ void launch_kernel(sycl::accessor& in * @param launch_params launch configuration, the global and local range with which the kernel will get launched * @param cgh associated command group handler */ -template +template void launch_kernel(const Scalar* input, Scalar* output, sycl::local_accessor& loc_for_input, sycl::local_accessor& loc_for_twiddles, sycl::local_accessor& loc_for_store_modifier, const Scalar* multipliers_between_factors, @@ -263,7 +258,7 @@ void launch_kernel(const Scalar* input, Scalar* output, sycl::local_accessor( + dispatch_level( &input[0] + input_batch_offset, output, impl_twiddles, multipliers_between_factors, &loc_for_input[0], &loc_for_twiddles[0], &loc_for_store_modifier[0], factors, inner_batches, inclusive_scan, n_transforms, scale_factor, global_data, kh); @@ -430,7 +425,6 @@ sycl::event transpose_level(const typename committed_descriptor: * Prepares the launch of fft compute at a particular level * @tparam Scalar Scalar type * @tparam Domain Domain of FFT - * @tparam Dir Direction of the FFT * @tparam LayoutIn Input layout * @tparam LayoutOut output layout * @tparam SubgroupSize subgroup size @@ -456,8 +450,8 @@ sycl::event transpose_level(const typename committed_descriptor: * @param queue queue * @return vector events, one for each batch in l2 */ -template +template std::vector compute_level( const typename 
committed_descriptor::kernel_data_struct& kd_struct, const TIn input, Scalar* output, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, @@ -513,7 +507,7 @@ std::vector compute_level( // the subimpl_twiddles + subimpl_twiddle_offset may point to the end of the allocation and therefore be invalid. const bool using_wi_level = kd_struct.level == detail::level::WORKITEM; const Scalar* subimpl_twiddles = using_wi_level ? nullptr : twiddles_ptr + subimpl_twiddle_offset; - detail::launch_kernel( + detail::launch_kernel( in_acc_or_usm, output + 2 * batch_in_l2 * committed_size, loc_for_input, loc_for_twiddles, loc_for_modifier, twiddles_ptr + intermediate_twiddle_offset, subimpl_twiddles, factors_triple, inner_batches, inclusive_scan, batch_size, scale_factor, 2 * committed_size * batch_in_l2 + input_global_offset, diff --git a/src/portfft/common/helpers.hpp b/src/portfft/common/helpers.hpp index 476dc820..f164c4ff 100644 --- a/src/portfft/common/helpers.hpp +++ b/src/portfft/common/helpers.hpp @@ -182,6 +182,20 @@ PORTFFT_INLINE constexpr Idx int_log2(Idx x) { } return y; } + +/** + * Takes the conjugate of the complex data in private array + * @tparam T Scalar type + * @param priv pointer to the data in registers + * @param num_elements number of complex numbers in the private memory + */ +template +PORTFFT_INLINE void take_conjugate(T* priv, Idx num_elements) { + PORTFFT_UNROLL + for (Idx i = 0; i < num_elements; i++) { + priv[2 * i + 1] *= -1; + } +} } // namespace portfft::detail #endif diff --git a/src/portfft/common/subgroup.hpp b/src/portfft/common/subgroup.hpp index b396ced9..c20a8009 100644 --- a/src/portfft/common/subgroup.hpp +++ b/src/portfft/common/subgroup.hpp @@ -60,14 +60,13 @@ factors and does transposition and twiddle multiplication inbetween. */ // forward declaration -template +template PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg); /** * Calculates DFT using naive algorithm by using workitems of one subgroup. * Each workitem holds one input and one output complex value. * - * @tparam Dir direction of the FFT * @tparam T type of the scalar to work on * @param[in,out] real real component of the input/output complex value for one * workitem @@ -78,7 +77,7 @@ PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, syc * DFT * @param sg subgroup */ -template +template PORTFFT_INLINE void cross_sg_naive_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { if (fft_size == 2 && (stride & (stride - 1)) == 0) { Idx local_id = static_cast(sg.get_local_linear_id()); @@ -106,9 +105,7 @@ PORTFFT_INLINE void cross_sg_naive_dft(T& real, T& imag, Idx fft_size, Idx strid for (Idx idx_in = 0; idx_in < fft_size; idx_in++) { T multi_re = twiddle::Re[fft_size][idx_in * idx_out % fft_size]; T multi_im = twiddle::Im[fft_size][idx_in * idx_out % fft_size]; - if constexpr (Dir == direction::BACKWARD) { - multi_im = -multi_im; - } + Idx source_wi_id = fft_start + idx_in * stride; T cur_real = sycl::select_from_group(sg, real, static_cast(source_wi_id)); @@ -157,7 +154,6 @@ PORTFFT_INLINE void cross_sg_transpose(T& real, T& imag, Idx factor_n, Idx facto * Calculates DFT using Cooley-Tukey FFT algorithm. Size of the problem is N*M. * Each workitem holds one input and one output complex value. 
* - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam SubgroupSize Size of subgroup in kernel * @tparam RecursionLevel level of recursion in SG dft * @tparam T type of the scalar to work on @@ -171,7 +167,7 @@ PORTFFT_INLINE void cross_sg_transpose(T& real, T& imag, Idx factor_n, Idx facto * DFT * @param sg subgroup */ -template +template PORTFFT_INLINE void cross_sg_cooley_tukey_dft(T& real, T& imag, Idx factor_n, Idx factor_m, Idx stride, sycl::sub_group& sg) { Idx local_id = static_cast(sg.get_local_linear_id()); @@ -180,24 +176,20 @@ PORTFFT_INLINE void cross_sg_cooley_tukey_dft(T& real, T& imag, Idx factor_n, Id Idx n = index_in_outer_dft / factor_n; // index of the contiguous factor/fft // factor N - cross_sg_dft(real, imag, factor_n, factor_m * stride, sg); + cross_sg_dft(real, imag, factor_n, factor_m * stride, sg); // transpose cross_sg_transpose(real, imag, factor_n, factor_m, stride, sg); T multi_re = twiddle::Re[factor_n * factor_m][k * n]; T multi_im = twiddle::Im[factor_n * factor_m][k * n]; - if constexpr (Dir == direction::BACKWARD) { - multi_im = -multi_im; - } detail::multiply_complex(real, imag, multi_re, multi_im, real, imag); // factor M - cross_sg_dft(real, imag, factor_m, factor_n * stride, sg); + cross_sg_dft(real, imag, factor_m, factor_n * stride, sg); } /** * Calculates DFT using FFT algorithm. Each workitem holds one input and one * output complex value. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam SubgroupSize Size of subgroup in kernel * @tparam RecursionLevel level of recursion in SG dft * @tparam T type of the scalar to work on @@ -210,15 +202,15 @@ PORTFFT_INLINE void cross_sg_cooley_tukey_dft(T& real, T& imag, Idx factor_n, Id * DFT * @param sg subgroup */ -template +template PORTFFT_INLINE void cross_sg_dft(T& real, T& imag, Idx fft_size, Idx stride, sycl::sub_group& sg) { constexpr Idx MaxRecursionLevel = detail::int_log2(SubgroupSize); if constexpr (RecursionLevel < MaxRecursionLevel) { const Idx f0 = detail::factorize(fft_size); if (f0 >= 2 && fft_size / f0 >= 2) { - cross_sg_cooley_tukey_dft(real, imag, fft_size / f0, f0, stride, sg); + cross_sg_cooley_tukey_dft(real, imag, fft_size / f0, f0, stride, sg); } else { - cross_sg_naive_dft(real, imag, fft_size, stride, sg); + cross_sg_naive_dft(real, imag, fft_size, stride, sg); } } } @@ -266,7 +258,6 @@ constexpr bool fits_in_sg(IdxGlobal N, Idx sg_size) { * Calculates FFT of size N*M using workitems in a subgroup. Works in place. The * end result needs to be transposed when storing it to the local memory! 
* - * @tparam Dir direction of the FFT * @tparam SubgroupSize Size of subgroup in kernel * @tparam T type of the scalar used for computations * @param inout pointer to private memory where the input/output data is @@ -277,7 +268,7 @@ constexpr bool fits_in_sg(IdxGlobal N, Idx sg_size) { * commit * @param private_scratch Scratch memory for wi implementation */ -template +template PORTFFT_INLINE void sg_dft(T* inout, sycl::sub_group& sg, Idx factor_wi, Idx factor_sg, const T* sg_twiddles, T* private_scratch) { Idx idx_of_wi_in_fft = static_cast(sg.get_local_linear_id()) % factor_sg; @@ -299,7 +290,7 @@ PORTFFT_INLINE void sg_dft(T* inout, sycl::sub_group& sg, Idx factor_wi, Idx fac } } }; - wi_dft(inout, inout, factor_wi, 1, 1, private_scratch); + wi_dft<0>(inout, inout, factor_wi, 1, 1, private_scratch); } /** diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 3190ed88..38972834 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -56,7 +56,6 @@ namespace detail { /** * Calculate all dfts in one dimension of the data stored in local memory. * - * @tparam Dir Direction of the FFT * @tparam LayoutIn Input Layout * @tparam SubgroupSize Size of the subgroup * @tparam LocalT The type of the local view @@ -77,15 +76,18 @@ namespace detail { * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param MultiplyOnStore Whether the input data is multiplied with some data array after fft computation. * @param ApplyScaleFactor Whether or not the scale factor is applied + * @param take_conjugate_on_load whether or not to take conjugate of the input + * @param take_conjugate_on_store whether or not to take conjugate of the output * @param global_data global data for the kernel */ -template +template __attribute__((always_inline)) inline void dimension_dft( LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem, Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel, Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, detail::layout layout_in, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, - detail::apply_scale_factor apply_scale_factor, global_data_struct<1> global_data) { + detail::apply_scale_factor apply_scale_factor, bool take_conjugate_on_load, bool take_conjugate_on_store, + global_data_struct<1> global_data) { static_assert(std::is_same_v, T>, "Real type mismatch"); global_data.log_message_global(__func__, "entered", "DFTSize", dft_size, "stride_within_dft", stride_within_dft, "ndfts_in_outer_dimension", ndfts_in_outer_dimension, "max_num_batches_in_local_mem", @@ -168,9 +170,6 @@ __attribute__((always_inline)) inline void dimension_dft( sycl::vec twiddles = reinterpret_cast*>(wg_twiddles)[twiddle_index]; T twiddle_real = twiddles[0]; T twiddle_imag = twiddles[1]; - if constexpr (Dir == direction::BACKWARD) { - twiddle_imag = -twiddle_imag; - } multiply_complex(priv[2 * i], priv[2 * i + 1], twiddle_real, twiddle_imag, priv[2 * i], priv[2 * i + 1]); } global_data.log_dump_private("data in registers after twiddle multiplication:", priv, 2 * fact_wi); @@ -197,9 +196,13 @@ __attribute__((always_inline)) inline void dimension_dft( } } } - - sg_dft(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch); - + if (take_conjugate_on_load) { + take_conjugate(priv, fact_wi); + } + sg_dft(priv, 
global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch); + if (take_conjugate_on_store) { + take_conjugate(priv, fact_wi); + } if (working) { if (multiply_on_store == detail::elementwise_multiply::APPLIED) { // Store modifier data layout in global memory - n_transforms x N x FactorSG x FactorWI @@ -238,7 +241,6 @@ __attribute__((always_inline)) inline void dimension_dft( /** * Calculates FFT using Bailey 4 step algorithm. * - * @tparam Dir Direction of the FFT * @tparam SubgroupSize Size of the subgroup * @tparam LocalT Local memory view type * @tparam T Scalar type @@ -259,15 +261,18 @@ __attribute__((always_inline)) inline void dimension_dft( * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation. * @param apply_scale_factor Whether or not the scale factor is applied + * @param take_conjugate_on_load whether or not to take conjugate of the input + * @param take_conjugate_on_store whether or not to take conjugate of the output * @param global_data global data for the kernel */ -template +template PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel, const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M, detail::layout layout_in, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, - detail::apply_scale_factor apply_scale_factor, detail::global_data_struct<1> global_data) { + detail::apply_scale_factor apply_scale_factor, bool take_conjugate_on_load, + bool take_conjugate_on_store, detail::global_data_struct<1> global_data) { global_data.log_message_global(__func__, "entered", "FFTSize", fft_size, "N", N, "M", M, "max_num_batches_in_local_mem", max_num_batches_in_local_mem, "batch_num_in_local", batch_num_in_local); @@ -275,13 +280,15 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T detail::dimension_dft( loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, N, M, 1, layout_in, multiply_on_load, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, global_data); + detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, take_conjugate_on_load, + take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); // row-wise DFTs, including twiddle multiplications and scaling detail::dimension_dft( loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, layout_in, - detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, global_data); + detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, take_conjugate_on_load, + take_conjugate_on_store, global_data); global_data.log_message_global(__func__, "exited"); } diff --git a/src/portfft/common/workitem.hpp b/src/portfft/common/workitem.hpp index b20e3045..619455fd 100644 --- a/src/portfft/common/workitem.hpp +++ b/src/portfft/common/workitem.hpp @@ -53,7 +53,6 @@ strides. /** * Calculates DFT using naive algorithm. Can work in or out of place. 
* - * @tparam Dir direction of the FFT * @tparam T type of the scalar used for computations * @param in pointer to input * @param out pointer to output * @@ -62,7 +61,7 @@ strides. * @param stride_out stride (in complex values) between complex values in `out` * @param privateScratch Scratch memory for this WI. Expects 2 * dftSize size. */ -template +template PORTFFT_INLINE void naive_dft(const T* in, T* out, Idx fft_size, Idx stride_in, Idx stride_out, T* privateScratch) { PORTFFT_UNROLL for (Idx idx_out = 0; idx_out < fft_size; idx_out++) { @@ -71,12 +70,8 @@ PORTFFT_INLINE void naive_dft(const T* in, T* out, Idx fft_size, Idx stride_in, PORTFFT_UNROLL for (Idx idx_in = 0; idx_in < fft_size; idx_in++) { auto re_multiplier = twiddle::Re[fft_size][idx_in * idx_out % fft_size]; - auto im_multiplier = [&]() { - if constexpr (Dir == direction::FORWARD) { - return twiddle::Im[fft_size][idx_in * idx_out % fft_size]; - } - return -twiddle::Im[fft_size][idx_in * idx_out % fft_size]; - }(); + auto im_multiplier = twiddle::Im[fft_size][idx_in * idx_out % fft_size]; + // multiply the input value with the twiddle multiplier T tmp_real; T tmp_complex; @@ -98,7 +93,6 @@ PORTFFT_INLINE void naive_dft(const T* in, T* out, Idx fft_size, Idx stride_in, /** * Calculates DFT using Cooley-Tukey FFT algorithm. Can work in or out of place. Size of the problem is N*M * - * @tparam Dir direction of the FFT * @tparam T type of the scalar used for computations * @param in pointer to input * @param out pointer to output * @@ -108,7 +102,7 @@ PORTFFT_INLINE void naive_dft(const T* in, T* out, Idx fft_size, Idx stride_in, * @param stride_in stride (in complex values) between complex values in `in` * @param privateScratch Scratch memory for this WI. Expects 2 * dftSize size. */ -template +template PORTFFT_INLINE void cooley_tukey_dft(const T* in, T* out, Idx factor_n, Idx factor_m, Idx stride_in, Idx stride_out, T* privateScratch) { PORTFFT_UNROLL @@ -118,12 +112,8 @@ PORTFFT_INLINE void cooley_tukey_dft(const T* in, T* out, Idx factor_n, Idx fact PORTFFT_UNROLL for (Idx j = 0; j < factor_n; j++) { auto re_multiplier = twiddle::Re[factor_n * factor_m][i * j]; - auto im_multiplier = [&]() { - if constexpr (Dir == direction::FORWARD) { - return twiddle::Im[factor_n * factor_m][i * j]; - } - return -twiddle::Im[factor_n * factor_m][i * j]; - }(); + auto im_multiplier = twiddle::Im[factor_n * factor_m][i * j]; + detail::multiply_complex(privateScratch[2 * i * factor_n + 2 * j], privateScratch[2 * i * factor_n + 2 * j + 1], re_multiplier, im_multiplier, privateScratch[2 * i * factor_n + 2 * j], privateScratch[2 * i * factor_n + 2 * j + 1]); @@ -199,7 +189,6 @@ PORTFFT_INLINE constexpr bool fits_in_wi(TIdx N) { /** * Calculates DFT using FFT algorithm. Can work in or out of place. * - * @tparam Dir direction of the FFT * @tparam T type of the scalar used for computations * @param in pointer to input * @param out pointer to output * @@ -208,7 +197,7 @@ PORTFFT_INLINE constexpr bool fits_in_wi(TIdx N) { * @param stride_out stride (in complex values) between complex values in `out` * @param privateScratch Scratch memory for this WI.
*/ -template +template PORTFFT_INLINE void wi_dft(const T* in, T* out, Idx fft_size, Idx stride_in, Idx stride_out, T* privateScratch) { const Idx f0 = detail::factorize(fft_size); constexpr Idx MaxRecursionLevel = detail::int_log2(detail::MaxComplexPerWI) - 1; @@ -222,10 +211,9 @@ PORTFFT_INLINE void wi_dft(const T* in, T* out, Idx fft_size, Idx stride_in, Idx out[0 * stride_out + 1] = b; out[2 * stride_out + 0] = c; } else if (f0 >= 2 && fft_size / f0 >= 2) { - detail::cooley_tukey_dft(in, out, fft_size / f0, f0, stride_in, stride_out, - privateScratch); + detail::cooley_tukey_dft(in, out, fft_size / f0, f0, stride_in, stride_out, privateScratch); } else { - detail::naive_dft(in, out, fft_size, stride_in, stride_out, privateScratch); + detail::naive_dft(in, out, fft_size, stride_in, stride_out, privateScratch); } } } diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index d563fa96..22191724 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -40,8 +40,8 @@ namespace portfft { template class committed_descriptor; namespace detail { -template +template std::vector compute_level( const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, Scalar* output, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, Scalar scale_factor, @@ -59,13 +59,13 @@ sycl::event transpose_level(const typename committed_descriptor: // kernel names // TODO: Remove all templates except Scalar, Domain and Memory and SubgroupSize -template +template class workitem_kernel; -template +template class subgroup_kernel; -template +template class workgroup_kernel; -template +template class global_kernel; template class transpose_kernel; @@ -231,7 +231,8 @@ class committed_descriptor { }; struct dimension_struct { - std::vector kernels; + std::vector forward_kernels; + std::vector backward_kernels; std::shared_ptr factors_and_scan; detail::level level; std::size_t length; @@ -239,8 +240,13 @@ class committed_descriptor { Idx num_batches_in_l2; Idx num_factors; - dimension_struct(std::vector kernels, detail::level level, std::size_t length, Idx used_sg_size) - : kernels(kernels), level(level), length(length), used_sg_size(used_sg_size) {} + dimension_struct(std::vector forward_kernels, std::vector backward_kernels, + detail::level level, std::size_t length, Idx used_sg_size) + : forward_kernels(forward_kernels), + backward_kernels(backward_kernels), + level(level), + length(length), + used_sg_size(used_sg_size) {} }; std::vector dimensions; @@ -432,17 +438,20 @@ class committed_descriptor { * @param in_bundle input kernel bundle to set spec constants for * @param length length of the fft * @param factors factors of the corresponding length - * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. - * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation. + * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation + * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation * @param scale_factor_applied whether or not to multiply scale factor - * @param level sub implementation to run which will be set as a spec constant.
+ * @param level sub implementation to run which will be set as a spec constant + * @param take_conjugate_on_load whether or not to take conjugate of the input + * @param take_conjugate_on_store whether or not to take conjugate of the output * @param factor_num factor number which is set as a spec constant - * @param num_factors total number of factors of the committed size, set as a spec constant. + * @param num_factors total number of factors of the committed size, set as a spec constant */ void set_spec_constants(detail::level top_level, sycl::kernel_bundle& in_bundle, std::size_t length, const std::vector& factors, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, - detail::apply_scale_factor scale_factor_applied, detail::level level, Idx factor_num = 0, + detail::apply_scale_factor scale_factor_applied, detail::level level, + bool take_conjugate_on_load, bool take_conjugate_on_store, Idx factor_num = 0, Idx num_factors = 0) { const Idx length_idx = static_cast(length); // These spec constants are used in all implementations, so we set them here @@ -452,6 +461,8 @@ class committed_descriptor { in_bundle.template set_specialization_constant(multiply_on_load); in_bundle.template set_specialization_constant(multiply_on_store); in_bundle.template set_specialization_constant(scale_factor_applied); + in_bundle.template set_specialization_constant(take_conjugate_on_load); + in_bundle.template set_specialization_constant(take_conjugate_on_store); dispatch(top_level, in_bundle, length, factors, level, factor_num, num_factors); } @@ -525,45 +536,62 @@ class committed_descriptor { break; } } - std::vector result; + if (is_compatible) { - std::size_t counter = 0; - for (auto [level, ids, factors] : prepared_vec) { - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - if (top_level == detail::level::GLOBAL) { - if (counter == prepared_vec.size() - 1) { - set_spec_constants(detail::level::GLOBAL, in_bundle, - static_cast( - std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), - factors, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::APPLIED, level, - static_cast(counter), static_cast(prepared_vec.size())); + auto set_specialization_constants = [&](direction compute_direction) -> std::vector { + std::size_t counter = 0; + bool take_conjugate_on_load = compute_direction == direction::FORWARD ? false : true; + bool take_conjugate_on_store = compute_direction == direction::FORWARD ? 
false : true; + std::vector result; + for (auto [level, ids, factors] : prepared_vec) { + auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); + if (top_level == detail::level::GLOBAL) { + if (counter == prepared_vec.size() - 1) { + if (compute_direction == direction::BACKWARD) { + take_conjugate_on_load = false; + take_conjugate_on_store = true; + } + set_spec_constants(detail::level::GLOBAL, in_bundle, + static_cast( + std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), + factors, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::APPLIED, + level, static_cast(counter), static_cast(prepared_vec.size())); + } else { + if (counter == 0) { + take_conjugate_on_load = true; + take_conjugate_on_store = false; + } + set_spec_constants(detail::level::GLOBAL, in_bundle, + static_cast( + std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), + factors, detail::elementwise_multiply::NOT_APPLIED, + detail::elementwise_multiply::APPLIED, detail::apply_scale_factor::NOT_APPLIED, + level, static_cast(counter), static_cast(prepared_vec.size())); + } } else { - set_spec_constants(detail::level::GLOBAL, in_bundle, - static_cast( - std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), - factors, detail::elementwise_multiply::NOT_APPLIED, - detail::elementwise_multiply::APPLIED, detail::apply_scale_factor::NOT_APPLIED, level, - static_cast(counter), static_cast(prepared_vec.size())); + set_spec_constants(level, in_bundle, params.lengths[kernel_num], factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, + detail::apply_scale_factor::APPLIED, level, take_conjugate_on_load, + take_conjugate_on_store); } - } else { - set_spec_constants(level, in_bundle, params.lengths[kernel_num], factors, - detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - detail::apply_scale_factor::APPLIED, level); - } - try { - result.emplace_back(sycl::build(in_bundle), factors, params.lengths[kernel_num], SubgroupSize, - PORTFFT_SGS_IN_WG, std::shared_ptr(), level); - } catch (std::exception& e) { - std::cerr << "Build for subgroup size " << SubgroupSize << " failed with message:\n" - << e.what() << std::endl; - is_compatible = false; - break; + try { + result.emplace_back(sycl::build(in_bundle), factors, params.lengths[kernel_num], SubgroupSize, + PORTFFT_SGS_IN_WG, std::shared_ptr(), level); + } catch (std::exception& e) { + std::cerr << "Build for subgroup size " << SubgroupSize << " failed with message:\n" + << e.what() << std::endl; + is_compatible = false; + break; + } + counter++; } - counter++; - } + return result; + }; + std::vector forward_kernels = set_specialization_constants(direction::FORWARD); + std::vector backward_kernels = set_specialization_constants(direction::BACKWARD); if (is_compatible) { - return {result, top_level, params.lengths[kernel_num], SubgroupSize}; + return {forward_kernels, backward_kernels, top_level, params.lengths[kernel_num], SubgroupSize}; } } } @@ -885,7 +913,7 @@ class committed_descriptor { * @param out buffer containing output data */ void compute_forward(const sycl::buffer& in, sycl::buffer& out) { - dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX); + dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD); } /** @@ -898,7 +926,7 @@ class committed_descriptor { */ void compute_forward(const sycl::buffer& 
in_real, const sycl::buffer& in_imag, sycl::buffer& out_real, sycl::buffer& out_imag) { - dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX); + dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD); } /** @@ -918,7 +946,7 @@ class committed_descriptor { * @param out buffer containing output data */ void compute_backward(const sycl::buffer& in, sycl::buffer& out) { - dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX); + dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD); } /** @@ -931,7 +959,7 @@ class committed_descriptor { */ void compute_backward(const sycl::buffer& in_real, const sycl::buffer& in_imag, sycl::buffer& out_real, sycl::buffer& out_imag) { - dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX); + dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD); } /** @@ -1009,7 +1037,7 @@ class committed_descriptor { */ sycl::event compute_forward(const complex_type* in, complex_type* out, const std::vector& dependencies = {}) { - return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, dependencies); + return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD, dependencies); } /** @@ -1024,8 +1052,8 @@ class committed_descriptor { */ sycl::event compute_forward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real, scalar_type* out_imag, const std::vector& dependencies = {}) { - return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, - dependencies); + return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD, + dependencies); } /** @@ -1052,8 +1080,8 @@ class committed_descriptor { */ sycl::event compute_backward(const complex_type* in, complex_type* out, const std::vector& dependencies = {}) { - return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, - dependencies); + return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD, + dependencies); } /** @@ -1068,15 +1096,14 @@ class committed_descriptor { */ sycl::event compute_backward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real, scalar_type* out_imag, const std::vector& dependencies = {}) { - return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, - dependencies); + return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD, + dependencies); } private: /** * Dispatches to the implementation for the appropriate direction. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam TIn Type of the input buffer or USM pointer * @tparam TOut Type of the output buffer or USM pointer * @param in buffer or USM pointer to memory containing input data. Real part of input data if @@ -1088,12 +1115,14 @@ class committed_descriptor { * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if * `descriptor.complex_storage` is interleaved. 
* @param used_storage how components of a complex value are stored - either split or interleaved + * @param compute_direction direction of compute, forward / backward * @param dependencies events that must complete before the computation * @return sycl::event */ - template + template sycl::event dispatch_direction(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, - complex_storage used_storage, const std::vector& dependencies = {}) { + complex_storage used_storage, direction compute_direction, + const std::vector& dependencies = {}) { #ifndef PORTFFT_ENABLE_BUFFER_BUILDS if constexpr (!std::is_pointer_v || !std::is_pointer_v) { throw invalid_configuration("Buffer interface can not be called when buffer builds are disabled."); @@ -1110,20 +1139,21 @@ class committed_descriptor { "INTERLEAVED_COMPLEX."); } if constexpr (Dir == direction::FORWARD) { - return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.forward_strides, - params.backward_strides, params.forward_distance, params.backward_distance, - params.forward_offset, params.backward_offset, params.forward_scale); + return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.forward_strides, + params.backward_strides, params.forward_distance, params.backward_distance, + params.forward_offset, params.backward_offset, params.forward_scale, + compute_direction); } else { - return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.backward_strides, - params.forward_strides, params.backward_distance, params.forward_distance, - params.backward_offset, params.forward_offset, params.backward_scale); + return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.backward_strides, + params.forward_strides, params.backward_distance, params.forward_distance, + params.backward_offset, params.forward_offset, params.backward_scale, + compute_direction); } } /** * Dispatches to the implementation for the appropriate number of dimensions. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam TIn Type of the input buffer or USM pointer * @tparam TOut Type of the output buffer or USM pointer * @param in buffer or USM pointer to memory containing input data. 
Real part of input data if @@ -1142,7 +1172,7 @@ class committed_descriptor { * @param input_offset offset into input allocation where the data for FFTs start * @param output_offset offset into output allocation where the data for FFTs start * @param scale_factor scaling factor applied to the result - * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward * @return sycl::event */ template @@ -1151,7 +1181,7 @@ class committed_descriptor { const std::vector& input_strides, const std::vector& output_strides, std::size_t input_distance, std::size_t output_distance, std::size_t input_offset, std::size_t output_offset, - Scalar scale_factor) { + Scalar scale_factor, direction compute_direction) { using TOutConst = std::conditional_t, const std::remove_pointer_t*, const TOut>; std::size_t n_dimensions = params.lengths.size(); std::size_t total_size = params.get_flattened_length(); @@ -1181,9 +1211,10 @@ class committed_descriptor { output_distance = params.lengths.back(); } - sycl::event previous_event = dispatch_kernel_1d( - in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size, input_stride_0, - output_stride_0, input_distance, output_distance, input_offset, output_offset, scale_factor, dimensions.back()); + sycl::event previous_event = + dispatch_kernel_1d(in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size, + input_stride_0, output_stride_0, input_distance, output_distance, input_offset, + output_offset, scale_factor, dimensions.back(), compute_direction); if (n_dimensions == 1) { return previous_event; } @@ -1196,10 +1227,10 @@ class committed_descriptor { // kernels. std::size_t stride_between_kernels = inner_size * params.lengths[i]; for (std::size_t j = 0; j < params.number_of_transforms * outer_size; j++) { - sycl::event e = dispatch_kernel_1d( + sycl::event e = dispatch_kernel_1d( out, out, out_imag, out_imag, previous_events, inner_size, inner_size, inner_size, 1, 1, output_offset + j * stride_between_kernels, output_offset + j * stride_between_kernels, - static_cast(1.0), dimensions[i]); + static_cast(1.0), dimensions[i], compute_direction); next_events.push_back(e); } inner_size *= params.lengths[i]; @@ -1212,7 +1243,6 @@ class committed_descriptor { /** * Dispatches the kernel with the first subgroup size that is supported by the device. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam TIn Type of the input buffer or USM pointer * @tparam TOut Type of the output buffer or USM pointer * @param in buffer or USM pointer to memory containing input data. 
Real part of input data if @@ -1233,6 +1263,7 @@ class committed_descriptor { * @param output_offset offset into output allocation where the data for FFTs start * @param scale_factor scaling factor applied to the result * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward * @return sycl::event */ template @@ -1240,16 +1271,15 @@ class committed_descriptor { const std::vector& dependencies, std::size_t n_transforms, std::size_t input_stride, std::size_t output_stride, std::size_t input_distance, std::size_t output_distance, std::size_t input_offset, std::size_t output_offset, - Scalar scale_factor, dimension_struct& dimension_data) { + Scalar scale_factor, dimension_struct& dimension_data, direction compute_direction) { return dispatch_kernel_1d_helper( in, out, in_imag, out_imag, dependencies, n_transforms, input_stride, output_stride, input_distance, - output_distance, input_offset, output_offset, scale_factor, dimension_data); + output_distance, input_offset, output_offset, scale_factor, dimension_data, compute_direction); } /** * Helper for dispatching the kernel with the first subgroup size that is supported by the device. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam TIn Type of the input buffer or USM pointer * @tparam TOut Type of the output buffer or USM pointer * @tparam SubgroupSize first subgroup size @@ -1272,15 +1302,16 @@ class committed_descriptor { * @param output_offset offset into output allocation where the data for FFTs start * @param scale_factor scaling factor applied to the result * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward * @return sycl::event */ - template + template sycl::event dispatch_kernel_1d_helper(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t input_stride, std::size_t output_stride, std::size_t input_distance, std::size_t output_distance, std::size_t input_offset, std::size_t output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { + dimension_struct& dimension_data, direction compute_direction) { if (SubgroupSize == dimension_data.used_sg_size) { const bool input_packed = input_distance == dimension_data.length && input_stride == 1; const bool output_packed = output_distance == dimension_data.length && output_stride == 1; @@ -1301,24 +1332,24 @@ class committed_descriptor { } } if (input_packed && output_packed) { - return run_kernel( + return run_kernel( in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, scale_factor, - dimension_data); + dimension_data, compute_direction); } if (input_batch_interleaved && output_packed && in != out) { - return run_kernel( + return run_kernel( in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, scale_factor, - dimension_data); + dimension_data, compute_direction); } if (input_packed && output_batch_interleaved && in != out) { - return run_kernel( + return run_kernel( in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, scale_factor, - dimension_data); + dimension_data, compute_direction); } if (input_batch_interleaved && output_batch_interleaved) { - return run_kernel( + return run_kernel( in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, scale_factor, - 
dimension_data); + dimension_data, compute_direction); } throw unsupported_configuration("Only PACKED or BATCH_INTERLEAVED transforms are supported"); } @@ -1334,15 +1365,13 @@ class committed_descriptor { /** * Struct for dispatching `run_kernel()` call. * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam LayoutIn Input Layout * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam TIn Type of the input USM pointer or buffer * @tparam TOut Type of the output USM pointer or buffer */ - template + template struct run_kernel_struct { // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class template @@ -1350,14 +1379,13 @@ class committed_descriptor { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t forward_offset, std::size_t backward_offset, Scalar scale_factor, - dimension_struct& dimension_data); + dimension_struct& dimension_data, direction compute_direction); }; }; /** * Common interface to run the kernel called by compute_forward and compute_backward * - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam LayoutIn Input Layout * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup @@ -1377,14 +1405,14 @@ class committed_descriptor { * @param output_offset offset into output allocation where the data for FFTs start * @param scale_factor scaling factor applied to the result * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of fft, forward / backward * @return sycl::event */ - template + template sycl::event run_kernel(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t input_offset, std::size_t output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { + dimension_struct& dimension_data, direction compute_direction) { // mixing const and non-const inputs leads to hard-to-debug linking errors, as both use the same kernel name, but // are called from different template instantiations. static_assert(!std::is_pointer_v || std::is_const_v>, @@ -1397,11 +1425,11 @@ class committed_descriptor { using TInReinterpret = decltype(detail::reinterpret(in)); using TOutReinterpret = decltype(detail::reinterpret(out)); std::size_t vec_multiplier = params.complex_storage == complex_storage::INTERLEAVED_COMPLEX ? 
2 : 1; - return dispatch>( + return dispatch>( dimension_data.level, detail::reinterpret(in), detail::reinterpret(out), detail::reinterpret(in_imag), detail::reinterpret(out_imag), dependencies, static_cast(n_transforms), static_cast(vec_multiplier * input_offset), - static_cast(vec_multiplier * output_offset), scale_factor, dimension_data); + static_cast(vec_multiplier * output_offset), scale_factor, dimension_data, compute_direction); } }; diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 8776d051..84cd184b 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -279,18 +279,18 @@ struct committed_descriptor::num_scalars_in_local_mem_struct::in }; template -template +template template struct committed_descriptor::run_kernel_struct::inner { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { + dimension_struct& dimension_data, direction compute_direction) { (void)in_imag; (void)out_imag; - const auto& kernels = dimension_data.kernels; + const auto& kernels = + compute_direction == direction::FORWARD ? dimension_data.forward_kernels : dimension_data.backward_kernels; const Scalar* twiddles_ptr = static_cast(kernels.at(0).twiddles_forward.get()); const IdxGlobal* factors_and_scan = static_cast(dimension_data.factors_and_scan.get()); std::size_t num_batches = desc.params.number_of_transforms; @@ -310,7 +310,7 @@ struct committed_descriptor::run_kernel_struct( kernels.at(0), in, desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, @@ -324,14 +324,14 @@ struct committed_descriptor::run_kernel_struct(factor_num) == dimension_data.num_factors - 1) { l2_events = - detail::compute_level( + detail::compute_level( kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), static_cast(factor_num), dimension_data.num_factors, l2_events, desc.queue); } else { - l2_events = detail::compute_level( kernels.at(factor_num), static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), twiddles_ptr, factors_and_scan, scale_factor, intermediate_twiddles_offset, impl_twiddle_offset, 0, diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index e511d280..d2ee2b49 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -59,7 +59,6 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su /** * Implementation of FFT for sizes that can be done by a subgroup. 
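// The conjugation flags reach the kernels below as SYCL 2020 specialization
// constants (SpecConstTakeConjugateOnLoad / SpecConstTakeConjugateOnStore)
// rather than template parameters, which is what allows one compiled kernel per
// factor instead of one per direction. A minimal sketch of the mechanism, with
// illustrative names (ConjOnLoadId, demo_kernel) that are not portFFT's own;
// note that portFFT sets the constants on a kernel_bundle before sycl::build
// rather than on the handler:
#include <sycl/sycl.hpp>

constexpr static sycl::specialization_id<bool> ConjOnLoadId{};

void launch_demo(sycl::queue& q, bool take_conjugate_on_load) {
  q.submit([&](sycl::handler& h) {
    // Fixed at submission time; backends that JIT can fold the branch away.
    h.set_specialization_constant<ConjOnLoadId>(take_conjugate_on_load);
    h.parallel_for<class demo_kernel>(
        sycl::range<1>{1}, [=](sycl::item<1>, sycl::kernel_handler kh) {
          const bool conj_on_load = kh.get_specialization_constant<ConjOnLoadId>();
          (void)conj_on_load;  // a real kernel would conjugate its loads when set
        });
  });
}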
* - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam LayoutIn Input Layout * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup @@ -86,7 +85,7 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su * @param loc_load_modifier Pointer to load modifier data in local memory * @param loc_store_modifier Pointer to store modifier data in local memory */ -template +template PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc, T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles, T scaling_factor, global_data_struct<1> global_data, sycl::kernel_handler& kh, @@ -96,6 +95,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); + bool take_conjugate_on_load = kh.get_specialization_constant(); + bool take_conjugate_on_store = kh.get_specialization_constant(); const Idx factor_wi = kh.get_specialization_constant(); const Idx factor_sg = kh.get_specialization_constant(); @@ -256,7 +257,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } } - sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); + if (take_conjugate_on_load) { + take_conjugate(priv, factor_wi); + } + sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); + if (take_conjugate_on_store) { + take_conjugate(priv, factor_wi); + } if (working_inner) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); } @@ -450,7 +457,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } } - sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); + if (take_conjugate_on_load) { + take_conjugate(priv, factor_wi); + } + sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); + if (take_conjugate_on_store) { + take_conjugate(priv, factor_wi); + } if (working) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); } @@ -601,17 +614,17 @@ struct committed_descriptor::calculate_twiddles_struct::inner -template +template template -struct committed_descriptor::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { + dimension_struct& dimension_data, direction compute_direction) { constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; - auto& kernel_data = dimension_data.kernels.at(0); + auto& kernel_data = compute_direction == direction::FORWARD ? 
dimension_data.forward_kernels.at(0) + : dimension_data.backward_kernels.at(0); Scalar* twiddles = kernel_data.twiddles_forward.get(); Idx factor_sg = kernel_data.factors[1]; std::size_t local_elements = @@ -641,7 +654,7 @@ struct committed_descriptor::run_kernel_struct( + detail::subgroup_impl( &in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, &in_imag_acc_or_usm[0] + input_offset, &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], n_transforms, twiddles, scale_factor, global_data, kh); diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index 445cd5a6..ddce0762 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -75,7 +75,6 @@ IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, I /** * Implementation of FFT for sizes that can be done by a workgroup. * - * @tparam Dir Direction of the FFT * @tparam LayoutIn Input Layout * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup @@ -99,7 +98,7 @@ IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, I * @param load_modifier_data Pointer to the load modifier data in global Memory * @param store_modifier_data Pointer to the store modifier data in global Memory */ -template +template PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_imag*/, T* /*output_imag*/, T* loc, T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles, T scaling_factor, global_data_struct<1> global_data, sycl::kernel_handler& kh, @@ -108,6 +107,8 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); + bool take_conjugate_on_load = kh.get_specialization_constant(); + bool take_conjugate_on_store = kh.get_specialization_constant(); const Idx fft_size = kh.get_specialization_constant(); @@ -149,10 +150,10 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i std::array{fft_size, 2 * num_batches_in_local_mem}); sycl::group_barrier(global_data.it.get_group()); for (Idx sub_batch = 0; sub_batch < num_batches_in_local_mem; sub_batch++) { - wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, - sub_batch, offset / (2 * fft_size), load_modifier_data, store_modifier_data, fft_size, - factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, - global_data); + wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, + sub_batch, offset / (2 * fft_size), load_modifier_data, store_modifier_data, fft_size, + factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, + take_conjugate_on_load, take_conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); } if constexpr (LayoutOut == detail::layout::PACKED) { @@ -176,10 +177,10 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i global_data.log_message_global(__func__, "loading non-transposed data from global to local memory"); global2local(global_data, input, loc_view, 2 * fft_size, offset); sycl::group_barrier(global_data.it.get_group()); - wg_dft(loc_view, loc_twiddles, wg_twiddles, 
scaling_factor, max_num_batches_in_local_mem, 0, - offset / static_cast(2 * fft_size), load_modifier_data, store_modifier_data, - fft_size, factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, - apply_scale_factor, global_data); + wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, + offset / static_cast(2 * fft_size), load_modifier_data, store_modifier_data, + fft_size, factor_n, factor_m, LayoutIn, multiply_on_load, multiply_on_store, + take_conjugate_on_load, take_conjugate_on_store, apply_scale_factor, global_data); sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT @@ -203,16 +204,16 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* /*input_i } // namespace detail template -template +template template -struct committed_descriptor::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { - auto& kernel_data = dimension_data.kernels.at(0); + dimension_struct& dimension_data, direction compute_direction) { + auto& kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels.at(0) + : dimension_data.backward_kernels.at(0); Idx num_batches_in_local_mem = [=]() { if constexpr (LayoutIn == detail::layout::BATCH_INTERLEAVED) { return kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2; @@ -251,7 +252,7 @@ struct committed_descriptor::run_kernel_struct( + detail::workgroup_impl( &in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, &in_imag_acc_or_usm[0] + input_offset, &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc[0] + sg_twiddles_offset, n_transforms, twiddles, scale_factor, global_data, kh); diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 533fac4d..19e87d24 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -80,7 +80,6 @@ PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifi /** * Implementation of FFT for sizes that can be done by independent work items. 
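// take_conjugate is called throughout this patch but defined elsewhere in
// portFFT; a hypothetical sketch consistent with every call site here (values
// held in private memory as interleaved real/imaginary pairs, count given in
// complex elements) would simply negate the imaginary components:
#include <cstdint>

template <typename T>
inline void take_conjugate(T* priv, std::int32_t num_complex) {
  for (std::int32_t i = 0; i < num_complex; i++) {
    priv[2 * i + 1] = -priv[2 * i + 1];  // Im -> -Im; Re at priv[2 * i] unchanged
  }
}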
* - * @tparam Dir FFT direction, takes either direction::FORWARD or direction::BACKWARD * @tparam LayoutIn Input Layout * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup @@ -104,7 +103,7 @@ PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifi * @param loc_load_modifier Pointer to load modifier data in local memory * @param loc_store_modifier Pointer to store modifier data in local memory */ -template +template PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc, IdxGlobal n_transforms, T scaling_factor, global_data_struct<1> global_data, sycl::kernel_handler& kh, const T* load_modifier_data = nullptr, @@ -114,6 +113,9 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); + bool take_conjugate_on_load = kh.get_specialization_constant(); + bool take_conjugate_on_store = kh.get_specialization_constant(); + const Idx fft_size = kh.get_specialization_constant(); global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms); @@ -207,7 +209,13 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "applying load modifier"); detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } - wi_dft(priv, priv, fft_size, 1, 1, wi_private_scratch); + if (take_conjugate_on_load) { + take_conjugate(priv, fft_size); + } + wi_dft<0>(priv, priv, fft_size, 1, 1, wi_private_scratch); + if (take_conjugate_on_store) { + take_conjugate(priv, fft_size); + } global_data.log_dump_private("data in registers after computation:", priv, n_reals); if (multiply_on_store == detail::elementwise_multiply::APPLIED) { // Assumes store modifier data is stored in a transposed fashion (fft_size x num_batches_local_mem) @@ -272,17 +280,17 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag } // namespace detail template -template +template template -struct committed_descriptor::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, Scalar scale_factor, - dimension_struct& dimension_data) { + dimension_struct& dimension_data, direction compute_direction) { constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; - auto& kernel_data = dimension_data.kernels.at(0); + auto& kernel_data = compute_direction == direction::FORWARD ? 
dimension_data.forward_kernels.at(0) + : dimension_data.backward_kernels.at(0); std::size_t local_elements = num_scalars_in_local_mem_struct::template inner::execute( desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); @@ -308,7 +316,7 @@ struct committed_descriptor::run_kernel_struct( + detail::workitem_impl( &in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, &in_imag_acc_or_usm[0] + input_offset, &out_imag_acc_or_usm[0] + output_offset, &loc[0], n_transforms, scale_factor, global_data, kh); diff --git a/src/portfft/specialization_constant.hpp b/src/portfft/specialization_constant.hpp index cefd5b37..5c18773b 100644 --- a/src/portfft/specialization_constant.hpp +++ b/src/portfft/specialization_constant.hpp @@ -44,6 +44,9 @@ constexpr static sycl::specialization_id GlobalSubImplSpecConst{}; constexpr static sycl::specialization_id GlobalSpecConstLevelNum{}; constexpr static sycl::specialization_id GlobalSpecConstNumFactors{}; +constexpr static sycl::specialization_id SpecConstTakeConjugateOnLoad{}; +constexpr static sycl::specialization_id SpecConstTakeConjugateOnStore{}; + } // namespace detail } // namespace portfft #endif diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index f9b086c9..8b01b7d4 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -41,39 +41,34 @@ class transpose_kernel; * @tparam SubgroupSize size of the subgroup * @return vector of kernel ids */ -template
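// For the GLOBAL implementation the same trick only touches the edges of the
// factor chain: set_specialization_constants above conjugates on load solely for
// factor 0 and on store solely for the last factor, while intermediate factors
// and the inter-factor twiddle multiplications run unchanged. That is valid
// because the whole pipeline is one complex-linear map F, and
// conj(F(conj(x))) = conj(F)(x), i.e. the unnormalised inverse when F is the
// forward DFT. A standalone sketch (no portFFT types; stages stand in for the
// per-factor kernels and twiddle multiplies):
#include <complex>
#include <functional>
#include <utility>
#include <vector>

using cvec = std::vector<std::complex<double>>;
using stage_fn = std::function<cvec(cvec)>;

// Forward pipeline: factor kernels and twiddle multiplications, in order.
inline cvec run_forward(const std::vector<stage_fn>& stages, cvec x) {
  for (const auto& stage : stages) x = stage(std::move(x));
  return x;
}

// Backward pipeline reusing the forward stages untouched: conjugate only before
// the first stage (load of factor 0) and after the last stage (store of the
// final factor). No per-stage conjugation is needed in between.
inline cvec run_backward(const std::vector<stage_fn>& stages, cvec x) {
  for (auto& v : x) v = std::conj(v);
  x = run_forward(stages, std::move(x));
  for (auto& v : x) v = std::conj(v);
  return x;  // equals N * idft(x) when the stages compose to the forward DFT
}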