@@ -40,12 +40,14 @@ layout(push_constant) uniform restrict Block {
40
40
41
41
// Workgroup size on each axis is not fixed here: it is taken from the
// specialization constants with ids 0 (x), 1 (y) and 2 (z).
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
42
42
43
// Enables the explicit 16-bit integer types used in main() (uint16_t,
// u16vec2, u16vec3). `require` makes compilation fail if the extension is
// not supported.
// NOTE(review): this directive comes after a non-preprocessor declaration;
// the GLSL spec generally expects #extension before any non-preprocessor
// tokens - confirm the target compiler accepts this placement.
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
44
+
43
45
void main() {
44
- const uint out_width_ntexels = divup4(out_sizes.x);
45
- const uint out_col = ( gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
46
- const uint out_row = ( gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
46
+ const uint16_t out_width_ntexels = uint16_t( divup4(out_sizes.x) );
47
+ const uint16_t out_col = uint16_t(( gl_GlobalInvocationID.x % out_width_ntexels) << 2 ) ;
48
+ const uint16_t out_row = uint16_t(( gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS) ;
47
49
48
- if (out_row >= out_sizes.y) {
50
+ if (out_row >= uint16_t( out_sizes.y) ) {
49
51
return ;
50
52
}
51
53
@@ -54,29 +56,29 @@ void main() {
54
56
VEC4_T c[TILE_ROWS];
55
57
56
58
$if SCALES_STORAGE == "buffer ":
57
- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
59
+ const VEC4_T scales = VEC4_T(t_scales[int ( out_col >> 2 ) ]);
58
60
$else :
59
- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
61
+ const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2 (out_col >> 2 , 0 ), 0 ));
60
62
61
63
[[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
62
64
c[i] = VEC4_T(0.0 );
63
65
}
64
66
65
- for (int pos = 0 ; pos < in_sizes.x; pos += 4 ) {
67
+ for (uint16_t pos = uint16_t( 0 ) ; pos < uint16_t( in_sizes.x) ; pos += uint16_t( 4 ) ) {
66
68
// Preload weight tensor
67
69
[[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
68
70
$if WEIGHT_STORAGE == "buffer ":
69
71
b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2 ];
70
72
$else :
71
- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
73
+ b[i] = VEC4_T(texelFetch(t_weight, u16vec2 (out_col >> 2 , pos + i), 0 ));
72
74
}
73
75
74
76
// Preload input tensor
75
77
[[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
76
78
$if IN_STORAGE == "buffer ":
77
79
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2 ];
78
80
$else :
79
- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
81
+ a[i] = VEC4_T(texelFetch(t_in, u16vec3 (pos >> 2 , out_row + i, 0 ), 0 ));
80
82
}
81
83
82
84
// Accumulate output
0 commit comments