diff --git a/.clang-tidy b/.clang-tidy index 0b3225d5..1fe9e338 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -14,6 +14,7 @@ Checks: > performance-*, -performance-avoid-endl, readability-*, + -readability-magic-numbers, -readability-function-cognitive-complexity, -readability-identifier-length, -readability-named-parameter, diff --git a/README.md b/README.md index fa0b3cb4..57012ffe 100644 --- a/README.md +++ b/README.md @@ -98,10 +98,8 @@ portFFT is still in early development. The supported configurations are: * Arbitrary forward and backward offsets * Arbitrary strides and distance where the problem size + auxilary data fits in the registers of a single work-item. -Any 1D arbitrarily large input size that fits in global memory is supported, with a restriction that large input sizes should not have large prime factors. -The largest prime factor depend on the device and the values set by `PORTFFT_REGISTERS_PER_WI` and `PORTFFT_SUBGROUP_SIZES`. -For instance with `PORTFFT_REGISTERS_PER_WI` set to `128` (resp. `256`) each work-item can hold a maximum of 27 (resp. 56) complex values, thus with `PORTFFT_SUBGROUP_SIZES` set to `32` the largest prime factor cannot exceed `27*32=864` (resp. `56*32=1792`). -portFFT may allocate up to `2 * PORTFFT_MAX_CONCURRENT_KERNELS * input_size` scratch memory, depending on the configuration passed. +Any 1D arbitrarily large input size that fits in global memory is supported. +portFFT may allocate up to `2 * PORTFFT_MAX_CONCURRENT_KERNELS * input_size` scratch memory, in addition to memory allocated to hold precomputed values to be used during compute, depending on the configuration passed. Any batch size is supported as long as the input and output data fits in global memory. diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index d985e816..19f89996 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -45,22 +45,26 @@ namespace detail { template class committed_descriptor_impl; +template +using kernels_vec = std::vector::kernel_data_struct>; template -std::vector compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, - Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, - const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, - IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, - IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, - const std::vector& dependencies, sycl::queue& queue); +std::vector compute_level(const typename committed_descriptor_impl::kernel_data_struct&, + const TIn&, Scalar*, const TIn&, Scalar*, const Scalar*, const Scalar*, + const Scalar*, const IdxGlobal*, IdxGlobal, IdxGlobal, Idx, IdxGlobal, IdxGlobal, + Idx, complex_storage, const std::vector&, sycl::queue&); template -sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct& kd_struct, - const Scalar* input, TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, - Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, - IdxGlobal output_offset, sycl::queue& queue, const std::vector& events, - complex_storage storage); +sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct&, + const Scalar*, TOut, const IdxGlobal*, IdxGlobal, Idx, IdxGlobal, 
IdxGlobal, Idx, IdxGlobal, + sycl::queue&, const std::vector&, complex_storage); + +template +sycl::event global_impl_driver(const TIn&, const TIn&, TOut, TOut, committed_descriptor_impl&, + typename committed_descriptor_impl::dimension_struct&, + const kernels_vec&, const kernels_vec&, Idx, IdxGlobal, + IdxGlobal, std::size_t, std::size_t, IdxGlobal, IdxGlobal, IdxGlobal, complex_storage, + detail::elementwise_multiply, const Scalar*); // kernel names template @@ -84,20 +88,24 @@ template class committed_descriptor_impl { friend struct descriptor; template - friend std::vector detail::compute_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, - Scalar1* output, const TIn& input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, - const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, - IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, - IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, - const std::vector& dependencies, sycl::queue& queue); + friend std::vector compute_level( + const typename committed_descriptor_impl::kernel_data_struct&, const TIn&, Scalar1*, const TIn&, + Scalar1*, const Scalar1*, const Scalar1*, const Scalar1*, const IdxGlobal*, IdxGlobal, IdxGlobal, Idx, IdxGlobal, + IdxGlobal, Idx, complex_storage, const std::vector&, sycl::queue&); template friend sycl::event detail::transpose_level( - const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const Scalar1* input, - TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, Idx num_batches_in_l2, - IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, - const std::vector& events, complex_storage storage); + const typename committed_descriptor_impl::kernel_data_struct&, const Scalar1*, TOut, + const IdxGlobal*, IdxGlobal, Idx, IdxGlobal, IdxGlobal, Idx, IdxGlobal, sycl::queue&, + const std::vector&, complex_storage); + + template + friend sycl::event global_impl_driver(const TIn&, const TIn&, TOut, TOut, + committed_descriptor_impl&, + typename committed_descriptor_impl::dimension_struct&, + const kernels_vec&, const kernels_vec&, Idx, + IdxGlobal, IdxGlobal, std::size_t, std::size_t, IdxGlobal, IdxGlobal, IdxGlobal, + complex_storage, detail::elementwise_multiply, const Scalar1*); /** * Vector containing the sub-implementation level, kernel_ids and factors for each factor that requires a separate @@ -105,6 +113,12 @@ class committed_descriptor_impl { */ using kernel_ids_and_metadata_t = std::vector, std::vector>>; + /** + * Tuple of the level, an input kernel_bundle, and factors pertaining to each factor of the committed size + */ + using input_bundles_and_metadata_t = + std::tuple, std::vector>; + descriptor params; sycl::queue queue; sycl::device dev; @@ -148,18 +162,32 @@ class committed_descriptor_impl { std::vector transpose_kernels; std::shared_ptr factors_and_scan; detail::level level; + // The problem size for which DFT will be computed std::size_t length; + // The committed size corresponding to the dimension + std::size_t committed_length; Idx used_sg_size; Idx num_batches_in_l2; - Idx num_factors; + Idx num_forward_factors; + Idx num_backward_factors; + bool is_prime; + IdxGlobal backward_twiddles_offset; + IdxGlobal bluestein_modifiers_offset; + IdxGlobal forward_impl_twiddle_offset; + IdxGlobal backward_impl_twiddle_offset; 
dimension_struct(std::vector forward_kernels, std::vector backward_kernels, - detail::level level, std::size_t length, Idx used_sg_size) + detail::level level, std::size_t length, std::size_t committed_length, Idx used_sg_size, + Idx num_forward_factors, Idx num_backward_factors, bool is_prime) : forward_kernels(std::move(forward_kernels)), backward_kernels(std::move(backward_kernels)), level(level), length(length), - used_sg_size(used_sg_size) {} + committed_length(committed_length), + used_sg_size(used_sg_size), + num_forward_factors(num_forward_factors), + num_backward_factors(num_backward_factors), + is_prime(is_prime) {} }; std::vector dimensions; @@ -204,11 +232,11 @@ class committed_descriptor_impl { * * @tparam SubgroupSize size of the subgroup * @param kernel_num the consecutive number of the kernel to prepare - * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, - * vector of kernel ids, factors + * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, the + * size of the transform and vector of kernel ids, factors */ template - std::tuple prepare_implementation(std::size_t kernel_num) { + std::tuple prepare_implementation(std::size_t kernel_num) { PORTFFT_LOG_FUNCTION_ENTRY(); // TODO: check and support all the parameter values if constexpr (Domain != domain::COMPLEX) { @@ -221,7 +249,9 @@ class committed_descriptor_impl { if (detail::fits_in_wi(fft_size)) { ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); - return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; + return {detail::level::WORKITEM, + static_cast(fft_size), + {{detail::level::WORKITEM, ids, {static_cast(fft_size)}}}}; } if (detail::fits_in_sg(fft_size, SubgroupSize)) { Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); @@ -232,7 +262,7 @@ class committed_descriptor_impl { factors.push_back(factor_sg); ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg); - return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; + return {detail::level::SUBGROUP, static_cast(fft_size), {{detail::level::SUBGROUP, ids, factors}}}; } IdxGlobal n_idx_global = detail::factorize(fft_size); if (detail::can_cast_safely(n_idx_global) && @@ -265,11 +295,12 @@ class committed_descriptor_impl { ids = detail::get_ids(); PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); - return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; + return { + detail::level::WORKGROUP, static_cast(fft_size), {{detail::level::WORKGROUP, ids, factors}}}; } } PORTFFT_LOG_TRACE("Preparing global impl"); - std::vector, std::vector>> param_vec; + kernel_ids_and_metadata_t param_vec; auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { if (detail::fits_in_wi(factor_size)) { // Throughout we have assumed there would always be enough local memory for the WI implementation. 
@@ -286,12 +317,10 @@ class committed_descriptor_impl { if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { std::size_t input_scalars = num_scalars_in_local_mem(detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg, + {static_cast(factor_wi), static_cast(factor_sg)}, temp_num_sgs_in_wg, batch_interleaved_layout ? layout::BATCH_INTERLEAVED : layout::PACKED); - std::size_t store_modifiers = batch_interleaved_layout ? input_scalars : 0; std::size_t twiddle_scalars = 2 * static_cast(factor_size); - return (sizeof(Scalar) * (input_scalars + store_modifiers + twiddle_scalars)) < - static_cast(local_memory_size); + return (sizeof(Scalar) * (input_scalars + twiddle_scalars)) <= static_cast(local_memory_size); } return false; }(); @@ -303,13 +332,22 @@ class committed_descriptor_impl { "and factor_sg:", factor_sg); param_vec.emplace_back(detail::level::SUBGROUP, detail::get_ids(), - std::vector{factor_sg, factor_wi}); + std::vector{factor_wi, factor_sg}); return true; } return false; }; - detail::factorize_input(fft_size, check_and_select_target_level); - return {detail::level::GLOBAL, param_vec}; + bool encountered_large_prime = detail::factorize_input(fft_size, check_and_select_target_level); + if (encountered_large_prime) { + param_vec.clear(); + auto padded_fft_size = detail::get_padded_length(static_cast(fft_size)); + + // Forward DFT within Bluestein implementation (pre convolution). + detail::factorize_input(padded_fft_size, check_and_select_target_level); + // Backward DFT within Bluestein implementation (post convolution). + detail::factorize_input(padded_fft_size, check_and_select_target_level); + } + return {detail::level::GLOBAL, static_cast(fft_size), param_vec}; } /** @@ -437,6 +475,119 @@ class committed_descriptor_impl { return dispatch(level, dimension_data, kernels); } + /** + * Sets the specialization constants for the global implementation + * @param input_kernels_and_metadata vector of input_bundles_and_metadata_t + * @param num_forward_factors Number of forward factors + * @param num_backward_factors Number of backward factors + * @param compute_direction direction of compute: forward / backward + * @param is_prime whether or not the dimension size is a prime number + * @param scale_factor Scaling factor to be applied to the result + */ + void set_global_impl_spec_constants(std::vector& input_kernels_and_metadata, + std::size_t num_forward_factors, std::size_t num_backward_factors, + direction compute_direction, bool is_prime, Scalar scale_factor) { + std::vector global_impl_factors; + std::vector inner_batches; + for (std::size_t i = 0; i < num_forward_factors; i++) { + const auto& [level, input_bundle, factors] = input_kernels_and_metadata.at(i); + global_impl_factors.push_back( + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies()))); + } + for (std::size_t i = 0; i < num_forward_factors; i++) { + inner_batches.push_back(std::accumulate(global_impl_factors.begin() + static_cast(i + 1), + global_impl_factors.end(), IdxGlobal(1), std::multiplies())); + } + + for (std::size_t i = 0; i < num_backward_factors; i++) { + const auto& [level, input_bundle, factors] = input_kernels_and_metadata.at(num_forward_factors + i); + global_impl_factors.push_back( + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies()))); + } + for (std::size_t i = 0; i < num_backward_factors; i++) { + inner_batches.push_back( + 
std::accumulate(global_impl_factors.begin() + static_cast(num_forward_factors + i + 1), + global_impl_factors.end(), IdxGlobal(1), std::multiplies())); + } + + for (std::size_t i = 0; i < num_forward_factors; i++) { + auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; + auto multiply_on_load = detail::elementwise_multiply::NOT_APPLIED; + auto multiply_on_store = detail::elementwise_multiply::APPLIED; + auto scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED; + + IdxGlobal input_stride = global_impl_factors.at(i); + IdxGlobal output_stride = global_impl_factors.at(i); + IdxGlobal input_distance = 1; + IdxGlobal output_distance = 1; + + if (i == num_forward_factors - 1) { + input_stride = 1; + output_stride = 1; + input_distance = global_impl_factors.at(i); + output_distance = global_impl_factors.at(i); + } + + if (i == 0 && compute_direction == direction::BACKWARD) { + conjugate_on_load = detail::complex_conjugate::APPLIED; + } + if (i == 0 && is_prime) { + multiply_on_load = detail::elementwise_multiply::APPLIED; + } + if (i == num_forward_factors - 1) { + if (compute_direction == direction::BACKWARD && !is_prime) { + conjugate_on_store = detail::complex_conjugate::APPLIED; + } + if (!is_prime) { + multiply_on_store = detail::elementwise_multiply::NOT_APPLIED; + } + if (!is_prime) { + scale_factor_applied = detail::apply_scale_factor::APPLIED; + } + } + auto& [level, input_bundle, factors] = input_kernels_and_metadata.at(i); + set_spec_constants(detail::level::GLOBAL, input_bundle, static_cast(global_impl_factors.at(i)), factors, + multiply_on_load, multiply_on_store, scale_factor_applied, level, conjugate_on_load, + conjugate_on_store, scale_factor, input_stride, output_stride, input_distance, output_distance, + Idx(i), static_cast(num_forward_factors)); + } + + for (std::size_t i = 0; i < num_backward_factors; i++) { + auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; + auto multiply_on_load = detail::elementwise_multiply::NOT_APPLIED; + auto multiply_on_store = detail::elementwise_multiply::APPLIED; + auto scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED; + + IdxGlobal input_stride = global_impl_factors.at(num_forward_factors + i); + IdxGlobal output_stride = global_impl_factors.at(num_forward_factors + i); + IdxGlobal input_distance = 1; + IdxGlobal output_distance = 1; + + if (i == num_forward_factors - 1) { + input_stride = 1; + output_stride = 1; + input_distance = global_impl_factors.at(num_forward_factors + i); + output_distance = global_impl_factors.at(num_forward_factors + i); + } + + if (i == 0) { + conjugate_on_load = detail::complex_conjugate::APPLIED; + } + if (i == num_forward_factors - 1) { + multiply_on_store = detail::elementwise_multiply::APPLIED; + scale_factor_applied = detail::apply_scale_factor::APPLIED; + } + auto& [level, input_bundle, factors] = input_kernels_and_metadata.at(num_forward_factors + i); + set_spec_constants(detail::level::GLOBAL, input_bundle, + static_cast(global_impl_factors.at(num_forward_factors + i)), factors, multiply_on_load, + multiply_on_store, scale_factor_applied, level, conjugate_on_load, conjugate_on_store, + scale_factor, input_stride, output_stride, input_distance, output_distance, Idx(i), + static_cast(num_backward_factors)); + } + } + /** * Sets the specialization constants for all the kernel_ids contained in the vector * returned from 
prepare_implementation @@ -446,91 +597,68 @@ class committed_descriptor_impl { * vector of kernel ids, factors * @param compute_direction direction of compute: forward or backward * @param dimension_num which dimension are the kernels being built for - * @param skip_scaling whether or not to skip scaling * @return vector of kernel_data_struct if all kernel builds are successful, std::nullopt otherwise */ template - std::optional> set_spec_constants_driver(detail::level top_level, - kernel_ids_and_metadata_t& prepared_vec, - direction compute_direction, - std::size_t dimension_num) { + std::optional> set_spec_constants_driver( + detail::level top_level, kernel_ids_and_metadata_t& prepared_vec, direction compute_direction, + std::size_t dimension_num, Idx num_forward_factors, Idx num_backward_factors) { Scalar scale_factor = compute_direction == direction::FORWARD ? params.forward_scale : params.backward_scale; - std::size_t counter = 0; - IdxGlobal remaining_factors_prod = static_cast(params.get_flattened_length()); std::vector result; - for (auto [level, ids, factors] : prepared_vec) { - const bool is_multi_dim = params.lengths.size() > 1; - const bool is_global = top_level == detail::level::GLOBAL; - const bool is_final_factor = counter == (prepared_vec.size() - 1); - const bool is_final_dim = dimension_num == (params.lengths.size() - 1); - const bool is_backward = compute_direction == direction::BACKWARD; - if (is_multi_dim && is_global) { - throw unsupported_configuration("multidimensional global transforms are not supported."); - } + std::vector input_kernels_and_metadata; + bool skip_scaling = dimension_num != params.lengths.size() - 1; + for (const auto& [level, kernel_ids, factors] : prepared_vec) { + input_kernels_and_metadata.emplace_back( + level, sycl::get_kernel_bundle(queue.get_context(), kernel_ids), factors); + } - const auto multiply_on_store = is_global && !is_final_factor ? detail::elementwise_multiply::APPLIED - : detail::elementwise_multiply::NOT_APPLIED; - const auto conjugate_on_load = - is_backward && counter == 0 ? detail::complex_conjugate::APPLIED : detail::complex_conjugate::NOT_APPLIED; - const auto conjugate_on_store = - is_backward && is_final_factor ? detail::complex_conjugate::APPLIED : detail::complex_conjugate::NOT_APPLIED; - const auto apply_scale = is_final_factor && is_final_dim ? detail::apply_scale_factor::APPLIED - : detail::apply_scale_factor::NOT_APPLIED; - - Idx length{}; - IdxGlobal forward_stride{}; - IdxGlobal backward_stride{}; - IdxGlobal forward_distance{}; - IdxGlobal backward_distance{}; - - if (is_global) { - length = std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies()); - - remaining_factors_prod /= length; - forward_stride = remaining_factors_prod; - backward_stride = remaining_factors_prod; - forward_distance = is_final_factor ? length : 1; - backward_distance = is_final_factor ? 
length : 1; - - } else { - length = static_cast(params.lengths[dimension_num]); - forward_stride = static_cast(params.forward_strides[dimension_num]); - backward_stride = static_cast(params.backward_strides[dimension_num]); - if (is_multi_dim) { - if (is_final_dim) { - forward_distance = length; - backward_distance = length; - } else { - forward_distance = 1; - backward_distance = 1; - } - } else { - forward_distance = static_cast(params.forward_distance); - backward_distance = static_cast(params.backward_distance); + if (top_level == detail::level::GLOBAL) { + set_global_impl_spec_constants(input_kernels_and_metadata, static_cast(num_forward_factors), + static_cast(num_backward_factors), compute_direction, + num_backward_factors > 0, scale_factor); + } else { + detail::complex_conjugate conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + detail::complex_conjugate conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; + detail::apply_scale_factor scale_factor_applied = detail::apply_scale_factor::APPLIED; + const auto input_stride = compute_direction == direction::FORWARD ? params.forward_strides[dimension_num] + : params.backward_strides[dimension_num]; + const auto output_stride = compute_direction == direction::FORWARD ? params.backward_strides[dimension_num] + : params.forward_strides[dimension_num]; + const auto input_distance = + compute_direction == direction::FORWARD ? params.forward_distance : params.backward_distance; + const auto output_distance = + compute_direction == direction::FORWARD ? params.backward_distance : params.forward_distance; + + if (compute_direction == direction::BACKWARD) { + if (dimension_num == 0) { + conjugate_on_load = detail::complex_conjugate::APPLIED; + } + if (dimension_num == params.lengths.size() - 1) { + conjugate_on_store = detail::complex_conjugate::APPLIED; } } + if (skip_scaling) { + scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED; + } + for (auto& [level, input_bundle, factors] : input_kernels_and_metadata) { + set_spec_constants(level, input_bundle, static_cast(params.lengths[dimension_num]), factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, + scale_factor_applied, level, conjugate_on_load, conjugate_on_store, scale_factor, + static_cast(input_stride), static_cast(output_stride), + static_cast(input_distance), static_cast(output_distance)); + } + } - const IdxGlobal input_stride = compute_direction == direction::FORWARD ? forward_stride : backward_stride; - const IdxGlobal output_stride = compute_direction == direction::FORWARD ? backward_stride : forward_stride; - const IdxGlobal input_distance = compute_direction == direction::FORWARD ? forward_distance : backward_distance; - const IdxGlobal output_distance = compute_direction == direction::FORWARD ? 
backward_distance : forward_distance; - - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - - set_spec_constants(top_level, in_bundle, length, factors, detail::elementwise_multiply::NOT_APPLIED, - multiply_on_store, apply_scale, level, conjugate_on_load, conjugate_on_store, scale_factor, - input_stride, output_stride, input_distance, output_distance, static_cast(counter), - static_cast(prepared_vec.size())); + for (const auto& [level, input_bundle, factors] : input_kernels_and_metadata) { try { - PORTFFT_LOG_TRACE("Building kernel bundle with subgroup size", SubgroupSize); - result.emplace_back(sycl::build(in_bundle), factors, params.lengths[dimension_num], SubgroupSize, - PORTFFT_SGS_IN_WG, std::shared_ptr(), level); - PORTFFT_LOG_TRACE("Kernel bundle build complete."); - } catch (std::exception& e) { + result.emplace_back( + sycl::build(input_bundle), factors, + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())), + SubgroupSize, PORTFFT_SGS_IN_WG, std::shared_ptr(), level); + } catch (const std::exception& e) { PORTFFT_LOG_WARNING("Build for subgroup size", SubgroupSize, "failed with message:\n", e.what()); return std::nullopt; } - counter++; } return result; } @@ -549,23 +677,33 @@ class committed_descriptor_impl { dimension_struct build_w_spec_const(std::size_t dimension_num) { PORTFFT_LOG_FUNCTION_ENTRY(); if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) { - auto [top_level, prepared_vec] = prepare_implementation(dimension_num); + auto [top_level, dimension_size, prepared_vec] = prepare_implementation(dimension_num); bool is_compatible = true; - for (auto [level, ids, factors] : prepared_vec) { + std::size_t accumulated_size = 1; + Idx num_forward_factors = 0; + for (const auto& [level, ids, factors] : prepared_vec) { is_compatible = is_compatible && sycl::is_compatible(ids, dev); if (!is_compatible) { break; } + if (accumulated_size == dimension_size) { + break; + } + accumulated_size *= + static_cast(std::accumulate(factors.begin(), factors.end(), 1, std::multiplies())); + num_forward_factors++; } - + Idx num_backward_factors = static_cast(prepared_vec.size()) - num_forward_factors; + bool is_prime = static_cast(dimension_size != params.lengths[dimension_num]); if (is_compatible) { - auto forward_kernels = - set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, dimension_num); - auto backward_kernels = - set_spec_constants_driver(top_level, prepared_vec, direction::BACKWARD, dimension_num); + auto forward_kernels = set_spec_constants_driver( + top_level, prepared_vec, direction::FORWARD, dimension_num, num_forward_factors, num_backward_factors); + auto backward_kernels = set_spec_constants_driver( + top_level, prepared_vec, direction::BACKWARD, dimension_num, num_forward_factors, num_backward_factors); if (forward_kernels.has_value() && backward_kernels.has_value()) { - return {forward_kernels.value(), backward_kernels.value(), top_level, params.lengths[dimension_num], - SubgroupSize}; + return {forward_kernels.value(), backward_kernels.value(), top_level, + dimension_size, params.lengths[dimension_num], SubgroupSize, + num_forward_factors, num_backward_factors, is_prime}; } } } @@ -576,138 +714,137 @@ class committed_descriptor_impl { } } + /** + * Builds transpose kernels required for global implementation + * @param dimension_data dimension_struct associated with the dimension + * @param num_transpositions Number of transpose kernels to build + * @param ld_input vector 
containing leading dimensions of the inputs for transpositions at each level + * @param ld_output vector containing leading dimensions of the outputs for transpositions at each level + */ + void build_transpose_kernels(dimension_struct& dimension_data, std::size_t num_transpositions, + std::vector& ld_input, std::vector& ld_output) { + for (std::size_t i = 0; i < num_transpositions; i++) { + auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), + detail::get_transpose_kernel_ids()); + in_bundle.template set_specialization_constant(static_cast(i)); + in_bundle.template set_specialization_constant( + static_cast(num_transpositions + 1)); + try { + dimension_data.transpose_kernels.emplace_back( + sycl::build(in_bundle), + std::vector{static_cast(ld_input.at(i)), static_cast(ld_output.at(i))}, 1, 1, 1, + std::shared_ptr(), detail::level::GLOBAL); + } catch (const std::exception& e) { + throw internal_error("Error building transpose kernel: ", e.what()); + } + } + } + + /** + * Precomputes the inclusive scan required for the global implementation, and populates the device pointer containing + * the same. Also calculates the ideal amount of LLC cache occupied by the load/store modifiers and returns the + * same. + * @param dimension_data Dimension struct for which the inclusive scan is being precomputed + * @param num_factors Number of factors + * @param kernel_offset Index from which the kernel_data_struct are to be considered + * @param ptr Pointer to the global memory for the precomputed data. + * @return cache space in number of bytes required for the load/store modifiers + */ + IdxGlobal precompute_scan_impl(dimension_struct& dimension_data, std::size_t num_factors, std::size_t kernel_offset, + IdxGlobal* ptr) { + std::vector factors; + std::vector inner_batches; + std::vector inclusive_scan; + + for (std::size_t i = 0; i < num_factors; i++) { + const auto& kernel_data = dimension_data.forward_kernels.at(kernel_offset + i); + factors.push_back(static_cast(kernel_data.length)); + inner_batches.push_back(kernel_data.batch_size); + } + + inclusive_scan.push_back(factors.at(0)); + for (std::size_t i = 1; i < static_cast(num_factors); i++) { + inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); + } + queue.copy(factors.data(), ptr, factors.size()); + queue.copy(inner_batches.data(), ptr + factors.size(), inner_batches.size()); + queue.copy(inclusive_scan.data(), ptr + factors.size() + inner_batches.size(), inclusive_scan.size()); + + build_transpose_kernels(dimension_data, num_factors - 1, inner_batches, factors); + + // calculate the ideal amount of LLC cache required for the load/store modifiers + std::size_t llc_cache_space_for_twiddles = 0; + for (std::size_t i = 0; i < num_factors - 1; i++) { + llc_cache_space_for_twiddles += + static_cast(2 * factors.at(i) * inner_batches.at(i)) * sizeof(Scalar); + } + + if (dimension_data.is_prime) { + llc_cache_space_for_twiddles += 4 * dimension_data.length * sizeof(Scalar); + } + queue.wait(); + return static_cast(llc_cache_space_for_twiddles); + } + + /** + * Gets the number of transforms to accommodate in the last level cache + * @param llc_cache_space_for_twiddles Amount of cache space in bytes required for load/store modifiers + * @param n_transforms The batch size corresponding to the dimension size + * @param length Length of the transform + * @return Number of transforms that can be accommodated in the last level cache at once + */ + Idx get_num_batches_in_llc(IdxGlobal llc_cache_space_for_twiddles, IdxGlobal n_transforms, std::size_t length) { + IdxGlobal cache_space_remaining = + std::max(IdxGlobal(0),
static_cast(llc_size) - llc_cache_space_for_twiddles); + IdxGlobal sizeof_one_transform = static_cast(2 * length * sizeof(Scalar)); + + return static_cast( + std::min(IdxGlobal(PORTFFT_MAX_CONCURRENT_KERNELS), + std::min(n_transforms, std::max(IdxGlobal(1), cache_space_remaining / sizeof_one_transform)))); + } + /** * Function which calculates the amount of scratch space required, and also pre computes the necessary scans required. + * Builds the transpose kernels required for the global implementation * @param num_global_level_dimensions number of global level dimensions in the committed size */ void allocate_scratch_and_precompute_scan(Idx num_global_level_dimensions) { PORTFFT_LOG_FUNCTION_ENTRY(); - std::size_t n_kernels = params.lengths.size(); + std::size_t n_dimensions = params.lengths.size(); if (num_global_level_dimensions == 1) { std::size_t global_dimension = 0; - for (std::size_t i = 0; i < n_kernels; i++) { + for (std::size_t i = 0; i < n_dimensions; i++) { if (dimensions.at(i).level == detail::level::GLOBAL) { global_dimension = i; break; } } - std::vector factors; - std::vector sub_batches; - std::vector inclusive_scan; - std::size_t cache_required_for_twiddles = 0; - for (const auto& kernel_data : dimensions.at(global_dimension).forward_kernels) { - IdxGlobal factor_size = static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); - cache_required_for_twiddles += - static_cast(2 * factor_size * kernel_data.batch_size) * sizeof(Scalar); - factors.push_back(factor_size); - sub_batches.push_back(kernel_data.batch_size); - } - dimensions.at(global_dimension).num_factors = static_cast(factors.size()); - std::size_t cache_space_left_for_batches = static_cast(llc_size) - cache_required_for_twiddles; - // TODO: In case of multi-dim (single dim global sized), this should be batches corresponding to that dim - dimensions.at(global_dimension).num_batches_in_l2 = static_cast(std::min( - static_cast(PORTFFT_MAX_CONCURRENT_KERNELS), - std::min(params.number_of_transforms, - std::max(std::size_t(1), cache_space_left_for_batches / - (2 * dimensions.at(global_dimension).length * sizeof(Scalar)))))); - scratch_space_required = 2 * dimensions.at(global_dimension).length * - static_cast(dimensions.at(global_dimension).num_batches_in_l2); - PORTFFT_LOG_TRACE("Allocating 2 scratch arrays of size", scratch_space_required, "scalars in global memory"); - scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); - scratch_ptr_2 = detail::make_shared(scratch_space_required, queue); - inclusive_scan.push_back(factors.at(0)); - for (std::size_t i = 1; i < factors.size(); i++) { - inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); - } - PORTFFT_LOG_TRACE("Dimension:", global_dimension, - "num_batches_in_l2:", dimensions.at(global_dimension).num_batches_in_l2, - "scan:", inclusive_scan); - dimensions.at(global_dimension).factors_and_scan = - detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); - queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), factors.size()); - queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + factors.size(), - sub_batches.size()); - queue.copy(inclusive_scan.data(), - dimensions.at(global_dimension).factors_and_scan.get() + factors.size() + sub_batches.size(), - inclusive_scan.size()); - queue.wait(); - // build transpose kernels - std::size_t num_transposes_required = factors.size() - 1; - for 
(std::size_t i = 0; i < num_transposes_required; i++) { - std::vector ids; - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), - detail::get_transpose_kernel_ids()); - PORTFFT_LOG_TRACE("Setting specialization constants for transpose kernel", i); - PORTFFT_LOG_TRACE("SpecConstComplexStorage:", params.complex_storage); - in_bundle.template set_specialization_constant(params.complex_storage); - PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); - in_bundle.template set_specialization_constant(static_cast(i)); - PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); - in_bundle.template set_specialization_constant( - static_cast(factors.size())); - dimensions.at(global_dimension) - .transpose_kernels.emplace_back( - sycl::build(in_bundle), - std::vector{static_cast(factors.at(i)), static_cast(sub_batches.at(i))}, 1, 1, 1, - std::shared_ptr(), detail::level::GLOBAL); + auto& dimension_data = dimensions.at(global_dimension); + std::size_t space_for_scans = + static_cast(3 * (dimension_data.num_forward_factors + + (dimension_data.is_prime ? dimension_data.num_backward_factors : 0))); + dimension_data.factors_and_scan = detail::make_shared(space_for_scans, queue); + IdxGlobal cache_req_for_modifiers = static_cast( + precompute_scan_impl(dimension_data, static_cast(dimension_data.num_forward_factors), 0, + dimension_data.factors_and_scan.get())); + Idx num_batches_in_llc = get_num_batches_in_llc(cache_req_for_modifiers, IdxGlobal(params.number_of_transforms), + dimension_data.length); + scratch_ptr_1 = + detail::make_shared(2 * dimension_data.length * static_cast(num_batches_in_llc), queue); + scratch_ptr_2 = + detail::make_shared(2 * dimension_data.length * static_cast(num_batches_in_llc), queue); + dimension_data.num_batches_in_l2 = num_batches_in_llc; + + if (dimension_data.is_prime) { + // only need populate the scans and build transpose kernels + precompute_scan_impl(dimension_data, static_cast(dimension_data.num_backward_factors), + static_cast(dimension_data.num_forward_factors), + dimension_data.factors_and_scan.get() + 3 * dimension_data.num_forward_factors); } } else { - std::size_t max_encountered_global_size = 0; - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - max_encountered_global_size = max_encountered_global_size > dimensions.at(i).length - ? max_encountered_global_size - : dimensions.at(i).length; - } - } - // TODO: max_scratch_size should be max(global_size_1 * corresponding_batches_in_l2, global_size_1 * - // corresponding_batches_in_l2), in the case of multi-dim global FFTs. 
- scratch_space_required = 2 * max_encountered_global_size * params.number_of_transforms; - scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); - scratch_ptr_2 = detail::make_shared(scratch_space_required, queue); - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - std::vector factors; - std::vector sub_batches; - std::vector inclusive_scan; - for (const auto& kernel_data : dimensions.at(i).forward_kernels) { - IdxGlobal factor_size = static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); - factors.push_back(factor_size); - sub_batches.push_back(kernel_data.batch_size); - } - inclusive_scan.push_back(factors.at(0)); - for (std::size_t j = 1; j < factors.size(); j++) { - inclusive_scan.push_back(inclusive_scan.at(j - 1) * factors.at(j)); - } - dimensions.at(i).num_factors = static_cast(factors.size()); - dimensions.at(i).factors_and_scan = - detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); - queue.copy(factors.data(), dimensions.at(i).factors_and_scan.get(), factors.size()); - queue.copy(sub_batches.data(), dimensions.at(i).factors_and_scan.get() + factors.size(), sub_batches.size()); - queue.copy(inclusive_scan.data(), - dimensions.at(i).factors_and_scan.get() + factors.size() + sub_batches.size(), - inclusive_scan.size()); - queue.wait(); - // build transpose kernels - std::size_t num_transposes_required = factors.size() - 1; - for (std::size_t j = 0; j < num_transposes_required; j++) { - auto in_bundle = sycl::get_kernel_bundle( - queue.get_context(), detail::get_transpose_kernel_ids()); - PORTFFT_LOG_TRACE("Setting specilization constants for transpose kernel", j); - PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); - in_bundle.template set_specialization_constant(static_cast(i)); - PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); - in_bundle.template set_specialization_constant( - static_cast(factors.size())); - dimensions.at(i).transpose_kernels.emplace_back( - sycl::build(in_bundle), - std::vector{static_cast(factors.at(j)), static_cast(sub_batches.at(j))}, 1, 1, 1, - std::shared_ptr(), detail::level::GLOBAL); - } - } - } + // TODO: accuractely calculate the scratch space required when there are more than one global level sizes to + // ensure least amount of evictions + throw internal_error("Scratch space calculation for more than one global level dimensions is not handled"); } } diff --git a/src/portfft/common/bluestein.hpp b/src/portfft/common/bluestein.hpp new file mode 100644 index 00000000..f264c804 --- /dev/null +++ b/src/portfft/common/bluestein.hpp @@ -0,0 +1,76 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_BLUESTEIN_HPP +#define PORTFFT_COMMON_BLUESTEIN_HPP + +#include "portfft/common/host_fft.hpp" +#include "portfft/defines.hpp" + +#include +#include + +namespace portfft { +namespace detail { +/** + * Utility function to get the dft transform of the chirp signal + * @tparam T Scalar Type + * @param ptr Host Pointer containing the load/store modifiers. + * @param committed_size original problem size + * @param dimension_size padded size + */ +template +void get_fft_chirp_signal(T* ptr, std::size_t committed_size, std::size_t dimension_size) { + using complex_t = std::complex; + std::vector chirp_signal(dimension_size, 0); + std::vector chirp_fft(dimension_size, 0); + for (std::size_t i = 0; i < committed_size; i++) { + double theta = M_PI * static_cast(i * i) / static_cast(committed_size); + chirp_signal[i] = complex_t(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + } + std::size_t num_zeros = dimension_size - 2 * committed_size + 1; + for (std::size_t i = 0; i < committed_size; i++) { + chirp_signal[committed_size + num_zeros + i - 1] = chirp_signal[committed_size - i]; + } + host_naive_dft(chirp_signal.data(), chirp_fft.data(), dimension_size); + std::memcpy(ptr, reinterpret_cast(chirp_fft.data()), 2 * dimension_size * sizeof(T)); +} + +/** + * Populates input modifiers required for bluestein + * @tparam T Scalar Type + * @param ptr Host Pointer containing the load/store modifiers. + * @param committed_size committed problem length + * @param dimension_size padded dft length + */ +template +void populate_bluestein_input_modifiers(T* ptr, std::size_t committed_size, std::size_t dimension_size) { + using complex_t = std::complex; + std::vector scratch(dimension_size, 0); + for (std::size_t i = 0; i < committed_size; i++) { + double theta = -M_PI * static_cast(i * i) / static_cast(committed_size); + scratch[i] = complex_t(static_cast(std::cos(theta)), static_cast(std::sin(theta))); + } + std::memcpy(ptr, reinterpret_cast(scratch.data()), 2 * dimension_size * sizeof(T)); +} +} // namespace detail +} // namespace portfft + +#endif diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 727d0bed..da4d13c5 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -34,6 +34,25 @@ namespace portfft { namespace detail { +/** + * Helper function to determine the increment of twiddle pointer between factors + * @param level Corresponding implementation for the previous factor + * @param factor_size length of the factor + * @return value to increment the pointer by + */ +inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) { + PORTFFT_LOG_FUNCTION_ENTRY(); + if (level == detail::level::SUBGROUP) { + return 2 * factor_size; + } + if (level == detail::level::WORKGROUP) { + Idx n = detail::factorize(factor_size); + Idx m = factor_size / n; + return 2 * (factor_size + m + n); + } + return 0; +} + /** * inner batches refers to the batches associated per factor which will be computed in a single implementation call * corresponding to that factor. 
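// Editor's note (illustrative sketch, not part of this patch): the two helpers in bluestein.hpp
// above supply the two chirp sequences that Bluestein's algorithm needs.
// populate_bluestein_input_modifiers produces e^{-i*pi*n^2/N} (the load/store modifier) and
// get_fft_chirp_signal produces the DFT of the zero-padded chirp e^{+i*pi*n^2/N}. The sketch
// below shows, on the host and with naive O(M^2) DFTs standing in for the device kernels, how
// those sequences turn a length-N DFT of arbitrary (e.g. prime) N into a length-M circular
// convolution with M >= 2N - 1. The function names and the chirp wrap convention here are
// assumptions for the example, not portFFT's exact implementation.

#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Out-of-place naive DFT, same idea as host_naive_dft in this patch.
std::vector<cd> naive_dft(const std::vector<cd>& in) {
  std::size_t n = in.size();
  std::vector<cd> out(n, cd(0, 0));
  for (std::size_t k = 0; k < n; k++) {
    for (std::size_t j = 0; j < n; j++) {
      out[k] += in[j] * std::polar(1.0, -2.0 * M_PI * static_cast<double>(k * j) / static_cast<double>(n));
    }
  }
  return out;
}

// Length-n DFT of x via Bluestein's algorithm, padded to length m >= 2 * n - 1.
std::vector<cd> bluestein_dft(const std::vector<cd>& x, std::size_t m) {
  std::size_t n = x.size();
  std::vector<cd> a(m, cd(0, 0));  // input multiplied by e^{-i*pi*j^2/n}, then zero padded
  std::vector<cd> b(m, cd(0, 0));  // chirp e^{+i*pi*j^2/n}, wrapped so that b[m - j] == b[j]
  for (std::size_t j = 0; j < n; j++) {
    double theta = M_PI * static_cast<double>(j * j) / static_cast<double>(n);
    a[j] = x[j] * std::polar(1.0, -theta);
    b[j] = std::polar(1.0, theta);
    if (j != 0) {
      b[m - j] = b[j];
    }
  }
  // Circular convolution a (*) b through DFTs; the inverse DFT is done via conjugation.
  std::vector<cd> fa = naive_dft(a);
  std::vector<cd> fb = naive_dft(b);
  std::vector<cd> prod(m);
  for (std::size_t j = 0; j < m; j++) {
    prod[j] = std::conj(fa[j] * fb[j]);
  }
  std::vector<cd> conv = naive_dft(prod);
  std::vector<cd> out(n);
  for (std::size_t k = 0; k < n; k++) {
    double theta = M_PI * static_cast<double>(k * k) / static_cast<double>(n);
    // conj(conv[k]) / m is the inverse DFT of fa .* fb; the trailing chirp is the store modifier.
    out[k] = std::conj(conv[k]) / static_cast<double>(m) * std::polar(1.0, -theta);
  }
  return out;
}
// For example, a length-89 (prime) transform can be padded to m = 256 (2 * 89 - 1 = 177 <= 256)
// and computed entirely with power-of-two sized kernels this way.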
Optimization note: currently the factors_triple pointer is in global memory, however @@ -134,8 +153,8 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, */ template PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag, - const Scalar* implementation_twiddles, const Scalar* store_modifier_data, - Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc, + const Scalar* implementation_twiddles, const Scalar* load_modifier_data, + const Scalar* store_modifier_data, Scalar* input_loc, Scalar* twiddles_loc, const IdxGlobal* factors, const IdxGlobal* inner_batches, const IdxGlobal* inclusive_scan, IdxGlobal batch_size, detail::global_data_struct<1> global_data, sycl::kernel_handler& kh) { @@ -143,27 +162,33 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc auto level = kh.get_specialization_constant(); Idx level_num = kh.get_specialization_constant(); Idx num_factors = kh.get_specialization_constant(); + detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); global_data.log_message_global(__func__, "dispatching sub implementation for factor num = ", level_num); IdxGlobal outer_batch_product = get_outer_batch_product(inclusive_scan, num_factors, level_num); for (IdxGlobal iter_value = 0; iter_value < outer_batch_product; iter_value++) { IdxGlobal outer_batch_offset = get_outer_batch_offset(factors, inner_batches, inclusive_scan, num_factors, level_num, iter_value, outer_batch_product, storage); + IdxGlobal store_modifier_offset = [&]() { + if (level_num == num_factors - 1 && multiply_on_store == detail::elementwise_multiply::APPLIED) { + return outer_batch_offset; + } + return static_cast(0); + }(); if (level == detail::level::WORKITEM) { workitem_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, - batch_size, global_data, kh, static_cast(nullptr), - store_modifier_data, static_cast(nullptr), store_modifier_loc); + batch_size, global_data, kh, load_modifier_data, + store_modifier_data + store_modifier_offset); } else if (level == detail::level::SUBGROUP) { subgroup_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data, kh, - static_cast(nullptr), store_modifier_data, - static_cast(nullptr), store_modifier_loc); + load_modifier_data, store_modifier_data + store_modifier_offset); } else if (level == detail::level::WORKGROUP) { workgroup_impl(input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data, kh, - static_cast(nullptr), store_modifier_data); + load_modifier_data, store_modifier_data + store_modifier_offset); } sycl::group_barrier(global_data.it.get_group()); } @@ -199,8 +224,8 @@ sycl::event transpose_level(const typename committed_descriptor_impl ? detail::memory::USM : detail::memory::BUFFER; const IdxGlobal vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 
2 : 1; std::vector transpose_events; - IdxGlobal ld_input = kd_struct.factors.at(1); - IdxGlobal ld_output = kd_struct.factors.at(0); + IdxGlobal ld_input = kd_struct.factors.at(0); + IdxGlobal ld_output = kd_struct.factors.at(1); const IdxGlobal* inner_batches = factors_triple + total_factors; const IdxGlobal* inclusive_scan = factors_triple + 2 * total_factors; for (Idx batch_in_l2 = 0; @@ -293,7 +318,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl std::vector compute_level( const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, - Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, - const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* load_modifier_data, + const Scalar* store_modifier_data, const Scalar* subimpl_twiddles, const IdxGlobal* factors_triple, IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, - IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, - const std::vector& dependencies, sycl::queue& queue) { + IdxGlobal batch_start, Idx total_factors, complex_storage storage, const std::vector& dependencies, + sycl::queue& queue) { PORTFFT_LOG_FUNCTION_ENTRY(); constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; IdxGlobal local_range = kd_struct.local_range; IdxGlobal global_range = kd_struct.global_range; IdxGlobal batch_size = kd_struct.batch_size; std::size_t local_memory_for_input = kd_struct.local_mem_required; - std::size_t local_mem_for_store_modifier = [&]() -> std::size_t { - if (factor_id < total_factors - 1) { - if (kd_struct.level == detail::level::WORKITEM || kd_struct.level == detail::level::WORKGROUP) { - return 1; - } - if (kd_struct.level == detail::level::SUBGROUP) { - return kd_struct.local_mem_required; - } - } - return std::size_t(1); - }(); + // Backends may check pointer validity. For the WI implementation, where no subimpl_twiddles alloc is used, + // the subimpl_twiddles + subimpl_twiddle_offset may point to the end of the allocation and therefore be invalid. + // The same check is performed on the pointer containing the imag data, which would fall OOB in the + // INTERLEAVED_COMPLEX case + const Scalar* subimpl_twiddles_ptr = + kd_struct.level == detail::level::WORKITEM ? static_cast(nullptr) : subimpl_twiddles; + std::size_t loc_mem_for_twiddles = [&]() { if (kd_struct.level == detail::level::WORKITEM) { - return std::size_t(1); + return std::size_t(0); } if (kd_struct.level == detail::level::SUBGROUP) { return 2 * kd_struct.length; } if (kd_struct.level == detail::level::WORKGROUP) { - return std::size_t(1); + return std::size_t(0); } throw internal_error("illegal level encountered"); }(); + const IdxGlobal* inner_batches = factors_triple + total_factors; const IdxGlobal* inclusive_scan = factors_triple + 2 * total_factors; const Idx vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 
2 : 1; std::vector events; - PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "store modifiers", - local_mem_for_store_modifier, "twiddles", loc_mem_for_twiddles, "total", - local_memory_for_input + local_mem_for_store_modifier + loc_mem_for_twiddles); + PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "twiddles", loc_mem_for_twiddles, "total", + local_memory_for_input + loc_mem_for_twiddles); for (Idx batch_in_l2 = 0; batch_in_l2 < num_batches_in_l2 && batch_in_l2 + batch_start < n_transforms; batch_in_l2++) { events.push_back(queue.submit([&](sycl::handler& cgh) { sycl::local_accessor loc_for_input(local_memory_for_input, cgh); sycl::local_accessor loc_for_twiddles(loc_mem_for_twiddles, cgh); - sycl::local_accessor loc_for_modifier(local_mem_for_store_modifier, cgh); auto in_acc_or_usm = detail::get_access(input, cgh); auto in_imag_acc_or_usm = detail::get_access(input_imag, cgh); cgh.use_kernel_bundle(kd_struct.exec_bundle); @@ -360,15 +379,11 @@ std::vector compute_level( // level cache. cgh.depends_on(dependencies.at(static_cast(batch_in_l2))); } - // Backends may check pointer validity. For the WI implementation, where no subimpl_twiddles alloc is used, - // the subimpl_twiddles + subimpl_twiddle_offset may point to the end of the allocation and therefore be invalid. - const bool using_wi_level = kd_struct.level == detail::level::WORKITEM; - const Scalar* subimpl_twiddles = using_wi_level ? nullptr : twiddles_ptr + subimpl_twiddle_offset; + Scalar* offset_output_imag = storage == complex_storage::INTERLEAVED_COMPLEX ? nullptr : output_imag + vec_size * batch_in_l2 * committed_size; Scalar* offset_output = output + vec_size * batch_in_l2 * committed_size; - const Scalar* multipliers_between_factors = twiddles_ptr + intermediate_twiddle_offset; IdxGlobal input_batch_offset = vec_size * committed_size * batch_in_l2 + input_global_offset; #ifdef PORTFFT_KERNEL_LOG sycl::stream s{1024 * 16, 1024, cgh}; @@ -389,16 +404,145 @@ std::vector compute_level( s, global_logging_config, #endif it}; - dispatch_level(&in_acc_or_usm[0] + input_batch_offset, offset_output, - &in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag, - subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], - &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, - inner_batches, inclusive_scan, batch_size, global_data, kh); + dispatch_level( + &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, + offset_output_imag, subimpl_twiddles_ptr, load_modifier_data, store_modifier_data, &loc_for_input[0], + &loc_for_twiddles[0], factors_triple, inner_batches, inclusive_scan, batch_size, global_data, kh); }); })); } return events; } + +/** + * Run the global implementation + * @tparam Scalar Scalar type of committed descriptor + * @tparam TIn Input type + * @tparam TOut Output Type + * @tparam SubgroupSize Subgroup Size + * @tparam Domain Domain of the committed descriptor + * @param input sycl::buffer / pointer containing the input data. In the case SPLIT_COMPLEX storage, it contains only + * the real part + * @param input_imag sycl::buffer / pointer containing the imaginary part of the input in the case where storage is + * SPLIT_COMPLEX + * @param output sycl::buffer / pointer containing the output data. 
In the case SPLIT_COMPLEX storage, it contains only + * the real part + * @param output_imag sycl::buffer / pointer containing the imaginary part of the output in the case where storage is + * SPLIT_COMPLEX + * @param desc committed descriptor + * @param dimension_data Dimension struct pertaining to the dimension being dispatched + * @param kernels vector containing the kernels for the computation + * @param transpose_kernels vector containing transpose kernels as required by the global implementation + * @param num_factors Number of factors + * @param ptr_offset Offset applied to the twiddles pointer to obtain the start of twiddles applied between factors. + * @param subimpl_twiddles_offset Offset applied to the twiddles pointer to obtain the start of twiddles required by the + * level specific implementation. + * @param kd_struct_offset offset applied to vector of kernels + * @param transforms_begin ID of the transform from which the next transforms which fit in LLC will be processed + * @param n_transforms number of transforms + * @param batch_offset_input offset applied to the input + * @param batch_offset_output offset applied to the output + * @param storage complex storage scheme: split_complex / complex_interleaved + * @param first_uses_load_modifier whether or not the very first kernel modifies data before computation + * @param last_kernel_store_modifier_data whether or not the very last kernel modifies the data after computation + * @return sycl::event waiting on the last transposes + */ +template +sycl::event global_impl_driver(const TIn& input, const TIn& input_imag, TOut output, TOut output_imag, + committed_descriptor_impl& desc, + typename committed_descriptor_impl::dimension_struct& dimension_data, + const kernels_vec& kernels, + const kernels_vec& transpose_kernels, Idx num_factors, + IdxGlobal ptr_offset, IdxGlobal subimpl_twiddles_offset, std::size_t kd_struct_offset, + std::size_t transforms_begin, IdxGlobal n_transforms, IdxGlobal batch_offset_input, + IdxGlobal batch_offset_output, complex_storage storage, + detail::elementwise_multiply first_uses_load_modifier, + const Scalar* last_kernel_store_modifier_data) { + std::vector l2_events; + sycl::event event; + + IdxGlobal intermediate_twiddles_offset = ptr_offset; + IdxGlobal impl_twiddle_offset = subimpl_twiddles_offset; + const IdxGlobal vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1; + auto imag_offset = static_cast(dimension_data.length) * vec_size; + const Scalar* twiddles_ptr = static_cast(kernels.at(0).twiddles_forward.get()); + const IdxGlobal* factors_and_scan = + static_cast(dimension_data.factors_and_scan.get()) + 3 * kd_struct_offset; + IdxGlobal dimension_size = static_cast(dimension_data.length); + Idx max_batches_in_l2 = dimension_data.num_batches_in_l2; + + auto& kernel0 = kernels.at(kd_struct_offset + 0); + const Scalar* load_modifier_data = first_uses_load_modifier == detail::elementwise_multiply::APPLIED + ? 
twiddles_ptr + dimension_data.bluestein_modifiers_offset + : static_cast(nullptr); + l2_events = detail::compute_level( + kernel0, input, desc.scratch_ptr_1.get(), input_imag, desc.scratch_ptr_1.get() + imag_offset, load_modifier_data, + twiddles_ptr + intermediate_twiddles_offset, twiddles_ptr + impl_twiddle_offset, factors_and_scan, + batch_offset_input, dimension_size, max_batches_in_l2, n_transforms, static_cast(transforms_begin), + num_factors, storage, {event}, desc.queue); + intermediate_twiddles_offset += 2 * kernel0.batch_size * static_cast(kernel0.length); + impl_twiddle_offset += increment_twiddle_offset(kernel0.level, static_cast(kernel0.length)); + + for (std::size_t factor_num = 1; factor_num < static_cast(num_factors); factor_num++) { + auto& current_kernel = kernels.at(kd_struct_offset + factor_num); + if (static_cast(factor_num) == num_factors - 1) { + l2_events = detail::compute_level( + current_kernel, static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + static_cast(desc.scratch_ptr_1.get() + imag_offset), desc.scratch_ptr_1.get() + imag_offset, + static_cast(nullptr), last_kernel_store_modifier_data, twiddles_ptr + impl_twiddle_offset, + factors_and_scan, 0, dimension_size, max_batches_in_l2, n_transforms, + static_cast(transforms_begin), num_factors, storage, l2_events, desc.queue); + } else { + l2_events = detail::compute_level( + current_kernel, static_cast(desc.scratch_ptr_1.get()), desc.scratch_ptr_1.get(), + static_cast(desc.scratch_ptr_1.get() + imag_offset), desc.scratch_ptr_1.get() + imag_offset, + static_cast(nullptr), twiddles_ptr + intermediate_twiddles_offset, + twiddles_ptr + impl_twiddle_offset, factors_and_scan, 0, dimension_size, max_batches_in_l2, n_transforms, + static_cast(transforms_begin), num_factors, storage, l2_events, desc.queue); + intermediate_twiddles_offset += 2 * current_kernel.batch_size * static_cast(current_kernel.length); + impl_twiddle_offset += increment_twiddle_offset(current_kernel.level, static_cast(current_kernel.length)); + } + } + + event = desc.queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(l2_events); + cgh.host_task([&]() {}); + }); + + for (Idx num_transpose = num_factors - 2; num_transpose > 0; num_transpose--) { + event = detail::transpose_level( + transpose_kernels.at(static_cast(num_transpose) + kd_struct_offset - 1), desc.scratch_ptr_1.get(), + desc.scratch_ptr_2.get(), factors_and_scan, dimension_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(transforms_begin), num_factors, 0, desc.queue, {event}, storage); + if (storage == complex_storage::SPLIT_COMPLEX) { + event = detail::transpose_level( + transpose_kernels.at(static_cast(num_transpose) + kd_struct_offset - 1), + desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_2.get() + imag_offset, factors_and_scan, + dimension_size, static_cast(max_batches_in_l2), n_transforms, static_cast(transforms_begin), + num_factors, 0, desc.queue, {event}, storage); + } + desc.queue + .submit([&](sycl::handler& cgh) { + cgh.depends_on(event); + cgh.host_task([&]() { desc.scratch_ptr_1.swap(desc.scratch_ptr_2); }); + }) + .wait(); + } + + std::size_t transpose_kernel_pos = kd_struct_offset == 0 ? 
0 : kd_struct_offset - 1; + event = detail::transpose_level( + transpose_kernels.at(transpose_kernel_pos), desc.scratch_ptr_1.get(), output, factors_and_scan, dimension_size, + static_cast(max_batches_in_l2), n_transforms, static_cast(transforms_begin), num_factors, + batch_offset_output, desc.queue, {event}, storage); + if (storage == complex_storage::SPLIT_COMPLEX) { + event = detail::transpose_level( + transpose_kernels.at(transpose_kernel_pos), desc.scratch_ptr_1.get() + imag_offset, output_imag, + factors_and_scan, dimension_size, static_cast(max_batches_in_l2), n_transforms, + static_cast(transforms_begin), num_factors, batch_offset_output, desc.queue, {event}, storage); + } + return event; +} + } // namespace detail } // namespace portfft diff --git a/src/portfft/common/host_fft.hpp b/src/portfft/common/host_fft.hpp new file mode 100644 index 00000000..aac96e32 --- /dev/null +++ b/src/portfft/common/host_fft.hpp @@ -0,0 +1,54 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMON_HOST_FFT_HPP +#define PORTFFT_COMMON_HOST_FFT_HPP + +#include "portfft/defines.hpp" +#include + +namespace portfft { +namespace detail { + +/** + * Host Naive DFT. 
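+ * Computes output[i] = sum_{j=0}^{fft_size-1} input[j] * exp(-2*pi*I*i*j / fft_size), i.e. the DFT evaluated
+ * directly in O(N^2) operations.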
+ * Works out-of-place (OOP) only.
+ * @tparam T Scalar type
+ * @param input input pointer
+ * @param output output pointer
+ * @param fft_size fft size
+ */
+template <typename T>
+void host_naive_dft(std::complex<T>* input, std::complex<T>* output, std::size_t fft_size) {
+  using complex_t = std::complex<T>;
+  for (std::size_t i = 0; i < fft_size; i++) {
+    complex_t temp = complex_t(0, 0);
+    for (std::size_t j = 0; j < fft_size; j++) {
+      complex_t multiplier =
+          complex_t(static_cast<T>(std::cos((-2 * M_PI * static_cast<double>(i * j)) / static_cast<double>(fft_size))),
+                    static_cast<T>(std::sin((-2 * M_PI * static_cast<double>(i * j)) / static_cast<double>(fft_size))));
+      temp += input[j] * multiplier;
+    }
+    output[i] = temp;
+  }
+}
+}  // namespace detail
+}  // namespace portfft
+
+#endif
diff --git a/src/portfft/common/transfers.hpp b/src/portfft/common/transfers.hpp
index 99ce928d..7105ce44 100644
--- a/src/portfft/common/transfers.hpp
+++ b/src/portfft/common/transfers.hpp
@@ -488,6 +488,28 @@ PORTFFT_INLINE void local2global(detail::global_data_struct<1> global_data, Loca
       global_data, global, local, total_num_elems, global_offset, local_offset);
 }
 
+/**
+ * Device-level copy function to copy data between two arrays that use different distances between consecutive
+ * transforms.
+ * @tparam T Scalar type
+ * @param src source pointer of type T
+ * @param dst destination pointer of type T
+ * @param src_distance distance between 2 transforms in the source array
+ * @param dst_distance distance between 2 transforms in the destination array
+ * @param num_elements number of elements to copy per transform
+ * @param num_copies number of transforms to copy
+ * @param it nd_item used to distribute the copy across the launched work-items
+ */
+template <typename T>
+PORTFFT_INLINE void copy(const T* src, T* dst, IdxGlobal src_distance, IdxGlobal dst_distance, IdxGlobal num_elements,
+                         IdxGlobal num_copies, sycl::nd_item<1> it) {
+  for (IdxGlobal i = IdxGlobal(it.get_global_id(0)); i < num_elements * num_copies;
+       i += IdxGlobal(it.get_global_range(0))) {
+    IdxGlobal transform_id = (i / num_elements) % num_copies;
+    IdxGlobal element_in_transform = i % num_elements;
+    IdxGlobal src_index = transform_id * src_distance + element_in_transform;
+    IdxGlobal dst_index = transform_id * dst_distance + element_in_transform;
+    dst[dst_index] = src[src_index];
+  }
+}
+
 }  // namespace portfft
 
 #endif
diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp
index 25038527..e26977c6 100644
--- a/src/portfft/common/workgroup.hpp
+++ b/src/portfft/common/workgroup.hpp
@@ -188,7 +188,9 @@ __attribute__((always_inline)) inline void dimension_dft(
       }
     }
     global_data.log_dump_private("data loaded in registers:", priv, 2 * fact_wi);
-
+    if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
+      conjugate_inplace(priv, fact_wi);
+    }
     if (wg_twiddles) {
       PORTFFT_UNROLL
       for (Idx i = 0; i < fact_wi; i++) {
@@ -225,13 +227,7 @@ __attribute__((always_inline)) inline void dimension_dft(
         }
       }
     }
-    if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
-      conjugate_inplace(priv, fact_wi);
-    }
     sg_dft(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch);
-    if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
-      conjugate_inplace(priv, fact_wi);
-    }
     if (working) {
       if (multiply_on_store == detail::elementwise_multiply::APPLIED) {
         // Store modifier data layout in global memory - n_transforms x N x FactorSG x FactorWI
@@ -247,6 +243,9 @@ __attribute__((always_inline)) inline void dimension_dft(
                          priv[2 * idx + 1]);
       }
     }
+    if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
+      conjugate_inplace(priv, fact_wi);
+    }
     global_data.log_dump_private("data in registers after computation:", priv, 2 * fact_wi);
     if (input_layout == detail::layout::BATCH_INTERLEAVED) {
global_data.log_message_global(__func__, "storing transposed data from private to local memory"); diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 8b67e55a..e9dbaf5e 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -25,6 +25,7 @@ #include +#include "portfft/common/bluestein.hpp" #include "portfft/common/global.hpp" #include "portfft/common/subgroup.hpp" #include "portfft/defines.hpp" @@ -86,70 +87,65 @@ void complex_transpose(const T* a, T* b, IdxGlobal lda, IdxGlobal ldb, IdxGlobal } /** - * Helper function to determine the increment of twiddle pointer between factors - * @param level Corresponding implementation for the previous factor - * @param factor_size length of the factor - * @return value to increment the pointer by + * + * @tparam TIn Input Type + * @tparam TOut Output Type + * @param src sycl::buffer / pointer containing the input data. In the case SPLIT_COMPLEX storage, it contains only the + * real part + * @param src_imag sycl::buffer / pointer containing the imaginary part of the input in the case where storage is + * SPLIT_COMPLEX + * @param dst sycl::buffer / pointer containing the output data. In the case SPLIT_COMPLEX storage, it contains only the + * real part + * @param dst_imag sycl::buffer / pointer containing the imaginary part of the output in the case where storage is + * SPLIT_COMPLEX + * @param num_elements number of elements to copy in each transform + * @param src_distance distance between two consecutive batches in the input + * @param dst_distance disance between two consecutive batches of the output + * @param num_copies number of batches to copy + * @param input_offset offset applied to the input + * @param output_offset offset applied to the output + * @param storage complex storage scheme: split_complex / complex_interleaved + * @param queue sycl queue associated with the commit + * @return sycl event */ -inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size) { - PORTFFT_LOG_FUNCTION_ENTRY(); - if (level == detail::level::SUBGROUP) { - return 2 * factor_size; - } - if (level == detail::level::WORKGROUP) { - Idx n = detail::factorize(factor_size); - Idx m = factor_size / n; - return 2 * (factor_size + m + n); +template +sycl::event copy_global2global(const TIn src, const TIn src_imag, TOut dst, TOut dst_imag, IdxGlobal num_elements, + IdxGlobal src_distance, IdxGlobal dst_distance, IdxGlobal num_copies, + IdxGlobal input_offset, IdxGlobal output_offset, complex_storage storage, + sycl::queue& queue) { + std::vector events; + auto copy_input_to_scratch_impl = [&](const TIn& input, TOut& output) -> sycl::event { + return queue.submit([&](sycl::handler& cgh) { + auto in_acc_or_usm = get_access(input, cgh); + auto out_acc_or_usm = get_access(output, cgh); + cgh.parallel_for(sycl::nd_range<1>({static_cast(num_copies * num_elements)}, + {static_cast(SubgroupSize)}), + [=](sycl::nd_item<1> it) { + copy(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, src_distance, + dst_distance, num_elements, num_copies, it); + }); + }); + }; + events.push_back(copy_input_to_scratch_impl(src, dst)); + if (storage == complex_storage::SPLIT_COMPLEX) { + events.push_back(copy_input_to_scratch_impl(src_imag, dst_imag)); } - return 0; + return queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(events); + cgh.host_task([]() {}); + }); } template template struct 
committed_descriptor_impl::calculate_twiddles_struct::inner { - static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/, + static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& dimension_data, std::vector& kernels) { + using idxglobal_vec_t = std::vector; PORTFFT_LOG_FUNCTION_ENTRY(); - std::vector factors_idx_global; - // Get factor sizes per level; - for (const auto& kernel_data : kernels) { - factors_idx_global.push_back(static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies()))); - } - - std::vector sub_batches; - // Get sub batches - for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { - sub_batches.push_back(std::accumulate(factors_idx_global.begin() + static_cast(i + 1), - factors_idx_global.end(), IdxGlobal(1), std::multiplies())); - } - sub_batches.push_back(factors_idx_global.at(factors_idx_global.size() - 2)); - // calculate total memory required for twiddles; - IdxGlobal mem_required_for_twiddles = 0; - // First calculate mem required for twiddles between factors; - for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { - mem_required_for_twiddles += 2 * factors_idx_global.at(i) * sub_batches.at(i); - } - // Now calculate mem required for twiddles per implementation - std::size_t counter = 0; - for (const auto& kernel_data : kernels) { - if (kernel_data.level == detail::level::SUBGROUP) { - mem_required_for_twiddles += 2 * factors_idx_global.at(counter); - } else if (kernel_data.level == detail::level::WORKGROUP) { - IdxGlobal factor_1 = detail::factorize(factors_idx_global.at(counter)); - IdxGlobal factor_2 = factors_idx_global.at(counter) / factor_1; - mem_required_for_twiddles += 2 * (factor_1 * factor_2) + 2 * (factor_1 + factor_2); - } - counter++; - } - std::vector host_memory(static_cast(mem_required_for_twiddles)); - std::vector scratch_space(static_cast(mem_required_for_twiddles)); - PORTFFT_LOG_TRACE("Allocating global memory for twiddles for workgroup implementation. Allocation size", - mem_required_for_twiddles); - Scalar* device_twiddles = - sycl::malloc_device(static_cast(mem_required_for_twiddles), desc.queue); - - // Helper Lambda to calculate twiddles + /** + * Helper Lambda to calculate twiddles + */ auto calculate_twiddles = [](IdxGlobal N, IdxGlobal M, IdxGlobal& offset, Scalar* ptr) { for (IdxGlobal i = 0; i < N; i++) { for (IdxGlobal j = 0; j < M; j++) { @@ -160,98 +156,178 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn } }; - IdxGlobal offset = 0; - // calculate twiddles to be multiplied between factors - for (std::size_t i = 0; i < factors_idx_global.size() - 1; i++) { - calculate_twiddles(sub_batches.at(i), factors_idx_global.at(i), offset, host_memory.data()); - } - // Now calculate per twiddles. 
- counter = 0; - for (const auto& kernel_data : kernels) { + auto calculate_level_specific_twiddles = [calculate_twiddles](Scalar* host_twiddles_ptr, Scalar* scratch_ptr, + const kernel_data_struct& kernel_data, + IdxGlobal& ptr_offset) { if (kernel_data.level == detail::level::SUBGROUP) { - for (Idx i = 0; i < kernel_data.factors.at(0); i++) { - for (Idx j = 0; j < kernel_data.factors.at(1); j++) { + for (Idx i = 0; i < kernel_data.factors.at(1); i++) { + for (Idx j = 0; j < kernel_data.factors.at(0); j++) { double theta = -2 * M_PI * static_cast(i * j) / - static_cast(kernel_data.factors.at(0) * kernel_data.factors.at(1)); + static_cast(kernel_data.factors.at(1) * kernel_data.factors.at(0)); auto twiddle = std::complex(static_cast(std::cos(theta)), static_cast(std::sin(theta))); - host_memory[static_cast(offset + static_cast(j * kernel_data.factors.at(0) + i))] = - twiddle.real(); - host_memory[static_cast( - offset + static_cast((j + kernel_data.factors.at(1)) * kernel_data.factors.at(0) + i))] = + host_twiddles_ptr[static_cast( + ptr_offset + static_cast(j * kernel_data.factors.at(1) + i))] = twiddle.real(); + host_twiddles_ptr[static_cast( + ptr_offset + static_cast((j + kernel_data.factors.at(0)) * kernel_data.factors.at(1) + i))] = twiddle.imag(); } } - offset += 2 * kernel_data.factors.at(0) * kernel_data.factors.at(1); + ptr_offset += static_cast(2 * kernel_data.length); } else if (kernel_data.level == detail::level::WORKGROUP) { Idx factor_n = kernel_data.factors.at(0) * kernel_data.factors.at(1); Idx factor_m = kernel_data.factors.at(2) * kernel_data.factors.at(3); calculate_twiddles(static_cast(kernel_data.factors.at(0)), - static_cast(kernel_data.factors.at(1)), offset, host_memory.data()); + static_cast(kernel_data.factors.at(1)), ptr_offset, host_twiddles_ptr); calculate_twiddles(static_cast(kernel_data.factors.at(2)), - static_cast(kernel_data.factors.at(3)), offset, host_memory.data()); + static_cast(kernel_data.factors.at(3)), ptr_offset, host_twiddles_ptr); // Calculate wg twiddles and transpose them - calculate_twiddles(static_cast(factor_n), static_cast(factor_m), offset, - host_memory.data()); + calculate_twiddles(static_cast(factor_n), static_cast(factor_m), ptr_offset, + host_twiddles_ptr); for (Idx j = 0; j < factor_n; j++) { - detail::complex_transpose(host_memory.data() + offset + 2 * j * factor_n, scratch_space.data(), factor_m, - factor_n, factor_n * factor_m); + detail::complex_transpose(host_twiddles_ptr + ptr_offset + 2 * j * factor_n, scratch_ptr, factor_m, factor_n, + factor_n * factor_m); + std::memcpy(host_twiddles_ptr + ptr_offset + 2 * j * factor_n, scratch_ptr, 2 * kernel_data.length); } } - counter++; - } + }; + + auto get_sub_batches_and_factors = [&dimension_data, &kernels]() -> std::tuple { + auto get_sub_batches_and_factors_impl = [&kernels](idxglobal_vec_t& factors, idxglobal_vec_t& sub_batches, + std::size_t num_factors, std::size_t offset) -> void { + for (std::size_t i = 0; i < num_factors; i++) { + factors.push_back(static_cast(kernels.at(offset + i).length)); + } + for (std::size_t i = 0; i < num_factors - 1; i++) { + sub_batches.push_back(std::accumulate(factors.begin() + static_cast(offset + i + 1), factors.end(), + IdxGlobal(1), std::multiplies())); + } + sub_batches.push_back(factors.at(factors.size() - 2)); + }; + idxglobal_vec_t factors; + idxglobal_vec_t sub_batches; + get_sub_batches_and_factors_impl(factors, sub_batches, + static_cast(dimension_data.num_forward_factors), 0); + if (dimension_data.is_prime) { + 
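+        // For prime committed sizes the Bluestein algorithm is used: `kernels` holds the kernels for the factors of
+        // the padded forward FFT followed by those of the backward FFT, so the backward factors and their sub-batches
+        // are appended as well (starting at offset num_forward_factors).
+        // For reference, get_sub_batches_and_factors_impl computes, e.g. for factors = {16, 16, 16},
+        // sub_batches = {16 * 16, 16, 16} = {256, 16, 16}.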
get_sub_batches_and_factors_impl(factors, sub_batches, + static_cast(dimension_data.num_backward_factors), + static_cast(dimension_data.num_forward_factors)); + } + return {factors, sub_batches}; + }; + + auto get_total_mem_for_twiddles = [&dimension_data, &kernels](const idxglobal_vec_t& factors, + const idxglobal_vec_t& sub_batches) -> std::size_t { + auto get_total_mem_for_twiddles_impl = [&kernels](const idxglobal_vec_t& factors, + const idxglobal_vec_t& sub_batches, std::size_t offset, + std::size_t num_factors) -> std::size_t { + IdxGlobal mem_required_for_twiddles = 0; + // account for memory required for store modifiers + for (std::size_t i = 0; i < num_factors - 1; i++) { + mem_required_for_twiddles += 2 * factors.at(offset + i) * sub_batches.at(offset + i); + } + // account for memory required for factor specific twiddles + for (std::size_t i = 0; i < num_factors; i++) { + const auto& kd_struct = kernels.at(offset + i); + if (kd_struct.level == detail::level::SUBGROUP) { + mem_required_for_twiddles += static_cast(2 * kd_struct.length); + } + if (kd_struct.level == detail::level::WORKGROUP) { + mem_required_for_twiddles += + static_cast(2 * kd_struct.length) + + static_cast( + 2 * std::accumulate(kd_struct.factors.begin(), kd_struct.factors.end(), 0, std::plus())); + } + } + return static_cast(mem_required_for_twiddles); + }; + std::size_t mem_required_for_twiddles = get_total_mem_for_twiddles_impl( + factors, sub_batches, 0, static_cast(dimension_data.num_forward_factors)); + if (dimension_data.is_prime) { + mem_required_for_twiddles += get_total_mem_for_twiddles_impl( + factors, sub_batches, static_cast(dimension_data.num_forward_factors), + static_cast(dimension_data.num_backward_factors)); + // load / store modifiers for bluestein + mem_required_for_twiddles += 4 * dimension_data.length; + } + return mem_required_for_twiddles; + }; - // Rearrage the twiddles between factors for optimal access patters in shared memory - // Also take this opportunity to populate local memory size, and batch size, and launch params and local memory - // usage Note, global impl only uses store modifiers - // TODO: there is a heap corruption in workitem's access of loaded modifiers, hence loading from global directly for - // now. - counter = 0; - for (auto& kernel_data : kernels) { - kernel_data.batch_size = sub_batches.at(counter); - kernel_data.length = static_cast(factors_idx_global.at(counter)); - if (kernel_data.level == detail::level::WORKITEM) { - // See comments in workitem_dispatcher for layout requirments. 
- Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; - if (counter < kernels.size() - 1) { - kernel_data.local_mem_required = static_cast(1); - } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::WORKITEM, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factors_idx_global.at(counter))}, num_sgs_in_wg, - layout::PACKED); + auto get_local_memory_usage = [&desc](detail::layout layout, const kernel_data_struct& kernel_data, + Idx& num_sgs_in_wg) -> std::size_t { + if (kernel_data.level == detail::level::WORKITEM && layout == detail::layout::BATCH_INTERLEAVED) { + return 0; + } + return desc.num_scalars_in_local_mem(kernel_data.level, kernel_data.length, kernel_data.used_sg_size, + kernel_data.factors, num_sgs_in_wg, layout); + }; + + auto calculate_twiddles_and_populate_metadata = [&dimension_data, &kernels, &desc, &calculate_twiddles, + &calculate_level_specific_twiddles, &get_local_memory_usage]( + Scalar* host_twiddles_ptr, const idxglobal_vec_t& factors, + const idxglobal_vec_t& sub_batches) -> void { + auto calculate_twiddles_and_populate_metadata_impl = + [&](Scalar* scratch_ptr, std::size_t num_factors, std::size_t kd_offset, IdxGlobal& ptr_offset) -> IdxGlobal { + IdxGlobal impl_twiddles_offset; + // First populate metadata + for (std::size_t i = 0; i < num_factors; i++) { + auto& kernel_data = kernels.at(kd_offset + i); + kernel_data.batch_size = sub_batches.at(kd_offset + i); + Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; + detail::layout layout = i == num_factors - 1 ? detail::layout::PACKED : detail::layout::BATCH_INTERLEAVED; + kernel_data.local_mem_required = get_local_memory_usage(layout, kernel_data, num_sgs_in_wg); + kernel_data.num_sgs_per_wg = num_sgs_in_wg; + const auto [global_range, local_range] = + get_launch_params(factors.at(kd_offset + i), sub_batches.at(kd_offset + i), kernel_data.level, + desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); + kernel_data.global_range = global_range; + kernel_data.local_range = local_range; + } + + // Populate store modifiers + for (std::size_t i = 0; i < num_factors - 1; i++) { + calculate_twiddles(sub_batches.at(kd_offset + i), factors.at(kd_offset + i), ptr_offset, host_twiddles_ptr); } - auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM, - desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); - kernel_data.global_range = global_range; - kernel_data.local_range = local_range; - } else if (kernel_data.level == detail::level::SUBGROUP) { - Idx num_sgs_in_wg = PORTFFT_SGS_IN_WG; - // See comments in subgroup_dispatcher for layout requirements. 
- IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size); - IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg; - if (counter < kernels.size() - 1) { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg, - layout::BATCH_INTERLEAVED); - } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg, - layout::PACKED); + impl_twiddles_offset = ptr_offset; + + // Populate Implementation specific twiddles + for (std::size_t i = 0; i < num_factors; i++) { + const auto& kernel_data = kernels.at(kd_offset + i); + calculate_level_specific_twiddles(host_twiddles_ptr, scratch_ptr, kernel_data, ptr_offset); } - auto [global_range, local_range] = - detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP, - desc.n_compute_units, kernel_data.used_sg_size, num_sgs_in_wg); - kernel_data.global_range = global_range; - kernel_data.local_range = local_range; + return impl_twiddles_offset; + }; + + std::vector scratch_space(2 * dimension_data.length); + IdxGlobal offset = 0; + dimension_data.forward_impl_twiddle_offset = calculate_twiddles_and_populate_metadata_impl( + host_twiddles_ptr, static_cast(dimension_data.num_forward_factors), 0, offset); + if (dimension_data.is_prime) { + dimension_data.backward_twiddles_offset = offset; + dimension_data.backward_impl_twiddle_offset = calculate_twiddles_and_populate_metadata_impl( + host_twiddles_ptr, static_cast(dimension_data.num_backward_factors), + static_cast(dimension_data.num_forward_factors), offset); + dimension_data.bluestein_modifiers_offset = offset; + detail::populate_bluestein_input_modifiers(host_twiddles_ptr + offset, dimension_data.committed_length, + dimension_data.length); + offset += static_cast(2 * dimension_data.length); + detail::get_fft_chirp_signal(host_twiddles_ptr + offset, dimension_data.committed_length, + dimension_data.length); } - counter++; + }; + + const auto [factors, sub_batches] = get_sub_batches_and_factors(); + std::size_t mem_required_for_twiddles = get_total_mem_for_twiddles(factors, sub_batches); + Scalar* device_twiddles_ptr = + sycl::aligned_alloc_device(alignof(sycl::vec), mem_required_for_twiddles, desc.queue); + if (!device_twiddles_ptr) { + throw internal_error("Could not allocate usm memory of size: ", mem_required_for_twiddles * sizeof(Scalar), + " bytes"); } - desc.queue.copy(host_memory.data(), device_twiddles, static_cast(mem_required_for_twiddles)).wait(); - return device_twiddles; + std::vector host_memory_twiddles(mem_required_for_twiddles); + calculate_twiddles_and_populate_metadata(host_memory_twiddles.data(), factors, sub_batches); + desc.queue.copy(host_memory_twiddles.data(), device_twiddles_ptr, mem_required_for_twiddles).wait_and_throw(); + return device_twiddles_ptr; } }; @@ -272,10 +348,10 @@ struct committed_descriptor_impl::set_spec_constants_struct::inn PORTFFT_LOG_TRACE("SpecConstFftSize:", length); in_bundle.template set_specialization_constant(length); } else if (level == detail::level::SUBGROUP) { - PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[1]); - in_bundle.template set_specialization_constant(factors[1]); - 
PORTFFT_LOG_TRACE("SubgroupFactorSGSpecConst:", factors[0]); - in_bundle.template set_specialization_constant(factors[0]); + PORTFFT_LOG_TRACE("SubgroupFactorWISpecConst:", factors[0]); + in_bundle.template set_specialization_constant(factors[0]); + PORTFFT_LOG_TRACE("SubgroupFactorSGSpecConst:", factors[1]); + in_bundle.template set_specialization_constant(factors[1]); } } }; @@ -302,97 +378,73 @@ struct committed_descriptor_impl::run_kernel_struct(kernels.at(0).twiddles_forward.get()); - const IdxGlobal* factors_and_scan = static_cast(dimension_data.factors_and_scan.get()); - std::size_t num_batches = desc.params.number_of_transforms; + std::size_t num_batches = static_cast(n_transforms); std::size_t max_batches_in_l2 = static_cast(dimension_data.num_batches_in_l2); std::size_t imag_offset = dimension_data.length * max_batches_in_l2; - IdxGlobal initial_impl_twiddle_offset = 0; - Idx num_factors = dimension_data.num_factors; - IdxGlobal committed_size = static_cast(desc.params.lengths[0]); - Idx num_transposes = num_factors - 1; - std::vector l2_events; + sycl::event event = desc.queue.submit([&](sycl::handler& cgh) { cgh.depends_on(dependencies); cgh.host_task([&]() {}); }); - for (std::size_t i = 0; i < static_cast(num_factors - 1); i++) { - initial_impl_twiddle_offset += 2 * kernels.at(i).batch_size * static_cast(kernels.at(i).length); - } + for (std::size_t i = 0; i < num_batches; i += max_batches_in_l2) { - PORTFFT_LOG_TRACE("Global implementation working on batches", i, "through", i + max_batches_in_l2, "out of", - num_batches); - IdxGlobal intermediate_twiddles_offset = 0; - IdxGlobal impl_twiddle_offset = initial_impl_twiddle_offset; - auto& kernel0 = kernels.at(0); - PORTFFT_LOG_TRACE("Dispatching the kernel for factor 0 of global implementation"); - l2_events = detail::compute_level( - kernel0, in, desc.scratch_ptr_1.get(), in_imag, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, - factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, - vec_size * static_cast(i) * committed_size + input_offset, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), 0, - dimension_data.num_factors, storage, {event}, desc.queue); - detail::dump_device(desc.queue, "after factor 0:", desc.scratch_ptr_1.get(), - desc.params.number_of_transforms * dimension_data.length * 2, l2_events); - intermediate_twiddles_offset += 2 * kernel0.batch_size * static_cast(kernel0.length); - impl_twiddle_offset += detail::increment_twiddle_offset(kernel0.level, static_cast(kernel0.length)); - for (std::size_t factor_num = 1; factor_num < static_cast(dimension_data.num_factors); - factor_num++) { - auto& current_kernel = kernels.at(factor_num); - PORTFFT_LOG_TRACE("Dispatching the kernel for factor", factor_num, "of global implementation"); - if (static_cast(factor_num) == dimension_data.num_factors - 1) { - PORTFFT_LOG_TRACE("This is the last kernel"); - } - l2_events = detail::compute_level( - current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get() + imag_offset, - desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, factors_and_scan, intermediate_twiddles_offset, - impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), - static_cast(num_batches), static_cast(i), static_cast(factor_num), - dimension_data.num_factors, storage, l2_events, desc.queue); - intermediate_twiddles_offset += 2 * current_kernel.batch_size * static_cast(current_kernel.length); - impl_twiddle_offset += - 
detail::increment_twiddle_offset(current_kernel.level, static_cast(current_kernel.length)); - detail::dump_device(desc.queue, "after factor:", desc.scratch_ptr_1.get(), - desc.params.number_of_transforms * dimension_data.length * 2, l2_events); - } - event = desc.queue.submit([&](sycl::handler& cgh) { - cgh.depends_on(l2_events); - cgh.host_task([&]() {}); - }); - for (Idx num_transpose = num_transposes - 1; num_transpose > 0; num_transpose--) { - PORTFFT_LOG_TRACE("Dispatching the transpose kernel", num_transpose); - event = detail::transpose_level( - dimension_data.transpose_kernels.at(static_cast(num_transpose)), desc.scratch_ptr_1.get(), - desc.scratch_ptr_2.get(), factors_and_scan, committed_size, static_cast(max_batches_in_l2), - n_transforms, static_cast(i), num_factors, 0, desc.queue, {event}, storage); - if (storage == complex_storage::SPLIT_COMPLEX) { - event = detail::transpose_level( - dimension_data.transpose_kernels.at(static_cast(num_transpose)), - desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_2.get() + imag_offset, factors_and_scan, - committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_factors, - 0, desc.queue, {event}, storage); - } - desc.scratch_ptr_1.swap(desc.scratch_ptr_2); - } - PORTFFT_LOG_TRACE("Dispatching the transpose kernel 0"); - event = detail::transpose_level( - dimension_data.transpose_kernels.at(0), desc.scratch_ptr_1.get(), out, factors_and_scan, committed_size, - static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_factors, - vec_size * static_cast(i) * committed_size + output_offset, desc.queue, {event}, storage); - if (storage == complex_storage::SPLIT_COMPLEX) { - event = detail::transpose_level( - dimension_data.transpose_kernels.at(0), desc.scratch_ptr_1.get() + imag_offset, out_imag, factors_and_scan, - committed_size, static_cast(max_batches_in_l2), n_transforms, static_cast(i), num_factors, - vec_size * static_cast(i) * committed_size + output_offset, desc.queue, {event}, storage); + if (dimension_data.is_prime) { + IdxGlobal num_copies = static_cast( + i + max_batches_in_l2 < num_batches ? max_batches_in_l2 : num_batches - max_batches_in_l2); + // TODO: look into other library implementations to check whether is it possible at all to avoid this explicit + // copy. 
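+          // Bluestein path (prime committed size): copy each committed-length batch into scratch spaced by the padded
+          // dimension_data.length, run the forward pass over the padded size with the Bluestein/chirp modifiers
+          // prepared in calculate_twiddles, swap scratch buffers and run the backward-factor pass, then copy the
+          // first committed_length elements of every batch back out to the user-visible output.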
+ copy_global2global( + in, in_imag, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get() + imag_offset, + static_cast(vec_size * dimension_data.committed_length), + static_cast(vec_size * dimension_data.committed_length), + static_cast(vec_size * dimension_data.length), + static_cast(vec_size * i * dimension_data.committed_length) + input_offset, 0, num_copies, + storage, desc.queue) + .wait(); + + detail::global_impl_driver( + static_cast(desc.scratch_ptr_1.get()), + static_cast(desc.scratch_ptr_1.get() + imag_offset), desc.scratch_ptr_2.get(), + desc.scratch_ptr_2.get() + imag_offset, desc, dimension_data, kernels, dimension_data.transpose_kernels, + dimension_data.num_forward_factors, 0, dimension_data.forward_impl_twiddle_offset, 0, i, + static_cast(num_batches), 0, 0, storage, detail::elementwise_multiply::APPLIED, + twiddles_ptr + dimension_data.bluestein_modifiers_offset + 2 * dimension_data.length) + .wait(); + std::swap(desc.scratch_ptr_1, desc.scratch_ptr_2); + detail::global_impl_driver( + static_cast(desc.scratch_ptr_1.get()), + static_cast(desc.scratch_ptr_1.get() + imag_offset), desc.scratch_ptr_2.get(), + desc.scratch_ptr_2.get() + imag_offset, desc, dimension_data, kernels, dimension_data.transpose_kernels, + dimension_data.num_backward_factors, dimension_data.backward_twiddles_offset, + dimension_data.backward_impl_twiddle_offset, std::size_t(dimension_data.num_forward_factors), i, + static_cast(num_batches), 0, 0, storage, detail::elementwise_multiply::NOT_APPLIED, + twiddles_ptr + dimension_data.bluestein_modifiers_offset); + + copy_global2global( + desc.scratch_ptr_2.get(), desc.scratch_ptr_2.get() + imag_offset, out, out_imag, + static_cast(vec_size * dimension_data.committed_length), + static_cast(vec_size * dimension_data.length), + static_cast(vec_size * dimension_data.committed_length), 0, + static_cast(vec_size * i * dimension_data.committed_length) + output_offset, num_copies, storage, + desc.queue) + .wait(); + } else { + event = detail::global_impl_driver( + in, in_imag, out, out_imag, desc, dimension_data, kernels, dimension_data.transpose_kernels, + dimension_data.num_forward_factors, 0, dimension_data.forward_impl_twiddle_offset, 0, i, + static_cast(num_batches), + static_cast(vec_size * i * dimension_data.length) + input_offset, + static_cast(vec_size * i * dimension_data.length) + output_offset, storage, + detail::elementwise_multiply::NOT_APPLIED, static_cast(nullptr)); } } return event; } }; - } // namespace detail } // namespace portfft diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index aebc8629..1a9a4347 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -79,15 +79,12 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su * @param twiddles pointer containing twiddles * @param load_modifier_data Pointer to the load modifier data in global Memory * @param store_modifier_data Pointer to the store modifier data in global Memory - * @param loc_load_modifier Pointer to load modifier data in local memory - * @param loc_store_modifier Pointer to store modifier data in local memory */ template PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc, T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles, global_data_struct<1> global_data, sycl::kernel_handler& kh, - const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr, - T* 
loc_load_modifier = nullptr, T* loc_store_modifier = nullptr) { + const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr) { const complex_storage storage = kh.get_specialization_constant(); const detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); @@ -101,6 +98,9 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag kh.get_specialization_constant(); const T scaling_factor = kh.get_specialization_constant()>(); + using vec2_t = sycl::vec; + vec2_t modifier_vec; + const Idx factor_wi = kh.get_specialization_constant(); const Idx factor_sg = kh.get_specialization_constant(); const IdxGlobal input_distance = kh.get_specialization_constant(); @@ -155,8 +155,6 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag constexpr Idx BankLinesPerPad = 1; auto loc_view = detail::padded_view(loc, BankLinesPerPad); - auto loc_load_modifier_view = detail::padded_view(loc_load_modifier, BankLinesPerPad); - auto loc_store_modifier_view = detail::padded_view(loc_store_modifier, BankLinesPerPad); global_data.log_message_global(__func__, "loading sg twiddles from global to local memory"); global2local(global_data, twiddles, loc_twiddles, n_reals_per_fft); @@ -190,19 +188,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag }(); Idx rounded_up_sub_batches = detail::round_up_to_multiple(num_batches_in_local_mem, n_ffts_per_sg); Idx local_imag_offset = factor_wi * factor_sg * max_num_batches_local_mem; - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading load multipliers from global to local memory"); - global2local(global_data, load_modifier_data, loc_load_modifier_view, - n_reals_per_fft * num_batches_in_local_mem, - i * n_reals_per_fft); - } - // TODO: Replace this with Async DMA where the hardware supports it. - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading store multipliers from global to local memory"); - global2local(global_data, store_modifier_data, loc_store_modifier_view, - n_reals_per_fft * num_batches_in_local_mem, - i * n_reals_per_fft); - } + global_data.log_message_global(__func__, "loading transposed data from global to local memory"); // load / store in a transposed manner if (storage == complex_storage::INTERLEAVED_COMPLEX) { @@ -250,50 +236,43 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); + } if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - // Note: if using load modifier, this data need to be stored in the transposed fashion per batch to ensure - // low latency reads from shared memory, as this will result in much lesser bank conflicts. 
- // Tensor shape for load modifier in local memory = num_batches_in_local_mem x FactorWI x FactorSG - // TODO: change the above mentioned layout to the following tenshor shape: num_batches_in_local_mem x - // n_ffts_in_sg x FactorWI x FactorSG - global_data.log_message_global(__func__, "multiplying load modifier data"); if (working_inner) { PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { - Idx base_offset = sub_batch * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - multiply_complex(priv[2 * j], priv[2 * j + 1], loc_load_modifier_view[base_offset], - loc_load_modifier_view[base_offset + 1], priv[2 * j], priv[2 * j + 1]); + IdxGlobal idx = + static_cast(n_reals_per_fft) * (i + static_cast(sub_batch + id_of_fft_in_sg)) + + 2 * static_cast(id_of_wi_in_fft * factor_wi + j); + modifier_vec = *reinterpret_cast(&load_modifier_data[idx]); + multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); } } } - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } if (working_inner) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); } if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - // No need to store the store modifier data in a transposed fashion as data after sg_dft is already transposed - // Tensor Shape for store modifier is num_batches_in_local_memory x FactorSG x FactorWI global_data.log_message_global(__func__, "multiplying store modifier data"); if (working_inner) { PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { - sycl::vec modifier_priv; - Idx base_offset = sub_batch * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - // TODO: this leads to compilation error on AMD. 
Revert back to this once it is resolved - // modifier_priv.load(0, detail::get_local_multi_ptr(&loc_store_modifier_view[base_offset])); - modifier_priv[0] = loc_store_modifier_view[base_offset]; - modifier_priv[1] = loc_store_modifier_view[base_offset + 1]; - multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_priv[0], modifier_priv[1], priv[2 * j], + IdxGlobal idx = + static_cast(n_reals_per_fft) * (i + static_cast(sub_batch + id_of_fft_in_sg)) + + static_cast(2 * j * factor_sg + 2 * id_of_wi_in_fft); + modifier_vec = *reinterpret_cast(&store_modifier_data[idx]); + multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], priv[2 * j + 1]); } } } + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); + } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL for (Idx idx = 0; idx < factor_wi; idx++) { @@ -418,18 +397,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag static_cast(fft_size) * (i - static_cast(id_of_fft_in_sg)), local_imag_offset + subgroup_id * n_cplx_per_sg); } - if (multiply_on_load == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading load modifier data"); - global2local( - global_data, load_modifier_data, loc_load_modifier_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - n_reals_per_fft * (i - id_of_fft_in_sg), subgroup_id * n_reals_per_sg); - } - if (multiply_on_store == detail::elementwise_multiply::APPLIED) { - global_data.log_message_global(__func__, "loading store modifier data"); - global2local( - global_data, store_modifier_data, loc_store_modifier_view, n_ffts_worked_on_by_sg * n_reals_per_fft, - n_reals_per_fft * (i - id_of_fft_in_sg), subgroup_id * n_reals_per_sg); - } + sycl::group_barrier(global_data.sg); global_data.log_dump_local("data in local memory:", loc_view, n_reals_per_fft); if (working) { @@ -450,25 +418,23 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_dump_private("data loaded in registers:", priv, n_reals_per_wi); } sycl::group_barrier(global_data.sg); + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); + } if (multiply_on_load == detail::elementwise_multiply::APPLIED) { if (working) { global_data.log_message_global(__func__, "Multiplying load modifier before sg_dft"); PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { - Idx base_offset = static_cast(global_data.sg.get_group_id()) * n_ffts_per_sg + - id_of_fft_in_sg * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - multiply_complex(priv[2 * j], priv[2 * j + 1], loc_load_modifier_view[base_offset], - loc_load_modifier_view[base_offset + 1], priv[2 * j], priv[2 * j + 1]); + IdxGlobal idx = static_cast(n_reals_per_fft) * (i + IdxGlobal(id_of_fft_in_sg)) + + 2 * static_cast(id_of_wi_in_fft * factor_wi + j); + modifier_vec = *reinterpret_cast(&load_modifier_data[idx]); + multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], + priv[2 * j + 1]); } } } - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, factor_wi); - } if (working) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); } @@ 
-477,17 +443,17 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "Multiplying store modifier before sg_dft"); PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { - sycl::vec modifier_priv; - Idx base_offset = static_cast(global_data.it.get_sub_group().get_group_id()) * n_ffts_per_sg + - id_of_fft_in_sg * n_reals_per_fft + 2 * j * factor_sg + 2 * id_of_wi_in_fft; - // modifier_priv.load(0, detail::get_local_multi_ptr(&loc_store_modifier_view[base_offset])); - modifier_priv[0] = loc_store_modifier_view[base_offset]; - modifier_priv[1] = loc_store_modifier_view[base_offset + 1]; - multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_priv[0], modifier_priv[1], priv[2 * j], + IdxGlobal idx = static_cast(n_reals_per_fft) * (i + IdxGlobal(id_of_fft_in_sg)) + + static_cast(2 * j * factor_sg + 2 * id_of_wi_in_fft); + modifier_vec = *reinterpret_cast(&store_modifier_data[idx]); + multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], priv[2 * j + 1]); } } } + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); + } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL for (Idx j = 0; j < factor_wi; j++) { @@ -603,6 +569,10 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn kernel_data.length * 2); Scalar* res = sycl::aligned_alloc_device( alignof(sycl::vec), kernel_data.length * 2, desc.queue); + if (!res) { + throw internal_error("Could not allocate usm memory of size: ", kernel_data.length * 2 * sizeof(Scalar), + " bytes"); + } sycl::range<2> kernel_range({static_cast(factor_sg), static_cast(factor_wi)}); desc.queue.submit([&](sycl::handler& cgh) { PORTFFT_LOG_TRACE("Launching twiddle calculation kernel for subgroup implementation with global size", factor_sg, diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index dbbca454..be4ed660 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -383,6 +383,10 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn Scalar* res = sycl::aligned_alloc_device(alignof(sycl::vec), static_cast(res_size), desc.queue); + if (!res) { + throw internal_error( + "Could not allocate usm memory of size: ", static_cast(res_size) * sizeof(Scalar), " bytes"); + } desc.queue.submit([&](sycl::handler& cgh) { PORTFFT_LOG_TRACE( "Launching twiddle calculation kernel for factor 1 of workgroup implementation with global size", factor_sg_n, diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 28b6962b..844aca70 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -66,10 +66,11 @@ IdxGlobal get_global_size_workitem(IdxGlobal n_transforms, Idx subgroup_size, Id */ template PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifier_data, IdxGlobal offset) { + using vec2_t = sycl::vec; + vec2_t modifier_vec; PORTFFT_UNROLL for (Idx j = 0; j < num_elements; j++) { - sycl::vec modifier_vec; - modifier_vec.load(0, detail::get_global_multi_ptr(&modifier_data[offset + 2 * j])); + modifier_vec = *reinterpret_cast(&modifier_data[offset + 2 * j]); multiply_complex(priv[2 * j], priv[2 * j + 1], modifier_vec[0], modifier_vec[1], priv[2 * j], priv[2 * j + 1]); } } @@ -93,14 +94,11 @@ PORTFFT_INLINE void 
apply_modifier(Idx num_elements, PrivT priv, const T* modifi * @param kh kernel handler associated with the kernel launch * @param load_modifier_data Pointer to the load modifier data in global memory * @param store_modifier_data Pointer to the store modifier data in global memory - * @param loc_load_modifier Pointer to load modifier data in local memory - * @param loc_store_modifier Pointer to store modifier data in local memory */ template PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc, IdxGlobal n_transforms, global_data_struct<1> global_data, sycl::kernel_handler& kh, - const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr, - T* loc_load_modifier = nullptr, T* loc_store_modifier = nullptr) { + const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr) { complex_storage storage = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); @@ -144,8 +142,6 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag Idx local_imag_offset = fft_size * SubgroupSize; constexpr Idx BankLinesPerPad = 1; auto loc_view = detail::padded_view(loc, BankLinesPerPad); - auto loc_load_modifier_view = detail::padded_view(loc_load_modifier, BankLinesPerPad); - auto loc_store_modifier_view = detail::padded_view(loc_store_modifier, BankLinesPerPad); const IdxGlobal transform_idx_begin = static_cast(global_data.it.get_global_id(0)); const IdxGlobal transform_idx_step = static_cast(global_data.it.get_global_range(0)); @@ -242,6 +238,9 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag copy_wi(global_data, local_imag_view, priv_imag_view, fft_size); } } + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fft_size); + } global_data.log_dump_private("data loaded in registers:", priv, n_reals); if (multiply_on_load == detail::elementwise_multiply::APPLIED) { @@ -250,13 +249,7 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "applying load modifier"); detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } - if (conjugate_on_load == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, fft_size); - } wi_dft<0>(priv, priv, fft_size, 1, 1, wi_private_scratch); - if (conjugate_on_store == detail::complex_conjugate::APPLIED) { - conjugate_inplace(priv, fft_size); - } global_data.log_dump_private("data in registers after computation:", priv, n_reals); if (multiply_on_store == detail::elementwise_multiply::APPLIED) { @@ -265,6 +258,9 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "applying store modifier"); detail::apply_modifier(fft_size, priv, store_modifier_data, i * n_reals); } + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fft_size); + } if (apply_scale_factor == detail::apply_scale_factor::APPLIED) { PORTFFT_UNROLL for (Idx idx = 0; idx < n_reals; idx += 2) { diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp index db837e3e..151878ab 100644 --- a/src/portfft/utils.hpp +++ b/src/portfft/utils.hpp @@ -89,22 +89,27 @@ constexpr bool can_cast_safely(const InputType& x) { * The function should accept factor size and whether it would be have a 
BATCH_INTERLEAVED layout or not as an input, * and should return a boolean indicating whether or not the factor size can fit in any of the implementation. * @param transposed whether or not the factor will be computed in a BATCH_INTERLEAVED format - * @return + * @param encountered_prime_factor A flag to be set if a prime factor which cannot be dispatched to workitem + * implementation + * @return A factor of the committed size which can be dispatched to either workitem or subgroup implementation */ template -IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, bool transposed) { +IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_target_level, bool transposed, + bool& encountered_prime_factor) { PORTFFT_LOG_FUNCTION_ENTRY(); IdxGlobal fact_1 = factor_size; if (check_and_select_target_level(fact_1, transposed)) { return fact_1; } if ((detail::factorize(fact_1) == 1)) { - throw unsupported_configuration("Large prime sized factors are not supported at the moment"); + encountered_prime_factor = true; + return fact_1; } do { fact_1 = detail::factorize(fact_1); if (fact_1 == 1) { - throw internal_error("Factorization Failed !"); + encountered_prime_factor = true; + return fact_1; } } while (!check_and_select_target_level(fact_1)); return fact_1; @@ -118,17 +123,20 @@ IdxGlobal factorize_input_impl(IdxGlobal factor_size, F&& check_and_select_targe * implementations. The function should accept factor size and whether it would be have a BATCH_INTERLEAVED layout or * not as an input, and should return a boolean indicating whether or not the factor size can fit in any of the * implementation. + * @return whether or not a large prime was encountered during factorization */ template -void factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { +bool factorize_input(IdxGlobal input_size, F&& check_and_select_target_level) { PORTFFT_LOG_FUNCTION_ENTRY(); if (detail::factorize(input_size) == 1) { - throw unsupported_configuration("Large Prime sized FFTs are currently not supported"); + return true; } IdxGlobal temp = 1; + bool encountered_prime = false; while (input_size / temp != 1) { - temp *= factorize_input_impl(input_size / temp, check_and_select_target_level, true); + temp *= factorize_input_impl(input_size / temp, check_and_select_target_level, true, encountered_prime); } + return encountered_prime; } /** @@ -161,11 +169,15 @@ std::vector get_transpose_kernel_ids() { */ template inline std::shared_ptr make_shared(std::size_t size, sycl::queue& queue) { - return std::shared_ptr(sycl::malloc_device(size, queue), [captured_queue = queue](T* ptr) { - if (ptr != nullptr) { - sycl::free(ptr, captured_queue); - } - }); + T* ptr = sycl::malloc_device(size, queue); + if (ptr != nullptr) { + return std::shared_ptr(ptr, [captured_queue = queue](T* ptr) { + if (ptr != nullptr) { + sycl::free(ptr, captured_queue); + } + }); + } + throw internal_error("Could not allocate usm memory of size: ", size * sizeof(T), " bytes"); } /** @@ -245,6 +257,15 @@ detail::layout get_layout(const Descriptor& desc, direction dir) { return detail::layout::UNPACKED; } +/** + * return the padded length to be used for the Bluestein implementation + * @param committed_length Committed problem length which needs to be padded + * @return The padded length to be used for the Bluestein implementation + */ +inline IdxGlobal get_padded_length(double committed_length) { + return static_cast(std::pow(2, ceil(log(committed_length) / log(2.0)))); +} + } // namespace 
detail
} // namespace portfft

#endif
diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp
index 94c74130..e5e973d4 100644
--- a/test/unit_test/instantiate_fft_tests.hpp
+++ b/test/unit_test/instantiate_fft_tests.hpp
@@ -156,6 +156,15 @@ INSTANTIATE_TEST_SUITE_P(WorkgroupOrGlobalRegressionTest, FFTTest,
                                              ::testing::Values(sizes_t{9800}, sizes_t{15360}, sizes_t{68640}))),
                         test_params_print());
 
+// Test suite contains both prime-sized values (211, 523, 65537) as well as sizes which have prime factors that cannot
+// be handled directly (33012, 45232)
+INSTANTIATE_TEST_SUITE_P(PrimeSizedTest, FFTTest,
+                         ::testing::ConvertGenerator(::testing::Combine(
+                             all_valid_global_placement_layouts, fwd_only, interleaved_storage, ::testing::Values(1, 3),
+                             ::testing::Values(sizes_t{211}, sizes_t{523}, sizes_t{65537}, sizes_t{33012},
+                                               sizes_t{45232}))),
+                         test_params_print());
+
 // Backward FFT test suite
 INSTANTIATE_TEST_SUITE_P(BackwardTest, FFTTest,
                          ::testing::ConvertGenerator(::testing::Combine(