@@ -37,7 +37,21 @@ KERNEL(stft_ref)(
37
37
const INPUT0_TYPE * restrict signal_for_this_frame = signal + batch * INPUT0_SIZE_X + frame_id * frame_step + start_offset ;
38
38
39
39
// Preload into shared mem:
40
- for (size_t i = get_local_linear_id (); i < window_size ; i += block_size ) {
40
+ for (size_t i = get_local_linear_id ()* 4 ; i < window_size ; i += block_size * 4 ) {
41
+ // NOTE: Vectorization by an internal unrolled loop, so that the compiler can
42
+ // decide whether it can use vectorized instructions
43
+ // (which may depend on data type, pointer alignment, etc.).
44
+ #pragma unroll
45
+ for (size_t j = 0 ; j < 4 ; ++ j ) {
46
+ const float signal_val = (float )signal_for_this_frame [i + j ];
47
+ const float window_val = (float )window [i + j ];
48
+ x_i_shared [i + j ] = signal_val * window_val ;
49
+ }
50
+ }
51
+
52
+ // Handle leftovers:
53
+ const size_t leftovers_start = window_size %(block_size * 4 );
54
+ for (size_t i = leftovers_start + get_local_linear_id (); i < window_size ; i += block_size * 4 ) {
41
55
const float signal_val = (float )signal_for_this_frame [i ];
42
56
const float window_val = (float )window [i ];
43
57
x_i_shared [i ] = signal_val * window_val ;
@@ -47,22 +61,22 @@ KERNEL(stft_ref)(
47
61
48
62
const size_t max_freq_for_this_block = min (freq_start + FREQ_PER_BLOCK , FREQS );
49
63
50
- // Currently each sub group calcs 4 freq_id at the same time
64
+ // Currently each sub group calcs 4 freq_id at the same time.
51
65
for (size_t freq_id = get_sub_group_id ()* FREQS_PER_THREAD + freq_start ; freq_id < max_freq_for_this_block ; freq_id += get_num_sub_groups ()* FREQS_PER_THREAD ) {
52
66
53
67
float4 freq_val_real = 0.0f ;
54
68
float4 freq_val_img = 0.0f ;
55
69
56
- // // dft_power = 2*PI*(k/N) from dft def.
70
+ // dft_power = 2*PI*(k/N) from dft def.
57
71
float4 dft_power = 2.0f * M_PI_F / (float )frame_size ;
58
72
dft_power .s0 *= (float )(freq_id + 0 );
59
73
dft_power .s1 *= (float )(freq_id + 1 );
60
74
dft_power .s2 *= (float )(freq_id + 2 );
61
75
dft_power .s3 *= (float )(freq_id + 3 );
62
76
63
- // sin cos bound(?) : Probably there is some external unit to calc sin cos
64
- // which is overloaded with commands(each thread issues 8 such instructions)
65
- // TODO: Implement fft.
77
+ // For bigger window_size the kernel is sin/cos bound: probably there is some external
78
+ // unit to calculate sin/cos, which is overloaded with commands (each thread issues 8 such instructions).
79
+ // TODO: Implement FFT for those cases.
66
80
for (int i = get_sub_group_local_id (); i < window_size ; i += get_sub_group_size ()) {
67
81
const float x_i = x_i_shared [i ];
68
82
0 commit comments