Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU]: STFT optimization(s) #28367

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/stft_opt.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//


#define FREQS_PER_THREAD 4

KERNEL(stft_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* restrict signal,
const __global INPUT1_TYPE* restrict window,
const __global INPUT2_TYPE* restrict frame_size_buff,
const __global INPUT3_TYPE* restrict frame_step_buff,
__global OUTPUT_TYPE* restrict output)
{
#if TRANSPOSE_FRAMES
const size_t FREQS = OUTPUT_FEATURE_NUM;
#else
const size_t FREQS = OUTPUT_SIZE_Y;
#endif

const size_t blocksPerFreq = (FREQS + FREQ_PER_BLOCK-1)/FREQ_PER_BLOCK;
const size_t batch = get_global_id(0);
const size_t frame_id = get_group_id(1)/blocksPerFreq;
const size_t freq_start = (get_group_id(1)%blocksPerFreq)*FREQ_PER_BLOCK;
const size_t frame_size = (size_t)frame_size_buff[0];
const size_t frame_step = (size_t)frame_step_buff[0];
const size_t window_size = INPUT1_SIZE_X;

__local float x_i_shared[SHARED_X_I_BUFFER_SIZE];

const size_t block_size = get_local_size(0)*get_local_size(1)*get_local_size(2);

// Handling case where window size is smaller than frame size.
const int start_offset = (frame_size - window_size) / 2;

const INPUT0_TYPE* restrict signal_for_this_frame = signal + batch*INPUT0_SIZE_X + frame_id*frame_step + start_offset;

// Preload into shared mem:
for(size_t i = get_local_linear_id()*4; i < window_size; i+= block_size*4) {
// NOTE: Vectorization by internal unrolling loop, in order to compiler to
// decide it if can use vectorized vectorized instructions,
// which may depend on data type, pointer alignment etc).
#pragma unroll
for(size_t j = 0; j < 4; ++j) {
const float signal_val = (float)signal_for_this_frame[i+j];
const float window_val = (float)window[i+j];
x_i_shared[i+j] = signal_val*window_val;
}
}

// Handle leftovers:
const size_t leftovers_start = window_size%(block_size*4);
for(size_t i = leftovers_start + get_local_linear_id(); i < window_size; i+= block_size*4) {
const float signal_val = (float)signal_for_this_frame[i];
const float window_val = (float)window[i];
x_i_shared[i] = signal_val*window_val;
}

barrier(CLK_LOCAL_MEM_FENCE);

const size_t max_freq_for_this_block = min(freq_start + FREQ_PER_BLOCK, FREQS);

// Currently each sub group calcs 4 freq_id at the same time.
for(size_t freq_id = get_sub_group_id()*FREQS_PER_THREAD + freq_start; freq_id < max_freq_for_this_block; freq_id += get_num_sub_groups()*FREQS_PER_THREAD) {

float4 freq_val_real = 0.0f;
float4 freq_val_img = 0.0f;

// dft_power = 2*PI*(k/N) from dft def.
float4 dft_power = 2.0f * M_PI_F / (float)frame_size;
dft_power.s0 *= (float)(freq_id + 0);
dft_power.s1 *= (float)(freq_id + 1);
dft_power.s2 *= (float)(freq_id + 2);
dft_power.s3 *= (float)(freq_id + 3);

// For bigger window_size kernel is sin cos bound: Probably there is some external
// unit to calc sin cos, which is overloaded with commands(each thread issues 8 such instructions).
// TODO: Implement fft for those cases.
for(int i = get_sub_group_local_id(); i < window_size; i+= get_sub_group_size()) {
const float x_i = x_i_shared[i];

const float4 real = native_cos(dft_power*(float)(i+start_offset))*x_i;
const float4 img = -native_sin(dft_power*(float)(i+start_offset))*x_i;

freq_val_real += real;
freq_val_img += img;
}

freq_val_real.s0 = sub_group_reduce_add(freq_val_real.s0);
freq_val_real.s1 = sub_group_reduce_add(freq_val_real.s1);
freq_val_real.s2 = sub_group_reduce_add(freq_val_real.s2);
freq_val_real.s3 = sub_group_reduce_add(freq_val_real.s3);

freq_val_img.s0 = sub_group_reduce_add(freq_val_img.s0);
freq_val_img.s1 = sub_group_reduce_add(freq_val_img.s1);
freq_val_img.s2 = sub_group_reduce_add(freq_val_img.s2);
freq_val_img.s3 = sub_group_reduce_add(freq_val_img.s3);

if((freq_id < FREQS) && (get_sub_group_local_id() < 2*min((size_t)FREQS_PER_THREAD, (FREQS - freq_id)))) {
#if TRANSPOSE_FRAMES
const int output_idx = OUTPUT_GET_INDEX(batch, freq_id + get_sub_group_local_id()/2, frame_id, get_sub_group_local_id() % 2);
#else
const int output_idx = OUTPUT_GET_INDEX(batch, frame_id, freq_id + get_sub_group_local_id()/2, get_sub_group_local_id() % 2);
#endif
if ( (get_sub_group_local_id() % 2) == 0)
praasz marked this conversation as resolved.
Show resolved Hide resolved
output[output_idx] = (OUTPUT_TYPE)freq_val_real[get_sub_group_local_id()/2];
else
output[output_idx] = (OUTPUT_TYPE)freq_val_img[get_sub_group_local_id()/2];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ JitConstants STFTKernelBase::GetJitConstants(const STFT_params& params) const {
}

void STFTKernelBase::GetUpdateDispatchDataFunc(KernelData& kd) const {
kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
const auto& prim_params = static_cast<const STFT_params&>(params);
auto dispatchData = SetDefault(prim_params);
auto dispatchData = CalcLaunchConfig(prim_params);
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
kd.kernels[0].params.workGroups.global = dispatchData.gws;
kd.kernels[0].params.workGroups.local = dispatchData.lws;
kd.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params);
};
}

STFTKernelBase::DispatchData STFTKernelBase::SetDefault(const STFT_params& params) {
CommonDispatchData STFTKernelBase::CalcLaunchConfig(const STFT_params& params) const {
CommonDispatchData dispatchData;
const auto inLayout = params.inputs.front().GetLayout();
const auto& output = params.outputs.front();
Expand Down Expand Up @@ -57,11 +57,13 @@ STFTKernelBase::DispatchData STFTKernelBase::SetDefault(const STFT_params& param
}

KernelsData STFTKernelBase::GetCommonKernelsData(const Params& params) const {
assert(params.GetType() == KernelType::STFT);
if (!Validate(params)) {
return {};
}

const auto& prim_params = static_cast<const STFT_params&>(params);

auto dispatchData = SetDefault(prim_params);
auto dispatchData = CalcLaunchConfig(prim_params);
KernelData k_data = KernelData::Default<STFT_params>(params);

auto cldnn_jit = GetJitConstants(prim_params);
Expand All @@ -87,4 +89,9 @@ KernelsData STFTKernelBase::GetCommonKernelsData(const Params& params) const {

return {k_data};
}

bool STFTKernelBase::Validate(const Params& p) const {
return p.GetType() == KernelType::STFT;
}

} // namespace kernel_selector
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ class STFTKernelBase : public KernelBaseOpenCL {
public:
using KernelBaseOpenCL::KernelBaseOpenCL;

using DispatchData = CommonDispatchData;

protected:
JitConstants GetJitConstants(const STFT_params& params) const;
static DispatchData SetDefault(const STFT_params& params);
virtual JitConstants GetJitConstants(const STFT_params& params) const;
virtual CommonDispatchData CalcLaunchConfig(const STFT_params& params) const;
KernelsData GetCommonKernelsData(const Params& params) const;
void GetUpdateDispatchDataFunc(KernelData& kd) const override;
bool Validate(const Params& p) const override;
};
} // namespace kernel_selector
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (C) 2018-2024 Intel Corporation
pkowalc1 marked this conversation as resolved.
Show resolved Hide resolved
// SPDX-License-Identifier: Apache-2.0
//

#include "stft_kernel_opt.h"

const size_t FREQ_PER_BLOCK = 256;
const size_t STATIC_MAX_X_I_BUFFER = 2048;
const size_t THREADS_PER_BLOCK = 256;
namespace kernel_selector {
ParamsKey STFTKernelOpt::GetSupportedKey() const {
ParamsKey k;

k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT64);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::F16);

k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);

k.EnableInputLayout(DataLayout::bfyx);

k.EnableOutputLayout(DataLayout::bfyx);

k.EnableBatching();
k.EnableDifferentTypes();
k.EnableDynamicShapesSupport();
return k;
}

JitConstants STFTKernelOpt::GetJitConstants(const STFT_params& params) const {
JitConstants jit = STFTKernelBase::GetJitConstants(params);

jit.AddConstants({MakeJitConstant("FREQ_PER_BLOCK", FREQ_PER_BLOCK)});
jit.AddConstants({MakeJitConstant("STATIC_MAX_X_I_BUFFER", STATIC_MAX_X_I_BUFFER)});

const auto xiMaxBuffer = params.is_shape_agnostic ? "STATIC_MAX_X_I_BUFFER" : "INPUT1_SIZE_X";
jit.AddConstants({MakeJitConstant("SHARED_X_I_BUFFER_SIZE", xiMaxBuffer)});

return jit;
}

KernelsData STFTKernelOpt::GetKernelsData(const Params& params) const {
return GetCommonKernelsData(params);
}

KernelsPriority STFTKernelOpt::GetKernelsPriority(const Params& /*params*/) const {
return FORCE_PRIORITY_8;
}

CommonDispatchData STFTKernelOpt::CalcLaunchConfig(const STFT_params& params) const {
CommonDispatchData dispatchData;
const auto& output = params.outputs.front();

OPENVINO_ASSERT(output.Dimentions() == 4);
OPENVINO_ASSERT(output.X().v == 2);

const size_t freqSize = params.transpose_frames ? output.Feature().v : output.Y().v;
const size_t blocksPerFreq = (freqSize + FREQ_PER_BLOCK - 1) / FREQ_PER_BLOCK;

const size_t framesSize = params.transpose_frames ? output.Y().v : output.Feature().v;
const size_t batchSize = output.Batch().v;

dispatchData.lws = {1, THREADS_PER_BLOCK};
dispatchData.gws = {batchSize, framesSize * THREADS_PER_BLOCK * blocksPerFreq};

return dispatchData;
}

bool STFTKernelOpt::Validate(const Params& p) const {
if (STFTKernelBase::Validate(p) == false)
return false;

const auto& params = static_cast<const STFT_params&>(p);
const auto windowSize = params.inputs[1].LogicalSize();

if (params.is_shape_agnostic && windowSize > STATIC_MAX_X_I_BUFFER)
return false;

return true;
}

} // namespace kernel_selector
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (C) 2018-2024 Intel Corporation
pkowalc1 marked this conversation as resolved.
Show resolved Hide resolved
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "stft_kernel_base.h"

namespace kernel_selector {
class STFTKernelOpt : public STFTKernelBase {
public:
STFTKernelOpt() : STFTKernelBase("stft_opt") {}

JitConstants GetJitConstants(const STFT_params& params) const override;
KernelsData GetKernelsData(const Params& params) const override;
KernelsPriority GetKernelsPriority(const Params& params) const override;
CommonDispatchData CalcLaunchConfig(const STFT_params& params) const override;
ParamsKey GetSupportedKey() const override;
bool Validate(const Params& p) const override;
};
} // namespace kernel_selector
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

#include "stft_kernel_selector.h"

#include "stft_kernel_opt.h"
#include "stft_kernel_ref.h"

namespace kernel_selector {
STFT_kernel_selector::STFT_kernel_selector() {
Attach<STFTKernelRef>();
Attach<STFTKernelOpt>();
}

KernelsData STFT_kernel_selector::GetBestKernels(const Params& params) const {
Expand Down
Loading
Loading