diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index e7d66328..1b799d1d 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -138,34 +138,6 @@ detail::layout get_layout(const Descriptor& desc, direction dir) { return detail::layout::UNPACKED; } -/* -Compute functions in the `committed_descriptor_impl` call `dispatch_kernel` and `dispatch_kernel_helper`. These two -functions ensure the kernel is run with a supported subgroup size. Next `dispatch_kernel_helper` calls `run_kernel`. The -`run_kernel` member function picks appropriate implementation and calls the static `run_kernel of that implementation`. -The implementation specific `run_kernel` handles differences between forward and backward computations, casts the memory -(USM or buffers) from complex to scalars and launches the kernel. Each function described in this doc has only one -templated overload that handles both directions of transforms and buffer and USM memory. - -Device functions make no assumptions on the size of a work group or the number of workgroups in a kernel. These numbers -can be tuned for each device. - -Implementation-specific `run_kernel` function make the size of the FFT that is handled by the individual workitems -compile time constant. The one for subgroup implementation also calls `cross_sg_dispatcher` that makes the -cross-subgroup factor of FFT size compile time constant. They do that by using a switch on the FFT size for one -workitem, before calling `workitem_impl`, `subgroup_impl` or `workgroup_impl` . The `_impl` functions take the FFT size -for one workitem as a template parameter. Only the calls that are determined to fit into available registers (depending -on the value of PORTFFT_TARGET_REGS_PER_WI macro) are actually instantiated. - -The `_impl` functions iterate over the batch of problems, loading data for each first in -local memory then from there into private one. This is done in these two steps to avoid non-coalesced global memory -accesses. `workitem_impl` loads one problem per workitem, `subgroup_impl` loads one problem per subgroup and -`workgroup_impl` loads one problem per workgroup. After doing computations by the calls to `wi_dft` for workitem, -`sg_dft` for subgroup and `wg_dft` for workgroup, the data is written out, going through local memory again. - -The computational parts of the implementations are further documented in files with their implementations -`workitem.hpp`, `subgroup.hpp` and `workgroup.hpp`. -*/ - /** * A committed descriptor that contains everything that is needed to run FFT. * diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index 7798a789..d49ebe2c 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -45,7 +45,7 @@ namespace detail { /** * Gets the precomputed inclusive scan of the factors at a particular index. * - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors * @param num_factors Number of factors * @param level_num factor number * @return Outer batch product @@ -72,9 +72,9 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_product(const IdxGlobal* inclusi * required m-dimensional loop into the single loop (dispatch level), and this function calculates the offset. * Precomputed inclusive scans are used to further reduce the number of calculations required. * - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors * @param num_factors Number of factors * @param iter_value Current iterator value of the flattened n-dimensional loop * @param outer_batch_product Inclusive Scan of factors at position level_num-1 @@ -122,14 +122,14 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, * @param output output pointer * @param input_imag input pointer for imaginary data * @param output_imag output pointer for imaginary data - * @param implementation_twiddles global twiddles pointer containing twiddles for the sub implementation + * @param implementation_twiddles pointer to global memory containing twiddles for the sub implementation * @param store_modifier store modifier data - * @param input_loc local memory for storing the input - * @param twiddles_loc local memory for storing the twiddles for sub-implementation - * @param store_modifier_loc local memory for store modifier data - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors + * @param input_loc pointer to local memory for storing the input + * @param twiddles_loc pointer to local memory for storing the twiddles for sub-implementation + * @param store_modifier_loc pointer to local memory for store modifier data + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors * @param batch_size Batch size for the corresponding input * @param global_data global data * @param kh kernel handler @@ -187,10 +187,10 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc * @param loc_for_store_modifier local memory for store modifier data * @param multipliers_between_factors twiddles to be multiplied between factors * @param impl_twiddles twiddles required for sub implementation - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors - * @param n_transforms batch size corresposding to the factor + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors + * @param n_transforms batch size corresponding to the factor * @param input_batch_offset offset for the input pointer * @param launch_params launch configuration, the global and local range with which the kernel will get launched * @param cgh associated command group handler @@ -246,10 +246,10 @@ void launch_kernel(sycl::accessor& in * @param loc_for_store_modifier local memory for store modifier data * @param multipliers_between_factors twiddles to be multiplied between factors * @param impl_twiddles twiddles required for sub implementation - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors - * @param n_transforms batch size corresposding to the factor + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors + * @param n_transforms batch size corresponding to the factor * @param input_batch_offset offset for the input pointer * @param launch_params launch configuration, the global and local range with which the kernel will get launched * @param cgh associated command group handler @@ -297,9 +297,9 @@ void launch_kernel(const Scalar* input, Scalar* output, const Scalar* input_imag * @param input input pointer * @param output output accessor * @param loc 2D local memory - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors * @param output_offset offset to output pointer * @param ldb leading dimension of the output * @param lda leading dimension of the input @@ -357,9 +357,9 @@ static void dispatch_transpose_kernel_impl(const Scalar* input, * @param input input pointer * @param output output pointer * @param loc 2D local memory - * @param factors global memory pointer containing factors of the input - * @param inner_batches global memory pointer containing the inner batch for each factor - * @param inclusive_scan global memory pointer containing the inclusive scan of the factors + * @param factors pointer to global memory containing factors of the input + * @param inner_batches pointer to global memory containing the inner batch for each factor + * @param inclusive_scan pointer to global memory containing the inclusive scan of the factors * @param output_offset offset to output pointer * @param ldb leading dimension of the output * @param lda leading dimension of the input @@ -418,7 +418,7 @@ static void dispatch_transpose_kernel_impl(const Scalar* input, Scalar* output, * @param kd_struct kernel data struct * @param input input pointer * @param output output usm/buffer - * @param factors_triple global memory pointer containing factors, inner batches corresponding per factor, and the + * @param factors_triple pointer to global memory containing factors, inner batches corresponding per factor, and the * inclusive scan of the factors * @param committed_size committed size of the FFT * @param num_batches_in_l2 number of batches in l2 @@ -481,8 +481,8 @@ sycl::event transpose_level(const typename committed_descriptor_impl get_launch_params(IdxGlobal fft_size, Idx /** * Transposes A into B, for complex inputs only - * @param a Input pointer a - * @param b Input pointer b - * @param lda leading dimension A - * @param ldb leading Dimension B + * @param a Input pointer + * @param b Output pointer + * @param lda leading dimension of `a` + * @param ldb leading dimension of `b` * @param num_elements Total number of complex values in the matrix */ template diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 322abbd9..9a9e0c8d 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -64,19 +64,18 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam T type of the scalar used for computations - * @param input accessor or pointer to global memory containing input data. If complex storage (from + * @param input pointer to global memory containing input data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param output accessor or pointer to global memory for output data. If complex storage (from + * @param output pointer to global memory for output data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param input accessor or pointer to global memory containing imaginary part of the input data if complex storage + * @param input pointer to global memory containing imaginary part of the input data if complex storage * (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param output accessor or pointer to global memory containing imaginary part of the input data if complex storage + * @param output pointer to global memory containing imaginary part of the input data if complex storage * (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param loc local accessor. Must have enough space for 2*FactorWI*FactorSG*SubgroupSize + * @param loc pointer to local memory. Size requirement is determined by `num_scalars_in_local_mem_struct`. + * @param loc_twiddles pointer to local memory for twiddle factors. Must have enough space for `2 * FactorWI * FactorSG` * values - * @param loc_twiddles local accessor for twiddle factors. Must have enough space for 2*FactorWI*FactorSG - * values - * @param n_transforms number of FT transforms to do in one call + * @param n_transforms number of FFT transforms to do in one call * @param global_data global data for the kernel * @param kh kernel handler associated with the kernel launch * @param twiddles pointer containing twiddles diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index a0a65dd6..4ed83076 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -81,15 +81,15 @@ IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, I * @tparam SubgroupSize size of the subgroup * @tparam T Scalar type * - * @param input accessor or pointer to global memory containing input data. If complex storage (from + * @param input pointer to global memory containing input data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param output accessor or pointer to global memory for output data. If complex storage (from + * @param output pointer to global memory for output data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param input_imag accessor or pointer to global memory containing imaginary part of the input data if complex storage + * @param input_imag pointer to global memory containing imaginary part of the input data if complex storage * (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param output_imag accessor or pointer to global memory containing imaginary part of the input data if complex + * @param output_imag pointer to global memory containing imaginary part of the input data if complex * storage (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param loc Pointer to local memory + * @param loc Pointer to local memory. Size requirement is determined by `num_scalars_in_local_mem_struct`. * @param loc_twiddles pointer to local allocation for subgroup level twiddles * @param n_transforms number of fft batches * @param global_data global data for the kernel diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 614229a8..9a351749 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -81,16 +81,15 @@ PORTFFT_INLINE void apply_modifier(Idx num_elements, PrivT priv, const T* modifi * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam T type of the scalar used for computations - * @param input accessor or pointer to global memory containing input data. If complex storage (from + * @param input pointer to global memory containing input data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param output accessor or pointer to global memory for output data. If complex storage (from + * @param output pointer to global memory for output data. If complex storage (from * `SpecConstComplexStorage`) is split, this is just the real part of data. - * @param input accessor or pointer to global memory containing imaginary part of the input data if complex storage + * @param input pointer to global memory containing imaginary part of the input data if complex storage * (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param output accessor or pointer to global memory containing imaginary part of the input data if complex storage + * @param output pointer to global memory containing imaginary part of the input data if complex storage * (from `SpecConstComplexStorage`) is split. Otherwise unused. - * @param loc local memory pointer. Must have enough space for 2*fft_size*SubgroupSize - * values + * @param loc local memory pointer. Size requirement is determined by `num_scalars_in_local_mem_struct`. * @param n_transforms number of FT transforms to do in one call * @param global_data global data for the kernel * @param kh kernel handler associated with the kernel launch