Skip to content

Commit

Permalink
formatq
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 committed Jan 30, 2024
1 parent 727fa15 commit 273b2c9
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/portfft/committed_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace portfft {

template <typename Scalar, domain Domain>
class committed_descriptor : private committed_descriptor_impl<Scalar, Domain> {
public:
public:
/**
* Alias for `Scalar`.
*/
Expand Down
20 changes: 10 additions & 10 deletions src/portfft/committed_descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include "utils.hpp"

namespace portfft {

template <typename Scalar, domain Domain>
class committed_descriptor_impl;

Expand Down Expand Up @@ -142,8 +142,8 @@ template <typename Scalar, domain Domain>
struct descriptor;

/*
Compute functions in the `committed_descriptor_impl` call `dispatch_kernel` and `dispatch_kernel_helper`. These two functions
ensure the kernel is run with a supported subgroup size. Next `dispatch_kernel_helper` calls `run_kernel`. The
Compute functions in the `committed_descriptor_impl` call `dispatch_kernel` and `dispatch_kernel_helper`. These two
functions ensure the kernel is run with a supported subgroup size. Next `dispatch_kernel_helper` calls `run_kernel`. The
`run_kernel` member function picks the appropriate implementation and calls the static `run_kernel` of that implementation.
The implementation specific `run_kernel` handles differences between forward and backward computations, casts the memory
(USM or buffers) from complex to scalars and launches the kernel. Each function described in this doc has only one
Expand Down Expand Up @@ -177,16 +177,16 @@ The computational parts of the implementations are further documented in files w
*/
template <typename Scalar, domain Domain>
class committed_descriptor_impl {

friend struct descriptor<Scalar, Domain>;
template <typename Scalar1, domain Domain1, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
typename TIn>
friend std::vector<sycl::event> detail::compute_level(
const typename committed_descriptor_impl<Scalar1, Domain1>::kernel_data_struct& kd_struct, TIn input, Scalar1* output,
TIn input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, const IdxGlobal* factors_triple,
IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset,
IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id,
Idx total_factors, complex_storage storage, const std::vector<sycl::event>& dependencies, sycl::queue& queue);
const typename committed_descriptor_impl<Scalar1, Domain1>::kernel_data_struct& kd_struct, TIn input,
Scalar1* output, TIn input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr,
const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset,
IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms,
IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage,
const std::vector<sycl::event>& dependencies, sycl::queue& queue);

template <typename Scalar1, domain Domain1, typename TOut>
friend sycl::event detail::transpose_level(
Expand Down Expand Up @@ -934,7 +934,7 @@ class committed_descriptor_impl {
}

public:
committed_descriptor_impl(const committed_descriptor_impl& desc) : params(desc.params) { //TODO params copied twice
/**
 * Copy constructor.
 *
 * Initializes `params` from `desc.params` in the member initializer list, then hands `desc` to `create_copy`.
 * NOTE(review): the TODO below suggests `create_copy` copies `params` again - confirm against its definition.
 *
 * @param desc committed descriptor to copy from
 */
committed_descriptor_impl(const committed_descriptor_impl& desc) : params(desc.params) { // TODO params copied twice
PORTFFT_LOG_FUNCTION_ENTRY();
// The rest of the copy is delegated to create_copy.
create_copy(desc);
}
Expand Down
15 changes: 8 additions & 7 deletions src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,12 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
typename TIn>
std::vector<sycl::event> compute_level(
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn input, Scalar* output,
const TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple,
IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset,
IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id,
Idx total_factors, complex_storage storage, const std::vector<sycl::event>& dependencies, sycl::queue& queue) {
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn input,
Scalar* output, const TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset,
IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms,
IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage,
const std::vector<sycl::event>& dependencies, sycl::queue& queue) {
PORTFFT_LOG_FUNCTION_ENTRY();
IdxGlobal local_range = kd_struct.local_range;
IdxGlobal global_range = kd_struct.global_range;
Expand Down Expand Up @@ -539,8 +540,8 @@ std::vector<sycl::event> compute_level(
const IdxGlobal* inclusive_scan = factors_triple + 2 * total_factors;
const Idx vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1;
std::vector<sycl::event> events;
PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "store modifiers",
local_mem_for_store_modifier, "twiddles", loc_mem_for_twiddles, "total",
PORTFFT_LOG_TRACE("Local mem requirement - input:", local_memory_for_input, "store modifiers",
local_mem_for_store_modifier, "twiddles", loc_mem_for_twiddles, "total",
local_memory_for_input + local_mem_for_store_modifier + loc_mem_for_twiddles);
for (Idx batch_in_l2 = 0; batch_in_l2 < num_batches_in_l2 && batch_in_l2 + batch_start < n_transforms;
batch_in_l2++) {
Expand Down
10 changes: 5 additions & 5 deletions src/portfft/dispatcher/global_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inn

template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::GLOBAL, LayoutIn,
Dummy> {
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::GLOBAL,
LayoutIn, Dummy> {
static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t /*length*/, Idx /*used_sg_size*/,
const std::vector<Idx>& /*factors*/, Idx& /*num_sgs_per_wg*/) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand All @@ -296,9 +296,9 @@ template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
template <typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
TOut>::inner<detail::level::GLOBAL, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
TOut>::inner<detail::level::GLOBAL, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
direction compute_direction) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down
10 changes: 5 additions & 5 deletions src/portfft/dispatcher/subgroup_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,9 @@ template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
template <typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
TOut>::inner<detail::level::SUBGROUP, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
TOut>::inner<detail::level::SUBGROUP, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
direction compute_direction) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down Expand Up @@ -690,8 +690,8 @@ struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inn

template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::SUBGROUP, LayoutIn,
Dummy> {
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::SUBGROUP,
LayoutIn, Dummy> {
static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
const std::vector<Idx>& factors, Idx& num_sgs_per_wg) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down
10 changes: 5 additions & 5 deletions src/portfft/dispatcher/workgroup_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
template <typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
TOut>::inner<detail::level::WORKGROUP, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
TOut>::inner<detail::level::WORKGROUP, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
direction compute_direction) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down Expand Up @@ -356,8 +356,8 @@ struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inn

template <typename Scalar, domain Domain>
template <typename detail::layout LayoutIn, typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKGROUP, LayoutIn,
Dummy> {
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKGROUP,
LayoutIn, Dummy> {
static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t length, Idx used_sg_size,
const std::vector<Idx>& factors, Idx& /*num_sgs_per_wg*/) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down
10 changes: 5 additions & 5 deletions src/portfft/dispatcher/workitem_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
template <typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
TOut>::inner<detail::level::WORKITEM, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
TOut>::inner<detail::level::WORKITEM, Dummy> {
static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
direction compute_direction) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down Expand Up @@ -346,8 +346,8 @@ struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inn

template <typename Scalar, domain Domain>
template <detail::layout LayoutIn, typename Dummy>
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKITEM, LayoutIn,
Dummy> {
struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKITEM,
LayoutIn, Dummy> {
static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
const std::vector<Idx>& /*factors*/, Idx& num_sgs_per_wg) {
PORTFFT_LOG_FUNCTION_ENTRY();
Expand Down

0 comments on commit 273b2c9

Please sign in to comment.