
Commit 6ca0800

trivedivivek and SS-JIA authored
[ET-VK] Using uint16 for quantized linear tiling shader to reduce register pressure and improve performance. (#10509)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #10509
* #10508

This diff reduces integer precision for certain variables in the 8-bit quantized tiled linear op to reduce register pressure and improve performance.

Differential Revision: [D73752090](https://our.internmc.facebook.com/intern/diff/D73752090/)

Co-authored-by: Sicheng Stephen Jia <[email protected]>
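For context on the technique: the change enables GL_EXT_shader_explicit_arithmetic_types_int16 and keeps small index variables in uint16_t, so the compiler can use narrower registers for the index math in the hot loop. Below is a minimal, hypothetical GLSL compute-shader sketch of that pattern. It is not the shader touched by this commit; the output buffer t_out, its binding layout, and the workgroup size are illustrative assumptions (Block and out_sizes mirror names that appear in the diff).

```glsl
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Push constant holding the output tensor sizes (mirrors out_sizes in the diff).
layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
};

// Hypothetical output buffer; the binding and layout are assumptions for this sketch.
layout(std430, set = 0, binding = 0) buffer OutBuf {
  vec4 data[];
} t_out;

void main() {
  // Keep index math in 16-bit unsigned ints: these are small texel/row indices,
  // so the narrower type is sufficient and reduces register pressure.
  const uint16_t out_width_ntexels = uint16_t((out_sizes.x + 3) >> 2);
  const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2);
  const uint16_t out_row = uint16_t(gl_GlobalInvocationID.x / out_width_ntexels);

  // Bounds check against a value cast to the same narrow type.
  if (out_row >= uint16_t(out_sizes.y)) {
    return;
  }

  // Buffer indexing still uses 32-bit ints, so widen only at the point of use.
  t_out.data[int(out_row) * int(out_width_ntexels) + int(out_col >> 2)] = vec4(0.0);
}
```

This matches the shape of the diff below: buffer indices are widened back to int at the point of use (as in t_scales[int(out_col >> 2)]), while texel-fetch coordinates are built from the 16-bit values via u16vec2/u16vec3.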
1 parent 9f5988e commit 6ca0800

File tree

1 file changed: +11 -9 lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

+11 -9

@@ -40,12 +40,14 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+
 void main() {
-  const uint out_width_ntexels = divup4(out_sizes.x);
-  const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
-  const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
+  const uint16_t out_width_ntexels = uint16_t(divup4(out_sizes.x));
+  const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2);
+  const uint16_t out_row = uint16_t((gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS);
 
-  if (out_row >= out_sizes.y) {
+  if (out_row >= uint16_t(out_sizes.y)) {
     return;
   }
 
@@ -54,29 +56,29 @@ void main() {
   VEC4_T c[TILE_ROWS];
 
   $if SCALES_STORAGE == "buffer":
-    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
+    const VEC4_T scales = VEC4_T(t_scales[int(out_col >> 2)]);
   $else:
-    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2(out_col >> 2, 0), 0));
 
   [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
     c[i] = VEC4_T(0.0);
   }
 
-  for (int pos = 0; pos < in_sizes.x; pos += 4) {
+  for (uint16_t pos = uint16_t(0); pos < uint16_t(in_sizes.x); pos += uint16_t(4)) {
     // Preload weight tensor
     [[unroll]] for (int i = 0; i < 4; i++) {
       $if WEIGHT_STORAGE == "buffer":
         b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
       $else:
-        b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
+        b[i] = VEC4_T(texelFetch(t_weight, u16vec2(out_col >> 2, pos + i), 0));
     }
 
     // Preload input tensor
     [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
       $if IN_STORAGE == "buffer":
         a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
       $else:
-        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
+        a[i] = VEC4_T(texelFetch(t_in, u16vec3(pos >> 2, out_row + i, 0), 0));
     }
 
     // Accumulate output