Skip to content

Commit b12ae24

Browse files
author
morelos
committed
[ET-VK][Ops] dequantize_per_tensor.tensor variant
Pull Request resolved: #12209 # Context We need a tensor variant for dequantize/quantize operators since that is the expected output of choose_qparams. # Changes This extends the logic that currently exists to support a tensor variant for scales and zeros. ghstack-source-id: 294163235 @exported-using-ghexport Differential Revision: [D77746135](https://our.internmc.facebook.com/intern/diff/D77746135/)
1 parent 5b70c9f commit b12ae24

File tree

6 files changed

+434
-13
lines changed

6 files changed

+434
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2727
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2828

2929
$if MODE == "per_tensor":
30+
$if SHAPE == "tensor":
31+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
32+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
33+
3034
layout(push_constant) uniform restrict Block {
31-
float scale;
32-
int zero_point;
35+
$if SHAPE == "scalar":
36+
float scale;
37+
int zero_point;
3338
int quant_min;
3439
int quant_max;
3540
};
@@ -146,7 +151,10 @@ void dequantize_per_tensor() {
146151
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
147152

148153
IN_T qvalue = t_in[in_bufi];
149-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
154+
$if SHAPE == "scalar":
155+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
156+
$if SHAPE == "tensor":
157+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
150158

151159
t_out[out_bufi] = value;
152160
}

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dequantize_buffer:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: uint8
@@ -15,6 +16,9 @@ dequantize_buffer:
1516
shader_variants:
1617
- NAME: dequantize_per_tensor_buffer
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: dequantize_per_tensor_tensor_buffer
21+
MODE: per_tensor
1822
- NAME: dequantize_per_token_buffer
1923
MODE: per_token
2024
- NAME: dequantize_per_channel_buffer

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3030
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3131

3232
$if MODE == "per_tensor":
33+
$if SHAPE == "tensor":
34+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
35+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
36+
3337
layout(push_constant) uniform restrict Block {
34-
float scale;
35-
int zero_point;
38+
$if SHAPE == "scalar":
39+
float scale;
40+
int zero_point;
3641
int quant_min;
3742
int quant_max;
3843
};
@@ -148,7 +153,11 @@ void dequantize_per_tensor() {
148153

149154
[[unroll]] for (int i = 0; i < 4; ++i) {
150155
IN_T qvalue = IN_T(intex[i]);
151-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
156+
$if SHAPE == "scalar":
157+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
158+
$if SHAPE == "tensor":
159+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
160+
152161
$if OUT_DTYPE == "double":
153162
outtex[i] = float(value);
154163
$else:

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dequantize_texture:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: uint8
@@ -15,6 +16,9 @@ dequantize_texture:
1516
shader_variants:
1617
- NAME: dequantize_per_tensor_texture3d
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: dequantize_per_tensor_tensor_texture3d
21+
MODE: per_tensor
1822
- NAME: dequantize_per_token_texture3d
1923
MODE: per_token
2024
- NAME: dequantize_per_channel_texture3d

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,23 @@ void add_dequantize_per_tensor_node(
8080
const ValueRef& quant_min,
8181
const ValueRef& quant_max,
8282
const ValueRef& output) {
83+
const bool is_tensor_scale_zp =
84+
graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point);
85+
8386
std::string kernel_name("dequantize_per_tensor");
87+
if (is_tensor_scale_zp) {
88+
kernel_name += "_tensor";
89+
}
8490
add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
8591
add_dtype_suffix(kernel_name, graph.dtype_of(input));
8692
add_dtype_suffix(kernel_name, graph.dtype_of(output));
8793

88-
float scale_val = static_cast<float>(graph.get_double(scale));
89-
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
94+
float scale_val = 1.0;
95+
int zero_point_val = 0;
96+
if (!is_tensor_scale_zp) {
97+
scale_val = static_cast<float>(graph.get_double(scale));
98+
zero_point_val = static_cast<int>(graph.get_int(zero_point));
99+
}
90100
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
91101
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
92102

@@ -100,15 +110,17 @@ void add_dequantize_per_tensor_node(
100110
graph.strides_ubo(input),
101111
graph.sizes_ubo(output),
102112
graph.strides_ubo(output)};
113+
} else {
114+
param_ubos = {
115+
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
116+
}
117+
118+
if (is_tensor_scale_zp) {
103119
push_constants = {
104-
PushConstantDataInfo(&scale_val, sizeof(float)),
105-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
106120
PushConstantDataInfo(&quant_min_val, sizeof(int)),
107121
PushConstantDataInfo(&quant_max_val, sizeof(int)),
108122
};
109123
} else {
110-
param_ubos = {
111-
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
112124
push_constants = {
113125
PushConstantDataInfo(&scale_val, sizeof(float)),
114126
PushConstantDataInfo(&zero_point_val, sizeof(int)),
@@ -122,13 +134,20 @@ void add_dequantize_per_tensor_node(
122134
graph.hashed_layout_of(input),
123135
};
124136

137+
std::vector<ArgGroup> inputs_and_outputs = {
138+
{output, vkapi::kWrite}, {input, vkapi::kRead}};
139+
if (is_tensor_scale_zp) {
140+
inputs_and_outputs.emplace_back(
141+
ArgGroup{{scale, zero_point}, vkapi::kRead});
142+
}
143+
125144
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
126145
graph,
127146
VK_KERNEL_FROM_STR(kernel_name),
128147
default_pick_global_wg_size,
129148
default_pick_local_wg_size,
130149
// Inputs and Outputs
131-
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
150+
inputs_and_outputs,
132151
// Shader param buffers
133152
param_ubos,
134153
// Push Constants
@@ -519,6 +538,9 @@ REGISTER_OPERATORS {
519538
VK_REGISTER_OP(
520539
quantized_decomposed.dequantize_per_tensor.default,
521540
dequantize_per_tensor_impl);
541+
VK_REGISTER_OP(
542+
quantized_decomposed.dequantize_per_tensor.tensor,
543+
dequantize_per_tensor_impl);
522544
VK_REGISTER_OP(
523545
quantized_decomposed.dequantize_per_token.default,
524546
dequantize_per_token_impl);

0 commit comments

Comments
 (0)